
ValueError: Incompatible shapes for broadcasting #407

@radna0

Description

```
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_sampler.py", line 311, in sample
    init_state = _prefill.prefill(
                 ^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_prefill.py", line 110, in prefill
    out = model.apply(
          ^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/utils/train_property.py", line 141, in decorated
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/utils/_jax_utils.py", line 96, in decorated
    output = fn(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/typing/type_check.py", line 270, in _reraise_with_shape_info
    retval = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 247, in __call__
    x, new_cache = self._apply_attention(inputs, cache)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 292, in _apply_attention
    layer_cache, x = block(
                     ^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 467, in __call__
    cache, attn_output = self.attn(
                         ^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 277, in __call__
    padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2821, in where
    return util._where(condition, x, y)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 311, in _where
    condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 264, in _broadcast_arrays
    result_shape = lax.broadcast_shapes(*shapes)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Incompatible shapes for broadcasting: shapes=[(1, 1447, 1, 5234), (1, 1447, 8, 4096), ()]
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
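The failing `jnp.where` call can be checked in isolation from the shapes in the error message alone, without loading the model. This is a minimal sketch, not part of gemma: `mismatched_axes` is a hypothetical helper, and the shapes are assumed to be exactly those printed in the `ValueError`, where the condition already carries the extra axis inserted by `jnp.expand_dims(attn_mask, -2)`. Broadcasting compares axes pairwise and only accepts sizes that are equal or 1, so the third axis (1 vs 8) is fine and the failure is confined to the last axis, 5234 vs 4096.

```python
import jax.numpy as jnp

# Shapes copied from the error message: (condition, x, y) passed to jnp.where.
mask_shape = (1, 1447, 1, 5234)    # jnp.expand_dims(attn_mask, -2)
logits_shape = (1, 1447, 8, 4096)  # logits
scalar_shape = ()                  # K_MASK is a scalar


def mismatched_axes(a, b):
    """Hypothetical helper: list (axis, size_a, size_b) for every axis where
    broadcasting fails, i.e. sizes differ and neither is 1.
    Assumes both shapes have the same rank, as in the traceback."""
    return [
        (axis, da, db)
        for axis, (da, db) in enumerate(zip(a, b))
        if da != db and 1 not in (da, db)
    ]


print(mismatched_axes(mask_shape, logits_shape))  # [(3, 5234, 4096)]

# JAX's own shape check raises the same kind of ValueError as in the issue.
try:
    jnp.broadcast_shapes(mask_shape, logits_shape, scalar_shape)
except ValueError as err:
    print("ValueError:", err)

# With a mask whose last axis matches the logits, the shapes broadcast fine.
print(jnp.broadcast_shapes((1, 1447, 1, 4096), logits_shape, scalar_shape))
```

The last axis of `logits` presumably indexes the attended (key) positions, so the mismatch suggests the mask spans 5234 positions while the keys span 4096. Reading 4096 as the sampler's cache length is an assumption that the traceback alone does not confirm.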
