Skip to content

Commit 0bf9342

Browse files
committed
Add prediction type to return the mean, variance, and (if implemented) mode
Additionally, this change avoids unnecessary sampling if the prediction type doesn't need it.
1 parent a99ae90 commit 0bf9342

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

lightgbmlss/distributions/distribution_utils.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def predict_dist(self,
364364
- "quantile" calculates the quantiles from the predicted distribution.
365365
- "parameters" returns the predicted distributional parameters.
366366
- "expectiles" returns the predicted expectiles.
367+
- "properties" returns the mean, variance, and (if implemented) mode.
367368
n_samples : int
368369
Number of samples to draw from the predicted distribution.
369370
quantiles : List[float]
@@ -402,18 +403,35 @@ def predict_dist(self,
402403
dist_params_predt = pd.DataFrame(dist_params_predt)
403404
dist_params_predt.columns = self.param_dict.keys()
404405

405-
# Draw samples from predicted response distribution
406-
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
407-
n_samples=n_samples,
408-
seed=seed)
409-
410406
if pred_type == "parameters":
411407
return dist_params_predt
412408

413409
elif pred_type == "expectiles":
414410
return dist_params_predt
411+
412+
elif pred_type == "properties":
413+
if self.tau is None:
414+
pred_params = torch.tensor(dist_params_predt.values)
415+
dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)}
416+
dist_pred = self.distribution(**dist_kwargs)
417+
pred_props = pd.DataFrame({"mean": dist_pred.mean.detach().numpy(),
418+
"variance": dist_pred.variance.detach().numpy()})
419+
try:
420+
dist_pred.mode
421+
except NotImplementedError:
422+
pass
423+
else:
424+
pred_props["mode"] = dist_pred.mode.detach().numpy()
425+
return pred_props
426+
else:
427+
raise ValueError("Invalid prediction type.")
428+
429+
# Draw samples from predicted response distribution
430+
pred_samples_df = self.draw_samples(predt_params=dist_params_predt,
431+
n_samples=n_samples,
432+
seed=seed)
415433

416-
elif pred_type == "samples":
434+
if pred_type == "samples":
417435
return pred_samples_df
418436

419437
elif pred_type == "quantiles":

0 commit comments

Comments
 (0)