Skip to content

[TORCH] Add Kullback-Leibler divergence loss support #4204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

zahidwx
Copy link
Contributor

@zahidwx zahidwx commented May 24, 2025

This PR takes care of #4203.

  • e2e support of aten.kl_div op supporting all reduction modes (mean, sum, batchmean, none)
  • reduction: batchmean requires special handling by calling op with sum and then dividing it by input batch_size.

Following tests are failing and are marked either in expected failures or crashing set

  • KlDivLossModule_batchmean_reduction_basic
    • config=linalg | RuntimeError: attribute lookup is not defined on builtin | LINALG_XFAIL_SET
  • KlDivLossModule_default_basic | KlDivLossModule_reduction_is_none_basic | KlDivLossModule_mean_reduction_basic | KlDivLossModule_sum_reduction_basic | KlDivLossModule_batchmean_reduction_basic
    • config=torchdynamo | error: failed to legalize operation 'torch.aten.xlogy.Tensor' | TORCHDYNAMO_CRASHING_SET
  • KlDivLossModule_batchmean_reduction_basic
    • config=onnx | Error: RuntimeError: aten::div() Expected a value of type 'number' for argument 'other' but instead found type 'Tensor' Position: 1
      Value: tensor(1)
      Declaration: aten::div.Scalar(Tensor self, Scalar other) -> Tensor
      Cast error details: Cannot cast tensor(1) to number | ONNX_XFAIL_SET
  • KlDivLossModule_default_basic | KlDivLossModule_reduction_is_none_basic | KlDivLossModule_reduction_is_none_log_target_is_true_basic | KlDivLossModule_mean_reduction_basic | KlDivLossModule_sum_reduction_basic | KlDivLossModule_batchmean_reduction_basic
    • config=onnx_tosa | Error: error: failed to legalize operation 'torch.aten.size.int' that was explicitly marked illegal | ONNX_TOSA_XFAIL_SET

@zahidwx zahidwx marked this pull request as ready for review May 25, 2025 11:11
@zahidwx
Copy link
Contributor Author

zahidwx commented May 26, 2025

Can someone review this please? @vivekkhandelwal1 @penguin-wwy @AmosLewis @rsuderman

zahidwx added 2 commits May 28, 2025 13:37
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
@zahidwx zahidwx force-pushed the feature/kl_div_loss branch from fb50c06 to 7fb518d Compare May 28, 2025 08:08
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant