Skip to content

Commit bfa52d8

Browse files
authored
Improve the ccxt backtest data source performance (#293)
1 parent 2307351 commit bfa52d8

File tree

8 files changed

+22431
-67
lines changed

8 files changed

+22431
-67
lines changed

examples/backtest_example/run_backtest.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
import logging.config
23
from datetime import datetime, timedelta
34

@@ -156,13 +157,17 @@ def apply_strategy(self, algorithm: Algorithm, market_data):
156157

157158
if __name__ == "__main__":
158159
end_date = datetime(2023, 12, 2)
159-
start_date = end_date - timedelta(days=100)
160+
start_date = end_date - timedelta(days=400)
160161
date_range = BacktestDateRange(
161162
start_date=start_date,
162163
end_date=end_date
163164
)
165+
start_time = time.time()
166+
164167
backtest_report = app.run_backtest(
165168
algorithm=algorithm,
166169
backtest_date_range=date_range,
167170
)
168171
pretty_print_backtest(backtest_report)
172+
end_time = time.time()
173+
print(f"Execution Time: {end_time - start_time:.6f} seconds")

investing_algorithm_framework/infrastructure/models/market_data_sources/ccxt.py

+33-34
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
import os
33
from datetime import timedelta, datetime, timezone
4-
from dateutil.parser import parse
54
import polars
65
from dateutil import parser
76

@@ -57,6 +56,9 @@ def __init__(
5756
self.data = None
5857
self._start_date_data_source = None
5958
self._end_date_data_source = None
59+
self.backtest_end_index = self.window_size
60+
self.backtest_start_index = 0
61+
self.window_cache = {}
6062

6163
def prepare_data(
6264
self,
@@ -100,8 +102,6 @@ def prepare_data(
100102

101103
self.backtest_data_start_date = backtest_data_start_date\
102104
.replace(microsecond=0)
103-
self.backtest_data_index_date = backtest_data_start_date\
104-
.replace(microsecond=0)
105105
self.backtest_data_end_date = backtest_end_date.replace(microsecond=0)
106106

107107
# Creating the backtest data directory and file
@@ -148,14 +148,30 @@ def prepare_data(
148148
self.write_data_to_file_path(file_path, ohlcv)
149149

150150
self.load_data()
151+
self._precompute_sliding_windows() # Precompute sliding windows!
152+
153+
def _precompute_sliding_windows(self):
154+
"""
155+
Precompute all sliding windows for fast retrieval.
156+
"""
157+
self.window_cache = {}
158+
timestamps = self.data["Datetime"].to_list()
159+
160+
for i in range(len(timestamps) - self.window_size + 1):
161+
# Use last timestamp as key
162+
end_time = timestamps[i + self.window_size - 1]
163+
self.window_cache[end_time] = self.data.slice(i, self.window_size)
151164

152165
def load_data(self):
153166
file_path = self._create_file_path()
154-
self.data = polars.read_csv(file_path)
167+
self.data = polars.read_csv(
168+
file_path, dtypes={"Datetime": polars.Datetime}, low_memory=True
169+
) # Faster parsing
155170
first_row = self.data.head(1)
156171
last_row = self.data.tail(1)
157-
self._start_date_data_source = parse(first_row["Datetime"][0])
158-
self._end_date_data_source = parse(last_row["Datetime"][0])
172+
173+
self._start_date_data_source = first_row["Datetime"][0]
174+
self._end_date_data_source = last_row["Datetime"][0]
159175

160176
def _create_file_path(self):
161177
"""
@@ -190,38 +206,21 @@ def get_data(
190206
source. This implementation will use polars to load and filter the
191207
data.
192208
"""
193-
if self.data is None:
194-
self.load_data()
195-
196-
end_date = date
197209

198-
if end_date is None:
199-
return self.data
210+
data = self.window_cache.get(date)
211+
if data is not None:
212+
return data
200213

201-
start_date = self.create_start_date(
202-
end_date, self.time_frame, self.window_size
203-
)
214+
# Find closest previous timestamp
215+
sorted_timestamps = sorted(self.window_cache.keys())
204216

205-
if start_date < self._start_date_data_source:
206-
raise OperationalException(
207-
f"Start date {start_date} is before the start date "
208-
f"of the data source {self._start_date_data_source}"
209-
)
217+
closest_date = None
218+
for ts in reversed(sorted_timestamps):
219+
if ts < date:
220+
closest_date = ts
221+
break
210222

211-
if end_date > self._end_date_data_source:
212-
raise OperationalException(
213-
f"End date {end_date} is after the end date "
214-
f"of the data source {self._end_date_data_source}"
215-
)
216-
217-
time_frame = TimeFrame.from_string(self.time_frame)
218-
start_date = start_date - \
219-
timedelta(minutes=time_frame.amount_of_minutes)
220-
selection = self.data.filter(
221-
(self.data['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
222-
& (self.data['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
223-
)
224-
return selection
223+
return self.window_cache.get(closest_date) if closest_date else None
225224

226225
def to_backtest_market_data_source(self) -> BacktestMarketDataSource:
227226
# Ignore this method for now

investing_algorithm_framework/infrastructure/models/market_data_sources/csv.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def __init__(
4040
self._columns = [
4141
"Datetime", "Open", "High", "Low", "Close", "Volume"
4242
]
43-
df = polars.read_csv(csv_file_path)
43+
44+
df = polars.read_csv(self._csv_file_path)
4445

4546
# Check if all column names are in the csv file
4647
if not all(column in df.columns for column in self._columns):
@@ -53,15 +54,25 @@ def __init__(
5354
f"Missing columns: {missing_columns}"
5455
)
5556

56-
first_row = df.head(1)
57-
last_row = df.tail(1)
58-
self._start_date_data_source = parse(first_row["Datetime"][0])
59-
self._end_date_data_source = parse(last_row["Datetime"][0])
57+
self.data = self._load_data(self.csv_file_path)
58+
first_row = self.data.head(1)
59+
last_row = self.data.tail(1)
60+
self._start_date_data_source = first_row["Datetime"][0]
61+
self._end_date_data_source = last_row["Datetime"][0]
6062

6163
@property
6264
def csv_file_path(self):
6365
return self._csv_file_path
6466

67+
def _load_data(self, file_path):
68+
return polars.read_csv(
69+
file_path, dtypes={"Datetime": polars.Datetime}, low_memory=True
70+
).with_columns(
71+
polars.col("Datetime").cast(
72+
polars.Datetime(time_unit="ms", time_zone="UTC")
73+
)
74+
)
75+
6576
def get_data(
6677
self,
6778
start_date: datetime = None,
@@ -86,9 +97,7 @@ def get_data(
8697
"""
8798

8899
if start_date is None and end_date is None:
89-
return polars.read_csv(
90-
self.csv_file_path, columns=self._columns, separator=","
91-
)
100+
return self.data
92101

93102
if end_date is not None and start_date is not None:
94103

@@ -101,13 +110,10 @@ def get_data(
101110
if start_date > self._end_date_data_source:
102111
return polars.DataFrame()
103112

104-
df = polars.read_csv(
105-
self.csv_file_path, columns=self._columns, separator=","
106-
)
107-
113+
df = self.data
108114
df = df.filter(
109-
(df['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
110-
& (df['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
115+
(df['Datetime'] >= start_date)
116+
& (df['Datetime'] <= end_date)
111117
)
112118
return df
113119

@@ -119,11 +125,9 @@ def get_data(
119125
if start_date > self._end_date_data_source:
120126
return polars.DataFrame()
121127

122-
df = polars.read_csv(
123-
self.csv_file_path, columns=self._columns, separator=","
124-
)
128+
df = self.data
125129
df = df.filter(
126-
(df['Datetime'] >= start_date.strftime(DATETIME_FORMAT))
130+
(df['Datetime'] >= start_date)
127131
)
128132
df = df.head(self.window_size)
129133
return df
@@ -136,11 +140,9 @@ def get_data(
136140
if end_date > self._end_date_data_source:
137141
return polars.DataFrame()
138142

139-
df = polars.read_csv(
140-
self.csv_file_path, columns=self._columns, separator=","
141-
)
143+
df = self.data
142144
df = df.filter(
143-
(df['Datetime'] <= end_date.strftime(DATETIME_FORMAT))
145+
(df['Datetime'] <= end_date)
144146
)
145147
df = df.tail(self.window_size)
146148
return df

investing_algorithm_framework/services/trade_service/trade_service.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
from queue import PriorityQueue
3-
from dateutil import parser
43

54
from investing_algorithm_framework.domain import OrderStatus, \
65
TradeStatus, Trade, OperationalException, MarketDataType
@@ -248,9 +247,7 @@ def update_trades_with_market_data(self, market_data):
248247
last_row = data.tail(1)
249248
update_data = {
250249
"last_reported_price": last_row["Close"][0],
251-
"updated_at": parser.parse(
252-
last_row["Datetime"][0]
253-
)
250+
"updated_at": last_row["Datetime"][0]
254251
}
255252
price = last_row["Close"][0]
256253

tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from datetime import datetime, timedelta
2+
from datetime import datetime, timedelta, timezone
33
from unittest import TestCase
44

55
from dateutil import parser
@@ -42,7 +42,7 @@ def test_right_columns(self):
4242
f"{file_name}",
4343
window_size=10,
4444
)
45-
date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
45+
date = datetime(2023, 8, 7, 8, 0, tzinfo=timezone.utc)
4646
df = data_source.get_data(start_date=date)
4747
self.assertEqual(
4848
["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns
@@ -61,7 +61,7 @@ def test_throw_exception_when_missing_column_names_columns(self):
6161
)
6262

6363
def test_start_date(self):
64-
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=tzutc())
64+
start_date = datetime(2023, 8, 7, 8, 0, tzinfo=timezone.utc)
6565
file_name = "OHLCV_BTC-EUR_BINANCE" \
6666
"_2h_2023-08-07-07-59_2023-12-02-00-00.csv"
6767
csv_ohlcv_market_data_source = CSVOHLCVMarketDataSource(
@@ -78,7 +78,7 @@ def test_start_date(self):
7878

7979
def test_start_date_with_window_size(self):
8080
start_date = datetime(
81-
year=2023, month=8, day=7, hour=10, minute=0, tzinfo=tzutc()
81+
year=2023, month=8, day=7, hour=10, minute=0, tzinfo=timezone.utc
8282
)
8383
file_name = "OHLCV_BTC-EUR_BINANCE" \
8484
"_2h_2023-08-07-07-59_2023-12-02-00-00.csv"
@@ -92,7 +92,7 @@ def test_start_date_with_window_size(self):
9292
start_date=start_date
9393
)
9494
self.assertEqual(12, len(data))
95-
first_date = parser.parse(data["Datetime"][0])
95+
first_date = data["Datetime"][0]
9696
self.assertEqual(
9797
start_date.strftime(DATETIME_FORMAT),
9898
first_date.strftime(DATETIME_FORMAT)
@@ -122,7 +122,7 @@ def test_empty(self):
122122
f"{file_name}",
123123
window_size=10,
124124
)
125-
start_date = datetime(2023, 12, 2, 0, 0, tzinfo=tzutc())
125+
start_date = datetime(2023, 12, 2, 0, 0, tzinfo=timezone.utc)
126126
self.assertFalse(data_source.empty(start_date))
127127

128128
def test_get_data(self):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import os
2+
from unittest import TestCase
3+
from datetime import datetime, timedelta
4+
from investing_algorithm_framework.infrastructure.models\
5+
.market_data_sources.ccxt import CCXTOHLCVBacktestMarketDataSource
6+
from investing_algorithm_framework.domain import RESOURCE_DIRECTORY, \
7+
BACKTEST_DATA_DIRECTORY_NAME
8+
9+
10+
class TestCCXTOHLCVBacktestDataSource(TestCase):
11+
12+
def setUp(self):
13+
self.resource_dir = os.path.abspath(
14+
os.path.join(
15+
os.path.join(
16+
os.path.join(
17+
os.path.join(
18+
os.path.join(
19+
os.path.realpath(__file__),
20+
os.pardir
21+
),
22+
os.pardir
23+
),
24+
os.pardir
25+
),
26+
os.pardir
27+
),
28+
"resources"
29+
)
30+
)
31+
self.backtest_data_dir = "market_data_sources_for_testing"
32+
33+
def test_prepare_data(self):
34+
pass
35+
36+
def test_get_data(self):
37+
data_source = CCXTOHLCVBacktestMarketDataSource(
38+
identifier="bitvavo",
39+
market="BITVAVO",
40+
symbol="BTC/EUR",
41+
time_frame="2h",
42+
window_size=200,
43+
)
44+
config = {
45+
RESOURCE_DIRECTORY: self.resource_dir,
46+
BACKTEST_DATA_DIRECTORY_NAME: self.backtest_data_dir
47+
}
48+
data_source.prepare_data(
49+
config=config,
50+
backtest_start_date=datetime(2021, 1, 1), backtest_end_date=datetime(2025, 1, 1)
51+
)
52+
number_of_data_retrievals = 0
53+
backtest_start_date = datetime(2021, 1, 1)
54+
backtest_end_date = datetime(2022, 1, 1)
55+
interval = timedelta(hours=2) # Define the 2-hour interval
56+
current_date = backtest_start_date
57+
delta = backtest_end_date - backtest_start_date
58+
runs = (delta.total_seconds() / 7200) + 1
59+
60+
while current_date <= backtest_end_date:
61+
data = data_source.get_data(date=current_date)
62+
63+
if data is not None:
64+
number_of_data_retrievals += 1
65+
self.assertTrue(abs(200 - len(data)) <= 4)
66+
67+
current_date += interval # Increment by 2 hours
68+
69+
self.assertEqual(runs, number_of_data_retrievals)

0 commit comments

Comments (0)