File tree Expand file tree Collapse file tree 2 files changed +10
-0
lines changed
aiu_fms_testing_utils/utils Expand file tree Collapse file tree 2 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -45,6 +45,14 @@ def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
45
45
model .config .linear_config .pop ("linear_type" , None )
46
46
return to_hf_api (model , task_specific_params = None )
47
47
48
def move_to_device(batch: dict, device: torch.device) -> dict:
    """Return a copy of *batch* with every tensor moved to *device*.

    Args:
        batch: Mapping of input names to ``torch.Tensor`` values
            (e.g. ``input_ids``, ``attention_mask``). Values are assumed
            to expose ``.to(device)`` — i.e. tensors.
        device: Target device (CPU, CUDA, or AIU device handle).

    Returns:
        A new dict with the same keys; each value is the result of
        ``v.to(device)`` (a no-op returning the same tensor when the
        value is already on *device*).
    """
    # Dict comprehension is the idiomatic form of the copy-per-key loop.
    return {k: v.to(device) for k, v in batch.items()}
55
+
48
56
49
57
class EncoderQAInfer ():
50
58
"""Run QuestionAnswering task with encoder models."""
@@ -587,6 +595,7 @@ def run_evaluation(self) -> None:
587
595
with torch .no_grad ():
588
596
dprint (f"Step { step + 1 } / { len (eval_dataloader )} " )
589
597
batch = self .convert_batch_to_fms_style (batch )
598
+ batch = move_to_device (batch , args .device )
590
599
start_logits , end_logits = self .model (** batch )
591
600
all_start_logits .append (start_logits .cpu ().numpy ())
592
601
all_end_logits .append (end_logits .cpu ().numpy ())
Original file line number Diff line number Diff line change 36
36
37
37
# Main model setup
38
38
default_dtype , device , dist_strat = setup_model (args )
39
+ args .device = device
39
40
40
41
# Retrieve linear configuration (quantized or not) to instantiate FMS model
41
42
linear_config = get_linear_config (args )
You can’t perform that action at this time.
0 commit comments