RotaryEncoding does not work for training on wgpu backend #3011
Comments
Can you share a MWE? Looks like an operation might not be respecting the device limits somehow.
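For reference, the device limits mentioned in the comment above can be printed with the wgpu crate directly. This is only a sketch under a couple of assumptions: wgpu and pollster are added as direct dependencies (neither is part of the MWE below), and the exact unwrapping of request_adapter may differ slightly between wgpu versions. The tensor-size arithmetic in the comments uses the constants from the MWE.

```rust
fn main() {
    // Pick the default adapter, as Burn's default Wgpu device would.
    let instance = wgpu::Instance::default();
    let adapter = pollster::block_on(
        instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
    )
    .expect("no compatible adapter found");

    let limits = adapter.limits();
    println!("adapter: {}", adapter.get_info().name);
    println!("max_buffer_size: {}", limits.max_buffer_size);
    println!(
        "max_storage_buffer_binding_size: {}",
        limits.max_storage_buffer_binding_size
    );

    // One activation tensor in the MWE below is
    // batch_size * n_heads * context_length * d_k * 4 bytes (f32)
    // = 128 * 4 * 512 * 64 * 4 = 67_108_864 bytes (64 MiB).
    println!("MWE activation tensor: {} bytes", 128u64 * 4 * 512 * 64 * 4);
}
```

If the failing buffer size reported in the error is close to one of these limits, that would support the device-limits theory.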
@laggui Here's a MWE.

main.rs

```rust
use burn::{
    backend::Wgpu,
    nn::{loss::CrossEntropyLossConfig, RotaryEncoding, RotaryEncodingConfig},
    prelude::*,
};

fn main() {
    type B = Wgpu;
    let device = <B as Backend>::Device::default();

    let context_length = 512;
    let d_model = 256;
    let n_heads = 4;
    let batch_size = 128;
    let d_k = d_model / n_heads;

    let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(context_length * 2, d_model / n_heads)
        .with_theta(0.001)
        .init(&device);

    let data_test_x = Tensor::<_, 4, Float>::from_data(
        TensorData::new(
            vec![0.0; batch_size * n_heads * context_length * d_k],
            Shape::new([batch_size, n_heads, context_length, d_k]),
        ),
        &device,
    );
    let data_test_y = Tensor::<_, 1, Int>::from_data(
        TensorData::new(
            vec![0; batch_size * context_length],
            Shape::new([batch_size * context_length]),
        ),
        &device,
    );

    let x = rope.forward(data_test_x);
    let x = x
        .swap_dims(1, 2)
        .reshape([batch_size * context_length, n_heads * d_k]);

    CrossEntropyLossConfig::new()
        .init(&x.device())
        .forward(x, data_test_y);
}
```
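The MWE above only exercises the forward pass. Since the report concerns training and names the backward step on the autodiff backend, a variant that also runs backward might look like the following sketch. The Autodiff<Wgpu> wrapper (available through the train/autodiff feature), the zeros/require_grad input setup, and the explicit backward() call are my additions, not part of the original report.

```rust
use burn::{
    backend::{Autodiff, Wgpu},
    nn::{loss::CrossEntropyLossConfig, RotaryEncoding, RotaryEncodingConfig},
    prelude::*,
};

fn main() {
    // Same shapes as the MWE, but wrapped in the autodiff backend.
    type B = Autodiff<Wgpu>;
    let device = <B as Backend>::Device::default();

    let (batch_size, n_heads, context_length, d_k) = (128usize, 4, 512, 64);

    let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(context_length * 2, d_k)
        .with_theta(0.001)
        .init(&device);

    // Zero-filled stand-ins for activations and targets, marked for gradient tracking.
    let x = Tensor::<B, 4>::zeros([batch_size, n_heads, context_length, d_k], &device)
        .require_grad();
    let targets = Tensor::<B, 1, Int>::zeros([batch_size * context_length], &device);

    let x = rope.forward(x);
    let x = x
        .swap_dims(1, 2)
        .reshape([batch_size * context_length, n_heads * d_k]);

    let loss = CrossEntropyLossConfig::new()
        .init(&device)
        .forward(x, targets);

    // The reported failure happens during this backward step.
    let _grads = loss.backward();
}
```

Swapping Wgpu for NdArray in the type alias (the ndarray feature is already enabled in the Cargo.toml below) is a quick way to check whether the same program succeeds on a CPU backend.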

Cargo.toml

```toml
[package]
name = "gpt-burn"
version = "0.1.0"
edition = "2024"
[dependencies]
burn = { git="https://github.com/Tracel-AI/burn", version = "0.17", features = [
"fusion",
"ndarray",
"train",
"vision",
"wgpu",
"metrics",
] }
```

Error
Describe the bug

When trying to use RotaryEncoding with the wgpu backend for training, it crashes with an error every time. It fails on relatively small input sizes (context_length = 512, d_model = 256); I did not test on very small inputs. Depending on the build configuration the error changes, but here's an example on latest main without autotune:

To Reproduce

Try to use RotaryEncoding on the wgpu autodiff backend as part of a backward step.

Desktop (please complete the following information):