Running the extformer_moe_enso example on an NVIDIA GPU fails with (External) CUDA error(719), unspecified launch failure #1218

@FelixFinnDu

Bug description

The error occurs on both Windows and Linux.
Partial error output:

Error: /paddle/paddle/phi/kernels/funcs/gather_scatter_functor.cu:268 Assertion `index >= -src_select_dim_size && index < src_select_dim_size` failed. The index is out of bounds, please check whether the index and input's shape meet the requirements. It should be greater or equal to [-10] and less than [-1], but received [10]
Error executing job with overrides: []
Traceback (most recent call last):
  File "/localnvme/application/duff/workspace/PaddleScience/examples/extformer_moe/extformer_moe_enso_train.py", line 195, in main
    train(cfg)
  File "/localnvme/application/duff/workspace/PaddleScience/examples/extformer_moe/extformer_moe_enso_train.py", line 137, in train
    solver.train()
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/solver/solver.py", line 593, in train
    self.train_epoch_func(self, epoch_id, self.log_freq)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/solver/train.py", line 118, in train_epoch_func
    losses_all, losses_constraint = solver.forward_helper.train_forward(
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/utils/expression.py", line 96, in train_forward
    output_dict = model(input_dicts[i])
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_cuboid.py", line 957, in forward
    mem_l = self.encoder(x)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_cuboid_encoder.py", line 1671, in forward
    x = self.blocks[i](x)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/container.py", line 769, in forward
    input = layer(input)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_cuboid_encoder.py", line 1351, in forward
    x = ffn(x)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_cuboid_encoder.py", line 1839, in forward
    ) = self.gate(
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/nn/layer/layers.py", line 1576, in __call__
    return self.forward(*inputs, **kwargs)
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_utils.py", line 156, in forward
    importance_loss = self.importance_loss_all(
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_utils.py", line 80, in importance_loss_all
    importance_loss = self.cv_squared(routing_weights.sum(axis=0))
  File "/localnvme/application/duff/workspace/PaddleScience/ppsci/arch/extformer_moe_utils.py", line 31, in cv_squared
    return x.var(axis=-1) / (x.mean(axis=-1) ** 2 + eps)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/utils/decorator_utils.py", line 50, in wrapper
    return func(*processed_args, **processed_kwargs)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/tensor/stat.py", line 256, in var
    if paddle.in_dynamic_mode() and paddle.any(corrected_n <= 0):
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 1016, in __bool__
    return self.__nonzero__()
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 1013, in __nonzero__
    return bool(np.array(self) > 0)
  File "/localnvme/application/duff/anaconda3/envs/paddle/lib/python3.10/site-packages/paddle/base/dygraph/tensor_patch_methods.py", line 1045, in __array__
    array = self.numpy(False)
OSError: (External) CUDA error(719), unspecified launch failure. 
  [Hint: 'cudaErrorLaunchFailure'. An exception occurred on the device while executing a kernel. Common causes include dereferencing an invalid device pointerand accessing out of bounds shared memory. Less common cases can be system specific - more information about these cases canbe found in the system specific user guide. This leaves the process in an inconsistent state and any further CUDA work willreturn the same error. To continue using CUDA, the process must be terminated and relaunched.] (at /paddle/paddle/phi/backends/gpu/cuda/cuda_info.cc:293)
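
The device-side assertion at the top of the log comes from a gather/scatter kernel, while the Python traceback ends at a host synchronization point (`x.var` converting a tensor to `bool`). Since CUDA kernel errors are reported asynchronously, the frame shown may not be the operation that actually went out of bounds. As a minimal pure-Python sketch of the bound the assertion enforces (`-size <= index < size`), assuming the indexed dimension has size 10 as the lower bound in the message suggests; `find_out_of_bounds` is a hypothetical helper, not part of Paddle or PaddleScience:

```python
def find_out_of_bounds(indices, dim_size):
    """Return the indices that violate -dim_size <= index < dim_size."""
    return [i for i in indices if not (-dim_size <= i < dim_size)]


# Assumption: the indexed dimension has 10 entries, so valid indices are -10..9.
bad = find_out_of_bounds([0, 9, 10, -10, -11], dim_size=10)
print(bad)  # [10, -11]
```

An index equal to the dimension size, like the reported `[10]`, is out of range by exactly one.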


Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
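
The Hydra hint above can be combined with `CUDA_LAUNCH_BLOCKING=1`, a standard CUDA environment variable that makes kernel launches synchronous so the error is more likely to surface at the call that caused it. A small sketch of building such an environment for a rerun; `debug_env` is a hypothetical helper name, and the script name in the comment is taken from the traceback:

```python
import os


def debug_env(base=None):
    """Copy an environment and add variables that help localize the failure."""
    env = dict(os.environ if base is None else base)
    env["HYDRA_FULL_ERROR"] = "1"      # full Hydra stack trace, as the log suggests
    env["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous kernel launches
    return env


env = debug_env({"PATH": "/usr/bin"})
print(env["CUDA_LAUNCH_BLOCKING"])  # 1

# To rerun the example with these settings (not executed here):
#   subprocess.run([sys.executable, "extformer_moe_enso_train.py"], env=debug_env())
```

Synchronous launches slow training considerably, so this is only meant for a short diagnostic run.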

The Linux environment is as follows:

name: paddle
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2025.9.9=h06a4308_0
  - expat=2.7.1=h6a678d5_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - libxcb=1.17.0=h9b100fa_0
  - libzlib=1.3.1=hb25bd0a_0
  - ncurses=6.5=h7934f7d_0
  - openssl=3.0.17=h5eee18b_0
  - pip=25.2=pyhc872135_0
  - pthread-stubs=0.3=h0ce48e5_1
  - python=3.10.18=h1a3bd86_0
  - readline=8.3=hc2a1206_0
  - setuptools=78.1.1=py310h06a4308_0
  - sqlite=3.50.2=hb25bd0a_1
  - tk=8.6.15=h54e0aa7_0
  - wheel=0.45.1=py310h06a4308_0
  - xorg-libx11=1.8.12=h9b100fa_1
  - xorg-libxau=1.0.12=h9b100fa_0
  - xorg-libxdmcp=1.1.5=h9b100fa_0
  - xorg-xorgproto=2024.1=h5eee18b_1
  - xz=5.6.4=h5eee18b_1
  - zlib=1.3.1=hb25bd0a_0
  - pip:
      - annotated-types==0.7.0
      - antlr4-python3-runtime==4.9.3
      - anyio==4.10.0
      - argon2-cffi==25.1.0
      - argon2-cffi-bindings==25.1.0
      - arrow==1.3.0
      - asttokens==3.0.0
      - async-lru==2.0.5
      - attrs==25.3.0
      - babel==2.17.0
      - beautifulsoup4==4.13.5
      - bleach==6.2.0
      - certifi==2025.8.3
      - cffi==2.0.0
      - charset-normalizer==3.4.3
      - colorlog==6.9.0
      - comm==0.2.3
      - contourpy==1.3.2
      - cycler==0.12.1
      - debugpy==1.8.17
      - decorator==5.2.1
      - defusedxml==0.7.1
      - einops==0.8.1
      - exceptiongroup==1.3.0
      - executing==2.2.1
      - fastjsonschema==2.21.2
      - fonttools==4.60.0
      - fqdn==1.5.1
      - h11==0.16.0
      - h5netcdf==1.6.4
      - h5py==3.14.0
      - httpcore==1.0.9
      - httpx==0.28.1
      - hydra-core==1.3.2
      - idna==3.10
      - imageio==2.37.0
      - ipykernel==6.30.1
      - ipython==8.37.0
      - isoduration==20.11.0
      - jedi==0.19.2
      - jinja2==3.1.6
      - joblib==1.5.2
      - json5==0.12.1
      - jsonpointer==3.0.0
      - jsonschema==4.25.1
      - jsonschema-specifications==2025.9.1
      - jupyter-client==8.6.3
      - jupyter-core==5.8.1
      - jupyter-events==0.12.0
      - jupyter-lsp==2.3.0
      - jupyter-server==2.17.0
      - jupyter-server-terminals==0.5.3
      - jupyterlab==4.4.7
      - jupyterlab-pygments==0.3.0
      - jupyterlab-server==2.27.3
      - kiwisolver==1.4.9
      - lark==1.2.2
      - markdown-it-py==4.0.0
      - markupsafe==3.0.2
      - matplotlib==3.10.6
      - matplotlib-inline==0.1.7
      - mdurl==0.1.2
      - meshio==5.3.4
      - mistune==3.1.4
      - mpmath==1.3.0
      - nbclient==0.10.2
      - nbconvert==7.16.6
      - nbformat==5.10.4
      - nest-asyncio==1.6.0
      - networkx==3.4.2
      - notebook==7.4.5
      - notebook-shim==0.2.4
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.9.0.13
      - nvidia-cuda-cccl-cu12==12.9.27
      - nvidia-cuda-cupti-cu12==12.9.19
      - nvidia-cuda-nvrtc-cu12==12.9.41
      - nvidia-cuda-runtime-cu12==12.9.37
      - nvidia-cudnn-cu12==9.9.0.52
      - nvidia-cufft-cu12==11.4.0.6
      - nvidia-cufile-cu12==1.14.0.30
      - nvidia-curand-cu12==10.3.10.19
      - nvidia-cusolver-cu12==11.7.4.40
      - nvidia-cusparse-cu12==12.5.9.5
      - nvidia-cusparselt-cu12==0.7.1
      - nvidia-ml-py==13.580.82
      - nvidia-nccl-cu12==2.27.3
      - nvidia-nvjitlink-cu12==12.9.41
      - nvidia-nvtx-cu12==12.9.19
      - nvitop==1.5.3
      - omegaconf==2.3.0
      - opt-einsum==3.3.0
      - overrides==7.7.0
      - packaging==25.0
      - paddlepaddle-gpu==3.2.0
      - paddlesci==1.4.0.post1.dev70+g0a3547ed9.d20250918
      - pandas==2.3.2
      - pandocfilters==1.5.1
      - parso==0.8.5
      - pexpect==4.9.0
      - pillow==11.3.0
      - platformdirs==4.4.0
      - prometheus-client==0.22.1
      - prompt-toolkit==3.0.52
      - protobuf==6.32.0
      - psutil==7.1.0
      - ptyprocess==0.7.0
      - pure-eval==0.2.3
      - pyaml==25.7.0
      - pycparser==2.23
      - pydantic==2.11.9
      - pydantic-core==2.33.2
      - pyevtk==1.6.0
      - pygments==2.19.2
      - pyparsing==3.2.4
      - python-dateutil==2.9.0.post0
      - python-json-logger==3.3.0
      - pytz==2025.2
      - pyyaml==6.0.2
      - pyzmq==27.1.0
      - referencing==0.36.2
      - requests==2.32.5
      - rfc3339-validator==0.1.4
      - rfc3986-validator==0.1.1
      - rfc3987-syntax==1.1.0
      - rich==14.1.0
      - rpds-py==0.27.1
      - safetensors==0.6.2
      - scikit-learn==1.4.2
      - scikit-optimize==0.10.2
      - scipy==1.15.3
      - seaborn==0.13.2
      - send2trash==1.8.3
      - six==1.17.0
      - sniffio==1.3.1
      - soupsieve==2.8
      - stack-data==0.6.3
      - sympy==1.14.0
      - terminado==0.18.1
      - threadpoolctl==3.6.0
      - tinycss2==1.4.0
      - tomli==2.2.1
      - tornado==6.5.2
      - tqdm==4.67.1
      - traitlets==5.14.3
      - types-python-dateutil==2.9.0.20250822
      - typing-extensions==4.15.0
      - typing-inspection==0.4.1
      - tzdata==2025.2
      - uri-template==1.3.0
      - urllib3==2.5.0
      - wcwidth==0.2.13
      - webcolors==24.11.1
      - webencodings==0.5.1
      - websocket-client==1.8.0
      - wget==3.2
      - xarray==2024.2.0
prefix: /localnvme/application/duff/anaconda3/envs/paddle
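
Most of the listing above is Jupyter and plotting dependencies. For a shorter report, a standard-library sketch like the following prints only a handful of packages; which ones are relevant to this failure is an assumption, and `report_versions` is a hypothetical helper:

```python
from importlib import metadata


def report_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Names taken from the environment listing above.
KEY_PACKAGES = [
    "paddlepaddle-gpu",
    "paddlesci",
    "numpy",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cudnn-cu12",
]

for name, version in report_versions(KEY_PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```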

Thank you very much.

Additional Supplementary Information

No response
