@@ -84,12 +84,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
                 break  # this message means "done sending"
 
     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -114,21 +109,6 @@ def step(
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
 
-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)
 
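These shape checks do not disappear; the hunks below move them (essentially verbatim) into the top-level `InferenceSession.step`, so they run once per step instead of once per server session. For reference, a minimal sketch of a `prompts` tensor that satisfies the contract the asserts encode, with made-up dimensions:

```python
import torch

# Hypothetical dimensions, for illustration only.
num_blocks, batch_size, prefix_len, hid_size = 24, 2, 4, 4096

# Deep prompts: one learned prefix per transformer block,
# shaped [num_blocks, batch_size, prefix_len, hid_size].
prompts = torch.zeros(num_blocks, batch_size, prefix_len, hid_size)

inputs = torch.zeros(batch_size, 8, hid_size)  # [batch, seq_len, hid_size]

# The invariants checked before serialization:
assert prompts.ndim == 4
assert prompts.shape[1] in (inputs.shape[0], 1)  # per-sequence or broadcast
assert prompts.shape[2] <= inputs.shape[1]       # prefix no longer than the chunk
assert prompts.shape[3] == inputs.shape[2]       # hidden sizes must match
```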
@@ -275,7 +255,9 @@ def __enter__(self) -> "InferenceSession":
         assert not self._closed and not self._server_sessions
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -285,11 +267,21 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64
 
         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())
 
         n_input_tokens = inputs.shape[1]
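As we read the new asserts, `hypo_ids` is an int64 index with one entry per row of `inputs`: during beam search, entry `i` names the previous-step hypothesis that output row `i` continues, so any per-sequence state can be gathered accordingly. A hedged illustration of that gather (not the server's actual cache code, just the indexing semantics the dtype and length checks imply):

```python
import torch

# Illustration only: how an int64 hypo_ids tensor expresses beam reordering.
prev_state = torch.randn(3, 10, 16)  # [batch, seq_len, hid] per-sequence state
hypo_ids = torch.tensor([0, 0, 2], dtype=torch.int64)

# The same invariants the diff checks before forwarding:
assert len(hypo_ids) == len(prev_state)
assert hypo_ids.dtype == torch.int64

# Rows 0 and 1 of the next step both continue from previous row 0.
reordered = prev_state[hypo_ids]
```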
@@ -310,7 +302,7 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
 
             server_session = self._server_sessions[server_idx]
             inputs = server_session.step(
-                inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
             )
 
             server_idx += 1
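After this change, callers pass `hypo_ids` once to the public `InferenceSession.step`, which forwards it to every server span alongside that span's slice of `prompts`, instead of smuggling it through `**kwargs`. A hypothetical call site (the model/session setup and tensor shapes are assumptions, not part of this diff):

```python
import torch

# Hypothetical usage of the new signature; in Petals, `session` would come
# from the client model, e.g.:
#
#     with model.inference_session(max_length=64) as session:
#         hidden = session.step(hidden, prompts=None, hypo_ids=hypo_ids)

hidden = torch.randn(2, 1, 4096)                    # one new token per beam
hypo_ids = torch.tensor([1, 0], dtype=torch.int64)  # swap the two beams' state
```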