@liqiangxl liqiangxl commented Aug 21, 2025

What does this PR do?

Extend nv_cross_entropy_fwd -> cross_entropy_fwd to support fp16 and bfloat16 inputs and the none reduction mode.
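For readers unfamiliar with the reduction modes, here is a minimal pure-Python sketch of what a none reduction returns compared with mean. It assumes the PR follows the usual PyTorch-style convention (ignored targets contribute zero loss, and mean divides by the number of non-ignored targets); `reference_cross_entropy` is a hypothetical stand-in, not the executor code in this PR, and Python floats stand in for the half-precision dtypes.

```python
import math


def reference_cross_entropy(logits, targets, ignore_index=-100, reduction="mean"):
    """Illustrative cross-entropy over rows of logits (hypothetical helper)."""
    losses = []
    valid = 0  # number of rows whose target is not ignore_index
    for row, target in zip(logits, targets):
        if target == ignore_index:
            losses.append(0.0)  # ignored rows contribute zero loss
            continue
        # Numerically stable log-sum-exp: subtract the row maximum first.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        losses.append(log_sum_exp - row[target])
        valid += 1

    if reduction == "none":
        return losses  # one loss per row, no aggregation
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Assumed convention: average over non-ignored rows only.
        return sum(losses) / valid if valid else float("nan")
    raise ValueError(f"unknown reduction: {reduction!r}")


logits = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
targets = [0, -100, 1]
per_row = reference_cross_entropy(logits, targets, reduction="none")
mean_loss = reference_cross_entropy(logits, targets, reduction="mean")
```

With none, the caller gets a list of three per-row losses (the middle one is zero because its target is ignored); with mean, the two non-ignored losses are averaged.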

@liqiangxl liqiangxl marked this pull request as draft August 21, 2025 00:48
nv_ignore_index = getnv(ignore_index, fd, lc_to_nv_map)

zero_scalar = fd.define_scalar(0, dtype=lcdtype_to_nvdtype(a.dtype))
if a.dtype in (dtypes.float16, dtypes.bfloat16):
Collaborator

nit-picking

Suggested change
- if a.dtype in (dtypes.float16, dtypes.bfloat16):
+ if requires_cast := a.dtype in (dtypes.float16, dtypes.bfloat16):

sum_3 = fd.ops.sum(where_1)
div = fd.ops.div(sum_3, sum_2_cvt)

if a.dtype in (dtypes.float16, dtypes.bfloat16):
Collaborator

Suggested change
- if a.dtype in (dtypes.float16, dtypes.bfloat16):
+ if requires_cast:
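The two suggestions together amount to binding the dtype check once with an assignment expression and reusing the flag later. A standalone sketch of that pattern follows; `HALF_DTYPES` and `plan_casts` are hypothetical names, the dtypes are plain strings rather than the `dtypes` objects in the diff, and the upcast to float32 is an assumption about what the guarded branches do.

```python
HALF_DTYPES = ("float16", "bfloat16")  # hypothetical stand-in for the dtypes tuple


def plan_casts(dtype):
    """List the steps a fusion might take for the given input dtype (illustrative)."""
    steps = []
    # `:=` binds the result of the whole `dtype in HALF_DTYPES` comparison,
    # because the assignment expression has lower precedence than `in`.
    if requires_cast := dtype in HALF_DTYPES:
        steps.append("cast input to float32")  # assumed upcast target
    steps.append("compute loss")
    # Reuse the flag instead of repeating the membership test.
    if requires_cast:
        steps.append(f"cast result back to {dtype}")
    return steps


half_plan = plan_casts("bfloat16")
full_plan = plan_casts("float32")
```

Because the flag is assigned in the `if` condition itself, it is always defined by the time the second check runs, whether or not the first branch was taken.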

2 participants