diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 55b13ee9b0..82c70670d1 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -312,15 +312,15 @@ defmodule EXLA.Defn do ) do [initial_arg, _arg, pred, body] = args + {initial, cache} = recur_composite(initial_arg, state, cache) + initial = if token = get_token(cache) do - {token, initial_arg} + [token | initial] else - initial_arg + initial end - {initial, cache} = recur_composite(initial, state, cache) - {pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache) {body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache)