Skip to content

Commit e989a23

Browse files
committed
Move batch to correct device at eval
Signed-off-by: Andrea Fasoli <andrea.fasoli@ibm.com>
1 parent fb3f224 commit e989a23

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

aiu_fms_testing_utils/utils/encoders_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
4545
model.config.linear_config.pop("linear_type", None)
4646
return to_hf_api(model, task_specific_params=None)
4747

48+
def move_to_device(batch: dict, device: torch.device) -> dict:
49+
"""Move batch to selected device."""
50+
51+
batch_on_device = {}
52+
for k, v in batch.items():
53+
batch_on_device[k] = v.to(device)
54+
return batch_on_device
55+
4856

4957
class EncoderQAInfer():
5058
"""Run QuestionAnswering task with encoder models."""
@@ -587,6 +595,7 @@ def run_evaluation(self) -> None:
587595
with torch.no_grad():
588596
dprint(f"Step {step + 1} / {len(eval_dataloader)}")
589597
batch = self.convert_batch_to_fms_style(batch)
598+
batch = move_to_device(batch, args.device)
590599
start_logits, end_logits = self.model(**batch)
591600
all_start_logits.append(start_logits.cpu().numpy())
592601
all_end_logits.append(end_logits.cpu().numpy())

scripts/run_encoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
# Main model setup
3838
default_dtype, device, dist_strat = setup_model(args)
39+
args.device = device
3940

4041
# Retrieve linear configuration (quantized or not) to instantiate FMS model
4142
linear_config = get_linear_config(args)

0 commit comments

Comments
 (0)