mlx_lm/lora.py (23 changes: 21 additions & 2 deletions)
@@ -72,6 +72,7 @@
"lr_schedule": None,
"lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 20.0},
"mask_prompt": False,
"report_accuracy": False,
"report_to": None,
"project_name": None,
}
@@ -196,6 +197,12 @@ def build_parser():
default=None,
help="Services to report logs to ('wandb', 'swanlab', or 'wandb,swanlab').",
)
parser.add_argument(
"--report-accuracy",
action="store_true",
help="Display token-level accuracy metrics during training/validation/test reporting",
default=None,
)
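
Note: with this flag a run would look something like "mlx_lm.lora --model <path> --data <path> --train --report-accuracy" (paths elided). The default=None mirrors the other boolean flags here, presumably so that an unset CLI value lets the report_accuracy entry in the config defaults above (or a YAML config) take precedence.
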
parser.add_argument(
"--project-name",
type=str,
@@ -262,6 +269,7 @@ def train_model(
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
grad_accumulation_steps=args.grad_accumulation_steps,
report_accuracy=args.report_accuracy,
)

# Initialize the selected optimizer
@@ -296,17 +304,28 @@ def train_model(


def evaluate_model(args, model: nn.Module, test_set):
test_loss = evaluate(
result = evaluate(
model=model,
dataset=CacheDataset(test_set),
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
return_accuracy=args.report_accuracy,
)
if args.report_accuracy:
test_loss, test_acc = result
else:
test_loss = result
test_acc = None

test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
if test_acc is not None:
print(
f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Test acc {(test_acc * 100):.3f}%."
)
else:
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
mlx_lm/tuner/trainer.py (132 changes: 107 additions & 25 deletions)
@@ -70,6 +70,34 @@ class TrainingArgs:
"help": "Number of steps to accumulate gradients before applying an optimizer update."
},
)
report_accuracy: bool = field(
default=False,
metadata={"help": "Display token-level accuracy during reporting steps."},
)
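
For API users, a minimal sketch of enabling the metric when constructing training arguments (values illustrative):

    args = TrainingArgs(
        batch_size=4,
        iters=1000,
        report_accuracy=True,  # adds token-level accuracy to train/val reports
    )
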


def _unpack_loss_output(out, need_accuracy: bool):
"""
Normalize various loss outputs to a common form (loss, ntoks, correct).
Supported outputs:
- (loss, ntoks, logits, targets, mask): preferred; computes accuracy only if requested.
- (loss, ntoks): no accuracy available; returns correct=0.
Any other shape falls back to (loss, 0, 0).
"""
if isinstance(out, tuple):
if len(out) >= 5:
losses, toks, logits, targets, mask = out[:5]
if need_accuracy:
preds = mx.argmax(logits, axis=-1)
correct = ((preds == targets) * mask).sum()
else:
correct = mx.array(0)
return losses, toks, correct
elif len(out) == 2:
losses, toks = out
return losses, toks, mx.array(0)
# Fallback
return out, mx.array(0), mx.array(0)
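
A minimal sketch of the contract this helper normalizes (toy shapes; the 5-tuple form matches the updated default_loss below):

    import mlx.core as mx

    logits = mx.zeros((1, 4, 8))                # (batch, seq, vocab)
    targets = mx.zeros((1, 4), dtype=mx.int32)  # all-zero targets
    mask = mx.ones((1, 4))                      # every position counts

    # 5-tuple: accuracy is computed on demand from logits/targets/mask.
    loss, toks, correct = _unpack_loss_output(
        (mx.array(1.0), mask.sum(), logits, targets, mask), need_accuracy=True
    )  # correct == 4: argmax over all-zero logits matches the zero targets

    # 2-tuple: legacy loss functions keep working; correct is a zero placeholder.
    loss, toks, correct = _unpack_loss_output(
        (mx.array(1.0), mx.array(4)), need_accuracy=True
    )
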


def default_loss(model, batch, lengths):
@@ -85,7 +113,8 @@ def default_loss(model, batch, lengths):
ntoks = mask.sum()
ce = ce.astype(mx.float32).sum() / ntoks

return ce, ntoks
# Also return intermediates so accuracy can be computed outside when requested.
return ce, ntoks, logits, targets, mask


def iterate_batches(
@@ -170,10 +199,12 @@ def evaluate(
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
return_accuracy: bool = False,
):
model.eval()
all_losses = mx.array(0.0)
ntokens = mx.array(0)
all_correct = mx.array(0)

index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

@@ -190,15 +221,25 @@
desc="Calculating loss...",
total=min(len(dataset) // batch_size, num_batches),
):
losses, toks = loss(model, *batch)
out = loss(model, *batch)
losses, toks, correct = _unpack_loss_output(out, return_accuracy)
all_losses += losses * toks
ntokens += toks
mx.eval(all_losses, ntokens)
if return_accuracy:
all_correct += correct
mx.eval(all_losses, ntokens, all_correct)
else:
mx.eval(all_losses, ntokens)

all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)

return (all_losses / ntokens).item()
if return_accuracy:
all_correct = mx.distributed.all_sum(all_correct, stream=mx.cpu)
avg_loss = (all_losses / ntokens).item()
acc = (all_correct / ntokens).item()
return avg_loss, acc
else:
return (all_losses / ntokens).item()
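
A hedged usage sketch of the new return shape (model, val_set, and sizes are placeholders; CacheDataset is the wrapper used by lora.py above):

    val_loss, val_acc = evaluate(
        model=model,
        dataset=CacheDataset(val_set),
        batch_size=4,
        num_batches=-1,  # full pass over the dataset
        return_accuracy=True,
    )
    print(f"val loss {val_loss:.3f}, val acc {val_acc * 100:.2f}%")

With return_accuracy=False the function still returns the scalar loss alone, so existing callers are unaffected.
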


def train(
@@ -233,7 +274,8 @@ def train(

@partial(mx.compile, inputs=state, outputs=state)
def step(batch, prev_grad, do_update):
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
out, grad = loss_value_and_grad(model, *batch)
lvalue, toks, correct = _unpack_loss_output(out, args.report_accuracy)

if prev_grad is not None:
grad = tree_map(lambda x, y: x + y, grad, prev_grad)
@@ -245,7 +287,7 @@ def step(batch, prev_grad, do_update):
optimizer.update(model, grad)
grad = None

return lvalue, toks, grad
return lvalue, toks, correct, grad

model.train()
losses = 0
@@ -254,6 +296,7 @@ def step(batch, prev_grad, do_update):
trained_tokens = 0
train_time = 0
grad_accum = None
correct_tokens = mx.array(0)

# Main training loop
for it, batch in zip(
@@ -271,72 +314,110 @@ def step(batch, prev_grad, do_update):
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
tic = time.perf_counter()
val_loss = evaluate(
val_result = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
return_accuracy=args.report_accuracy,
)
if args.report_accuracy:
val_loss, val_acc = val_result
else:
val_loss = val_result
model.train()
val_time = time.perf_counter() - tic
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if args.report_accuracy:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val acc {(val_acc * 100):.3f}%, "
f"Val took {val_time:.3f}s",
flush=True,
)
else:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)

if training_callback is not None:
val_info = {
"iteration": it - 1,
"val_loss": val_loss,
"val_time": val_time,
}
if args.report_accuracy:
val_info["val_acc"] = val_acc
training_callback.on_val_loss_report(val_info)

tic = time.perf_counter()

lvalue, toks, grad_accum = step(
lvalue, toks, correct, grad_accum = step(
batch,
grad_accum,
it % grad_accum_steps == 0,
)

losses += lvalue
n_tokens += toks
correct_tokens += correct
steps += 1
mx.eval(state, losses, n_tokens, grad_accum)
if args.report_accuracy:
mx.eval(state, losses, n_tokens, correct_tokens, grad_accum)
else:
mx.eval(state, losses, n_tokens, grad_accum)
train_time += time.perf_counter() - tic

# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * world_size
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
if args.report_accuracy:
correct_sum = mx.distributed.all_sum(
correct_tokens, stream=mx.cpu
).item()
train_acc = (correct_sum / n_tokens) if n_tokens > 0 else 0.0
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / train_time
trained_tokens += n_tokens
peak_mem = mx.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if args.report_accuracy:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Train acc {(train_acc * 100):.3f}%, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
else:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)

if training_callback is not None:
train_info = {
"iteration": it,
"train_loss": train_loss,
**({"train_acc": train_acc} if args.report_accuracy else {}),
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
Expand All @@ -349,6 +430,7 @@ def step(batch, prev_grad, do_update):
n_tokens = 0
steps = 0
train_time = 0
correct_tokens = mx.array(0)

# Save adapter weights
if it % args.steps_per_save == 0 and rank == 0: