Commit b732b75

Fix schema initialization in process_dataset_chunks
1 parent b6dd397 commit b732b75

python-lib/dku_io_utils.py

Lines changed: 24 additions & 21 deletions
@@ -18,28 +18,27 @@ def count_records(dataset: dataiku.Dataset) -> int:
 
     Returns:
         Number of records
+
     """
     metric_id = "records:COUNT_RECORDS"
     partitions = dataset.read_partitions
     client = dataiku.api_client()
     project = client.get_project(dataset.project_key)
     record_count = 0
-    logging.info("Counting records of dataset: {}...".format(dataset.name))
+    logging.info(f"Counting records of dataset: {dataset.name}...")
     if partitions is None or len(partitions) == 0:
         project.get_dataset(dataset.short_name).compute_metrics(metric_ids=[metric_id])
         metric = dataset.get_last_metric_values()
         record_count = dataiku.ComputedMetrics.get_value_from_data(metric.get_global_data(metric_id=metric_id))
-        logging.info("Dataset {} contains {:d} records and is not partitioned".format(dataset.name, record_count))
+        logging.info(f"Dataset {dataset.name} contains {record_count:d} records and is not partitioned")
     else:
         for partition in partitions:
             project.get_dataset(dataset.short_name).compute_metrics(partition=partition, metric_ids=[metric_id])
             metric = dataset.get_last_metric_values()
             record_count += dataiku.ComputedMetrics.get_value_from_data(
                 metric.get_partition_data(partition=partition, metric_id=metric_id)
             )
-        logging.info(
-            "Dataset {} contains {:d} records in partition(s) {}".format(dataset.name, record_count, partitions)
-        )
+        logging.info(f"Dataset {dataset.name} contains {record_count:d} records in partition(s) {partitions}")
     return record_count
 
 
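For reference, a minimal usage sketch of count_records from recipe code, assuming python-lib is on the import path; the dataset name is a placeholder and not part of this commit:

import dataiku
from dku_io_utils import count_records

# Hypothetical dataset name, for illustration only
dataset = dataiku.Dataset("my_input_dataset")

# Computes the records:COUNT_RECORDS metric (per partition when the dataset is partitioned)
# and returns the total record count, as implemented above
record_count = count_records(dataset)
print(f"{dataset.name} contains {record_count} records")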
@@ -48,8 +47,8 @@ def process_dataset_chunks(
 ) -> None:
     """Read a dataset by chunks, process each dataframe chunk with a function and write back to another dataset.
 
-    Passes keyword arguments to the function, adds a tqdm progress bar and generic logging.
-    Directly writes chunks to the output_dataset, so that only one chunk needs to be processed in-memory at a time.
+    Pass keyword arguments to the function, adds a tqdm progress bar and generic logging.
+    Directly write chunks to the output_dataset, so that only one chunk needs to be processed in-memory at a time.
 
     Args:
         input_dataset: Input dataiku.Dataset instance
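To make the calling convention concrete, here is a hedged usage sketch of process_dataset_chunks; the chunk-processing function, dataset names, and extra keyword argument are hypothetical and only illustrate how `func` receives each chunk as `df` together with the forwarded keyword arguments:

import dataiku
import pandas as pd
from dku_io_utils import process_dataset_chunks

def clean_chunk(df: pd.DataFrame, drop_empty_rows: bool = True) -> pd.DataFrame:
    # Hypothetical per-chunk transformation: normalize column names, optionally drop all-empty rows
    df.columns = [column.strip().lower() for column in df.columns]
    if drop_empty_rows:
        df = df.dropna(how="all")
    return df

# Dataset names are placeholders; extra keyword arguments are forwarded to `func`
process_dataset_chunks(
    input_dataset=dataiku.Dataset("raw_events"),
    output_dataset=dataiku.Dataset("clean_events"),
    func=clean_chunk,
    chunksize=10000,
    drop_empty_rows=True,
)

With the change in this commit (see the next hunk), the output schema is initialized from a small sample whenever it is not already present, before the chunked iteration starts, so that a failure in `iter_dataframes` surfaces the real error.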
@@ -59,45 +58,49 @@ def process_dataset_chunks(
             and output another pandas.DataFrame
         chunksize: Number of rows of each chunk of pandas.DataFrame fed to `func`
         **kwargs: Optional keyword arguments fed to `func`
+
+    Raises:
+        ValueError: If the input dataset is empty or if pandas cannot read it without type inference
+
     """
     input_count_records = count_records(input_dataset)
     if input_count_records == 0:
         raise ValueError("Input dataset has no records")
-    logging.info(
-        "Processing dataset {} of {:d} rows by chunks of {:d}...".format(
-            input_dataset.name, input_count_records, chunksize
-        )
-    )
+    logging.info(f"Processing dataset {input_dataset.name} of {input_count_records} rows by chunks of {chunksize}...")
     start = time()
+    # First, initialize output schema if not present. Required to show the real error if `iter_dataframes` fails.
+    if not output_dataset.read_schema(raise_if_empty=False):
+        df = input_dataset.get_dataframe(limit=5, infer_with_pandas=False)
+        output_df = func(df=df, **kwargs)
+        output_dataset.write_schema_from_dataframe(output_df)
     with output_dataset.get_writer() as writer:
         df_iterator = input_dataset.iter_dataframes(chunksize=chunksize, infer_with_pandas=False)
         len_iterator = math.ceil(input_count_records / chunksize)
-        for i, df in tqdm(enumerate(df_iterator), total=len_iterator):
+        for i, df in tqdm(enumerate(df_iterator), total=len_iterator, unit="chunk", mininterval=1.0):
             output_df = func(df=df, **kwargs)
             if i == 0:
                 output_dataset.write_schema_from_dataframe(
                     output_df, dropAndCreate=bool(not output_dataset.writePartition)
                 )
             writer.write_dataframe(output_df)
     logging.info(
-        "Processing dataset {} of {:d} rows: Done in {:.2f} seconds.".format(
-            input_dataset.name, input_count_records, time() - start
-        )
+        f"Processing dataset {input_dataset.name} of {input_count_records} rows: Done in {time() - start:.2f} seconds."
     )
 
 
-def set_column_description(
-    output_dataset: dataiku.Dataset, column_description_dict: Dict, input_dataset: dataiku.Dataset = None
+def set_column_descriptions(
+    output_dataset: dataiku.Dataset, column_descriptions: Dict, input_dataset: dataiku.Dataset = None
 ) -> None:
     """Set column descriptions of the output dataset based on a dictionary of column descriptions
 
-    Retains the column descriptions from the input dataset if the column name matches.
+    Retain the column descriptions from the input dataset if the column name matches.
 
     Args:
         output_dataset: Output dataiku.Dataset instance
-        column_description_dict: Dictionary holding column descriptions (value) by column name (key)
+        column_descriptions: Dictionary holding column descriptions (value) by column name (key)
         input_dataset: Optional input dataiku.Dataset instance
             in case you want to retain input column descriptions
+
     """
     output_dataset_schema = output_dataset.read_schema()
     input_dataset_schema = []
@@ -107,7 +110,7 @@ def set_column_description(
         input_columns_names = [col["name"] for col in input_dataset_schema]
     for output_col_info in output_dataset_schema:
         output_col_name = output_col_info.get("name", "")
-        output_col_info["comment"] = column_description_dict.get(output_col_name)
+        output_col_info["comment"] = column_descriptions.get(output_col_name)
         if output_col_name in input_columns_names:
             matched_comment = [
                 input_col_info.get("comment", "")
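For completeness, a short usage sketch of the renamed set_column_descriptions helper; the dataset names and column descriptions below are hypothetical:

import dataiku
from dku_io_utils import set_column_descriptions

# Placeholder datasets and column descriptions, for illustration only
set_column_descriptions(
    output_dataset=dataiku.Dataset("clean_events"),
    column_descriptions={
        "event_id": "Unique identifier of the event",
        "event_timestamp": "Time at which the event was recorded",
    },
    input_dataset=dataiku.Dataset("raw_events"),  # optional: retain matching descriptions from the input
)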
