[NF4] Support nf4 tensor shard and gather #2449
@@ -22,7 +22,51 @@
c10d_functional = torch.ops.c10d_functional


NF4_OPS_TABLE: Dict[Any, Any] = {}
def nf4_all_gather_into_tensor(func, *args, **kwargs):
    assert len(args) > 1, "Expected valid input"
    assert len(args[0]) == 3, "Expected 3 input args"
    nf4tensor = args[0][0]
Review comment: can we assert len(args) and len(args[0]) before accessing them?
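A minimal sketch of the ordering the reviewer is asking for (hypothetical prologue only, not code from this PR): both length checks run before any indexing into args, so a malformed dispatch fails with the assertion message instead of an IndexError.

def nf4_all_gather_into_tensor(func, *args, **kwargs):
    # check the outer tuple and the inner aten-args tuple before touching their elements
    assert len(args) > 1, "Expected valid input"
    assert len(args[0]) == 3, "Expected 3 input args"
    # safe to unpack only after both checks pass
    nf4tensor, group_size, name = args[0]
    ...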
    group_size = args[0][1]
    name = args[0][2]
    updated_attrs = {}
    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
        updated_attrs[attr] = func(getattr(nf4tensor, attr), group_size, name)
    updated_attrs.update(
        {
            "size": torch.Size((nf4tensor.size()[0] * group_size, nf4tensor.size()[1])),
        }
    )
    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
    return updatedNF4Tensor


def scatter_nf4tensor(func, *args, **kwargs):
    assert len(args) > 1, "Expected valid input"
    assert len(args[0][0]) == 1, "Expected 1 output tensor"
    output_tensor = args[0][0][0]
    input_tensors = args[0][1]
    new_attr, update_work = [], []
    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
        input_attrs = []
        if input_tensors:
            for input_tensor in input_tensors[0]:
                assert input_tensor.size() == output_tensor.size(), (
                    "Input tensor size must match output tensor size, tensors are not evenly divided."
                )
                if hasattr(input_tensor, attr):
                    input_attrs.append(getattr(input_tensor, attr))
Review comment: what happens when the tensors are not evenly divisible? Or is there a possibility of uneven sharding?

Reply: if sharded unevenly (i.e. it did not go through the split dispatch), we compare the input and output sizes here.
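An illustrative sketch of what that size check catches (hypothetical numbers, not part of this PR): when dim 0 is not divisible by the world size, the per-rank shards are unequal, so the size comparison above trips the "tensors are not evenly divided" assertion.

import torch

world_size = 4
full = torch.empty(510, 64)                    # 510 rows is not divisible by 4
shards = torch.chunk(full, world_size, dim=0)  # shard sizes: 128, 128, 128, 126
expected = torch.empty(full.size(0) // world_size, 64)  # 127 rows expected per rank
# no shard matches the expected per-rank size, so the assertion in
# scatter_nf4tensor would fire for this input
print([s.size(0) for s in shards], expected.size(0))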
            input_attrs = [input_attrs]
        new_attr, update_work = func(
            [getattr(output_tensor, attr)], input_attrs, *args[0][2:]
        )
    # the loop yields 3 works (one per inner tensor); return just one of them,
    # together with the tensor, to fit the required output format
    return new_attr, update_work


NF4_OPS_TABLE: Dict[Any, Any] = {
    torch.ops._c10d_functional.all_gather_into_tensor.default: nf4_all_gather_into_tensor,
    torch.ops.c10d.scatter_.default: scatter_nf4tensor,
}


_INNER_TENSOR_NAMES_FOR_SHARDING = [
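For orientation, a hedged usage sketch of how these table entries get exercised (assumed call pattern, requires an initialized process group; all_gather_nf4_shard is a hypothetical helper, not part of this PR): calling the functional all-gather on an NF4Tensor shard routes to nf4_all_gather_into_tensor, and the wait_tensor override added further down resolves the async result per inner tensor.

import torch

def all_gather_nf4_shard(nf4_shard, group_size, group_name):
    # looked up in NF4_OPS_TABLE -> nf4_all_gather_into_tensor
    gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
        nf4_shard, group_size, group_name
    )
    # handled by the wait_tensor override added later in this file
    return torch.ops._c10d_functional.wait_tensor.default(gathered)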
@@ -233,7 +277,6 @@ def nf4_split(aten_op, args, kwargs=None):
def nf4_new_zeros(aten_op, args, kwargs=None):
    nf4tensor = args[0]
    new_size = tuple(args[1])

    if nf4tensor.numel() % math.prod(new_size) != 0:
        raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
    ratio = nf4tensor.numel() // math.prod(new_size)
@@ -273,19 +316,37 @@ def nf4_slice(aten_op, args, kwargs=None):
        aten.view.default,
    ]
)
@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
def nf4_view(aten_op, args, kwargs=None):
    nf4tensor = args[0]
    size = args[1]
    if size[0] != -1:
        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
    updated_attrs.update(
        {
            "size": [nf4tensor.numel()],
            "stride": (1,),
        }
    )
    if len(size) == 1:
        if size[0] != -1:
            raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
        else:
            updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
            updated_attrs.update(
                {
                    "size": [nf4tensor.numel()],
                    "stride": (1,),
                }
            )
    elif len(size) == 2:
        if nf4tensor.numel() != size[0] * size[1]:
            raise NotImplementedError("NF4Tensor size does not match view size.")
        updated_attrs = {}
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            attr_size = [getattr(nf4tensor, attr).size()]
            updated_attrs[attr] = aten_op(
                getattr(nf4tensor, attr), *attr_size, **kwargs
            )
        updated_attrs.update(
            {
                "stride": (size[1], 1),
            }
        )
    else:
        raise NotImplementedError("aten.view(NF4Tensor) with empty size")
    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))

@@ -457,6 +518,20 @@ def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None):
    return tensors


@implements(
    [
        torch.ops._c10d_functional.wait_tensor.default,
    ]
)
def wait_tensor(func, *args, **kwargs):
    nf4tensor = args[0][0]
    updated_attrs = {}
    for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
        updated_attrs[attr] = func(getattr(nf4tensor, attr))
    updatedNF4Tensor = NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
    return updatedNF4Tensor


@dataclass(frozen=True)
class SubclassTensorArgs:
    original_shape: torch.Size
Review comment: why is (512, 512) being removed?

Reply: tensor = torch.randn(512, 512) followed by tensor.view(512, 512) is now valid after the changes to nf4_view, so it is moved to test_tensor_2d_view_valid.
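A minimal sketch of what that test could look like (hypothetical; the actual test lives in the PR's test suite and may be parametrized differently):

import torch
from torchao.dtypes.nf4tensor import to_nf4

def test_tensor_2d_view_valid():
    # viewing an NF4Tensor to its own 2D shape is accepted after the nf4_view change
    nf4 = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
    viewed = nf4.view(512, 512)
    assert viewed.size() == torch.Size([512, 512])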