[torch.library.custom_op] Add _register_nvfuser_translator #2481
The failure seems to be caused by the "polluted" translation_map and torch->symbol map.
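If the pollution does come from module-level registries being mutated by one test and then seen by the next, one possible way to isolate tests is to snapshot the registry and restore it afterwards. The sketch below only illustrates that idea: `translation_map` is a plain stand-in dict and `preserved` is a hypothetical helper, neither is Thunder's actual object or API.

```python
from contextlib import contextmanager

# Stand-in for a module-level registry; NOT Thunder's actual translation map.
translation_map = {"add": "add_translator"}


@contextmanager
def preserved(registry: dict):
    """Restore `registry` to its pre-block contents when the block exits."""
    snapshot = dict(registry)
    try:
        yield registry
    finally:
        # Mutate in place so other references to the same dict see the reset.
        registry.clear()
        registry.update(snapshot)


with preserved(translation_map):
    # Test-local registration that should not leak into other tests.
    translation_map["my_lib::mul"] = "mul_translator"

assert "my_lib::mul" not in translation_map
```

In a pytest suite the same snapshot and restore could live in a fixture, so each test starts from a clean map.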
thunder.executors.custom_op_ex._override_custom_op_forward
which allows overriding the forward of a `torch.library.custom_op`
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
…ranslator` Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
):
    if any(sub_bsym.sym.id == _symbol.id for sub_bsym in bsym.subsymbols):
        nvfuser_def_for_custom_op_found = True
assert nvfuser_def_for_custom_op_found
Maybe we can assert the number of custom ops we expect to find in the fusion definition and the backward trace.
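A minimal sketch of what the counting variant could look like, using toy stand-ins for the trace. `Sym`, `BoundSymbol`, `count_fusions_containing`, and the op names are hypothetical and only mirror the `sym.id` / `subsymbols` attributes used in the test above; they are not Thunder's classes.

```python
from dataclasses import dataclass, field


@dataclass
class Sym:
    id: str


@dataclass
class BoundSymbol:
    sym: Sym
    subsymbols: list = field(default_factory=list)


def count_fusions_containing(bound_symbols, target_id):
    """Count top-level bound symbols that have `target_id` among their subsymbols."""
    return sum(
        any(sub.sym.id == target_id for sub in bsym.subsymbols)
        for bsym in bound_symbols
    )


# Toy trace: two fusion regions, only the first one contains the custom op.
custom = BoundSymbol(Sym("my_lib::mul"))
trace = [
    BoundSymbol(Sym("nvFusion0"), [custom, BoundSymbol(Sym("add"))]),
    BoundSymbol(Sym("nvFusion1"), [BoundSymbol(Sym("add"))]),
]

num_found = count_fusions_containing(trace, "my_lib::mul")
assert num_found == 1, f"expected 1 fusion with the custom op, found {num_found}"
```

Compared with the boolean flag, an exact count would also catch the case where the custom op shows up in more regions than intended.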
# Register the custom_op of `mul` with :func:`thunder.torch._register_custom_op`
_symbol = _register_custom_op(mul)
# Register custom nvfuser definition for the already registered custom_op of mul
_register_nvfuser_translator(_symbol, mul_translator)
Is this the final API usage proposal (symbol creation and translation registration)?
Yes
from thunder.executors.torchex import _always_executable

register_supported(symbol, translator_for_nvfuser, checker or _always_executable)
register_supported(symbol.id, translator_for_nvfuser, checker or _always_executable)
Can you clarify why both calls are needed?
In most cases the first one should be enough. I don't quite remember the exact reason why I have both here.
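This does not answer why both calls exist, but it may help to see what double registration does mechanically if the registry is a plain mapping keyed by whatever is passed in. That keying behaviour is an assumption for illustration; `Symbol`, `registry`, and this `register_supported` are toy stand-ins, not the nvFuser executor's implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen so instances are hashable and usable as dict keys
class Symbol:
    id: str
    name: str


# Toy registry keyed by whatever object the caller passes in.
registry = {}


def register_supported(key, translator, checker):
    registry[key] = (translator, checker)


def always_executable(*args, **kwargs):
    return True


def translator(*args):
    return ("mul", *args)


sym = Symbol(id="my_lib::mul", name="mul")

register_supported(sym, translator, always_executable)
register_supported(sym.id, translator, always_executable)

# Either key now resolves to the same translator.
assert registry[sym][0] is registry[sym.id][0]
```

Under that assumption the second call only matters if some lookup path uses the id rather than the symbol object, which would be worth confirming before dropping either call.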
Great work! Only some curiosity from my side
What does this PR do?
- Adds `custom_op_ex` to the default executor list when `thunder.torch.custom_op._register_custom_op` is used
- Adds `register_executor(custom_op_ex)` to `thunder/executors/custom_op_ex.py`
- Adds `_register_nvfuser_translator`, which lets us register an nvfuser translator rule for an already registered `torch.library.custom_op`
close #2605
cc @Borda @lantiga