diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 61bf29224d66..f73fcdea5a6c 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -307,8 +307,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
@@ -330,8 +329,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


 class RowwiseParallel(TensorParallelLayer):
@@ -383,8 +381,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
@@ -446,8 +443,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


 class SequenceParallel(TensorParallelLayer):
@@ -531,8 +527,7 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())


 SUPPORTED_TP_STYLES = {
@@ -671,7 +666,7 @@ def shard_and_distribute_module(
     # SUPER IMPORTANT we have to use setattr
     # otherwise loading is crazy slow
     if not isinstance(param, torch.nn.Parameter):
-        param = torch.nn.Parameter(param)
+        param = torch.nn.Parameter(param, requires_grad=param.is_floating_point())
     setattr(module_to_tp, param_type, param)
     # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
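
All hunks converge on the same pattern: pass `tensor.is_floating_point()` directly as `requires_grad` instead of the redundant `True if ... else False` ternary, and apply the same guard in `shard_and_distribute_module` when wrapping a raw tensor in `torch.nn.Parameter`. The guard matters because autograd only allows floating-point (and complex) tensors to require gradients. Below is a minimal standalone sketch, not part of the patch, illustrating the behavior with plain PyTorch; the tensor names are made up for the example:

```python
import torch
import torch.nn as nn

float_tensor = torch.randn(4, 4)                              # floating point
int_tensor = torch.randint(0, 255, (4, 4), dtype=torch.uint8)  # integer, e.g. a quantized weight

# Wrapping an integer tensor with the default requires_grad=True raises
# "Only Tensors of floating point and complex dtype can require gradients".
# Gating the flag on is_floating_point() avoids that:
float_param = nn.Parameter(float_tensor, requires_grad=float_tensor.is_floating_point())
int_param = nn.Parameter(int_tensor, requires_grad=int_tensor.is_floating_point())

print(float_param.requires_grad)  # True
print(int_param.requires_grad)    # False
```

The one-liner also reads more idiomatically than assigning the boolean to a temporary and then passing it on the next line.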