-
Notifications
You must be signed in to change notification settings - Fork 595
Open
Description
```
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_sampler.py", line 311, in sample
init_state = _prefill.prefill(
^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_prefill.py", line 110, in prefill
out = model.apply(
^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/utils/train_property.py", line 141, in decorated
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/utils/_jax_utils.py", line 96, in decorated
output = fn(*bound_args.args, **bound_args.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/typing/type_check.py", line 270, in _reraise_with_shape_info
retval = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 247, in __call__
x, new_cache = self._apply_attention(inputs, cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 292, in _apply_attention
layer_cache, x = block(
^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 467, in __call__
cache, attn_output = self.attn(
^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 277, in __call__
padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2821, in where
return util._where(condition, x, y)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 311, in _where
condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 264, in _broadcast_arrays
result_shape = lax.broadcast_shapes(*shapes)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Incompatible shapes for broadcasting: shapes=[(1, 1447, 1, 5234), (1, 1447, 8, 4096), ()]
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
Metadata
Metadata
Assignees
Labels
No labels