Skip to content

Commit ad410c9

Browse files
authored
Add pipeline.embed support for Chronos-Bolt (#247)
1 parent 28e7b32 commit ad410c9

File tree

3 files changed

+106
-17
lines changed

3 files changed

+106
-17
lines changed

.github/workflows/eval-model.yml

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212
- labeled # When a label is added to the PR
1313

1414
jobs:
15-
evaluate-and-post:
15+
evaluate-and-print:
1616
if: contains(github.event.pull_request.labels.*.name, 'run-eval') # Only run if 'run-eval' label is added
1717
runs-on: ubuntu-latest
1818
env:
@@ -33,10 +33,5 @@ jobs:
3333
- name: Run Eval Script
3434
run: python scripts/evaluation/evaluate.py ci/evaluate/backtest_config.yaml $RESULTS_CSV --chronos-model-id=amazon/chronos-bolt-small --device=cpu --torch-dtype=float32
3535

36-
- name: Upload CSV
37-
uses: actions/upload-artifact@v4
38-
with:
39-
name: eval-metrics
40-
path: ${{ env.RESULTS_CSV }}
41-
retention-days: 1
42-
overwrite: true
36+
- name: Print CSV
37+
run: cat $RESULTS_CSV

src/chronos/chronos_bolt.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
from .base import BaseChronosPipeline, ForecastType
2727

28-
2928
logger = logging.getLogger(__file__)
3029

3130

@@ -240,13 +239,11 @@ def _init_weights(self, module):
240239
):
241240
module.output_layer.bias.data.zero_()
242241

243-
def forward(
244-
self,
245-
context: torch.Tensor,
246-
mask: Optional[torch.Tensor] = None,
247-
target: Optional[torch.Tensor] = None,
248-
target_mask: Optional[torch.Tensor] = None,
249-
) -> ChronosBoltOutput:
242+
def encode(
243+
self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
244+
) -> Tuple[
245+
torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
246+
]:
250247
mask = (
251248
mask.to(context.dtype)
252249
if mask is not None
@@ -301,8 +298,21 @@ def forward(
301298
attention_mask=attention_mask,
302299
inputs_embeds=input_embeds,
303300
)
304-
hidden_states = encoder_outputs[0]
305301

302+
return encoder_outputs[0], loc_scale, input_embeds, attention_mask
303+
304+
def forward(
305+
self,
306+
context: torch.Tensor,
307+
mask: Optional[torch.Tensor] = None,
308+
target: Optional[torch.Tensor] = None,
309+
target_mask: Optional[torch.Tensor] = None,
310+
) -> ChronosBoltOutput:
311+
batch_size = context.size(0)
312+
313+
hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
314+
context=context, mask=mask
315+
)
306316
sequence_output = self.decode(input_embeds, attention_mask, hidden_states)
307317

308318
quantile_preds_shape = (
@@ -426,6 +436,46 @@ def __init__(self, model: ChronosBoltModelForForecasting):
426436
def quantiles(self) -> List[float]:
427437
return self.model.config.chronos_config["quantiles"]
428438

439+
@torch.no_grad()
440+
def embed(
441+
self, context: Union[torch.Tensor, List[torch.Tensor]]
442+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
443+
"""
444+
Get encoder embeddings for the given time series.
445+
446+
Parameters
447+
----------
448+
context
449+
Input series. This is either a 1D tensor, or a list
450+
of 1D tensors, or a 2D tensor whose first dimension
451+
is batch. In the latter case, use left-padding with
452+
``torch.nan`` to align series of different lengths.
453+
454+
Returns
455+
-------
456+
embeddings, loc_scale
457+
A tuple of two items: the encoder embeddings and the loc_scale,
458+
i.e., the mean and std of the original time series.
459+
The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
460+
where num_patches is the number of patches in the time series
461+
and the extra 1 is for the [REG] token (if used by the model).
462+
"""
463+
context_tensor = self._prepare_and_validate_context(context=context)
464+
model_context_length = self.model.config.chronos_config["context_length"]
465+
466+
if context_tensor.shape[-1] > model_context_length:
467+
context_tensor = context_tensor[..., -model_context_length:]
468+
469+
context_tensor = context_tensor.to(
470+
device=self.model.device,
471+
dtype=torch.float32,
472+
)
473+
embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
474+
return embeddings.cpu(), (
475+
loc_scale[0].squeeze(-1).cpu(),
476+
loc_scale[1].squeeze(-1).cpu(),
477+
)
478+
429479
def predict( # type: ignore[override]
430480
self,
431481
context: Union[torch.Tensor, List[torch.Tensor]],

test/test_chronos_bolt.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,50 @@ def test_pipeline_predict_quantiles(
132132
validate_tensor(mean, (1, prediction_length), dtype=torch.float32)
133133

134134

135+
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
136+
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16, torch.int64])
137+
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
138+
pipeline = ChronosBoltPipeline.from_pretrained(
139+
Path(__file__).parent / "dummy-chronos-bolt-model",
140+
device_map="cpu",
141+
torch_dtype=model_dtype,
142+
)
143+
d_model = pipeline.model.config.d_model
144+
context = 10 * torch.rand(size=(4, 16)) + 10
145+
context = context.to(dtype=input_dtype)
146+
147+
# the patch size of dummy model is 16, so only 1 patch is created
148+
expected_embed_length = 1 + (
149+
1 if pipeline.model.config.chronos_config["use_reg_token"] else 0
150+
)
151+
152+
# input: tensor of shape (batch_size, context_length)
153+
154+
embedding, loc_scale = pipeline.embed(context)
155+
validate_tensor(
156+
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
157+
)
158+
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
159+
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
160+
161+
# input: batch_size-long list of tensors of shape (context_length,)
162+
163+
embedding, loc_scale = pipeline.embed(list(context))
164+
validate_tensor(
165+
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
166+
)
167+
validate_tensor(loc_scale[0], shape=(4,), dtype=torch.float32)
168+
validate_tensor(loc_scale[1], shape=(4,), dtype=torch.float32)
169+
170+
# input: tensor of shape (context_length,)
171+
embedding, loc_scale = pipeline.embed(context[0, ...])
172+
validate_tensor(
173+
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
174+
)
175+
validate_tensor(loc_scale[0], shape=(1,), dtype=torch.float32)
176+
validate_tensor(loc_scale[1], shape=(1,), dtype=torch.float32)
177+
178+
135179
# The following tests have been taken from
136180
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/tests/unittests/models/chronos/pipeline/test_chronos_bolt.py
137181
# Author: Caner Turkmen <atturkm@amazon.com>

0 commit comments

Comments
 (0)