Skip to content

Commit bc52fc9

Browse files
fit_curve update time (#167)
fit_curve with dimension convertion
1 parent 7ddf6c2 commit bc52fc9

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

openeo_processes_dask/process_implementations/ml/curve_fitting.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ def fit_curve(
2626
f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
2727
)
2828

29+
try:
30+
# Try parsing as datetime first
31+
dates = data[dimension].values
32+
dates = np.asarray(dates, dtype=np.datetime64)
33+
except ValueError:
34+
dates = np.asarray(data[dimension].values)
35+
36+
if np.issubdtype(dates.dtype, np.datetime64):
37+
timestep = [
38+
(
39+
(np.datetime64(x) - np.datetime64("1970-01-01", "s"))
40+
/ np.timedelta64(1, "s")
41+
)
42+
for x in dates
43+
]
44+
data[dimension] = np.array(timestep)
45+
2946
dims_before = list(data.dims)
3047

3148
# In the spec, parameters is a list, but xr.curvefit requires names for them,
@@ -87,8 +104,16 @@ def predict_curve(
87104
labels = np.asarray(labels)
88105

89106
if np.issubdtype(labels.dtype, np.datetime64):
90-
labels = labels.astype(int)
91107
labels_were_datetime = True
108+
initial_labels = labels
109+
timestep = [
110+
(
111+
(np.datetime64(x) - np.datetime64("1970-01-01", "s"))
112+
/ np.timedelta64(1, "s")
113+
)
114+
for x in labels
115+
]
116+
labels = np.array(timestep)
92117

93118
# This is necessary to pipe the arguments correctly through @process
94119
def wrapper(f):
@@ -122,6 +147,6 @@ def _wrap(*args, **kwargs):
122147
predictions = predictions.assign_coords({dimension: labels.data})
123148

124149
if labels_were_datetime:
125-
predictions[dimension] = pd.DatetimeIndex(predictions[dimension].values)
150+
predictions[dimension] = initial_labels
126151

127152
return predictions

tests/test_ml.py

+13
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ def fitFunction(x, parameters):
8484
assert len(result.coords["param"]) == len(parameters)
8585

8686
labels = dimension_labels(origin_cube, origin_cube.openeo.temporal_dims[0])
87+
labels = [float(l) for l in labels]
88+
predictions = predict_curve(
89+
result,
90+
_process,
91+
origin_cube.openeo.temporal_dims[0],
92+
labels=labels,
93+
).compute()
94+
95+
assert len(predictions.coords[origin_cube.openeo.temporal_dims[0]]) == len(labels)
96+
assert "param" not in predictions.dims
97+
assert result.rio.crs == predictions.rio.crs
98+
99+
labels = ["2020-02-02", "2020-03-02", "2020-04-02", "2020-05-02"]
87100
predictions = predict_curve(
88101
result,
89102
_process,

0 commit comments

Comments
 (0)