File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -91,8 +91,6 @@ def apply_impl(self, input_):
91
91
# Replace layer.forward to customize the layer computation process.
92
92
def apply (self , input_ ):
93
93
output , output_bias = self .apply_impl (input_ )
94
- if dense_optim_enable ():
95
- torch .ops .vllm .maybe_prefetch_mlp_gate_up_proj (output , self .prefix )
96
94
if not self .return_bias :
97
95
return output
98
96
return output , output_bias
@@ -123,6 +121,14 @@ def update_attrs(self):
123
121
self .reduce_results = self .layer .reduce_results
124
122
self .input_size_per_partition = self .layer .input_size_per_partition
125
123
124
+ def apply (self , input_ ):
125
+ output , output_bias = self .apply_impl (input_ )
126
+ if dense_optim_enable ():
127
+ torch .ops .vllm .maybe_prefetch_mlp_gate_up_proj (output , self .prefix )
128
+ if not self .return_bias :
129
+ return output
130
+ return output , output_bias
131
+
126
132
127
133
class MLPColumnParallelOp (CustomColumnParallelOp ):
128
134
You can’t perform that action at this time.
0 commit comments