@@ -26,6 +26,23 @@ def fit_curve(
26
26
f"Provided dimension ({ dimension } ) not found in data.dims: { data .dims } "
27
27
)
28
28
29
+ try :
30
+ # Try parsing as datetime first
31
+ dates = data [dimension ].values
32
+ dates = np .asarray (dates , dtype = np .datetime64 )
33
+ except ValueError :
34
+ dates = np .asarray (data [dimension ].values )
35
+
36
+ if np .issubdtype (dates .dtype , np .datetime64 ):
37
+ timestep = [
38
+ (
39
+ (np .datetime64 (x ) - np .datetime64 ("1970-01-01" , "s" ))
40
+ / np .timedelta64 (1 , "s" )
41
+ )
42
+ for x in dates
43
+ ]
44
+ data [dimension ] = np .array (timestep )
45
+
29
46
dims_before = list (data .dims )
30
47
31
48
# In the spec, parameters is a list, but xr.curvefit requires names for them,
@@ -87,8 +104,16 @@ def predict_curve(
87
104
labels = np .asarray (labels )
88
105
89
106
if np .issubdtype (labels .dtype , np .datetime64 ):
90
- labels = labels .astype (int )
91
107
labels_were_datetime = True
108
+ initial_labels = labels
109
+ timestep = [
110
+ (
111
+ (np .datetime64 (x ) - np .datetime64 ("1970-01-01" , "s" ))
112
+ / np .timedelta64 (1 , "s" )
113
+ )
114
+ for x in labels
115
+ ]
116
+ labels = np .array (timestep )
92
117
93
118
# This is necessary to pipe the arguments correctly through @process
94
119
def wrapper (f ):
@@ -122,6 +147,6 @@ def _wrap(*args, **kwargs):
122
147
predictions = predictions .assign_coords ({dimension : labels .data })
123
148
124
149
if labels_were_datetime :
125
- predictions [dimension ] = pd . DatetimeIndex ( predictions [ dimension ]. values )
150
+ predictions [dimension ] = initial_labels
126
151
127
152
return predictions
0 commit comments