@@ -364,6 +364,7 @@ def predict_dist(self,
364364 - "quantile" calculates the quantiles from the predicted distribution.
365365 - "parameters" returns the predicted distributional parameters.
366366 - "expectiles" returns the predicted expectiles.
367+ - "properties" returns the mean, variance, and (if implemented) mode.
367368 n_samples : int
368369 Number of samples to draw from the predicted distribution.
369370 quantiles : List[float]
@@ -402,18 +403,35 @@ def predict_dist(self,
402403 dist_params_predt = pd .DataFrame (dist_params_predt )
403404 dist_params_predt .columns = self .param_dict .keys ()
404405
405- # Draw samples from predicted response distribution
406- pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
407- n_samples = n_samples ,
408- seed = seed )
409-
410406 if pred_type == "parameters" :
411407 return dist_params_predt
412408
413409 elif pred_type == "expectiles" :
414410 return dist_params_predt
411+
412+ elif pred_type == "properties" :
413+ if self .tau is None :
414+ pred_params = torch .tensor (dist_params_predt .values )
415+ dist_kwargs = {arg_name : param for arg_name , param in zip (self .distribution_arg_names , pred_params .T )}
416+ dist_pred = self .distribution (** dist_kwargs )
417+ pred_props = pd .DataFrame ({"mean" : dist_pred .mean .detach ().numpy (),
418+ "variance" : dist_pred .variance .detach ().numpy ()})
419+ try :
420+ dist_pred .mode
421+ except NotImplementedError :
422+ pass
423+ else :
424+ pred_props ["mode" ] = dist_pred .mode .detach ().numpy ()
425+ return pred_props
426+ else :
427+ raise ValueError ("Invalid prediction type." )
428+
429+ # Draw samples from predicted response distribution
430+ pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
431+ n_samples = n_samples ,
432+ seed = seed )
415433
416- elif pred_type == "samples" :
434+ if pred_type == "samples" :
417435 return pred_samples_df
418436
419437 elif pred_type == "quantiles" :
0 commit comments