Skip to content

Commit 95cc0bf

Browse files
committed
Add rounding service
1 parent 6aa127d commit 95cc0bf

File tree

6 files changed

+65
-42
lines changed

6 files changed

+65
-42
lines changed

investing_algorithm_framework/app/algorithm.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import decimal
1+
import inspect
22
import logging
33
from typing import List
4-
import inspect
54

65
from investing_algorithm_framework.domain import OrderStatus, OrderFee, \
76
Position, Order, Portfolio, OrderType, OrderSide, \
87
BACKTESTING_FLAG, BACKTESTING_INDEX_DATETIME, MarketService, TimeUnit, \
9-
OperationalException, random_string
8+
OperationalException, random_string, RoundingService
109
from investing_algorithm_framework.services import MarketCredentialService, \
1110
MarketDataSourceService, PortfolioService, PositionService, TradeService, \
1211
OrderService, ConfigurationService, StrategyOrchestratorService, \
@@ -218,7 +217,7 @@ def create_limit_order(
218217
amount = position.get_amount() * (percentage_of_position / 100)
219218

220219
if precision is not None:
221-
amount = self.round_down(amount, precision)
220+
amount = RoundingService.round_down(amount, precision)
222221

223222
order_data = {
224223
"target_symbol": target_symbol,
@@ -594,7 +593,9 @@ def get_position_percentage_of_portfolio_by_net_size(
594593
net_size = portfolio.get_net_size()
595594
return (position.cost / net_size) * 100
596595

597-
def close_position(self, symbol, market=None, identifier=None):
596+
def close_position(
597+
self, symbol, market=None, identifier=None, precision=None
598+
):
598599
portfolio = self.portfolio_service.find(
599600
{"market": market, "identifier": identifier}
600601
)
@@ -623,6 +624,7 @@ def close_position(self, symbol, market=None, identifier=None):
623624
amount=position.get_amount(),
624625
order_side=OrderSide.SELL.value,
625626
price=ticker["bid"],
627+
precision=precision,
626628
)
627629

628630
def add_strategies(self, strategies):
@@ -886,28 +888,13 @@ def get_trades(self, market=None):
886888
def get_closed_trades(self):
887889
return self.trade_service.get_closed_trades()
888890

889-
def round_down(self, value, amount_of_decimals):
890-
891-
if self.count_decimals(value) <= amount_of_decimals:
892-
return value
893-
894-
with decimal.localcontext() as ctx:
895-
d = decimal.Decimal(value)
896-
ctx.rounding = decimal.ROUND_DOWN
897-
return float(round(d, amount_of_decimals))
898-
899-
def count_decimals(self, number):
900-
decimal_str = str(number)
901-
if '.' in decimal_str:
902-
return len(decimal_str.split('.')[1])
903-
else:
904-
return 0
905-
906891
def get_open_trades(self, target_symbol=None, market=None):
907892
return self.trade_service.get_open_trades(target_symbol, market)
908893

909-
def close_trade(self, trade, market=None):
910-
self.trade_service.close_trade(trade, market)
894+
def close_trade(self, trade, market=None, precision=None) -> None:
895+
self.trade_service.close_trade(
896+
trade=trade, market=market, precision=precision
897+
)
911898

912899
def get_number_of_positions(self):
913900
"""

investing_algorithm_framework/domain/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from .decimal_parsing import parse_decimal_to_string, parse_string_to_decimal
2727
from .services import TickerMarketDataSource, OrderBookMarketDataSource, \
2828
OHLCVMarketDataSource, BacktestMarketDataSource, MarketDataSource, \
29-
MarketService, MarketCredentialService, AbstractPortfolioSyncService
29+
MarketService, MarketCredentialService, AbstractPortfolioSyncService, \
30+
RoundingService
3031
from .data_structures import PeekableQueue
3132

3233
__all__ = [
@@ -109,5 +110,6 @@
109110
"RESERVED_BALANCES",
110111
"AbstractPortfolioSyncService",
111112
"APP_MODE",
112-
"AppMode"
113+
"AppMode",
114+
"RoundingService",
113115
]

investing_algorithm_framework/domain/services/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .market_service import MarketService
44
from .market_credential_service import MarketCredentialService
55
from .portfolios import AbstractPortfolioSyncService
6+
from .rounding_service import RoundingService
67

78
__all__ = [
89
"MarketDataSource",
@@ -12,5 +13,6 @@
1213
"BacktestMarketDataSource",
1314
"MarketService",
1415
"MarketCredentialService",
15-
"AbstractPortfolioSyncService"
16+
"AbstractPortfolioSyncService",
17+
"RoundingService",
1618
]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import decimal
2+
3+
4+
class RoundingService:
5+
"""
6+
Service to round numbers to a certain amount of decimals.
7+
It will always round down.
8+
"""
9+
10+
@staticmethod
11+
def round_down(value, amount_of_decimals):
12+
13+
if RoundingService.count_decimals(value) <= amount_of_decimals:
14+
return value
15+
16+
with decimal.localcontext() as ctx:
17+
d = decimal.Decimal(value)
18+
ctx.rounding = decimal.ROUND_DOWN
19+
return float(round(d, amount_of_decimals))
20+
21+
@staticmethod
22+
def count_decimals(number):
23+
decimal_str = str(number)
24+
if '.' in decimal_str:
25+
return len(decimal_str.split('.')[1])
26+
else:
27+
return 0

investing_algorithm_framework/services/trade_service/trade_service.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import logging
2-
from typing import List
32
from queue import PriorityQueue
3+
from typing import List
44

55
from investing_algorithm_framework.domain import OrderStatus, OrderSide, \
66
Trade, PeekableQueue, OrderType, TradeStatus, \
7-
OperationalException, Order
8-
from investing_algorithm_framework.services.position_service import \
9-
PositionService
7+
OperationalException, Order, RoundingService
108
from investing_algorithm_framework.services.market_data_source_service import \
119
MarketDataSourceService
10+
from investing_algorithm_framework.services.position_service import \
11+
PositionService
1212

1313
logger = logging.getLogger(__name__)
1414

@@ -199,7 +199,7 @@ def get_closed_trades(self, portfolio_id=None) -> List[Trade]:
199199
if order.get_trade_closed_at() is not None
200200
]
201201

202-
def close_trade(self, trade, market=None) -> None:
202+
def close_trade(self, trade, market=None, precision=None) -> None:
203203
"""
204204
Close trade method
205205
@@ -210,6 +210,7 @@ def close_trade(self, trade, market=None) -> None:
210210
211211
return: None
212212
"""
213+
213214
if trade.closed_at is not None:
214215
raise OperationalException("Trade already closed.")
215216

@@ -227,6 +228,9 @@ def close_trade(self, trade, market=None) -> None:
227228
)
228229
amount = order.get_amount()
229230

231+
if precision is not None:
232+
amount = RoundingService.round_down(amount, precision)
233+
230234
if position.get_amount() < amount:
231235
logger.warning(
232236
f"Order amount {amount} is larger then amount "

tests/app/algorithm/test_round_down.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from investing_algorithm_framework import create_app, RESOURCE_DIRECTORY, \
44
PortfolioConfiguration, Algorithm, MarketCredential
5+
from investing_algorithm_framework.domain import RoundingService
56
from tests.resources import TestBase, MarketServiceStub
67

78

@@ -49,52 +50,52 @@ def setUp(self) -> None:
4950
self.app.initialize()
5051

5152
def test_round_down(self):
52-
new_value = self.app.algorithm.round_down(1, 3)
53+
new_value = RoundingService.round_down(1, 3)
5354
self.assertEqual(
5455
0, self.count_decimals(new_value)
5556
)
5657
self.assertEqual(1, new_value)
57-
new_value = self.app.algorithm.round_down(1.23456789, 2)
58+
new_value = RoundingService.round_down(1.23456789, 2)
5859
self.assertEqual(
5960
2, self.count_decimals(new_value)
6061
)
6162
self.assertEqual(1.23, new_value)
62-
new_value = self.app.algorithm.round_down(1.987654321, 3)
63+
new_value = RoundingService.round_down(1.987654321, 3)
6364
self.assertEqual(
6465
3, self.count_decimals(new_value)
6566
)
6667
self.assertEqual(1.987, new_value)
67-
new_value = self.app.algorithm.round_down(1.987654321, 4)
68+
new_value = RoundingService.round_down(1.987654321, 4)
6869
self.assertEqual(
6970
4, self.count_decimals(new_value)
7071
)
7172
self.assertEqual(1.9876, new_value)
72-
new_value = self.app.algorithm.round_down(1.987654321, 5)
73+
new_value = RoundingService.round_down(1.987654321, 5)
7374
self.assertEqual(
7475
5, self.count_decimals(new_value)
7576
)
7677
self.assertEqual(1.98765, new_value)
77-
new_value = self.app.algorithm.round_down(1.987654321, 6)
78+
new_value = RoundingService.round_down(1.987654321, 6)
7879
self.assertEqual(
7980
6, self.count_decimals(new_value)
8081
)
8182
self.assertEqual(1.987654, new_value)
82-
new_value = self.app.algorithm.round_down(1.987654321, 7)
83+
new_value = RoundingService.round_down(1.987654321, 7)
8384
self.assertEqual(
8485
7, self.count_decimals(new_value)
8586
)
8687
self.assertEqual(1.9876543, new_value)
87-
new_value = self.app.algorithm.round_down(1.987654321, 8)
88+
new_value = RoundingService.round_down(1.987654321, 8)
8889
self.assertEqual(
8990
8, self.count_decimals(new_value)
9091
)
9192
self.assertEqual(1.98765432, new_value)
92-
new_value = self.app.algorithm.round_down(1.987654321, 9)
93+
new_value = RoundingService.round_down(1.987654321, 9)
9394
self.assertEqual(
9495
9, self.count_decimals(new_value)
9596
)
9697
self.assertEqual(1.987654321, new_value)
97-
new_value = self.app.algorithm.round_down(1.987654321, 10)
98+
new_value = RoundingService.round_down(1.987654321, 10)
9899
self.assertEqual(
99100
9, self.count_decimals(new_value)
100101
)

0 commit comments

Comments
 (0)