 
 from .base import BaseChronosPipeline, ForecastType
 
-
 logger = logging.getLogger(__file__)
 
 
@@ -240,13 +239,11 @@ def _init_weights(self, module):
             ):
                 module.output_layer.bias.data.zero_()
 
-    def forward(
-        self,
-        context: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        target: Optional[torch.Tensor] = None,
-        target_mask: Optional[torch.Tensor] = None,
-    ) -> ChronosBoltOutput:
+    def encode(
+        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> Tuple[
+        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
+    ]:
         mask = (
             mask.to(context.dtype)
             if mask is not None
@@ -301,8 +298,21 @@ def forward(
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
         )
-        hidden_states = encoder_outputs[0]
 
+        return encoder_outputs[0], loc_scale, input_embeds, attention_mask
+
+    def forward(
+        self,
+        context: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        target: Optional[torch.Tensor] = None,
+        target_mask: Optional[torch.Tensor] = None,
+    ) -> ChronosBoltOutput:
+        batch_size = context.size(0)
+
+        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
+            context=context, mask=mask
+        )
         sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
 
         quantile_preds_shape = (
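
For context, a minimal sketch of the call contract introduced by this refactor: forward() is now a thin wrapper that encodes and then decodes. The snippet assumes a loaded ChronosBoltModelForForecasting instance named model and context/mask tensors of shape (batch, length); those variable names are illustrative, not part of this diff.

    # encode() now returns everything forward() previously computed inline:
    #   hidden_states: encoder output, roughly (batch, num_patches [+ 1 for the
    #                  optional [REG] token], d_model)
    #   loc_scale: (loc, scale) pair used to instance-normalize the context
    #   input_embeds / attention_mask: patch embeddings and mask, reused by decode()
    hidden_states, loc_scale, input_embeds, attention_mask = model.encode(
        context=context, mask=mask
    )
    sequence_output = model.decode(input_embeds, attention_mask, hidden_states)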
@@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
     def quantiles(self) -> List[float]:
         return self.model.config.chronos_config["quantiles"]
 
+    @torch.no_grad()
+    def embed(
+        self, context: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Get encoder embeddings for the given time series.
+
+        Parameters
+        ----------
+        context
+            Input series. This is either a 1D tensor, or a list
+            of 1D tensors, or a 2D tensor whose first dimension
+            is batch. In the latter case, use left-padding with
+            ``torch.nan`` to align series of different lengths.
+
+        Returns
+        -------
+        embeddings, loc_scale
+            A tuple of two items: the encoder embeddings and the loc_scale,
+            i.e., the mean and std of the original time series.
+            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
+            where num_patches is the number of patches in the time series
+            and the extra 1 is for the [REG] token (if used by the model).
+        """
+        context_tensor = self._prepare_and_validate_context(context=context)
+        model_context_length = self.model.config.chronos_config["context_length"]
+
+        if context_tensor.shape[-1] > model_context_length:
+            context_tensor = context_tensor[..., -model_context_length:]
+
+        context_tensor = context_tensor.to(
+            device=self.model.device,
+            dtype=torch.float32,
+        )
+        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
+        return embeddings.cpu(), (
+            loc_scale[0].squeeze(-1).cpu(),
+            loc_scale[1].squeeze(-1).cpu(),
+        )
+
     def predict(  # type: ignore[override]
         self,
         context: Union[torch.Tensor, List[torch.Tensor]],
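
A hedged usage sketch of the new embed API added above. It assumes the enclosing pipeline class is loadable through BaseChronosPipeline.from_pretrained (imported from .base in this file) and that "amazon/chronos-bolt-small" is an available checkpoint; the import path, checkpoint id, and loading kwargs are illustrative assumptions, not part of this diff.

    import torch
    from chronos import BaseChronosPipeline  # assumed public import path

    pipeline = BaseChronosPipeline.from_pretrained(
        "amazon/chronos-bolt-small",  # assumed checkpoint id
        device_map="cpu",
        torch_dtype=torch.float32,
    )

    # Two series of different lengths; embed() also accepts a 2D batch tensor
    # left-padded with torch.nan, per the docstring above.
    context = [torch.randn(120), torch.randn(96)]

    embeddings, (loc, scale) = pipeline.embed(context)
    # embeddings: (batch_size, num_patches + 1, d_model), returned on CPU
    # loc, scale: per-series mean and std used to normalize the inputs
    print(embeddings.shape, loc.shape, scale.shape)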