From 7e84fe1a9f681ce738e22632f27148b00d53c256 Mon Sep 17 00:00:00 2001
From: Josh York
Date: Wed, 5 Nov 2025 13:04:05 -0500
Subject: [PATCH 1/2] Added accuracy reporting

---
 mlx_lm/lora.py          |  23 ++++++++-
 mlx_lm/tuner/trainer.py | 110 +++++++++++++++++++++++++++++++---------
 2 files changed, 106 insertions(+), 27 deletions(-)

diff --git a/mlx_lm/lora.py b/mlx_lm/lora.py
index b70f2e716..01866515c 100644
--- a/mlx_lm/lora.py
+++ b/mlx_lm/lora.py
@@ -72,6 +72,7 @@
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 20.0},
     "mask_prompt": False,
+    "report_accuracy": False,
     "report_to": None,
     "project_name": None,
 }
@@ -196,6 +197,12 @@ def build_parser():
         default=None,
         help="Services to report logs to ('wandb', 'swanlab', or 'wandb,swanlab').",
     )
+    parser.add_argument(
+        "--report-accuracy",
+        action="store_true",
+        help="Display token-level accuracy metrics during training/validation/test reporting",
+        default=None,
+    )
     parser.add_argument(
         "--project-name",
         type=str,
@@ -262,6 +269,7 @@ def train_model(
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint,
         grad_accumulation_steps=args.grad_accumulation_steps,
+        report_accuracy=args.report_accuracy,
     )

     # Initialize the selected optimizer
@@ -296,17 +304,28 @@


 def evaluate_model(args, model: nn.Module, test_set):
-    test_loss = evaluate(
+    result = evaluate(
         model=model,
         dataset=CacheDataset(test_set),
         batch_size=args.batch_size,
         num_batches=args.test_batches,
         max_seq_length=args.max_seq_length,
+        return_accuracy=args.report_accuracy,
     )
+    if args.report_accuracy:
+        test_loss, test_acc = result
+    else:
+        test_loss = result
+        test_acc = None

     test_ppl = math.exp(test_loss)

-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+    if test_acc is not None:
+        print(
+            f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Test acc {(test_acc * 100):.3f}%."
+        )
+    else:
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


 def run(args, training_callback: TrainingCallback = None):
diff --git a/mlx_lm/tuner/trainer.py b/mlx_lm/tuner/trainer.py
index a56301f8c..4ac2ae5a2 100644
--- a/mlx_lm/tuner/trainer.py
+++ b/mlx_lm/tuner/trainer.py
@@ -70,6 +70,10 @@ class TrainingArgs:
             "help": "Number of steps to accumulate gradients before applying an optimizer update."
         },
     )
+    report_accuracy: bool = field(
+        default=False,
+        metadata={"help": "Display token-level accuracy during reporting steps."},
+    )


 def default_loss(model, batch, lengths):
@@ -85,7 +89,11 @@ def default_loss(model, batch, lengths):
     ntoks = mask.sum()
     ce = ce.astype(mx.float32).sum() / ntoks

-    return ce, ntoks
+    preds = mx.argmax(logits, axis=-1)
+    correct = (preds == targets) * mask
+    correct = correct.sum()
+
+    return ce, ntoks, correct


 def iterate_batches(
@@ -163,10 +171,12 @@ def evaluate(
     max_seq_length=2048,
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
+    return_accuracy: bool = False,
 ):
     model.eval()
     all_losses = mx.array(0.0)
     ntokens = mx.array(0)
+    all_correct = mx.array(0)

     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

@@ -182,15 +192,25 @@ def evaluate(
         desc="Calculating loss...",
         total=min(len(dataset) // batch_size, num_batches),
     ):
-        losses, toks = loss(model, *batch)
+        out = loss(model, *batch)
+        if isinstance(out, tuple) and len(out) == 3:
+            losses, toks, correct = out
+            all_correct += correct
+        else:
+            losses, toks = out
         all_losses += losses * toks
         ntokens += toks
-        mx.eval(all_losses, ntokens)
+        mx.eval(all_losses, ntokens, all_correct)

     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
-
-    return (all_losses / ntokens).item()
+    if return_accuracy:
+        all_correct = mx.distributed.all_sum(all_correct, stream=mx.cpu)
+        avg_loss = (all_losses / ntokens).item()
+        acc = (all_correct / ntokens).item()
+        return avg_loss, acc
+    else:
+        return (all_losses / ntokens).item()


 def train(
@@ -225,7 +245,7 @@ def train(

     @partial(mx.compile, inputs=state, outputs=state)
     def step(batch, prev_grad, do_update):
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+        (lvalue, toks, correct), grad = loss_value_and_grad(model, *batch)

         if prev_grad is not None:
             grad = tree_map(lambda x, y: x + y, grad, prev_grad)
@@ -237,7 +257,7 @@ def step(batch, prev_grad, do_update):
             optimizer.update(model, grad)
             grad = None

-        return lvalue, toks, grad
+        return lvalue, toks, correct, grad

     model.train()
     losses = 0
@@ -246,6 +266,7 @@ def step(batch, prev_grad, do_update):
     trained_tokens = 0
     train_time = 0
     grad_accum = None
+    correct_tokens = mx.array(0)

     # Main training loop
     for it, batch in zip(
@@ -262,7 +283,7 @@ def step(batch, prev_grad, do_update):
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             tic = time.perf_counter()
-            val_loss = evaluate(
+            val_result = evaluate(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
@@ -270,16 +291,30 @@ def step(batch, prev_grad, do_update):
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
+                return_accuracy=args.report_accuracy,
             )
+            if args.report_accuracy:
+                val_loss, val_acc = val_result
+            else:
+                val_loss = val_result
             model.train()
             val_time = time.perf_counter() - tic
             if rank == 0:
-                print(
-                    f"Iter {it}: "
-                    f"Val loss {val_loss:.3f}, "
-                    f"Val took {val_time:.3f}s",
-                    flush=True,
-                )
+                if args.report_accuracy:
+                    print(
+                        f"Iter {it}: "
+                        f"Val loss {val_loss:.3f}, "
+                        f"Val acc {(val_acc * 100):.3f}%, "
+                        f"Val took {val_time:.3f}s",
+                        flush=True,
+                    )
+                else:
+                    print(
+                        f"Iter {it}: "
+                        f"Val loss {val_loss:.3f}, "
+                        f"Val took {val_time:.3f}s",
+                        flush=True,
+                    )

             if training_callback is not None:
                 val_info = {
@@ -287,11 +322,13 @@ def step(batch, prev_grad, do_update):
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
+                if args.report_accuracy:
+                    val_info["val_acc"] = val_acc
                 training_callback.on_val_loss_report(val_info)

             tic = time.perf_counter()

-        lvalue, toks, grad_accum = step(
+        lvalue, toks, correct, grad_accum = step(
             batch,
             grad_accum,
             it % grad_accum_steps == 0,
@@ -299,8 +336,12 @@ def step(batch, prev_grad, do_update):
         losses += lvalue
         n_tokens += toks
+        correct_tokens += correct
         steps += 1
-        mx.eval(state, losses, n_tokens, grad_accum)
+        if args.report_accuracy:
+            mx.eval(state, losses, n_tokens, correct_tokens, grad_accum)
+        else:
+            mx.eval(state, losses, n_tokens, grad_accum)
         train_time += time.perf_counter() - tic

         # Report training loss if needed
@@ -308,26 +349,44 @@ def step(batch, prev_grad, do_update):
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * world_size
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
+            if args.report_accuracy:
+                correct_sum = mx.distributed.all_sum(
+                    correct_tokens, stream=mx.cpu
+                ).item()
+                train_acc = (correct_sum / n_tokens) if n_tokens > 0 else 0.0
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / train_time
             tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.get_peak_memory() / 1e9
             if rank == 0:
-                print(
-                    f"Iter {it}: Train loss {train_loss:.3f}, "
-                    f"Learning Rate {learning_rate:.3e}, "
-                    f"It/sec {it_sec:.3f}, "
-                    f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
-                    f"Peak mem {peak_mem:.3f} GB",
-                    flush=True,
-                )
+                if args.report_accuracy:
+                    print(
+                        f"Iter {it}: Train loss {train_loss:.3f}, "
+                        f"Train acc {(train_acc * 100):.3f}%, "
+                        f"Learning Rate {learning_rate:.3e}, "
+                        f"It/sec {it_sec:.3f}, "
+                        f"Tokens/sec {tokens_sec:.3f}, "
+                        f"Trained Tokens {trained_tokens}, "
+                        f"Peak mem {peak_mem:.3f} GB",
+                        flush=True,
+                    )
+                else:
+                    print(
+                        f"Iter {it}: Train loss {train_loss:.3f}, "
+                        f"Learning Rate {learning_rate:.3e}, "
+                        f"It/sec {it_sec:.3f}, "
+                        f"Tokens/sec {tokens_sec:.3f}, "
+                        f"Trained Tokens {trained_tokens}, "
+                        f"Peak mem {peak_mem:.3f} GB",
+                        flush=True,
+                    )

             if training_callback is not None:
                 train_info = {
                     "iteration": it,
                     "train_loss": train_loss,
+                    **({"train_acc": train_acc} if args.report_accuracy else {}),
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
                     "tokens_per_second": tokens_sec,
@@ -340,6 +399,7 @@ def step(batch, prev_grad, do_update):
             n_tokens = 0
             steps = 0
             train_time = 0
+            correct_tokens = mx.array(0)

         # Save adapter weights
         if it % args.steps_per_save == 0 and rank == 0:
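
[PATCH 1/2] wires the metric end to end: the --report-accuracy flag (or report_accuracy in a config file) gates return_accuracy in evaluate() and the accuracy lines in the train/val/test reports, and the number itself is simply correct greedy predictions divided by unmasked target tokens. Below is a minimal standalone sketch of that computation; the shapes and values are made-up toy data for illustration, not anything from the patch:

    import mlx.core as mx

    # Toy batch: 2 sequences, 4 target positions, vocabulary of 5 tokens
    logits = mx.random.normal((2, 4, 5))         # model outputs
    targets = mx.random.randint(0, 5, (2, 4))    # shifted target token ids
    mask = mx.ones((2, 4))                       # 1 where a position counts toward loss/accuracy

    preds = mx.argmax(logits, axis=-1)           # greedy next-token predictions
    correct = ((preds == targets) * mask).sum()  # matches on unmasked positions only
    ntoks = mask.sum()
    print((correct / ntoks).item())              # token-level accuracy in [0, 1]
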
From 6fad197d249c9b2235f3750a4f2620cc8188324a Mon Sep 17 00:00:00 2001
From: Josh York
Date: Wed, 5 Nov 2025 15:42:35 -0500
Subject: [PATCH 2/2] Removed accuracy-related calculations from the default
 loss function to avoid needless compute when report_accuracy is turned off.
 Added a helper to cleanly unpack the loss output so the logic isn't
 duplicated between the train and evaluate functions.

---
 mlx_lm/tuner/trainer.py | 46 ++++++++++++++++++++++++++++++-----------
 1 file changed, 34 insertions(+), 12 deletions(-)

diff --git a/mlx_lm/tuner/trainer.py b/mlx_lm/tuner/trainer.py
index 4ac2ae5a2..f161a45d8 100644
--- a/mlx_lm/tuner/trainer.py
+++ b/mlx_lm/tuner/trainer.py
@@ -76,6 +76,30 @@ class TrainingArgs:
     )


+def _unpack_loss_output(out, need_accuracy: bool):
+    """
+    Normalize various loss outputs to a common form (loss, ntoks, correct).
+    Supported outputs:
+      - (loss, ntoks, logits, targets, mask): preferred; computes accuracy only if requested.
+      - (loss, ntoks): no accuracy available; returns correct=0.
+    Any other shape falls back to (loss, 0, 0).
+    """
+    if isinstance(out, tuple):
+        if len(out) >= 5:
+            losses, toks, logits, targets, mask = out[:5]
+            if need_accuracy:
+                preds = mx.argmax(logits, axis=-1)
+                correct = ((preds == targets) * mask).sum()
+            else:
+                correct = mx.array(0)
+            return losses, toks, correct
+        elif len(out) == 2:
+            losses, toks = out
+            return losses, toks, mx.array(0)
+    # Fallback
+    return out, mx.array(0), mx.array(0)
+
+
 def default_loss(model, batch, lengths):
     inputs = batch[:, :-1]
     targets = batch[:, 1:]
@@ -89,11 +113,8 @@ def default_loss(model, batch, lengths):
     ntoks = mask.sum()
     ce = ce.astype(mx.float32).sum() / ntoks

-    preds = mx.argmax(logits, axis=-1)
-    correct = (preds == targets) * mask
-    correct = correct.sum()
-
-    return ce, ntoks, correct
+    # Return intermediates so accuracy can be computed conditionally outside
+    return ce, ntoks, logits, targets, mask


 def iterate_batches(
@@ -193,14 +214,14 @@ def evaluate(
         total=min(len(dataset) // batch_size, num_batches),
     ):
         out = loss(model, *batch)
-        if isinstance(out, tuple) and len(out) == 3:
-            losses, toks, correct = out
-            all_correct += correct
-        else:
-            losses, toks = out
+        losses, toks, correct = _unpack_loss_output(out, return_accuracy)
         all_losses += losses * toks
         ntokens += toks
-        mx.eval(all_losses, ntokens, all_correct)
+        if return_accuracy:
+            all_correct += correct
+            mx.eval(all_losses, ntokens, all_correct)
+        else:
+            mx.eval(all_losses, ntokens)

     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
     ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
@@ -245,7 +266,8 @@ def train(

     @partial(mx.compile, inputs=state, outputs=state)
     def step(batch, prev_grad, do_update):
-        (lvalue, toks, correct), grad = loss_value_and_grad(model, *batch)
+        out, grad = loss_value_and_grad(model, *batch)
+        lvalue, toks, correct = _unpack_loss_output(out, args.report_accuracy)

         if prev_grad is not None:
             grad = tree_map(lambda x, y: x + y, grad, prev_grad)
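
With [PATCH 2/2], default_loss defers the accuracy math to _unpack_loss_output, so the argmax is only paid for when accuracy was requested, and a custom loss that still returns the old (loss, ntoks) pair keeps working, with the correct-token count falling back to zero. A rough illustration of the two paths, assuming the patch is applied; the helper is module-internal, so importing it directly here is only for demonstration:

    import mlx.core as mx
    from mlx_lm.tuner.trainer import _unpack_loss_output

    # Five-value form, as the patched default_loss now returns
    logits = mx.zeros((1, 3, 4))                # toy logits: argmax picks token 0 everywhere
    targets = mx.zeros((1, 3), dtype=mx.int32)  # toy targets: all token 0, so every position matches
    mask = mx.ones((1, 3))
    out5 = (mx.array(1.0), mask.sum(), logits, targets, mask)
    print(_unpack_loss_output(out5, need_accuracy=True))  # -> (loss, ntoks, correct=3)

    # Legacy two-value form from a custom loss: correct falls back to 0
    out2 = (mx.array(1.0), mx.array(3))
    print(_unpack_loss_output(out2, need_accuracy=True))  # -> (loss, ntoks, correct=0)
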