RotaryEncoding does not work for training on wgpu backend #3011

Open · Devon7925 opened this issue Apr 13, 2025 · 2 comments
Labels: bug (Something isn't working), wgpu (Related to WGPU backend)

@Devon7925

Describe the bug
When trying to use RotaryEncoding with the wgpu backend during training, it crashes with an error every time. It fails on relatively small input sizes (context_length = 512, d_model = 256); I did not test on very small inputs. The exact error depends on the build configuration, but here's an example on latest main without autotune:

wgpu error: Validation Error

Caused by:
  In ComputePass::end
    In a dispatch command, indirect:false
      Each current dispatch group size dimension ([1, 1, 81920]) must be less or equal to 65535
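
For reference, the 65535 ceiling is wgpu's default max_compute_workgroups_per_dimension device limit. A minimal sketch for printing it, assuming the plain wgpu and pollster crates (not part of the original reproduction):

fn main() {
    // Minimal sketch: print the per-dimension dispatch limit the failing kernel exceeds.
    let instance = wgpu::Instance::default();
    let adapter = pollster::block_on(
        instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
    )
    .unwrap();
    // Defaults to 65535, which is the bound reported in the validation error above.
    println!(
        "max_compute_workgroups_per_dimension = {}",
        adapter.limits().max_compute_workgroups_per_dimension
    );
}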

To Reproduce
Use RotaryEncoding on the wgpu autodiff backend as part of a backward pass.

Desktop (please complete the following information):

  • Windows 10
  • RTX 4070
@laggui added the bug (Something isn't working) and wgpu (Related to WGPU backend) labels on Apr 14, 2025
@laggui (Member) commented Apr 14, 2025

Can you share a MWE?

Looks like an operation might not be respecting the device limits somehow.

@Devon7925 (Author) commented

@laggui Here's a MWE.

main.rs

use burn::{
    backend::Wgpu,
    nn::{loss::CrossEntropyLossConfig, RotaryEncoding, RotaryEncodingConfig},
    prelude::*,
};

fn main() {
    type B = Wgpu;

    let device = <B as Backend>::Device::default();

    let context_length = 512;
    let d_model = 256;
    let n_heads = 4;
    let batch_size = 128;
    let d_k = d_model / n_heads;

    // Rotary encoding sized for twice the context length, applied per attention head.
    let rope: RotaryEncoding<B> = RotaryEncodingConfig::new(context_length * 2, d_model / n_heads)
        .with_theta(0.001)
        .init(&device);

    // Dummy input of shape [batch_size, n_heads, context_length, d_k] and dummy integer targets.
    let data_test_x = Tensor::<_, 4, Float>::from_data(
        TensorData::new(vec![0.0; batch_size * n_heads * context_length * d_k], Shape::new([batch_size, n_heads, context_length, d_k])),
        &device,
    );
    let data_test_y = Tensor::<_, 1, Int>::from_data(
        TensorData::new(vec![0; batch_size * context_length], Shape::new([batch_size * context_length])),
        &device,
    );

    // Apply RoPE, then flatten the heads back into the model dimension for the loss.
    let x = rope.forward(data_test_x);
    let x = x.swap_dims(1, 2).reshape([batch_size * context_length, n_heads * d_k]);

    CrossEntropyLossConfig::new()
        .init(&x.device())
        .forward(x, data_test_y);
}

Cargo.toml

[package]
name = "gpt-burn"
version = "0.1.0"
edition = "2024"

[dependencies]
burn = { git="https://github.com/Tracel-AI/burn", version = "0.17", features = [
    "fusion",
    "ndarray",
    "train",
    "vision",
    "wgpu",
    "metrics",
] }

Error

thread 'main' panicked at C:\Users\impor\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\wgpu-25.0.0\src\backend\wgpu_core.rs:2879:26:
wgpu error: Validation Error

Caused by:
  In ComputePass::end
    In a dispatch command, indirect:false
      Each current dispatch group size dimension ([1, 1, 262144]) must be less or equal to 65535


stack backtrace:
   0: std::panicking::begin_panic_handler
             at /rustc/4d91de4e48198da2e33413efdcd9cd2cc0c46688/library\std\src\panicking.rs:692
   1: core::panicking::panic_fmt
             at /rustc/4d91de4e48198da2e33413efdcd9cd2cc0c46688/library\core\src\panicking.rs:75
   2: wgpu::backend::wgpu_core::ContextWgpuCore::handle_error_inner
   3: <wgpu::backend::wgpu_core::CoreComputePass as wgpu::dispatch::ComputePassInterface>::end
   4: core::ptr::drop_in_place<core::option::Option<wgpu::api::compute_pass::ComputePass>>
   5: cubecl_wgpu::compute::stream::WgpuStream::flush
   6: cubecl_wgpu::compute::stream::WgpuStream::start_profile
   7: cubecl_runtime::client::ComputeClient<Server,Channel>::profile
   8: <alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter
   9: cubecl_runtime::tune::tune_benchmark::TuneBenchmark<S,C,In,Out>::profile
  10: cubecl_runtime::tune::local::LocalTuner<AK,ID>::execute
  11: burn_cubecl::kernel::matmul::tune::base::matmul_autotune
  12: ZN11burn_cubecl3ops9float_ops197_$LT$impl$u20$burn_tensor..tensor..ops..tensor..FloatTensorOps$LT$burn_cubecl..backend..CubeBackend$LT$R$C$F$C$I$C$BT$GT$$GT$$u20$for$u20$burn_cubecl..backend..CubeBackend$LT$R$C$F$C$I$C$BT$GT$$GT$12float_matmul17h4b09ad50c3
  13: ZN366_$LT$burn_fusion..ops..float..$LT$impl$u20$burn_tensor..tensor..ops..tensor..FloatTensorOps$LT$burn_fusion..backend..Fusion$LT$B$GT$$GT$$u20$for$u20$burn_fusion..backend..Fusion$LT$B$GT$$GT$..float_matmul..MatmulOps$LT$B$GT$$u20$as$u20$burn_fusion..st
  14: burn_fusion::stream::execution::base::<impl burn_fusion::stream::base::OperationQueue<R>>::execute
  15: burn_fusion::stream::execution::processor::Processor<O>::process
  16: burn_fusion::stream::multi::MultiStream<R>::register
  17: <burn_fusion::client::mutex::MutexFusionClient<R> as burn_fusion::client::base::FusionClient<R>>::register
  18: burn_fusion::ops::float::<impl burn_tensor::tensor::ops::tensor::FloatTensorOps<burn_fusion::backend::Fusion<B>> for burn_fusion::backend::Fusion<B>>::float_reshape
  19: burn_tensor::tensor::api::base::Tensor<B,_,K>::reshape
  20: hashbrown::raw::inner::RawTableInner::drop_inner_table
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
error: process didn't exit successfully: `target\release\gpt-burn.exe` (exit code: 101)
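
One observation from the numbers, not a confirmed diagnosis: the failing dispatch dimension in this run factors exactly as

262144 = 128 (batch_size) × 4 (n_heads) × 512 (context_length)

so the kernel appears to collapse the whole launch into a single dispatch axis instead of spreading it across the three available dimensions.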
