Conversation

@farawayboat (Contributor) commented on May 21, 2025

What this PR does / why we need it?

  • This PR introduces changes to adapt the pos_encoding_kernels.cpp, utils.h, attention.py, layernorm.py, platform.py, and utils.py files to support Ascend 310P devices.
  • Specifically, it adjusts the loadSize constant based on the Ascend AI Core version and adds conditional compilation directives for bfloat16_t support.
  • It also includes modifications to handle specific behaviors required for the Ascend 310P, such as tensor alignment and format casting (a hedged sketch of this handling follows the list).
  • The purpose of these changes is to ensure compatibility and optimal performance on Ascend 310P devices.
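
Below is a minimal, hedged sketch of the kind of 310P-specific tensor handling described above, not the exact code added by this PR; the helper name, the 16-element alignment, and the numeric value of `ACL_FORMAT_FRACTAL_NZ` are illustrative assumptions:

```python
import torch
import torch_npu  # Ascend PyTorch adapter, provides npu_format_cast

ACL_FORMAT_FRACTAL_NZ = 29  # assumed numeric value of the NZ (FRACTAL_NZ) format


def prepare_for_310p(x: torch.Tensor, is_310p: bool) -> torch.Tensor:
    """Hypothetical helper: align and format-cast a tensor for 310P NNAL operators."""
    if not is_310p:
        return x
    # Pad the last dimension up to a multiple of 16 elements, since some
    # 310P operators expect aligned shapes.
    pad = (-x.shape[-1]) % 16
    if pad:
        x = torch.nn.functional.pad(x, (0, pad))
    # Cast the contiguous tensor to FRACTAL_NZ so that NNAL operators on the
    # 310P receive it in the format they require.
    return torch_npu.npu_format_cast(x.contiguous(), ACL_FORMAT_FRACTAL_NZ)
```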

Does this PR introduce any user-facing change?

  • Yes, this PR introduces changes that affect the behavior of the library when running on Ascend 310P devices.
  • Users of Ascend 310P will see improved performance and compatibility due to the added support and optimizations.

How was this patch tested?

The patch has been tested locally on Ascend 310P hardware to ensure that the changes do not break existing functionality and that the new features work as intended.

ENV information

npu-smi info
+--------------------------------------------------------------------------------------------------------+
| npu-smi 24.1.0.1                                 Version: 24.1.0.1                                     |
+-------------------------------+-----------------+------------------------------------------------------+
| NPU     Name                  | Health          | Power(W)     Temp(C)           Hugepages-Usage(page) |
| Chip    Device                | Bus-Id          | AICore(%)    Memory-Usage(MB)                        |
+===============================+=================+======================================================+
| 1536    310P3                 | OK              | NA           65                0     / 0             |
| 0       0                     | 0000:06:00.0    | 0            1524 / 44280                            |
+-------------------------------+-----------------+------------------------------------------------------+
| 1536    310P3                 | OK              | NA           64                17452 / 17452         |
| 1       1                     | 0000:06:00.0    | 0            36314/ 43693                            |
+===============================+=================+======================================================+
| 1792    310P3                 | OK              | NA           72                17636 / 17636         |
| 0       2                     | 0000:07:00.0    | 95           36982/ 44280                            |
+-------------------------------+-----------------+------------------------------------------------------+
| 1792    310P3                 | OK              | NA           68                0     / 0             |
| 1       3                     | 0000:07:00.0    | 0            1216 / 43693                            |
+===============================+=================+======================================================+
| 2048    310P3                 | OK              | NA           60                0     / 0             |
| 0       4                     | 0000:08:00.0    | 0            1411 / 44280                            |
+-------------------------------+-----------------+------------------------------------------------------+
| 2048    310P3                 | OK              | NA           57                0     / 0             |
| 1       5                     | 0000:08:00.0    | 0            1494 / 43693                            |
+===============================+=================+======================================================+
| 2304    310P3                 | OK              | NA           53                18200 / 18200         |
| 0       6                     | 0000:09:00.0    | 0            37812/ 44280                            |
+-------------------------------+-----------------+------------------------------------------------------+
| 2304    310P3                 | OK              | NA           49                18194 / 18194         |
| 1       7                     | 0000:09:00.0    | 0            37958/ 43693                            |
+===============================+=================+======================================================+

CANN, NNAL version: 8.1.RC1

Important

Because the current PTA 2.5.1 version cannot pass parameters in the NZ format as required when calling NNAL operators on 310P, we used a temporary debugging version provided by the PTA team for testing.

Code example

Build vllm-ascend from source code
```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```
Run offline inference
```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
           "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    max_model_len=4096,
    max_num_seqs=4,
    dtype="float16", # IMPORTANT: some ATB ops do not support bf16 on 310P
    disable_custom_all_reduce=True,
    trust_remote_code=True,
    tensor_parallel_size=2,
    compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Inline review comment from a Collaborator on the snippet below:

```python
return new_tensor


def communication_adaptation_310p():
```

Move the patch function to the vllm_ascend/patch module.

@farawayboat force-pushed the feat-atlas-310p branch 4 times, most recently from 7d0b0f6 to fcdb6fc on June 3, 2025 06:30
@wangxiyuan mentioned this pull request on Jun 4, 2025

github-actions bot commented Jun 4, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@leo-pony (Contributor) commented on Jun 5, 2025

I tested on a 310I Duo with cann8.1.RC1_ubuntu22.04_py11, 2 cards:
docker image:

```shell
docker pull m.daocloud.io/quay.io/ascend/cann:8.1.rc1-310p-ubuntu22.04-py3.11
export DEVICE=
export IMAGE=
docker run \
    --name vllm_test \
    --device $DEVICE \
    --device /dev/davinci_manager \
    --device /dev/devmm_svm \
    --device /dev/hisi_hdc \
    -v /usr/local/dcmi:/usr/local/dcmi \
    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
    -v /etc/ascend_install.info:/etc/ascend_install.info \
    -v /var/run/docker.sock:/var/run/docker.sock \
    -it $IMAGE bash
```

torch-npu package link:
https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.5.1/20250528.3/pytorch_v2.5.1_py311.tar.gz

Details as follows:
Models that passed testing: qwen2.5-7b-instruct, qwen2.5-0.5b, qwen3-0.6b, qwen3-4b, qwen3-8B

Screenshots of the passing runs were attached for qwen2.5-7b-instruct, qwen2.5-0.5b, qwen3-0.6B, qwen3-4B, and qwen3-8B (images not reproduced here).


github-actions bot commented Jun 9, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@130B848 commented on Jun 13, 2025

I tested this PR on a 310P3 device following the steps mentioned by @leo-pony, and found that the Qwen series works well using the V0 engine.

However, when I tried to use the V1 engine (export VLLM_USE_V1=1), an unimplemented exception torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable was raised (a minimal standalone repro is sketched after the backtrace below).

I am wondering if I missed any dependencies or configurations for V1 engine support on 310P?

My testbed setup:

docker image: docker pull m.daocloud.io/quay.io/ascend/cann:8.1.rc1-310p-ubuntu22.04-py3.11

torch_npu: https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.5.1/20250528.3/pytorch_v2.5.1_py311.tar.gz

vllm: pip install vllm==0.9.0

vllm-ascend: pip install -v -e . (cloned from https://github.com/farawayboat/vllm-ascend/tree/feat-atlas-310p)

Full Exception Backtrace

Process EngineCore_0:
Traceback (most recent call last):
  File "/usr/local/python3.11.12/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/python3.11.12/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 504, in run_engine_core
    raise e
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 491, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 390, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 78, in __init__
    self._initialize_kv_caches(vllm_config)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 137, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/v1/executor/abstract.py", line 75, in determine_available_memory
    output = self.collective_rpc("determine_available_memory")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/utils.py", line 2605, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm-ascend/vllm_ascend/worker/worker_v1.py", line 139, in determine_available_memory
    self.model_runner.profile_run()
  File "/workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1241, in profile_run
    hidden_states = self._dummy_run(self.max_num_tokens)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py", line 1211, in _dummy_run
    hidden_states = model(
                    ^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl                                                [426/1930]
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/model_executor/models/qwen2.py", line 481, in forward
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 238, in __call__
    output = self.compiled_callable(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 234, in patched_inline_call
    return inline_call(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL                                               [344/1930]
    self._call(inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 234, in patched_inline_call
    return inline_call(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 234, in patched_inline_call
    return inline_call(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function                                    [261/1930]
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/compilation/decorators.py", line 234, in patched_inline_call
    return inline_call(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
    return handler(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 839, in builtin_dispatch
    rv = handler(tx, args, kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 766, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 1973, in call_id
    return tensor_variable.call_id(tx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 449, in call_id
    unimplemented("call_id not supported for sourceless TensorVariable")
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable

from user code:
   File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/model_executor/models/qwen2.py", line 358, in forward
    hidden_states, residual = layer(
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/model_executor/models/qwen2.py", line 257, in forward
    hidden_states = self.self_attn(
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/model_executor/models/qwen2.py", line 187, in forward
    attn_output = self.attn(q, k, v)
  File "/usr/local/python3.11.12/lib/python3.11/site-packages/vllm/attention/layer.py", line 216, in forward
    self.impl.forward(self,
  File "/workspace/vllm-ascend/vllm_ascend/attention/attention_v1.py", line 410, in forward
    if not id(ori_output) == id(output):


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
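
For reference, the failing user-code line is the `id(ori_output) == id(output)` check in `vllm_ascend/attention/attention_v1.py` shown above. A minimal standalone sketch (not part of this PR, written against torch 2.5) that reproduces the same Dynamo limitation:

```python
import torch


def attn_like(x: torch.Tensor) -> torch.Tensor:
    out = x + 1  # tensor created inside the traced region ("sourceless" to Dynamo)
    # Dynamo cannot trace id() on such a tensor and reports
    # "call_id not supported for sourceless TensorVariable".
    if not id(out) == id(x):
        out = out * 2
    return out


compiled = torch.compile(attn_like, fullgraph=True)  # fullgraph turns the graph break into an error
compiled(torch.ones(4))  # raises torch._dynamo.exc.Unsupported
```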

@Yikun mentioned this pull request on Jun 20, 2025
@Yikun changed the title from [WIP][Platform] Add support for Ascend 310P to [Platform] Add support for Ascend 310P on Jun 20, 2025


Inline review comment from a Collaborator on the snippet below:

```python
if is_310p():
    logger.info("patch torch communication while using Ascend 310P")
```

No need to add the log for the patch.

Inline review comment from a Collaborator on the snippet below:

```cpp
if (loopCnt)
    AscendC::Copy(dst, src, loadSize, loopCnt, {1, 1, 8, 8});
AscendC::Copy(dst[loopCnt * loadSize], src[loopCnt * loadSize], tailSize, 1, {1, 1, 8, 8});

__aicore__ inline void local_mem_copy(AscendC::LocalTensor<scalar_t> dst,
```

Unrelated change. You can add this file to the ignore list: https://github.com/vllm-project/vllm-ascend/blob/main/format.sh#L275

@farawayboat force-pushed the feat-atlas-310p branch 2 times, most recently from 07aacec to 6507f6d on June 20, 2025 08:31
@Yikun mentioned this pull request on Jun 20, 2025

github-actions bot: This pull request has conflicts, please resolve those before we can evaluate the pull request.

Yikun added a commit that referenced this pull request Jun 21, 2025
…1333)

### What this PR does / why we need it?
Add initial experimental support for Ascend 310P. This patch squashes the
PRs below into one to help validation:

- #914
- #1318
- #1327


### Does this PR introduce _any_ user-facing change?
Users can run vLLM on the Atlas 300I Duo series

### How was this patch tested?
CI passed with:
- E2E image build for 310P
- CI test on A2 with e2e test and longterm test
- Unit tests are missing because a real 310P image is needed to run them;
they will be added in a separate PR later.
- Manual e2e test:
- Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B:
#914 (comment)
  - Pangu MGoE 72B


The patch has been tested locally on Ascend 310P hardware to ensure that
the changes do not break existing functionality and that the new
features work as intended.

#### ENV information

CANN, NNAL version: 8.1.RC1
> [!IMPORTANT]  
> PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ
> format and calling NNAL operators on 310P

#### Code example

##### Build vllm-ascend from source code

```shell
# download source code as vllm-ascend
cd vllm-ascend
export SOC_VERSION=Ascend310P3
pip install -v -e .
cd ..
```

##### Run offline inference

```python
from vllm import LLM, SamplingParams
prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。",
           "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10)
# Create an LLM.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    max_model_len=4096,
    max_num_seqs=4,
    dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P
    disable_custom_all_reduce=True,
    trust_remote_code=True,
    tensor_parallel_size=2,
    compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]},
)

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

```

---------

Signed-off-by: Vincent Yuan <farawayboat@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: Vincent Yuan <farawayboat@gmail.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: shen-shanshan <467638484@qq.com>
@Yikun (Collaborator) commented on Jun 21, 2025

Many thanks for your contributions. I squashed all 310P-related commits with you as co-author: #1333

@Yikun closed this on Jun 21, 2025