
Conversation

@Datta0 (Collaborator) commented on Oct 6, 2025:

del W_deq
return grad_X, None, None

@torch.compile
Contributor commented:

Can you check if torch.compile(fullgraph = True, dynamic = True) works better.

Also try using:

from unsloth_zoo.temporary_patches.common import torch_compile_options, torch_compile

@torch_compile
def ...
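
For reference, a minimal sketch of what the two suggested variants could look like; dequant_backward and its body are hypothetical stand-ins for the backward helper touched in this PR, not the actual implementation.

import torch

# Variant 1: plain torch.compile with the suggested flags.
@torch.compile(fullgraph = True, dynamic = True)
def dequant_backward(grad_output: torch.Tensor, W_deq: torch.Tensor) -> torch.Tensor:
    # Stand-in body: propagate the incoming gradient through the dequantized weight.
    return grad_output @ W_deq

# Variant 2: reuse unsloth_zoo's shared compile wrapper so all patched
# kernels are compiled with the same options (shown commented out to
# avoid redefining the function above).
# from unsloth_zoo.temporary_patches.common import torch_compile_options, torch_compile
#
# @torch_compile
# def dequant_backward(grad_output, W_deq):
#     return grad_output @ W_deq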

Contributor commented:

See if the performance changes.

@Datta0 (Collaborator, PR author) replied:

I noticed no performance difference between any of the three options when trying Qwen3-8B.

if weight_fake_quantizer is not None:
W = weight_fake_quantizer(W)

W_quant = next((x for x in [getattr(W, "quant_state", None), getattr(base_layer, "weight_scale_inv", None), getattr(base_layer, "weight_scale", None)] if x is not None), None)
Contributor commented:

Very smart

Contributor commented:

Tbh it's best to make this an if/elif to make it faster.

@Datta0 (Collaborator, PR author) replied:

My only worry is someone mistakenly breaking this later if I switch to if..else, because if tensor fails when the tensor exists; one needs to explicitly write if tensor is not None or something like that. I thought this was the safer way to let people extend this and avoid that pitfall.

But I can change it if you feel it's better that way.
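
For illustration only, a sketch of the if/elif alternative under discussion, with the explicit is not None checks the author mentions (calling bool() on a multi-element tensor raises a RuntimeError, so a bare "if tensor:" check would break for the FP8 scale tensors); the attribute names are taken from the snippet above, and the per-branch comments are assumptions about which backend each attribute belongs to.

quant_state      = getattr(W, "quant_state", None)
weight_scale_inv = getattr(base_layer, "weight_scale_inv", None)
weight_scale     = getattr(base_layer, "weight_scale", None)

# Explicit None checks are required: "if weight_scale:" would raise
# "Boolean value of Tensor with more than one element is ambiguous"
# whenever the scale tensor has more than one element.
if quant_state is not None:
    W_quant = quant_state       # likely the bitsandbytes 4-bit path
elif weight_scale_inv is not None:
    W_quant = weight_scale_inv  # likely an FP8 block-quantized inverse scale
elif weight_scale is not None:
    W_quant = weight_scale      # likely an FP8 per-tensor/per-channel scale
else:
    W_quant = None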

Contributor replied:

Ok, it's fine.

@danielhanchen (Contributor) left a comment:

Nice work


@Datta0 changed the title from "vLLM FP8-E4M3 block quantized support" to "vLLM FP8 quantized support for SFT/GRPO" on Oct 15, 2025.
@danielhanchen (Contributor) left a comment:

Some changes left

@danielhanchen merged commit 092418f into unslothai:main on Oct 16, 2025.