
Commit 3f6aa11

Merge branch 'master' into fix_typo_using_standard_datasets
2 parents: e7f2b74 + b4370c0

File tree

8 files changed: +993 -59 lines


dbldatagen/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,16 +1,20 @@
 from .dataset_provider import DatasetProvider, dataset_definition
 from .basic_geometries import BasicGeometriesProvider
 from .basic_process_historian import BasicProcessHistorianProvider
+from .basic_stock_ticker import BasicStockTickerProvider
 from .basic_telematics import BasicTelematicsProvider
 from .basic_user import BasicUserProvider
 from .benchmark_groupby import BenchmarkGroupByProvider
+from .multi_table_sales_order_provider import MultiTableSalesOrderProvider
 from .multi_table_telephony_provider import MultiTableTelephonyProvider
 
 __all__ = ["dataset_provider",
            "basic_geometries",
            "basic_process_historian",
+           "basic_stock_ticker",
            "basic_telematics",
            "basic_user",
            "benchmark_groupby",
+           "multi_table_sales_order_provider",
            "multi_table_telephony_provider"
            ]
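
With these exports in place, the two new providers can be imported directly from the datasets subpackage. A minimal, hypothetical smoke test (assuming this branch of dbldatagen is installed):

from dbldatagen.datasets import DatasetProvider
from dbldatagen.datasets import BasicStockTickerProvider, MultiTableSalesOrderProvider

# Both new classes are expected to follow the existing provider pattern and
# derive from DatasetProvider (confirmed below for the stock ticker provider;
# assumed for the sales order provider, whose file is not shown in this diff).
assert issubclass(BasicStockTickerProvider, DatasetProvider)
assert issubclass(MultiTableSalesOrderProvider, DatasetProvider)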

dbldatagen/datasets/basic_geometries.py

Lines changed: 18 additions & 6 deletions
@@ -29,10 +29,18 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     """
     MIN_LOCATION_ID = 1000000
     MAX_LOCATION_ID = 9223372036854775807
+    DEFAULT_MIN_LAT = -90.0
+    DEFAULT_MAX_LAT = 90.0
+    DEFAULT_MIN_LON = -180.0
+    DEFAULT_MAX_LON = 180.0
     COLUMN_COUNT = 2
     ALLOWED_OPTIONS = [
         "geometryType",
         "maxVertices",
+        "minLatitude",
+        "maxLatitude",
+        "minLongitude",
+        "maxLongitude",
         "random"
     ]
 

@@ -45,6 +53,10 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
         generateRandom = options.get("random", False)
         geometryType = options.get("geometryType", "point")
         maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3)
+        minLatitude = options.get("minLatitude", self.DEFAULT_MIN_LAT)
+        maxLatitude = options.get("maxLatitude", self.DEFAULT_MAX_LAT)
+        minLongitude = options.get("minLongitude", self.DEFAULT_MIN_LON)
+        maxLongitude = options.get("maxLongitude", self.DEFAULT_MAX_LON)
 
         assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
         if rows is None or rows < 0:

@@ -62,9 +74,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             if maxVertices > 1:
                 w.warn('Ignoring property maxVertices for point geometries')
             df_spec = (
-                df_spec.withColumn("lat", "float", minValue=-90.0, maxValue=90.0,
+                df_spec.withColumn("lat", "float", minValue=minLatitude, maxValue=maxLatitude,
                                    step=1e-5, random=generateRandom, omit=True)
-                .withColumn("lon", "float", minValue=-180.0, maxValue=180.0,
+                .withColumn("lon", "float", minValue=minLongitude, maxValue=maxLongitude,
                             step=1e-5, random=generateRandom, omit=True)
                 .withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')")
             )

@@ -75,9 +87,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             j = 0
             while j < maxVertices:
                 df_spec = (
-                    df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0,
+                    df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude,
                                        step=1e-5, random=generateRandom, omit=True)
-                    .withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0,
+                    .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude,
                                 step=1e-5, random=generateRandom, omit=True)
                 )
                 j = j + 1

@@ -93,9 +105,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
             j = 0
             while j < maxVertices:
                 df_spec = (
-                    df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0,
+                    df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude,
                                        step=1e-5, random=generateRandom, omit=True)
-                    .withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0,
+                    .withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude,
                                 step=1e-5, random=generateRandom, omit=True)
                 )
                 j = j + 1
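
These changes thread caller-supplied bounds through every generated lat/lon column, while the DEFAULT_* constants preserve the old full-range behavior for callers that pass nothing. A minimal usage sketch, assuming the library's standard Datasets retrieval pattern and an active SparkSession named spark (the bounding-box values below are illustrative only):

import dbldatagen as dg

# Request polygon geometries confined to an illustrative bounding box;
# the option names match ALLOWED_OPTIONS above.
df_spec = dg.Datasets(spark, "basic/geometries").get(
    rows=100_000,
    geometryType="polygon",
    maxVertices=5,
    minLatitude=36.0,
    maxLatitude=60.0,
    minLongitude=-10.0,
    maxLongitude=20.0,
    random=True,
)

# Build the Spark DataFrame; the generated geometry text lands in the `wkt` column.
df = df_spec.build()
df.show(5, truncate=False)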
dbldatagen/datasets/basic_stock_ticker.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+from random import random
+
+from .dataset_provider import DatasetProvider, dataset_definition
+
+
+@dataset_definition(name="basic/stock_ticker",
+                    summary="Stock ticker dataset",
+                    autoRegister=True,
+                    supportsStreaming=True)
+class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider):
+    """
+    Basic Stock Ticker Dataset
+    ==========================
+
+    This is a basic stock ticker dataset with time-series `symbol`, `open`, `close`, `high`, `low`,
+    `adj_close`, and `volume` values.
+
+    It takes the following options when retrieving the table:
+        - rows : number of rows to generate
+        - partitions: number of partitions to use
+        - numSymbols: number of unique stock ticker symbols
+        - startDate: first date for stock ticker data
+        - endDate: last date for stock ticker data
+
+    As the data specification is a DataGenerator object, you can add further columns to the data set and
+    add constraints (when the feature is available)
+
+    Note that this dataset does not use any features that would prevent it from being used as a source for a
+    streaming dataframe, and so the flag `supportsStreaming` is set to True.
+
+    """
+    DEFAULT_NUM_SYMBOLS = 100
+    DEFAULT_START_DATE = "2024-10-01"
+    COLUMN_COUNT = 8
+    ALLOWED_OPTIONS = [
+        "numSymbols",
+        "startDate"
+    ]
+
+    @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
+    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
+        import dbldatagen as dg
+
+        numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS)
+        startDate = options.get("startDate", self.DEFAULT_START_DATE)
+
+        assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
+        if rows is None or rows < 0:
+            rows = DatasetProvider.DEFAULT_ROWS
+        if partitions is None or partitions < 0:
+            partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT)
+        if numSymbols <= 0:
+            raise ValueError("'numSymbols' must be > 0")
+
+        df_spec = (
+            dg.DataGenerator(sparkSession=sparkSession, rows=rows,
+                             partitions=partitions, randomSeedMethod="hash_fieldname")
+            .withColumn("symbol_id", "long", minValue=676, maxValue=676 + numSymbols - 1)
+            .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1,
+                        baseColumn="symbol_id", omit=True)
+            .withColumn("symbol", "string",
+                        expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''),
+                            x -> case when x < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""")
+            .withColumn("days_from_start_date", "int", expr=f"floor(id / {numSymbols})", omit=True)
+            .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)")
+            .withColumn("start_value", "decimal(11,2)",
+                        values=[1.0 + 199.0 * random() for _ in range(int(numSymbols / 10))], omit=True)
+            .withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(int(numSymbols / 10))],
+                        baseColumn="symbol_id")
+            .withColumn("volatility", "float", values=[0.0075 * random() for _ in range(int(numSymbols / 10))],
+                        baseColumn="symbol_id", omit=True)
+            .withColumn("prev_modifier_sign", "float",
+                        expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end",
+                        omit=True)
+            .withColumn("modifier_sign", "float",
+                        expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end",
+                        omit=True)
+            .withColumn("open_base", "decimal(11,2)",
+                        expr=f"""start_value
+                            + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17))
+                            + (growth_rate * start_value * (days_from_start_date - 1) / 365)""",
+                        omit=True)
+            .withColumn("close_base", "decimal(11,2)",
+                        expr="""start_value
+                            + (volatility * start_value * sin(id % 17))
+                            + (growth_rate * start_value * days_from_start_date / 365)""",
+                        omit=True)
+            .withColumn("high_base", "decimal(11,2)",
+                        expr="greatest(open_base, close_base) + rand() * volatility * open_base",
+                        omit=True)
+            .withColumn("low_base", "decimal(11,2)",
+                        expr="least(open_base, close_base) - rand() * volatility * open_base",
+                        omit=True)
+            .withColumn("open", "decimal(11,2)", expr="greatest(open_base, 0.0)")
+            .withColumn("close", "decimal(11,2)", expr="greatest(close_base, 0.0)")
+            .withColumn("high", "decimal(11,2)", expr="greatest(high_base, 0.0)")
+            .withColumn("low", "decimal(11,2)", expr="greatest(low_base, 0.0)")
+            .withColumn("dividend", "decimal(4,2)", expr="0.05 * rand_value * close", omit=True)
+            .withColumn("adj_close", "decimal(11,2)", expr="greatest(close - dividend, 0.0)")
+            .withColumn("volume", "long", minValue=100000, maxValue=5000000, random=True)
+        )
+
+        return df_spec
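
Because the @dataset_definition decorator sets autoRegister=True, the provider is retrievable by name like the other standard datasets. A usage sketch under the same assumptions as above (active SparkSession spark, standard Datasets retrieval API):

import dbldatagen as dg

# Fetch the registered spec with a small symbol universe, then build it.
df_spec = dg.Datasets(spark, "basic/stock_ticker").get(
    rows=50_000,
    numSymbols=20,
    startDate="2025-01-01",
)

df = df_spec.build()
df.orderBy("symbol", "post_date").show(10)

One design note grounded in the column definitions: days_from_start_date is floor(id / numSymbols), so row ids interleave across symbols and each generated day contributes one row per symbol; the covered date range is therefore roughly rows / numSymbols days starting at startDate.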
