@Met4physics Met4physics commented Oct 5, 2025

  1. PyTorch 2.8.0 has been released as a stable version, so the current setup commands will raise errors.
  2. FA3 has been imported incorrectly since commit "Initial AMD MI300X Support via AITER" (#10). This fix also resolves issue "flash_attn_3 not imported" (#19).
  3. There is a dimensional bug that causes the following error:
Traceback (most recent call last):
  File "/storage/lmh/reward/flux-fast/run_benchmark.py", line 85, in <module>
    main(args)
    ~~~~^^^^^^
  File "/storage/lmh/reward/flux-fast/run_benchmark.py", line 34, in main
    image = pipeline(
            ~~~~~~~~^
        prompt=args.prompt,
        ^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        **_determine_pipe_call_kwargs(args)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ).images[0]
    ^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
    noise_pred = self.transformer(
                 ~~~~~~~~~~~~~~~~^
        hidden_states=latents,
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<7 lines>...
        return_dict=False,
        ^^^^^^^^^^^^^^^^^^
    )[0]
    ^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/models/transformers/transformer_flux.py", line 733, in forward
    encoder_hidden_states, hidden_states = block(
                                           ~~~~~^
        hidden_states=hidden_states,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        joint_attention_kwargs=joint_attention_kwargs,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/models/transformers/transformer_flux.py", line 456, in forward
    attention_outputs = self.attn(
        hidden_states=norm_hidden_states,
    ...<2 lines>...
        **joint_attention_kwargs,
    )
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/models/transformers/transformer_flux.py", line 343, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/storage/lmh/reward/flux-fast/utils/pipeline_utils.py", line 175, in __call__
    value.transpose(1, 2))[0].transpose(1, 2)
                              ~~~~~~~~~^^^^^^
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
  4. The current hard-coded `return outputs[0]` causes a dimensional error when using compile+FA3:
Traceback (most recent call last):
  File "/storage/lmh/reward/flux-fast/run_benchmark.py", line 85, in <module>
    main(args)
    ~~~~^^^^^^
  File "/storage/lmh/reward/flux-fast/run_benchmark.py", line 27, in main
    pipeline = load_pipeline(args)
  File "/storage/lmh/reward/flux-fast/utils/pipeline_utils.py", line 507, in load_pipeline
    pipeline = optimize(pipeline, args)
  File "/storage/lmh/reward/flux-fast/utils/pipeline_utils.py", line 479, in optimize
    pipeline = use_compile(pipeline)
  File "/storage/lmh/reward/flux-fast/utils/pipeline_utils.py", line 292, in use_compile
    pipeline(**input_kwargs).images[0]
    ~~~~~~~~^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
    noise_pred = self.transformer(
                 ~~~~~~~~~~~~~~~~^
        hidden_states=latents,
        ^^^^^^^^^^^^^^^^^^^^^^
    ...<7 lines>...
        return_dict=False,
        ^^^^^^^^^^^^^^^^^^
    )[0]
    ^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 375, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 736, in compile_wrapper
    return fn(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/diffusers/models/transformers/transformer_flux.py", line 631, in forward
    def forward(
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
    return fn(*args, **kwargs)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_functorch/aot_autograd.py", line 1241, in forward
    return compiled_fn(full_args)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 384, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
        compiled_fn, args, disable_amp=disable_amp, steal_args=True
    )
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ~^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 750, in inner_fn
    outs = compiled_fn(args)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 556, in wrapper
    return compiled_fn(runtime_args)
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_inductor/output_code.py", line 584, in __call__
    return self.current_callable(inputs)
           ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
  File "/storage/lmh/miniconda3/envs/fastflux/lib/python3.13/site-packages/torch/_inductor/utils.py", line 2716, in run
    out = model(new_inputs)
  File "/tmp/torchinductor_lmh/de/cde2egsropmjrue2ojtrrndkfnmoyj5jx2fi5g3lefdapzq2xeo4.py", line 4638, in call
    assert_size_stride(buf43, (1, 4608, 24, 128), (14155776, 3072, 128, 1), 'torch.ops.flash.flash_attn_func.default')
    ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: wrong number of dimensions3 for op: torch.ops.flash.flash_attn_func.default
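The shape of the failure above can be illustrated in isolation. This is a minimal sketch, not the repository's actual fix: it assumes the attention kernel may return either a bare output array or an `(out, aux)` tuple depending on the FA3 version, so unconditionally indexing the result with `[0]` slices off a dimension when a bare array comes back. The helper name is hypothetical.

```python
def unpack_attn_output(outputs):
    """Return the attention output whether the kernel gave back a bare
    array or an (out, aux) tuple (hypothetical helper). A hard-coded
    `outputs[0]` breaks on a bare array: indexing it drops a dimension,
    which is what the IndexError/assert_size_stride failures reflect."""
    return outputs[0] if isinstance(outputs, tuple) else outputs


# Tiny stand-ins for the two return conventions (no GPU needed):
tuple_style = (["out"], ["lse"])   # (out, softmax_lse)-style return
bare_style = ["out"]               # bare-output-style return

print(unpack_attn_output(tuple_style))  # the "out" element of the tuple
print(unpack_attn_output(bare_style))   # the bare output, untouched
```

Branching on the return type (rather than hard-coding `[0]`) keeps the call site working across both conventions, including under `torch.compile`, where the wrong rank trips Inductor's `assert_size_stride` check.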
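For item 1, updated setup commands along these lines may apply. This is a sketch only: the wheel index URLs and the exact CUDA/ROCm versions below are assumptions, so verify them against the install selector on pytorch.org before use.

```shell
# Sketch of updated install commands for stable PyTorch 2.8.0.
# NOTE: index URLs and CUDA/ROCm versions are assumptions; check
# the selector on pytorch.org for the current ones.

# NVIDIA (CUDA wheels):
pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128

# AMD (ROCm wheels):
pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/rocm6.3
```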

@Met4physics changed the title from "Update PyTorch installation instructions for NVIDIA and AMD" to "Update PyTorch installation instructions; Fix import and dimentional bugs" on Oct 6, 2025
@Met4physics changed the title to "Update PyTorch installation instructions; Fix import and dimensional bugs" on Oct 6, 2025