from random import random

from .dataset_provider import DatasetProvider, dataset_definition


@dataset_definition(name="basic/stock_ticker",
                    summary="Stock ticker dataset",
                    autoRegister=True,
                    supportsStreaming=True)
class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider):
    """
    Basic Stock Ticker Dataset
    ==========================

    This is a basic stock ticker dataset with time-series `symbol`, `post_date`, `open`, `close`,
    `high`, `low`, `adj_close`, and `volume` values.

    It takes the following options when retrieving the table (see the example below):
        - rows: number of rows to generate
        - partitions: number of partitions to use
        - numSymbols: number of unique stock ticker symbols (default: 100)
        - startDate: first date for stock ticker data (default: 2024-10-01)
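
    Example (a minimal sketch, assuming a live SparkSession named `spark` and that this
    provider has been registered, e.g. via `autoRegister`):

        import dbldatagen as dg

        df = (dg.Datasets(spark, "basic/stock_ticker")
                .get(rows=100000, numSymbols=50, startDate="2025-01-01")
                .build())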

    As the data specification is a `DataGenerator` object, you can add further columns to the
    data set and add constraints (when the feature is available).

    Note that this dataset does not use any features that would prevent it from being used as
    a source for a streaming dataframe, so the flag `supportsStreaming` is set to True.

    """
    DEFAULT_NUM_SYMBOLS = 100
    DEFAULT_START_DATE = "2024-10-01"
    COLUMN_COUNT = 8
    ALLOWED_OPTIONS = [
        "numSymbols",
        "startDate"
    ]

    @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
    def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
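        # dbldatagen is imported inside the method rather than at module level,
        # likely to avoid a circular import while the package registers its providers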
        import dbldatagen as dg

        numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS)
        startDate = options.get("startDate", self.DEFAULT_START_DATE)

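        # validate the arguments and fall back to the documented defaults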
        assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
        if rows is None or rows < 0:
            rows = DatasetProvider.DEFAULT_ROWS
        if partitions is None or partitions < 0:
            partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT)
        if numSymbols <= 0:
            raise ValueError("'numSymbols' must be > 0")

        df_spec = (
            dg.DataGenerator(sparkSession=sparkSession, rows=rows,
                             partitions=partitions, randomSeedMethod="hash_fieldname")
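            # 676 == 26 ** 2, so every id renders as at least three base-26 digits ("BAA" onwards)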
| 58 | + .withColumn("symbol_id", "long", minValue=676, maxValue=676 + numSymbols - 1) |
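            # per-symbol value in [0, 1], used below to scale the dividend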
            .withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1,
                        baseColumn="symbol_id", omit=True)
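            # render symbol_id in base 26, then map each base-26 digit to an uppercase letter:
            # '0'-'9' -> 'A'-'J' and 'a'-'p' -> 'K'-'Z'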
            .withColumn("symbol", "string",
                        expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''),
                            x -> case when x between '0' and '9' then char(ascii(x) - 48 + 65)
                                      else char(ascii(x) - 22) end))""")
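            # ids cycle through the symbols, so each block of `numSymbols` consecutive ids
            # represents one trading day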
            .withColumn("days_from_start_date", "int", expr=f"floor(id / {numSymbols})", omit=True)
            .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)")
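            # pools of per-symbol price parameters; each pool entry is shared by roughly ten
            # symbols, and max(1, ...) keeps the pools non-empty when numSymbols < 10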
            .withColumn("start_value", "decimal(11,2)",
                        values=[1.0 + 199.0 * random() for _ in range(max(1, numSymbols // 10))],
                        baseColumn="symbol_id", omit=True)
            .withColumn("growth_rate", "float",
                        values=[-0.1 + 0.35 * random() for _ in range(max(1, numSymbols // 10))],
                        baseColumn="symbol_id", omit=True)
            .withColumn("volatility", "float",
                        values=[0.0075 * random() for _ in range(max(1, numSymbols // 10))],
                        baseColumn="symbol_id", omit=True)
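            # deterministic oscillating sign terms keyed off the row id; `prev_modifier_sign`
            # feeds the prior-day open price (`modifier_sign` is kept but unused below)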
            .withColumn("prev_modifier_sign", "float",
                        expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end",
                        omit=True)
            .withColumn("modifier_sign", "float",
                        expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end",
                        omit=True)
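            # baseline prices: start value, plus a volatility-scaled oscillation, plus linear
            # growth over the elapsed days; the open is derived from the prior day's terms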
            .withColumn("open_base", "decimal(11,2)",
                        expr=f"""start_value
                            + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17))
                            + (growth_rate * start_value * (days_from_start_date - 1) / 365)""",
                        omit=True)
            .withColumn("close_base", "decimal(11,2)",
                        expr="""start_value
                            + (volatility * start_value * sin(id % 17))
                            + (growth_rate * start_value * days_from_start_date / 365)""",
                        omit=True)
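            # the high/low values bracket open/close with a small random volatility-scaled offset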
            .withColumn("high_base", "decimal(11,2)",
                        expr="greatest(open_base, close_base) + rand() * volatility * open_base",
                        omit=True)
            .withColumn("low_base", "decimal(11,2)",
                        expr="least(open_base, close_base) - rand() * volatility * open_base",
                        omit=True)
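            # published OHLC columns, floored at zero so prices never go negative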
            .withColumn("open", "decimal(11,2)", expr="greatest(open_base, 0.0)")
            .withColumn("close", "decimal(11,2)", expr="greatest(close_base, 0.0)")
            .withColumn("high", "decimal(11,2)", expr="greatest(high_base, 0.0)")
            .withColumn("low", "decimal(11,2)", expr="greatest(low_base, 0.0)")
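            # the adjusted close discounts the close by a small per-symbol dividend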
            .withColumn("dividend", "decimal(4,2)", expr="0.05 * rand_value * close", omit=True)
            .withColumn("adj_close", "decimal(11,2)", expr="greatest(close - dividend, 0.0)")
            .withColumn("volume", "long", minValue=100000, maxValue=5000000, random=True)
        )

        return df_spec