Skip to content

Commit e6ea96f

Browse files
committed
Variable attrs
1 parent 32485ae commit e6ea96f

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

modelskill/comparison/_comparison.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,15 @@ def save(self, filename: Union[str, Path]) -> None:
12651265
con = duckdb.connect(filename)
12661266
# TODO figure out how to save the x, y, z coordinates and other attributes later
12671267
df = ds.to_dataframe().drop(columns=["x", "y", "z"]).reset_index() # noqa
1268-
duckdb.sql("CREATE TABLE data AS SELECT * FROM df", connection=con)
1268+
duckdb.sql("CREATE TABLE matched_data AS SELECT * FROM df", connection=con)
1269+
1270+
attr_dict = {key: str(ds[key].attrs) for key in ds.data_vars}
1271+
attr_df = pd.DataFrame(attr_dict.items(), columns=["key", "value"]) # noqa
1272+
1273+
# attr_df["global", "key"] = str(ds.attrs)
1274+
1275+
duckdb.sql("CREATE TABLE attrs AS SELECT * FROM attr_df", connection=con)
1276+
12691277
con.close()
12701278
elif ext == ".nc":
12711279
if self.gtype == "point":
@@ -1304,18 +1312,20 @@ def load(filename: Union[str, Path]) -> "Comparer":
13041312
import duckdb
13051313

13061314
con = duckdb.connect(filename)
1307-
df = duckdb.sql("SELECT * FROM data", connection=con).df().set_index("time")
1315+
df = (
1316+
duckdb.sql("SELECT * FROM matched_data", connection=con)
1317+
.df()
1318+
.set_index("time")
1319+
)
13081320

13091321
# convert pandas dataframe to xarray dataset
13101322
ds = xr.Dataset.from_dataframe(df)
13111323

1312-
# set observation attribute
1313-
ds.Observation.attrs["kind"] = "observation"
1314-
1315-
# set model attributes
1316-
for key in ds.data_vars:
1317-
if key != "Observation":
1318-
ds[key].attrs["kind"] = "model"
1324+
attrs = duckdb.sql("SELECT * FROM attrs", connection=con).df()
1325+
for row in attrs.iterrows():
1326+
key = row[1]["key"]
1327+
value = row[1]["value"]
1328+
ds[key].attrs = eval(value)
13191329

13201330
# TODO figure out aux variables
13211331

tests/test_comparer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,10 @@ def test_save_load(pc, tmp_path) -> None:
962962
assert "m1" in pc2.mod_names
963963
assert "m2" in pc2.mod_names
964964
assert pc2.n_points == 5
965+
assert pc2.data.m1.attrs["kind"] == "model"
966+
assert pc2.data.m2.attrs["kind"] == "model"
967+
assert pc2.data.Observation.attrs["kind"] == "observation"
968+
969+
# TODO global attrs
970+
971+
# TODO raw_mod_data

0 commit comments

Comments
 (0)