From 6a96993539ff95fea713d310f856e72069e3cce4 Mon Sep 17 00:00:00 2001 From: Liu Nazhou <1171509797@qq.com> Date: Wed, 21 Jan 2026 10:17:26 +0000 Subject: [PATCH 1/4] support vision input for Qwen3-VL --- pyproject.toml | 1 + src/hpcai/cookbook/eval/inspect_evaluators.py | 2 +- src/hpcai/cookbook/image_processing_utils.py | 56 + src/hpcai/cookbook/model_info.py | 5 +- .../cookbook/recipes/vlm_classifier/data.py | 529 ++++++++ .../cookbook/recipes/vlm_classifier/eval.py | 486 +++++++ .../recipes/vlm_classifier/eval_sweep.py | 258 ++++ .../cookbook/recipes/vlm_classifier/sweep.py | 222 +++ .../cookbook/recipes/vlm_classifier/train.py | 152 +++ src/hpcai/cookbook/renderers.py | 1206 ++++++++++++----- src/hpcai/types/__init__.py | 1 + src/hpcai/types/image_chunk.py | 54 + src/hpcai/types/model_input_chunk.py | 3 +- 13 files changed, 2652 insertions(+), 323 deletions(-) create mode 100644 src/hpcai/cookbook/image_processing_utils.py create mode 100644 src/hpcai/cookbook/recipes/vlm_classifier/data.py create mode 100644 src/hpcai/cookbook/recipes/vlm_classifier/eval.py create mode 100644 src/hpcai/cookbook/recipes/vlm_classifier/eval_sweep.py create mode 100644 src/hpcai/cookbook/recipes/vlm_classifier/sweep.py create mode 100644 src/hpcai/cookbook/recipes/vlm_classifier/train.py create mode 100644 src/hpcai/types/image_chunk.py diff --git a/pyproject.toml b/pyproject.toml index 32fe44b..e6e55c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ # AIOHTTP support "aiohttp", "httpx_aiohttp>=0.1.8", + "torchvision", ] requires-python = ">= 3.9" classifiers = [ diff --git a/src/hpcai/cookbook/eval/inspect_evaluators.py b/src/hpcai/cookbook/eval/inspect_evaluators.py index 9befeef..4109596 100644 --- a/src/hpcai/cookbook/eval/inspect_evaluators.py +++ b/src/hpcai/cookbook/eval/inspect_evaluators.py @@ -97,7 +97,7 @@ async def __call__(self, sampling_client: hpcai.SamplingClient) -> dict[str, flo debug_errors=self.config.debug_errors, # Never retry - the hpcai SDK is doing this for us already retry_on_error=0, - # Although Tinker sampling tries very hard to only throw unrecoverable failures, + # Although sampling tries very hard to only throw unrecoverable failures, # the inspect evaluation can still fail if e.g. the parser returns an error for # a given sample. fail_on_error=False, diff --git a/src/hpcai/cookbook/image_processing_utils.py b/src/hpcai/cookbook/image_processing_utils.py new file mode 100644 index 0000000..41ef360 --- /dev/null +++ b/src/hpcai/cookbook/image_processing_utils.py @@ -0,0 +1,56 @@ +# Copyright 2026 Thinking Machines Lab +# +# Licensed under the Apache License, Version 2.0 +# +# Modifications: +# - Adapted for HPC-AI cloud fine-tuning workflow +# Copyright © 2026 HPC-AI.COM + +""" +Utilities for working with image processors. Create new types to avoid needing to import AutoImageProcessor and BaseImageProcessor. +Avoid importing AutoImageProcessor and BaseImageProcessor until runtime, because they're slow imports. +""" + +from __future__ import annotations + +from functools import cache +from typing import TYPE_CHECKING, Any, TypeAlias + +from PIL import Image + +if TYPE_CHECKING: + # this import takes a few seconds, so avoid it on the module import when possible + from transformers.image_processing_utils import BaseImageProcessor + + ImageProcessor: TypeAlias = BaseImageProcessor +else: + # make it importable from other files as a type in runtime + ImageProcessor: TypeAlias = Any + + +@cache +def get_image_processor(model_name: str) -> ImageProcessor: + from transformers.models.auto.image_processing_auto import AutoImageProcessor + + return AutoImageProcessor.from_pretrained(model_name, use_fast=True) + + +def resize_image(image: Image.Image, max_size: int) -> Image.Image: + """ + Resize an image so that its longest side is at most max_size pixels. + Preserves aspect ratio and uses LANCZOS resampling for quality. + Returns the original image if it's already smaller than max_size. + """ + + width, height = image.size + if max(width, height) <= max_size: + return image + + if width > height: + new_width = max_size + new_height = int(height * max_size / width) + else: + new_height = max_size + new_width = int(width * max_size / height) + + return image.resize((new_width, new_height), Image.Resampling.LANCZOS) \ No newline at end of file diff --git a/src/hpcai/cookbook/model_info.py b/src/hpcai/cookbook/model_info.py index 8a9b56a..4306d12 100644 --- a/src/hpcai/cookbook/model_info.py +++ b/src/hpcai/cookbook/model_info.py @@ -55,6 +55,7 @@ def get_qwen_info() -> dict[str, ModelAttributes]: "Qwen3-4B-Instruct-2507": ModelAttributes(org, "3", "4B", True), "Qwen3-30B-A3B-Instruct-2507": ModelAttributes(org, "3", "30B-A3B", True), "Qwen3-235B-A22B-Instruct-2507": ModelAttributes(org, "3", "235B-A22B", True), + "Qwen3-VL-8B-Instruct": ModelAttributes(org, "3", "8B", True), } @@ -101,7 +102,9 @@ def get_recommended_renderer_names(model_name: str) -> list[str]: return ["llama3"] elif attributes.organization == "Qwen": if attributes.version_str == "3": - if "-Instruct" in model_name: + if attributes.is_vl: + return ["qwen3_vl"] + elif "-Instruct" in model_name: return ["qwen3_instruct"] else: return ["qwen3", "qwen3_disable_thinking"] diff --git a/src/hpcai/cookbook/recipes/vlm_classifier/data.py b/src/hpcai/cookbook/recipes/vlm_classifier/data.py new file mode 100644 index 0000000..09eb222 --- /dev/null +++ b/src/hpcai/cookbook/recipes/vlm_classifier/data.py @@ -0,0 +1,529 @@ +""" +Datasets for supervised learning (SFT) that use chat-formatted data, which we +convert to tokens using a Renderer. +""" + +import logging +from typing import Any, cast + +import random +import torch +import math +import io +import chz +import datasets +import hpcai +from PIL import Image +from collections import defaultdict +from hpcai.cookbook.supervised.common import datum_from_model_input_weights +from hpcai.cookbook.supervised.types import SupervisedDatasetBuilder, SupervisedDataset +from hpcai.cookbook.tokenizer_utils import get_tokenizer +from hpcai.cookbook.image_processing_utils import get_image_processor, resize_image +from hpcai.cookbook.renderers import ( + Message, + ContentPart, + ImagePart, + TextPart, + TrainOnWhat, + get_renderer, +) + +logger = logging.getLogger(__name__) + + +@chz.chz +class ClassifierDatasetConfig: + """ + Configuration for a classification dataset. + """ + + dataset: str + dataset_split: str + + image_column_name: str = "image" + label_column_name: str = "label" + + model_name_for_tokenizer: str + renderer_name: str + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE + + # If set, sample only this many examples per class (for few-shot experiments) + examples_per_class: int | None = None + subset_seed: int = 0 + + max_image_size: int = 480 + hflip_probability: float = 0.5 + + +class ClassifierDataset(SupervisedDataset): + def __init__(self, config: ClassifierDatasetConfig): + """ + Construct a VLM classifier dataset with the provided data config. + """ + + self.config = config + + tokenizer = get_tokenizer(self.config.model_name_for_tokenizer) + image_processor = get_image_processor(self.config.model_name_for_tokenizer) + + self.renderer = get_renderer( + name=self.config.renderer_name, tokenizer=tokenizer, image_processor=image_processor + ) + + dataset = datasets.load_dataset(self.config.dataset) + dataset = cast(datasets.DatasetDict, dataset) + self.dataset = dataset[self.config.dataset_split] + + # If examples_per_class is set, sample N examples per class for few-shot setting + if self.config.examples_per_class is not None: + self.dataset = self._sample_per_class(self.dataset) + + self.class_labels = self.dataset.features[self.config.label_column_name] + self.shuffled_indices = self.get_shuffled_indices() + + def get_shuffled_indices(self, seed: int = 0) -> list[int]: + """ + Get a shuffled set of dataset indices with a target number of num_repeats. + """ + + max_repeat = int(math.ceil(self.config.num_repeats)) + max_examples = int(math.ceil(self.config.num_repeats * len(self.dataset))) + + random_gen = random.Random(seed) + shuffled_indices: list[int] = [] + + for _ in range(max_repeat): + dataset_indices = list(range(len(self.dataset))) + random_gen.shuffle(dataset_indices) + shuffled_indices.extend(dataset_indices) + + return shuffled_indices[:max_examples] + + def _sample_per_class(self, dataset: datasets.Dataset) -> datasets.Dataset: + """ + Sample up to N examples per class from the dataset for few-shot experiments. + Uses self.config.examples_per_class, label_column_name, and subset_seed. + """ + rng = random.Random(self.config.subset_seed) + + # Group indices by class label + class_indices: dict[int, list[int]] = defaultdict(list) + for idx, label in enumerate(dataset[self.config.label_column_name]): + class_indices[label].append(idx) + + # Shuffle and sample up to examples_per_class from each class + selected_indices: list[int] = [] + for label in sorted(class_indices.keys()): + indices = class_indices[label] + rng.shuffle(indices) + + selected_indices.extend(indices[: self.config.examples_per_class]) + + logger.info( + f"Sampled {len(selected_indices)} examples " + f"({self.config.examples_per_class} per class, {len(class_indices)} classes)" + ) + + return dataset.select(selected_indices) + + def get_class_name(self, label: str) -> str: + """ + Helper function to clean up the original class name. + """ + + return label.replace("_", " ").replace(".", " ").replace("-", " ").lower() + + def build_supervised_example( + self, + example: dict[str, Any], + ) -> tuple[hpcai.ModelInput, torch.Tensor]: + """ + Generate an input to prompt the model. + """ + + class_label = example[self.config.label_column_name] + class_label_name = self.get_class_name(self.class_labels.int2str(class_label)) + + image = example[self.config.image_column_name] + pil_image: Image.Image | None = None + + if isinstance(image, dict) and "bytes" in image: + pil_image = Image.open(io.BytesIO(image["bytes"])) + + elif isinstance(image, Image.Image): + pil_image = cast(Image.Image, image) + + # If the dataset cannot be loaded + if pil_image is None: + raise AssertionError(f"Unable to interpret {image} as an image") + + pil_image = resize_image(image=pil_image, max_size=self.config.max_image_size) + + # horizontal flip 50% of the time + if random.random() < self.config.hflip_probability: + pil_image = pil_image.transpose(Image.Transpose.FLIP_LEFT_RIGHT) + + user_parts: list[ContentPart] = [ + ImagePart(type="image", image=pil_image), + TextPart(type="text", text="What is the name of the subject in this photo?"), + ] + + assistant_parts: list[ContentPart] = [ + TextPart(type="text", text=f"The subject in this photo is: {class_label_name}\n"), + ] + + messages = [ + Message(role="user", content=user_parts), + Message(role="assistant", content=assistant_parts), + ] + + return self.renderer.build_supervised_example( + messages=messages, + train_on_what=self.config.train_on_what, + ) + + def get_batch(self, index: int) -> list[hpcai.Datum]: + """ + Load a batch of training examples. + """ + + return [ + datum_from_model_input_weights( + *self.build_supervised_example(self.dataset[self.shuffled_indices[idx]]), + max_length=self.config.max_length, + ) + for idx in range( + self.config.batch_size * index, + min(self.config.batch_size * (index + 1), len(self.shuffled_indices)), + ) + ] + + def __len__(self) -> int: + """ + Number of batches in the dataloader + """ + + return int(math.ceil(len(self.shuffled_indices) / self.config.batch_size)) + + def set_epoch(self, seed: int = 0): + """ + Set the epoch for shuffling the dataloader. + """ + + self.shuffled_indices = self.get_shuffled_indices(seed=seed) + + +@chz.chz +class Caltech101DatasetBuilder(SupervisedDatasetBuilder): + """ + Configuration for a classification dataset. + """ + + model_name_for_tokenizer: str + renderer_name: str + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + train_on_what: TrainOnWhat | None = None + + # If set, sample only this many examples per class (for few-shot experiments) + examples_per_class: int | None = None + subset_seed: int = 0 + + max_image_size: int = 480 + + run_nll_evaluator: bool = False + + def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]: + default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE + + train_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/caltech101", + dataset_split="train", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + num_repeats=self.num_repeats, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + examples_per_class=self.examples_per_class, + subset_seed=self.subset_seed, + max_image_size=self.max_image_size, + hflip_probability=0.5, + ) + + test_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/caltech101", + dataset_split="test", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + max_image_size=self.max_image_size, + hflip_probability=0.0, # No augmentation for test set + # Note: test set uses full data, no few-shot sampling + ) + + train_dataset = ClassifierDataset(train_config) + + if not self.run_nll_evaluator: + return train_dataset, None + + return train_dataset, ClassifierDataset(test_config) + + +@chz.chz +class Flowers102DatasetBuilder(SupervisedDatasetBuilder): + """ + Configuration for a classification dataset. + """ + + model_name_for_tokenizer: str + renderer_name: str + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + train_on_what: TrainOnWhat | None = None + + # If set, sample only this many examples per class (for few-shot experiments) + examples_per_class: int | None = None + subset_seed: int = 0 + + max_image_size: int = 480 + + run_nll_evaluator: bool = False + + def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]: + default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE + + train_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/oxford_flowers102", + dataset_split="train", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + num_repeats=self.num_repeats, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + examples_per_class=self.examples_per_class, + subset_seed=self.subset_seed, + max_image_size=self.max_image_size, + hflip_probability=0.5, + ) + + test_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/oxford_flowers102", + dataset_split="test", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + max_image_size=self.max_image_size, + hflip_probability=0.0, + # Note: test set uses full data, no few-shot sampling + ) + + train_dataset = ClassifierDataset(train_config) + + if not self.run_nll_evaluator: + return train_dataset, None + + return train_dataset, ClassifierDataset(test_config) + + +@chz.chz +class OxfordPetsDatasetBuilder(SupervisedDatasetBuilder): + """ + Configuration for a classification dataset. + """ + + model_name_for_tokenizer: str + renderer_name: str + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + train_on_what: TrainOnWhat | None = None + + # If set, sample only this many examples per class (for few-shot experiments) + examples_per_class: int | None = None + subset_seed: int = 0 + + max_image_size: int = 480 + + run_nll_evaluator: bool = False + + def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]: + default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE + + train_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/oxford_iiit_pet", + dataset_split="train", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + num_repeats=self.num_repeats, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + examples_per_class=self.examples_per_class, + subset_seed=self.subset_seed, + max_image_size=self.max_image_size, + hflip_probability=0.5, + ) + + test_config = ClassifierDatasetConfig( + dataset="dpdl-benchmark/oxford_iiit_pet", + dataset_split="test", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + max_image_size=self.max_image_size, + hflip_probability=0.0, + # Note: test set uses full data, no few-shot sampling + ) + + train_dataset = ClassifierDataset(train_config) + + if not self.run_nll_evaluator: + return train_dataset, None + + return train_dataset, ClassifierDataset(test_config) + + +@chz.chz +class StanfordCarsDatasetBuilder(SupervisedDatasetBuilder): + """ + Configuration for a classification dataset. + """ + + model_name_for_tokenizer: str + renderer_name: str + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + train_on_what: TrainOnWhat | None = None + + # If set, sample only this many examples per class (for few-shot experiments) + examples_per_class: int | None = None + subset_seed: int = 0 + + max_image_size: int = 480 + + run_nll_evaluator: bool = False + + def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]: + default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE + + train_config = ClassifierDatasetConfig( + dataset="tanganke/stanford_cars", + dataset_split="train", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + num_repeats=self.num_repeats, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + examples_per_class=self.examples_per_class, + subset_seed=self.subset_seed, + max_image_size=self.max_image_size, + hflip_probability=0.5, + ) + + test_config = ClassifierDatasetConfig( + dataset="tanganke/stanford_cars", + dataset_split="test", + image_column_name="image", + label_column_name="label", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + batch_size=self.batch_size, + max_length=self.max_length, + train_on_what=default_train_on_what, + max_image_size=self.max_image_size, + hflip_probability=0.0, + # Note: test set uses full data, no few-shot sampling + ) + + train_dataset = ClassifierDataset(train_config) + + if not self.run_nll_evaluator: + return train_dataset, None + + return train_dataset, ClassifierDataset(test_config) + + +DATASETS = { + "caltech101": Caltech101DatasetBuilder, + "flowers102": Flowers102DatasetBuilder, + "pets": OxfordPetsDatasetBuilder, + "cars": StanfordCarsDatasetBuilder, +} + + +def get_dataset_builder( + dataset: str, + model_name_for_tokenizer: str, + renderer_name: str, + num_repeats: float = 1, + batch_size: int = 32, + max_length: int = 8192, + train_on_what: TrainOnWhat | None = None, + examples_per_class: int | None = None, + subset_seed: int = 0, + max_image_size: int = 480, + run_nll_evaluator: bool = False, +) -> SupervisedDatasetBuilder: + """ + Create a training and test dataset for a vlm classifier. + + Args: + examples_per_class: If set, sample only this many examples per class + from the training set (for few-shot experiments). Test set is + unaffected. + subset_seed: Seed for shuffling before selecting the few-shot subset. + max_image_size: Maximum size for the longest side of images. Images + larger than this will be resized while preserving aspect ratio. + """ + + return DATASETS[dataset]( + model_name_for_tokenizer=model_name_for_tokenizer, + renderer_name=renderer_name, + num_repeats=num_repeats, + batch_size=batch_size, + max_length=max_length, + train_on_what=train_on_what, + examples_per_class=examples_per_class, + subset_seed=subset_seed, + max_image_size=max_image_size, + run_nll_evaluator=run_nll_evaluator, + ) diff --git a/src/hpcai/cookbook/recipes/vlm_classifier/eval.py b/src/hpcai/cookbook/recipes/vlm_classifier/eval.py new file mode 100644 index 0000000..d050952 --- /dev/null +++ b/src/hpcai/cookbook/recipes/vlm_classifier/eval.py @@ -0,0 +1,486 @@ +import asyncio +import logging +from typing import TypedDict, Any, cast +from PIL import Image + +import datasets +import chz +import io + +import numpy as np +import hpcai +from hpcai.cookbook import renderers +from hpcai.cookbook.eval.evaluators import SamplingClientEvaluator, EvaluatorBuilder +from hpcai.cookbook.tokenizer_utils import get_tokenizer +from hpcai.cookbook.image_processing_utils import get_image_processor, resize_image +from hpcai.cookbook.renderers import Message, ImagePart, TextPart, get_text_content +from hpcai.cookbook.utils.misc_utils import timed + + +# Set up logger +logger = logging.getLogger(__name__) + + +@chz.chz +class ClassifierEvaluatorConfig: + """ + Configuration for classifier evaluation. + """ + + dataset: str + dataset_split: str + + image_column_name: str = "image" + label_column_name: str = "label" + + model_name_for_tokenizer: str + renderer_name: str + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + +class ClassifierOutput(TypedDict): + """ + Parsed output from an image classification model. + """ + + predicted_class_name: str + + +class ClassifierEvaluator(SamplingClientEvaluator): + """ + Evaluator that runs image classification evaluation. + """ + + def __init__( + self, + config: ClassifierEvaluatorConfig, + ): + """ + Initialize the CustomEvaluator. + Args: + config: Configuration object containing all evaluation parameters + """ + + self.config = config + + tokenizer = get_tokenizer(self.config.model_name_for_tokenizer) + image_processor = get_image_processor(self.config.model_name_for_tokenizer) + + self.renderer = renderers.get_renderer( + name=self.config.renderer_name, tokenizer=tokenizer, image_processor=image_processor + ) + + dataset = datasets.load_dataset(self.config.dataset) + dataset = cast(datasets.DatasetDict, dataset) + self.dataset = dataset[self.config.dataset_split] + + self.shuffled_dataset = self.dataset.shuffle(seed=0) + self.class_labels = self.dataset.features[self.config.label_column_name] + + def get_class_name(self, label: str) -> str: + """ + Helper function to clean up the original class name. + """ + + return label.replace("_", " ").replace(".", " ").replace("-", " ").lower() + + def build_generation_prompt( + self, + example: dict[str, Any], + ) -> hpcai.ModelInput: + """ + Generate an input to prompt the model. + """ + + image = example[self.config.image_column_name] + pil_image: Image.Image | None = None + + if isinstance(image, dict) and "bytes" in image: + pil_image = Image.open(io.BytesIO(image["bytes"])) + + elif isinstance(image, Image.Image): + pil_image = cast(Image.Image, image) + + # If the dataset cannot be loaded + if pil_image is None: + raise AssertionError(f"Unable to interpret {image} as an image") + + pil_image = resize_image(image=pil_image, max_size=self.config.max_image_size) + + content_parts = [ + ImagePart(type="image", image=pil_image), + TextPart(type="text", text="What is the name of the subject in this photo?"), + ] + + messages = [ + Message(role="user", content=content_parts), + ] + + return self.renderer.build_generation_prompt( + messages=messages, role="assistant", prefill="The subject in this photo is:" + ) + + async def generate_output( + self, + model_input: hpcai.ModelInput, + sampling_client: hpcai.SamplingClient, + sampling_params: hpcai.SamplingParams, + ) -> ClassifierOutput: + """ + Generate a completion and extract the class name from the model. + """ + + # Generate response + r: hpcai.SampleResponse = await sampling_client.sample_async( + prompt=model_input, num_samples=1, sampling_params=sampling_params + ) + tokens: list[int] = r.sequences[0].tokens + response = self.renderer.parse_response(tokens)[0] + + predicted_class_name = get_text_content(response).split(":")[-1].strip().lower() + + return ClassifierOutput(predicted_class_name=predicted_class_name) + + def get_metrics_for_output( + self, example: dict[str, Any], classifier_output: ClassifierOutput + ) -> dict[str, float]: + """ + Score the class name predicted by the model. + """ + + predicted_class_name = classifier_output["predicted_class_name"] + class_label = example[self.config.label_column_name] + class_label_name = self.get_class_name(self.class_labels.int2str(class_label)) + + return {"accuracy": float(predicted_class_name == class_label_name)} + + async def __call__(self, sampling_client: hpcai.SamplingClient) -> dict[str, float]: + """ + Evaluate a vision-language model as an image classifier. + + Args: + sampling_client: The sampling client to evaluate + + Returns: + Dictionary of metrics from evaluation + + """ + + sampling_params = hpcai.SamplingParams( + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, + top_p=self.config.top_p, + top_k=self.config.top_k, + stop=self.renderer.get_stop_sequences(), + ) + + num_examples = min( + len(self.shuffled_dataset), self.config.n_eval or len(self.shuffled_dataset) + ) + + # Limit concurrent sampling tasks + semaphore = asyncio.Semaphore(self.config.max_parallel_tasks) + + async def bounded_generate_output(example: dict[str, Any]) -> ClassifierOutput: + async with semaphore: + return await self.generate_output( + self.build_generation_prompt(example), sampling_client, sampling_params + ) + + # Sample from the model in parallel + async_tasks = [] + + logger.info( + f"Submitting {num_examples} sampling tasks (max {self.config.max_parallel_tasks} parallel)" + ) + for example_id in range(num_examples): + example = self.shuffled_dataset[example_id] + + # Prepare model input for sampling, generate + async_tasks.append(asyncio.create_task(bounded_generate_output(example))) + + # Wait for the hpcai API to return the sampled completions + with timed("sample outputs", {}): + outputs = await asyncio.gather(*async_tasks) + + # Aggregate metrics for each example + metrics_per_example = [] + + logger.info(f"Evaluating {num_examples} sampled responses") + for example_id in range(num_examples): + example = self.shuffled_dataset[example_id] + output = outputs[example_id] + + # Evaluate the model response + metrics = self.get_metrics_for_output(example, output) + metrics_per_example.append(metrics) + + # aggregate the performance metrics + aggregated_metrics = { + key: np.mean([example[key] for example in metrics_per_example]).item() + for key in metrics_per_example[0].keys() + } + + return aggregated_metrics + + +@chz.chz +class Caltech101EvaluatorBuilder: + """ + Configuration for classifier evaluation. + """ + + model_name_for_tokenizer: str + renderer_name: str + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + def __call__(self) -> ClassifierEvaluator: + config = ClassifierEvaluatorConfig( + dataset="dpdl-benchmark/caltech101", + dataset_split="test", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + temperature=self.temperature, + max_tokens=self.max_tokens, + top_p=self.top_p, + top_k=self.top_k, + image_column_name="image", + label_column_name="label", + n_eval=self.n_eval, + max_parallel_tasks=self.max_parallel_tasks, + max_image_size=self.max_image_size, + ) + + return ClassifierEvaluator(config) + + +@chz.chz +class Flowers102EvaluatorBuilder: + """ + Configuration for classifier evaluation. + """ + + model_name_for_tokenizer: str + renderer_name: str + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + def __call__(self) -> ClassifierEvaluator: + config = ClassifierEvaluatorConfig( + dataset="dpdl-benchmark/oxford_flowers102", + dataset_split="test", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + temperature=self.temperature, + max_tokens=self.max_tokens, + top_p=self.top_p, + top_k=self.top_k, + image_column_name="image", + label_column_name="label", + n_eval=self.n_eval, + max_parallel_tasks=self.max_parallel_tasks, + max_image_size=self.max_image_size, + ) + + return ClassifierEvaluator(config) + + +@chz.chz +class OxfordPetsEvaluatorBuilder: + """ + Configuration for classifier evaluation. + """ + + model_name_for_tokenizer: str + renderer_name: str + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + def __call__(self) -> ClassifierEvaluator: + config = ClassifierEvaluatorConfig( + dataset="dpdl-benchmark/oxford_iiit_pet", + dataset_split="test", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + temperature=self.temperature, + max_tokens=self.max_tokens, + top_p=self.top_p, + top_k=self.top_k, + image_column_name="image", + label_column_name="label", + n_eval=self.n_eval, + max_parallel_tasks=self.max_parallel_tasks, + max_image_size=self.max_image_size, + ) + + return ClassifierEvaluator(config) + + +@chz.chz +class StanfordCarsEvaluatorBuilder: + """ + Configuration for classifier evaluation. + """ + + model_name_for_tokenizer: str + renderer_name: str + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + def __call__(self) -> ClassifierEvaluator: + config = ClassifierEvaluatorConfig( + dataset="tanganke/stanford_cars", + dataset_split="test", + renderer_name=self.renderer_name, + model_name_for_tokenizer=self.model_name_for_tokenizer, + temperature=self.temperature, + max_tokens=self.max_tokens, + top_p=self.top_p, + top_k=self.top_k, + image_column_name="image", + label_column_name="label", + n_eval=self.n_eval, + max_parallel_tasks=self.max_parallel_tasks, + max_image_size=self.max_image_size, + ) + + return ClassifierEvaluator(config) + + +EVALUATORS = { + "caltech101": Caltech101EvaluatorBuilder, + "flowers102": Flowers102EvaluatorBuilder, + "pets": OxfordPetsEvaluatorBuilder, + "cars": StanfordCarsEvaluatorBuilder, +} + + +def get_evaluator_builder( + dataset: str, + model_name_for_tokenizer: str, + renderer_name: str, + temperature: float = 0.0, + max_tokens: int = 128, + top_p: float = 1.0, + top_k: int = -1, + n_eval: int | None = None, + max_parallel_tasks: int = 128, + max_image_size: int = 480, +) -> EvaluatorBuilder: + """ + Create a sampling based evaluator for a vlm classifier. + """ + + return EVALUATORS[dataset]( + model_name_for_tokenizer=model_name_for_tokenizer, + renderer_name=renderer_name, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + n_eval=n_eval, + max_parallel_tasks=max_parallel_tasks, + max_image_size=max_image_size, + ) + + +@chz.chz +class EvalConfig: + """ + Config for launching evaluation on a model checkpoint. + """ + + dataset: str + model_path: str + + renderer_name: str = "qwen3_vl" + model_name: str = "Qwen/Qwen3-VL-235B-A22B-Instruct" + + # Infrastructure parameters + base_url: str | None = None + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 128 + + max_image_size: int = 480 + + +def run_eval(eval_config: EvalConfig): + """ + Launch evaluation on a model checkpoint on an image dataset. + """ + + service_client = hpcai.ServiceClient(base_url=eval_config.base_url) + sampling_client = service_client.create_sampling_client(model_path=eval_config.model_path) + + evaluator_builder = get_evaluator_builder( + dataset=eval_config.dataset, + model_name_for_tokenizer=eval_config.model_name, + renderer_name=eval_config.renderer_name, + temperature=eval_config.temperature, + max_tokens=eval_config.max_tokens, + top_p=eval_config.top_p, + top_k=eval_config.top_k, + n_eval=eval_config.n_eval, + max_parallel_tasks=eval_config.max_parallel_tasks, + max_image_size=eval_config.max_image_size, + ) + + evaluator = evaluator_builder() + + async def main(): + result = await evaluator(sampling_client) # type: ignore[arg-type] + print(f"Metrics = {result}") + + asyncio.run(main()) + + +if __name__ == "__main__": + chz.nested_entrypoint(run_eval) diff --git a/src/hpcai/cookbook/recipes/vlm_classifier/eval_sweep.py b/src/hpcai/cookbook/recipes/vlm_classifier/eval_sweep.py new file mode 100644 index 0000000..2cb3201 --- /dev/null +++ b/src/hpcai/cookbook/recipes/vlm_classifier/eval_sweep.py @@ -0,0 +1,258 @@ +""" + +## VLM Image Classifier + +Launcher for evaluating trained image classifiers. + +""" + +import asyncio +import json +import logging +import os +import re +from typing import Any + +import chz +import hpcai + +from hpcai.cookbook.checkpoint_utils import get_last_checkpoint, load_checkpoints_file +from hpcai.cookbook.recipes.vlm_classifier.eval import get_evaluator_builder + + +# Set up logger +logger = logging.getLogger(__name__) + + +def get_checkpoint_at_step( + log_dir: str, + step: int, + required_key: str = "sampler_path", +) -> dict[str, Any] | None: + """ + Get the checkpoint at a specific step from the checkpoints.jsonl file. + + Args: + log_dir: The directory containing checkpoints.jsonl. + step: The step number to find. + required_key: The key to check for in the checkpoint. + + Returns: + The checkpoint at the specified step, or None if not found. + """ + checkpoints = load_checkpoints_file(log_dir) + for checkpoint in checkpoints: + if checkpoint.get("batch") == step and required_key in checkpoint: + logger.info(f"Found checkpoint at step {step}: {checkpoint}") + return checkpoint + logger.warning(f"No checkpoint found at step {step} with key '{required_key}' in {log_dir}") + return None + + +def parse_hyperparams_from_experiment_name(experiment_name: str) -> dict[str, Any]: + """ + Parse hyperparameters from the experiment directory name. + + Experiment names follow the format from sweep.py: + {dataset}-{model_name}-{lora_rank}rank-{learning_rate}lr-{batch_size}batch-{examples_per_class}shot-seed{subset_seed}-{date} + + Example: caltech101-Qwen-Qwen3-VL-235B-A22B-Instruct-32rank-0.0005lr-32batch-4shot-seed0-2025-11-26 + """ + + hyperparams: dict[str, Any] = {} + + # Parse dataset: first segment before the first dash + hyperparams["dataset"] = experiment_name.split("-")[0] + + # Parse lora_rank: look for pattern like "32rank" + if match := re.search(r"-(\d+)rank-", experiment_name): + hyperparams["lora_rank"] = int(match.group(1)) + + # Parse learning_rate: look for pattern like "0.0005lr" or "5e-4lr" + if match := re.search(r"-([\d.e+-]+)lr-", experiment_name): + hyperparams["learning_rate"] = float(match.group(1)) + + # Parse batch_size: look for pattern like "32batch" + if match := re.search(r"-(\d+)batch", experiment_name): + hyperparams["batch_size"] = int(match.group(1)) + + # Parse examples_per_class (shot): look for pattern like "4shot" + if match := re.search(r"-(\d+)shot-", experiment_name): + hyperparams["examples_per_class"] = int(match.group(1)) + + # Parse subset_seed: look for pattern like "seed0" + if match := re.search(r"-seed(\d+)-", experiment_name): + hyperparams["subset_seed"] = int(match.group(1)) + + # Parse date: look for pattern like "2025-11-26" at the end + if match := re.search(r"-(\d{4}-\d{2}-\d{2})$", experiment_name): + hyperparams["date"] = match.group(1) + + return hyperparams + + +@chz.chz +class EvalConfig: + """ + Config for evaluating all experiments in a sweep directory. + """ + + experiment_dir: str + output_file: str + + renderer_name: str = "qwen3_vl" + model_name: str = "Qwen/Qwen3-VL-235B-A22B-Instruct" + + # Infrastructure parameters + base_url: str | None = None + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int | None = None + max_parallel_tasks: int = 1024 + max_parallel_evals: int = 5 + + max_image_size: int = 480 + + # Early stopping: map experiment name to the step of the best checkpoint + # If not provided or experiment not in dict, uses the last checkpoint + early_stopping_checkpoints: dict[str, int] | None = None + + +async def evaluate_experiment( + experiment_name: str, + eval_config: EvalConfig, + service_client: hpcai.ServiceClient, +) -> dict[str, Any]: + """ + Evaluate a single few-shot image classifier experiment. + """ + + experiment_path = os.path.join(eval_config.experiment_dir, experiment_name) + assert os.path.isdir(experiment_path), f"Experiment directory does not exist: {experiment_path}" + + # Load checkpoint: use early stopping step if provided, otherwise use last checkpoint + early_stop_step = ( + eval_config.early_stopping_checkpoints.get(experiment_name) + if eval_config.early_stopping_checkpoints + else None + ) + + if early_stop_step is not None: + checkpoint = get_checkpoint_at_step( + experiment_path, early_stop_step, required_key="sampler_path" + ) + assert checkpoint is not None, ( + f"No checkpoint at step {early_stop_step} with sampler_path found in {experiment_path}" + ) + logger.info( + f"Using early stopping checkpoint at step {early_stop_step} for {experiment_name}" + ) + else: + checkpoint = get_last_checkpoint(experiment_path, required_key="sampler_path") + assert checkpoint is not None, f"No checkpoint with sampler_path found in {experiment_path}" + + # Parse hyperparameters (including dataset) from directory name + hyperparams = parse_hyperparams_from_experiment_name(experiment_name) + assert "dataset" in hyperparams, f"Unable to parse the dataset name from {experiment_path}" + + # Create evaluator for this dataset + evaluator_builder = get_evaluator_builder( + dataset=hyperparams["dataset"], + model_name_for_tokenizer=eval_config.model_name, + renderer_name=eval_config.renderer_name, + temperature=eval_config.temperature, + max_tokens=eval_config.max_tokens, + top_p=eval_config.top_p, + top_k=eval_config.top_k, + n_eval=eval_config.n_eval, + max_parallel_tasks=eval_config.max_parallel_tasks, + max_image_size=eval_config.max_image_size, + ) + + sampling_client = service_client.create_sampling_client(model_path=checkpoint["sampler_path"]) + metrics = await evaluator_builder()(sampling_client) # type: ignore[arg-type] + return { + "experiment_name": experiment_name, + "checkpoint_step": checkpoint.get("step"), + **metrics, + **hyperparams, + } + + +async def evaluate_sweep( + eval_config: EvalConfig, + experiment_names: list[str], +) -> dict[str, dict[str, Any]]: + """ + Evaluate all few-shot image classifier experiments in a sweep directory. + """ + + service_client = hpcai.ServiceClient(base_url=eval_config.base_url) + + # Limit concurrent evaluation tasks + semaphore = asyncio.Semaphore(eval_config.max_parallel_evals) + + async def bounded_evaluate_experiment(experiment_name: str) -> dict[str, Any]: + async with semaphore: + return await evaluate_experiment( + experiment_name=experiment_name, + eval_config=eval_config, + service_client=service_client, + ) + + # Evaluate all experiments in parallel (bounded by semaphore) + logger.info( + f"Submitting {len(experiment_names)} eval tasks (max {eval_config.max_parallel_evals} parallel)" + ) + async_tasks = [ + asyncio.create_task(bounded_evaluate_experiment(name)) for name in experiment_names + ] + + results = await asyncio.gather(*async_tasks) + return {metrics["experiment_name"]: metrics for metrics in results} + + +def run_eval_sweep(eval_config: EvalConfig): + """ + Evaluate all few-shot image classifier experiments in a sweep directory. + """ + + logging.basicConfig(level=logging.INFO) + + if not os.path.isdir(eval_config.experiment_dir): + raise ValueError(f"Experiment directory does not exist: {eval_config.experiment_dir}") + + # Find all experiment subdirectories + experiment_names = sorted( + [ + d + for d in os.listdir(eval_config.experiment_dir) + if os.path.isdir(os.path.join(eval_config.experiment_dir, d)) + ] + ) + + logger.info( + f"Found {len(experiment_names)} experiment directories in {eval_config.experiment_dir}" + ) + classifier_results_json = asyncio.run( + evaluate_sweep( + eval_config=eval_config, + experiment_names=experiment_names, + ) + ) + + # Save results to output file + os.makedirs(os.path.dirname(os.path.abspath(eval_config.output_file)), exist_ok=True) + with open(eval_config.output_file, "w") as f: + json.dump(classifier_results_json, f, indent=2) + + logger.info(f"Saved classifier results to {eval_config.output_file}") + print(json.dumps(classifier_results_json, indent=2)) + + +if __name__ == "__main__": + chz.nested_entrypoint(run_eval_sweep) diff --git a/src/hpcai/cookbook/recipes/vlm_classifier/sweep.py b/src/hpcai/cookbook/recipes/vlm_classifier/sweep.py new file mode 100644 index 0000000..152b851 --- /dev/null +++ b/src/hpcai/cookbook/recipes/vlm_classifier/sweep.py @@ -0,0 +1,222 @@ +""" + +## VLM Image Classifier + +Launcher for training image classifiers based on VLMs. + +""" + +import os +import asyncio +from concurrent.futures import ProcessPoolExecutor +from datetime import datetime +from itertools import product + +import chz +from hpcai.cookbook.renderers import TrainOnWhat +from hpcai.cookbook.utils.lr_scheduling import LRSchedule +from hpcai.cookbook import cli_utils +from hpcai.cookbook.recipes.vlm_classifier.data import get_dataset_builder +from hpcai.cookbook.recipes.vlm_classifier.eval import get_evaluator_builder +from hpcai.cookbook.supervised import train + + +@chz.chz +class ExperimentConfig: + """ + Experiments for few-shot image classification with VLMs. + """ + + experiment_dir: str + + dataset: str = "caltech101" + renderer_name: str = "qwen3_vl" + model_name: str = "Qwen/Qwen3-VL-235B-A22B-Instruct" + + # Infrastructure parameters + base_url: str | None = None + + # Training parameters + learning_rate: float = 5e-4 + num_epochs: int = 1 + lr_schedule: LRSchedule = "cosine" + + # Model parameters + lora_rank: int = 32 + + # Checkpointing and evaluation + save_every: int = 50 + eval_every: int = 50 + infrequent_eval_every: int = 100 + + # Logging parameters + wandb_project: str | None = None + + train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE + + num_repeats: float = 10 + batch_size: int = 32 + max_length: int = 8192 + + examples_per_class: int | None = None + subset_seed: int = 0 + + run_nll_evaluator: bool = False + run_sampling_evaluator: bool = True + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int = 256 + + +def run_experiment(experiment_config: ExperimentConfig): + """ + Run a supervised training experiment for a vlm classifier. + """ + + # build full config + model_name = experiment_config.model_name.replace("/", "-") + date_and_time = datetime.now().strftime("%Y-%m-%d") + + # Include examples_per_class and subset_seed in run name if doing few-shot + shot_suffix = ( + f"-{experiment_config.examples_per_class}shot-seed{experiment_config.subset_seed}" + if experiment_config.examples_per_class + else "" + ) + experiment_name = f"{experiment_config.dataset}-{model_name}-{experiment_config.lora_rank}rank-{experiment_config.learning_rate}lr-{experiment_config.batch_size}batch{shot_suffix}-{date_and_time}" + + experiment_path = os.path.join(experiment_config.experiment_dir, experiment_name) + cli_utils.check_log_dir(experiment_path, behavior_if_exists="delete") + + dataset_builder = get_dataset_builder( + dataset=experiment_config.dataset, + model_name_for_tokenizer=experiment_config.model_name, + renderer_name=experiment_config.renderer_name, + num_repeats=experiment_config.num_repeats, + batch_size=experiment_config.batch_size, + max_length=experiment_config.max_length, + train_on_what=experiment_config.train_on_what, + examples_per_class=experiment_config.examples_per_class, + subset_seed=experiment_config.subset_seed, + run_nll_evaluator=experiment_config.run_nll_evaluator, + ) + + evaluator_builders = [] + if experiment_config.run_sampling_evaluator: + evaluator_builders = [ + get_evaluator_builder( + dataset=experiment_config.dataset, + model_name_for_tokenizer=experiment_config.model_name, + renderer_name=experiment_config.renderer_name, + temperature=experiment_config.temperature, + max_tokens=experiment_config.max_tokens, + top_p=experiment_config.top_p, + top_k=experiment_config.top_k, + n_eval=experiment_config.n_eval, + ) + ] + + config = train.Config( + log_path=experiment_path, + model_name=experiment_config.model_name, + dataset_builder=dataset_builder, + evaluator_builders=evaluator_builders, + infrequent_evaluator_builders=[], + learning_rate=experiment_config.learning_rate, + lr_schedule=experiment_config.lr_schedule, + num_epochs=experiment_config.num_epochs, + base_url=experiment_config.base_url, + wandb_project=experiment_config.wandb_project, + wandb_name=experiment_name, + lora_rank=experiment_config.lora_rank, + save_every=experiment_config.save_every, + eval_every=experiment_config.eval_every, + infrequent_eval_every=experiment_config.infrequent_eval_every, + ) + + asyncio.run(train.main(config)) + + +@chz.chz +class SweepConfig: + """ + Configuration for the sweep. + """ + + experiment_dir: str + + renderer_name: str = "qwen3_vl" + model_name: str = "Qwen/Qwen3-VL-235B-A22B-Instruct" + + datasets: list[str] = chz.field(default_factory=lambda: ["caltech101"]) + examples_per_class: list[int] = chz.field(default_factory=lambda: [1, 2, 4, 8, 16]) + + learning_rate: float = 1e-4 + num_epochs: int = 1 + lr_schedule: LRSchedule = "constant" + + lora_rank: int = 32 + + num_repeats: float = 10 + batch_size: int = 32 + max_length: int = 8192 + + run_nll_evaluator: bool = False + run_sampling_evaluator: bool = True + + base_url: str | None = None + wandb_project: str | None = None + + # Number of experiments to run in parallel + num_parallel: int = 5 + + +# Adjust the number of epochs based on the amount of data +EXAMPLES_TO_MULTIPLIER = {16: 1, 8: 2, 4: 4, 2: 8, 1: 16} + + +def run_sweep(sweep_config: SweepConfig): + """ + Run all experiments in parallel using ProcessPoolExecutor. + """ + + experiment_configs = [ + ExperimentConfig( + experiment_dir=sweep_config.experiment_dir, + model_name=sweep_config.model_name, + renderer_name=sweep_config.renderer_name, + dataset=target_dataset, + learning_rate=sweep_config.learning_rate, + num_epochs=sweep_config.num_epochs, + lr_schedule=sweep_config.lr_schedule, + lora_rank=sweep_config.lora_rank, + num_repeats=EXAMPLES_TO_MULTIPLIER[examples_per_class] * sweep_config.num_repeats, + batch_size=sweep_config.batch_size, + max_length=sweep_config.max_length, + examples_per_class=examples_per_class, + wandb_project=sweep_config.wandb_project, + base_url=sweep_config.base_url, + run_nll_evaluator=sweep_config.run_nll_evaluator, + run_sampling_evaluator=sweep_config.run_sampling_evaluator, + ) + for target_dataset, examples_per_class in product( + sweep_config.datasets, sweep_config.examples_per_class + ) + ] + + print( + f"Running {len(experiment_configs)} experiments with {sweep_config.num_parallel} parallel workers" + ) + + with ProcessPoolExecutor(max_workers=sweep_config.num_parallel) as executor: + futures = [executor.submit(run_experiment, config) for config in experiment_configs] + results = [f.result() for f in futures] + print(f"{len(results)} experiments finished running") + + +if __name__ == "__main__": + chz.nested_entrypoint(run_sweep) diff --git a/src/hpcai/cookbook/recipes/vlm_classifier/train.py b/src/hpcai/cookbook/recipes/vlm_classifier/train.py new file mode 100644 index 0000000..30b8482 --- /dev/null +++ b/src/hpcai/cookbook/recipes/vlm_classifier/train.py @@ -0,0 +1,152 @@ +""" + +## VLM Image Classifier + +Launcher for training image classifiers based on VLMs. + +""" + +import os + +import asyncio +from datetime import datetime +from typing import Literal + +import chz +from hpcai.cookbook.renderers import TrainOnWhat +from hpcai.cookbook.utils.lr_scheduling import LRSchedule +from hpcai.cookbook import cli_utils +# from hpcai.cookbook.recipes.vlm_classifier.eval import get_evaluator_builder +from hpcai.cookbook.recipes.vlm_classifier.data import get_dataset_builder +from hpcai.cookbook.supervised import train + + +@chz.chz +class ExperimentConfig: + """ + Experiments for few-shot image classification with VLMs. + """ + + experiment_dir: str + load_checkpoint_path: str | None = None + + dataset: str = "caltech101" + + renderer_name: str = "qwen3_vl" + model_name: str = "Qwen/Qwen3-VL-8B-Instruct" + + # Infrastructure parameters + base_url: str | None = "http://0.0.0.0:8001" + behavior_if_log_dir_exists: Literal["delete", "resume", "ask", "raise"] = "ask" + + # Training parameters + learning_rate: float = 5e-4 + num_epochs: int = 3 + lr_schedule: LRSchedule = "cosine" + + # Model parameters + lora_rank: int = 32 + + # Checkpointing and evaluation + save_every: int = 50000 + eval_every: int = 50000 + infrequent_eval_every: int = 50000 + + # Logging parameters + wandb_project: str | None = None + wandb_name: str | None = None + + train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE + + num_repeats: float = 1 + batch_size: int = 32 + max_length: int = 8192 + + examples_per_class: int | None = None + subset_seed: int = 0 + + run_nll_evaluator: bool = False + run_sampling_evaluator: bool = False + + temperature: float = 0.0 + max_tokens: int = 128 + top_p: float = 1.0 + top_k: int = -1 + + n_eval: int = 128 + + +def run_experiment(experiment_config: ExperimentConfig): + """ + Launcher for training an image classifier based on a VLM on a custom vision dataset. + """ + + # build full config + model_name = experiment_config.model_name.replace("/", "-") + date_and_time = datetime.now().strftime("%Y-%m-%d") + + # Include examples_per_class and subset_seed in run name if doing few-shot + shot_suffix = ( + f"-{experiment_config.examples_per_class}shot-seed{experiment_config.subset_seed}" + if experiment_config.examples_per_class + else "" + ) + experiment_name = f"{experiment_config.dataset}-{model_name}-{experiment_config.lora_rank}rank-{experiment_config.learning_rate}lr-{experiment_config.batch_size}batch{shot_suffix}-{date_and_time}" + + experiment_path = os.path.join(experiment_config.experiment_dir, experiment_name) + cli_utils.check_log_dir( + experiment_path, behavior_if_exists=experiment_config.behavior_if_log_dir_exists + ) + + dataset_builder = get_dataset_builder( + dataset=experiment_config.dataset, + model_name_for_tokenizer=experiment_config.model_name, + renderer_name=experiment_config.renderer_name, + num_repeats=experiment_config.num_repeats, + batch_size=experiment_config.batch_size, + max_length=experiment_config.max_length, + train_on_what=experiment_config.train_on_what, + examples_per_class=experiment_config.examples_per_class, + subset_seed=experiment_config.subset_seed, + run_nll_evaluator=experiment_config.run_nll_evaluator, + ) + + evaluator_builders = [] + # if experiment_config.run_sampling_evaluator: + # evaluator_builders = [ + # get_evaluator_builder( + # dataset=experiment_config.dataset, + # model_name_for_tokenizer=experiment_config.model_name, + # renderer_name=experiment_config.renderer_name, + # temperature=experiment_config.temperature, + # max_tokens=experiment_config.max_tokens, + # top_p=experiment_config.top_p, + # top_k=experiment_config.top_k, + # n_eval=experiment_config.n_eval, + # ) + # ] + + config = train.Config( + log_path=experiment_path, + model_name=experiment_config.model_name, + load_checkpoint_path=experiment_config.load_checkpoint_path, + dataset_builder=dataset_builder, + evaluator_builders=evaluator_builders, + infrequent_evaluator_builders=[], + learning_rate=experiment_config.learning_rate, + lr_schedule=experiment_config.lr_schedule, + num_epochs=experiment_config.num_epochs, + base_url=experiment_config.base_url, + wandb_project=experiment_config.wandb_project, + wandb_name=experiment_config.wandb_name or experiment_name, + lora_rank=experiment_config.lora_rank, + save_every=experiment_config.save_every, + eval_every=experiment_config.eval_every, + infrequent_eval_every=experiment_config.infrequent_eval_every, + ) + + asyncio.run(train.main(config)) + + +if __name__ == "__main__": + chz.nested_entrypoint(run_experiment) diff --git a/src/hpcai/cookbook/renderers.py b/src/hpcai/cookbook/renderers.py index db26704..f54f0af 100644 --- a/src/hpcai/cookbook/renderers.py +++ b/src/hpcai/cookbook/renderers.py @@ -12,35 +12,261 @@ import json import logging import re +import urllib.request from datetime import datetime from enum import StrEnum -from typing import Callable, NotRequired, TypedDict +from typing import NotRequired, Optional, TypedDict, Literal, Protocol, cast +from PIL import Image import hpcai import torch +import pydantic + +import io from hpcai.cookbook.tokenizer_utils import Tokenizer +from hpcai.cookbook.image_processing_utils import ImageProcessor logger = logging.getLogger(__name__) +# Tool types are based on kosong (https://github.com/MoonshotAI/kosong). + + +class StrictBase(pydantic.BaseModel): + """ + Pydantic base class that's immutable and doesn't silently ignore extra fields. + """ + + model_config = pydantic.ConfigDict(frozen=True, extra="forbid") + + def __str__(self) -> str: + return repr(self) + + +class ToolCall(StrictBase): + """ + Structured tool invocation following OpenAI/kosong format. + + This represents a request to invoke a tool/function. The structure follows + the OpenAI function calling format for compatibility with various LLM APIs. + + Example: + tool_call = ToolCall( + function=ToolCall.FunctionBody( + name="search", + arguments='{"query_list": ["python async", "pydantic validation"]}' + ), + id="call_abc123" + ) + """ + + class FunctionBody(pydantic.BaseModel): + """ + Tool call function body containing the tool name and arguments. + + The arguments field must be a valid JSON string that will be parsed + by the tool implementation. + """ + + name: str + """The name of the tool to be called.""" + arguments: str + """Arguments of the tool call in JSON string format.""" + + type: Literal["function"] = "function" + """Tool call type, must be 'function' for compatibility.""" + + id: str | None = None + """Optional unique identifier for tracking this specific tool call.""" + + function: FunctionBody + """The function body containing tool name and arguments.""" + + +class ToolOk(StrictBase): + """ + Successful tool execution result. + + Used to indicate that a tool call completed successfully, with + the main output and optional metadata fields. + """ -class ToolCall(TypedDict): - name: str - # Each argument is a stringified JSON object - args: dict[str, str] + output: str + """The main output/result from the tool execution.""" + + message: str = "" + """Optional human-readable message about the execution.""" + + brief: str = "" + """Optional brief summary of the result for logging.""" + + +class ToolError(StrictBase): + """ + Tool execution error result. + + Used to indicate that a tool call failed or encountered an error, + with details about what went wrong. + """ + + output: str = "" + """Any partial output that was generated before the error.""" + + message: str = "" + """Error message describing what went wrong.""" + + brief: str = "" + """Brief error summary for logging.""" + + +ToolReturnType = ToolOk | ToolError +"""Union type for tool execution results - either success or error.""" + + +class ToolResult(StrictBase): + """ + Complete tool execution result with tracking ID. + + Wraps the actual result (ToolOk or ToolError) with the corresponding + tool call ID for correlation in multi-tool scenarios. + + Note: This class is defined for future use in handling multiple + concurrent tool calls with result correlation. + """ + + tool_call_id: str | None + """ID of the tool call this result corresponds to.""" + + result: ToolReturnType + """The actual execution result (success or error).""" + + +class TextPart(TypedDict): + """ + Container for a text part in a multimodal message. + + Args: + + type: Literal['text'] + The type of the content part, which must be text in this case. + text: str + The string content of the content part. + """ + + type: Literal["text"] + text: str + + +class ImagePart(TypedDict): + """ + Container for an image part in a multimodal message. + + Args: + + type: Literal['image'] + The type of the content part, which must be image in this case. + image: str | Image.Image + Either a url, data URL, or PIL image. + """ + + type: Literal["image"] + image: str | Image.Image + + +# Container for a part of a multimodal message content +ContentPart = TextPart | ImagePart # NOTE: we use a broad type definition for the role to be flexible # Common roles are "user", "assistant", "system", "tool" Role = str +# Content is a string or a list of parts +Content = str | list[ContentPart] + class Message(TypedDict): + """ + Container for a single turn in a multi-turn conversation. + + Args: + + role: Role + String that denotes the source of the message, typically system, user, assistant, and tool. + content: Content + Content of the message, can be a string, or a list of ContentPart. + tool_calls: NotRequired[list[ToolCall]] + Optional sequence of tool calls generated by the model. + thinking: NotRequired[str] + Optional thinking produced by the model before its final response. + trainable: NotRequired[bool] + Optional indicator whether this message should contribute to the training loss. + + """ + role: Role - content: str + content: Content + tool_calls: NotRequired[list[ToolCall]] thinking: NotRequired[str] trainable: NotRequired[bool] + tool_call_id: NotRequired[str] + name: NotRequired[str] + + +def ensure_text(content: Content) -> str: + """ + Assert that content is text-only and return it as a string. + + Raises ValueError if content contains images or multiple parts. + Use this to validate that message content is text-only before + processing it in code paths that don't support multimodal content. + """ + if isinstance(content, str): + return content + if len(content) == 1 and content[0]["type"] == "text": + return content[0]["text"] + raise ValueError(f"Expected text content, got multimodal content with {len(content)} parts") + +def get_text_content(message: Message) -> str: + """Extract text content from message, stripping thinking parts. + + Use this after parse_response when you only need the text output, + ignoring any thinking/reasoning content. + """ + content = message["content"] + if isinstance(content, str): + return content + return "".join(p["text"] for p in content if p["type"] == "text") + + +def _tool_call_payload(tool_call: ToolCall) -> dict[str, object]: + """Minimal JSON payload for embedding in blocks.""" + # Convert from nested structure to flat format for compatibility + return { + "name": tool_call.function.name, + "args": json.loads(tool_call.function.arguments), + } + + +class RenderedMessage(TypedDict): + """ + Container for parts of a rendered message, for masking. + + Args: + + prefix: NotRequired[hpcai.EncodedTextChunk] + Message header that typically includes the speaker's role in the conversation. + content: list[hpcai.ModelInputChunk] + Inner parts of the message that may include spans of image and text. + suffix: NotRequired[hpcai.EncodedTextChunk] + Message header that typically includes the turn stop token. + + """ + + prefix: NotRequired[hpcai.EncodedTextChunk] + content: list[hpcai.ModelInputChunk] + suffix: NotRequired[hpcai.EncodedTextChunk] class TrainOnWhat(StrEnum): @@ -52,42 +278,156 @@ class TrainOnWhat(StrEnum): CUSTOMIZED = "customized" -class Renderer: +class Renderer(Protocol): + """ + Render a message list into training and sampling prompts for language models. + """ + + tokenizer: Tokenizer + def __init__(self, tokenizer: Tokenizer): self.tokenizer = tokenizer - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError + def _preprocess_message_parts(self, message: Message) -> list[ImagePart | TextPart]: + return ( + message["content"] + if isinstance(message["content"], list) + else [TextPart(type="text", text=message["content"])] + ) - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - raise NotImplementedError + @property + def _bos_tokens(self) -> list[int]: + return [] def get_stop_sequences(self) -> list[str] | list[int]: raise NotImplementedError + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: + raise NotImplementedError + def parse_response(self, response: list[int]) -> tuple[Message, bool]: raise NotImplementedError + def build_generation_prompt( + self, messages: list[Message], role: Role = "assistant", prefill: str | None = None + ) -> hpcai.ModelInput: + """ + Generates tokens for sampling from the model. -def ensure_text(content) -> str: - """ - Assert that content is text-only and return it as a string. + Args: + messages: a list of messages to render. + role: the role of the partial message to be completed. + prefill: an optional string to prefill in the model's generation. + """ - Raises ValueError if content contains images or multiple parts. - Use this to validate that message content is text-only before - processing it in code paths that don't support multimodal content. - """ - if isinstance(content, str): - return content - if len(content) == 1 and content[0]["type"] == "text": - return content[0]["text"] - raise ValueError(f"Expected text content, got multimodal content with {len(content)} parts") + chunks: list[hpcai.types.ModelInputChunk] = [] + if self._bos_tokens: + chunks.append(hpcai.types.EncodedTextChunk(tokens=self._bos_tokens)) + for idx, message in enumerate(messages): + rendered_message = self.render_message(idx, message) + ob_chunk = rendered_message.get("prefix") + action_chunks = rendered_message["content"] + if ob_chunk: + chunks.append(ob_chunk) + chunks.extend([x for x in action_chunks if x]) + new_partial_message = Message(role=role, content="") + rendered_message = self.render_message(len(messages), new_partial_message) + ob_chunk = rendered_message.get("prefix") + if ob_chunk: + chunks.append(ob_chunk) + if prefill: + chunks.append( + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(prefill, add_special_tokens=False) + ) + ) + return hpcai.ModelInput(chunks=chunks) + + def build_supervised_example( + self, + messages: list[Message], + train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, + ) -> tuple[hpcai.ModelInput, torch.Tensor]: + """ + Generates tokens and weights (for SFT) in the most standard way; by concatenating + together tokens and weights for each message. + + Args: + messages: a list of messages to render. + train_on_what: an enum that controls how the weights are assigned to the tokens. + - TrainOnWhat.LAST_ASSISTANT_MESSAGE: only the last assistant message is used for training + - TrainOnWhat.ALL_ASSISTANT_MESSAGES: all assistant messages are used for training + - TrainOnWhat.ALL_MESSAGES: all messages are used for training + - TrainOnWhat.ALL_TOKENS: all tokens are used for training + - TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: all user and system messages are used for training + - TrainOnWhat.CUSTOMIZED: each message has a trainable field, and the weights are assigned based on the trainable field + + Returns: + A tuple of two tensors: + - model_input: the hpcai ModelInput for your model + - weights: a tensor of weights + """ + + model_input_chunks_weights: list[tuple[hpcai.types.ModelInputChunk, float]] = [] + if self._bos_tokens: + model_input_chunks_weights.append( + (hpcai.types.EncodedTextChunk(tokens=self._bos_tokens), 0.0) + ) + + for idx, message in enumerate(messages): + if train_on_what == TrainOnWhat.CUSTOMIZED: + assert "trainable" in message, ( + "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" + ) + else: + assert "trainable" not in message, ( + "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" + ) + + is_last_message = idx == len(messages) - 1 + is_assistant = message["role"] == "assistant" + is_user_or_system = message["role"] in ["user", "system"] + + # only apply weight to observation part if train_on_what is ALL_TOKENS + rendered_message = self.render_message(idx, message, is_last=is_last_message) + ob_part = rendered_message.get("prefix") + action_parts = rendered_message.get("content") + action_tail = rendered_message.get("suffix") + + ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS) + if ob_part: + model_input_chunks_weights += [(ob_part, ob_weight)] + + match train_on_what: + case TrainOnWhat.LAST_ASSISTANT_MESSAGE: + action_has_weight = is_last_message and is_assistant + case TrainOnWhat.ALL_ASSISTANT_MESSAGES: + action_has_weight = is_assistant + case TrainOnWhat.ALL_MESSAGES: + action_has_weight = True + case TrainOnWhat.ALL_TOKENS: + action_has_weight = True + case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: + action_has_weight = is_user_or_system + case TrainOnWhat.CUSTOMIZED: + action_has_weight = message.get("trainable", False) + case _: + raise ValueError(f"Unknown train_on_what: {train_on_what}") + + model_input_chunks_weights += [ + (action_part, int(action_has_weight)) for action_part in action_parts if action_part + ] + + # action tail is effectively the stop_token and the start token for the next turn + # e.g. \n\nUser: + if is_last_message and action_tail: + model_input_chunks_weights += [(action_tail, int(action_has_weight))] + + weights_data = [w for chunk, w in model_input_chunks_weights for _ in range(chunk.length)] + weights_tensor = torch.tensor(weights_data) + + model_input_chunks = [chunk for chunk, _ in model_input_chunks_weights] + return hpcai.ModelInput(chunks=model_input_chunks), weights_tensor def tokens_weights_from_strings_weights( @@ -104,85 +444,6 @@ def tokens_weights_from_strings_weights( return tokens, weights -def build_supervised_example( - start_tokens: list[int], - render_message: Callable[[int, Message], tuple[list[int], list[int], list[int]]], - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Generates tokens and weights (for SFT) in the most standard way; by concatenating - together tokens and weights for each message. - - Args: - start_tokens: a list of tokens that are added at the beginning of the sequence. - render_message: a function that takes an index and a message and returns a tuple of three lists of tokens: - - ob_part: tokens for the observation part of the message - - action_part: tokens for the action part of the message - - action_tail: tokens that are generated by the assistant in this message, which are also - part of the ob part of the next message. (Only relevant for some renderers, such as RoleColonRenderer) - train_on_what: an enum that controls how the weights are assigned to the tokens. - - TrainOnWhat.LAST_ASSISTANT_MESSAGE: only the last assistant message is used for training - - TrainOnWhat.ALL_ASSISTANT_MESSAGES: all assistant messages are used for training - - TrainOnWhat.ALL_MESSAGES: all messages are used for training - - TrainOnWhat.ALL_TOKENS: all tokens are used for training - - TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: all user and system messages are used for training - - TrainOnWhat.CUSTOMIZED: each message has a trainable field, and the weights are assigned based on the trainable field - messages: a list of messages to render. - - Returns: - A tuple of two tensors: - - tokens: a tensor of tokens - - weights: a tensor of weights - """ - tokens_weights = [(token, 0) for token in start_tokens] - for idx, message in enumerate(messages): - if train_on_what == TrainOnWhat.CUSTOMIZED: - assert "trainable" in message, ( - "When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise" - ) - else: - assert "trainable" not in message, ( - "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message" - ) - - is_last_message = idx == len(messages) - 1 - is_assistant = message["role"] == "assistant" - is_user_or_system = message["role"] in ["user", "system"] - - # only apply weight to observation part if train_on_what is ALL_TOKENS - ob_part, action_part, action_tail = render_message(idx, message) - ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS) - tokens_weights += [(token, ob_weight) for token in ob_part] - - action_tokens = action_part - # action tail is effectively the stop_token and the start token for the next turn - # e.g. \n\nUser: - if is_last_message: - action_tokens += action_tail - - match train_on_what: - case TrainOnWhat.LAST_ASSISTANT_MESSAGE: - action_has_weight = is_last_message and is_assistant - case TrainOnWhat.ALL_ASSISTANT_MESSAGES: - action_has_weight = is_assistant - case TrainOnWhat.ALL_MESSAGES: - action_has_weight = True - case TrainOnWhat.ALL_TOKENS: - action_has_weight = True - case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: - action_has_weight = is_user_or_system - case TrainOnWhat.CUSTOMIZED: - action_has_weight = message.get("trainable", False) - case _: - raise ValueError(f"Unknown train_on_what: {train_on_what}") - - tokens_weights += [(token, int(action_has_weight)) for token in action_tokens] - - tokens, weights = zip(*tokens_weights, strict=True) - return torch.tensor(tokens), torch.tensor(weights) - - def parse_response_for_stop_token( response: list[int], tokenizer: Tokenizer, stop_token: int ) -> tuple[Message, bool]: @@ -218,49 +479,29 @@ class RoleColonRenderer(Renderer): except that they use "Human" instead of "User". """ - def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]: + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("thinking") is None, "Thinking tokens not supported in RoleColonRenderer" + assert isinstance(message["content"], str), ( + "RoleColonRenderer only supports message with string content" + ) ob_str = message["role"].capitalize() + ":" # Observation (prompt) part ac_str = " " + message["content"] + "\n\n" # Action part ac_tail_str = "User:" if message["role"] == "assistant" else "" # Action part that's only included in the last message in SFT - return ( - self.tokenizer.encode(ob_str, add_special_tokens=False), - self.tokenizer.encode(ac_str, add_special_tokens=False), - self.tokenizer.encode(ac_tail_str, add_special_tokens=False), + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) ) - - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - tokens: list[int] = [] - tokens.extend(self._bos_tokens) - for message in messages: - ob_part, action_part, action_tail = self._render_message(message) - tokens.extend(ob_part) - tokens.extend(action_part) - new_partial_message = Message(role=role, content="") - ob_part, _action_part, _action_tail = self._render_message(new_partial_message) - tokens.extend(ob_part) - tokens.extend(self.tokenizer.encode(prefill or "", add_special_tokens=False)) - return hpcai.ModelInput.from_ints(tokens) - - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get tokens and weights for action corresponding to final message - """ - return build_supervised_example( - self._bos_tokens, - lambda _idx, message: self._render_message(message), - messages, - train_on_what, + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_str, add_special_tokens=False) + ) + ] + suffix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_tail_str, add_special_tokens=False) ) + return RenderedMessage(prefix=prefix, content=content, suffix=suffix) def get_stop_sequences(self) -> list[str]: return ["\n\nUser:"] @@ -300,49 +541,23 @@ class Llama3Renderer(Renderer): """ - def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]: + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("thinking") is None, "CoT tokens not supported in Llama3" + assert isinstance(message["content"], str), ( + "Llama3Renderer only supports message with string content" + ) ob_str = f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n" # Observation (prompt) part ac_str = f"{message['content']}<|eot_id|>" - # Action part - ac_tail_str = "" # No action tail needed for Llama3 format - # Action part that's only included in the last message in SFT - return ( - self.tokenizer.encode(ob_str, add_special_tokens=False), - self.tokenizer.encode(ac_str, add_special_tokens=False), - self.tokenizer.encode(ac_tail_str, add_special_tokens=False), - ) - - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - tokens: list[int] = [] - tokens.extend(self._bos_tokens) - for message in messages: - ob_part, action_part, action_tail = self._render_message(message) - tokens.extend(ob_part) - tokens.extend(action_part) - new_partial_message = Message(role=role, content="") - ob_part, _action_part, _action_tail = self._render_message(new_partial_message) - tokens.extend(ob_part) - tokens.extend(self.tokenizer.encode(prefill or "", add_special_tokens=False)) - return hpcai.ModelInput.from_ints(tokens) - - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get tokens and weights for action corresponding to final message - """ - return build_supervised_example( - self._bos_tokens, - lambda _idx, message: self._render_message(message), - messages, - train_on_what, + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) ) + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_str, add_special_tokens=False) + ) + ] + return RenderedMessage(prefix=prefix, content=content) @property def _bos_tokens(self) -> list[int]: @@ -372,16 +587,42 @@ class Qwen3Renderer(Renderer): I can help you with...<|im_end|> - - It is currently missing Qwen 3's functionality for removing thinking spans in multi-turn conversations. """ - def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]: + def __init__(self, tokenizer: Tokenizer, strip_thinking_from_history: bool = True): + """ + Args: + tokenizer: The tokenizer to use for encoding. + strip_thinking_from_history: When True (default), strips ... blocks + from assistant messages in multi-turn history. This matches how Qwen3 models + were trained - they only see their own thinking during the current turn, not + from previous turns. Set to False to preserve thinking in history (useful for + certain RL scenarios where you want the extension property for efficiency). + + See https://hpcai-docs.thinkingmachines.ai/rl/sequence-extension for details on + how this option affects multi-turn RL compute efficiency. + """ + super().__init__(tokenizer) + self.strip_thinking_from_history = strip_thinking_from_history + + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("thinking") is None, "TODO: support CoT in Qwen3 renderer" + assert isinstance(message["content"], str), ( + "Qwen3Renderer only supports message with string content" + ) maybe_newline = "\n" if idx > 0 else "" ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n" ac_content = message["content"] - if message["role"] == "assistant" and "" not in ac_content: + if ( + self.strip_thinking_from_history + and message["role"] == "assistant" + and "" in ac_content + ): + # Multi-turn conversation, we remove the thinking section from the assistant message. + # This matches how Qwen3 models were trained - they only see their own thinking + # during the current turn, not from previous turns. + ac_content = ac_content.split("")[1].lstrip() + elif message["role"] == "assistant" and "" not in ac_content: # Matching the paper, we force the assistant to start with . Some SFT datasets include # in the assistant messages, we so don't need to re-add it in those cases. ob_str += "\n" @@ -389,44 +630,21 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i if "tool_calls" in message: ac_content += "\n".join( [ - f"\n{json.dumps(tool_call)}\n" + f"\n{json.dumps(_tool_call_payload(tool_call))}\n" for tool_call in message["tool_calls"] ] ) ac_content += "<|im_end|>" # Action part - ac_tail_str = "" # No action tail needed for Qwen format - # Action part that's only included in the last message in SFT - return ( - self.tokenizer.encode(ob_str, add_special_tokens=False), - self.tokenizer.encode(ac_content, add_special_tokens=False), - self.tokenizer.encode(ac_tail_str, add_special_tokens=False), + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) ) - - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - tokens: list[int] = [] # No BOS token for Qwen - for idx, message in enumerate(messages): - ob_part, action_part, _ = self._render_message(idx, message) - tokens.extend(ob_part) - tokens.extend(action_part) - # Add generation prompt - new_partial_message = Message(role=role, content="") - ob_part, _, _ = self._render_message(len(messages), new_partial_message) - tokens.extend(ob_part) - tokens.extend(self.tokenizer.encode(prefill or "", add_special_tokens=False)) - return hpcai.ModelInput.from_ints(tokens) - - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get tokens and weights for action corresponding to final message. - """ - return build_supervised_example([], self._render_message, messages, train_on_what) + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_content, add_special_tokens=False) + ) + ] + return RenderedMessage(prefix=prefix, content=content) @property def _end_message_token(self) -> int: @@ -445,15 +663,20 @@ def _parse_tool_call(self, tool_call_str: str) -> list[ToolCall] | None: if not isinstance(tool_call, dict): return None - if ( - "name" not in tool_call - or "args" not in tool_call - or not isinstance(tool_call["name"], str) - or not isinstance(tool_call["args"], dict) - ): + name = tool_call.get("name") + args = tool_call.get("args") + tool_id = tool_call.get("id") + if not isinstance(name, str) or not isinstance(args, dict): return None - - return [ToolCall(**tool_call)] + if tool_id is not None and not isinstance(tool_id, str): + tool_id = None + # Convert to nested structure with arguments as JSON string + return [ + ToolCall( + function=ToolCall.FunctionBody(name=name, arguments=json.dumps(args)), + id=tool_id, + ) + ] def parse_response(self, response: list[int]) -> tuple[Message, bool]: assistant_message, parse_success = parse_response_for_stop_token( @@ -465,6 +688,7 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]: # Follow Qwen docs and Qwen-Agent's tool calling prompt to use ... tags to wrap the tool call. # - https://qwen.readthedocs.io/en/latest/getting_started/concepts.html#tool-calling # - https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py#L279-L282 + assert isinstance(assistant_message["content"], str) match = re.search(r"(.*?)", assistant_message["content"], re.DOTALL) if match: tool_calls = self._parse_tool_call(match.group(1)) @@ -496,8 +720,11 @@ class Qwen3InstructRenderer(Qwen3Renderer): use the tag at all. """ - def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[int], list[int]]: + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("thinking") is None, "CoT tokens not supported in Qwen3 instruct 2507" + assert isinstance(message["content"], str), ( + "Qwen3InstructRenderer only supports message with string content" + ) maybe_newline = "\n" if idx > 0 else "" ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n" ac_content = message["content"] @@ -505,31 +732,185 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i if "tool_calls" in message: ac_content += "\n".join( [ - f"\n{json.dumps(tool_call)}\n" + f"\n{json.dumps(_tool_call_payload(tool_call))}\n" for tool_call in message["tool_calls"] ] ) ac_content += "<|im_end|>" # Action part - ac_tail_str = "" # No action tail needed for Qwen format - # Action part that's only included in the last message in SFT - return ( - self.tokenizer.encode(ob_str, add_special_tokens=False), - self.tokenizer.encode(ac_content, add_special_tokens=False), - self.tokenizer.encode(ac_tail_str, add_special_tokens=False), + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) + ) + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_content, add_special_tokens=False) + ) + ] + return RenderedMessage(prefix=prefix, content=content) + + +class ImageProcessorProtocol(Protocol): + merge_size: int + patch_size: int + + def get_number_of_image_patches( + self, height: int, width: int, images_kwargs: Optional[dict] = None + ) -> int: + raise NotImplementedError() + + +def image_to_chunk( + image_or_str: Image.Image | str, image_processor: ImageProcessorProtocol +) -> hpcai.types.ImageChunk: + """ + Convert a PIL Image to a hpcai.types.ImageChunk for QwenVL + """ + + # load an image from a data URI or a URL + if isinstance(image_or_str, str): + with urllib.request.urlopen(image_or_str) as response: + pil_image = Image.open(io.BytesIO(response.read())) + + # Otherwise the image is a PIL image and can be loaded directly + elif isinstance(image_or_str, Image.Image): + pil_image = image_or_str + + # Validate the provided data is actually a valid image type + else: + raise ValueError("The provided image must be a PIL.Image.Image, URL, or data URI.") + + # Convert to RGB if needed (JPEG doesn't support RGBA/LA/P modes) + if pil_image.mode in ("RGBA", "LA", "P"): + pil_image = pil_image.convert("RGB") + + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format="JPEG") + image_data = img_byte_arr.getvalue() + + width, height = pil_image.size + num_image_tokens = ( + image_processor.get_number_of_image_patches(height, width, images_kwargs={}) + // image_processor.merge_size**2 + ) + + return hpcai.types.ImageChunk( + data=image_data, + format="jpeg", + expected_tokens=num_image_tokens, + ) + + +class Qwen3VLRenderer(Qwen3Renderer): + """ + Format like this: + <|im_start|>system + You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|> + <|im_start|>user + What can you help me with?<|im_end|> + <|im_start|>assistant + + + + I can help you with...<|im_end|> + + It is currently missing Qwen 3's functionality for removing thinking spans in multi-turn conversations. + """ + + image_processor: ImageProcessor + + def __init__(self, tokenizer: Tokenizer, image_processor: ImageProcessor): + self.tokenizer = tokenizer + self.image_processor = image_processor + + def _preprocess_message_parts(self, message: Message) -> list[ImagePart | TextPart]: + chunks: list[ImagePart | TextPart] = [] + + for content_chunk in super()._preprocess_message_parts(message): + if content_chunk["type"] == "image": + chunks.append(TextPart(type="text", text="<|vision_start|>")) + + chunks.append(content_chunk) + + if content_chunk["type"] == "image": + chunks.append(TextPart(type="text", text="<|vision_end|>")) + + return chunks + + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: + assert message.get("thinking") is None, "TODO: support CoT in Qwen3 renderer" + maybe_newline = "\n" if idx > 0 else "" + ob_str = f"{maybe_newline}<|im_start|>{message['role']}\n" + + ac_content_chunks = self._preprocess_message_parts(message) + + contains_think_token = any( + [ + ( + "" in x + if isinstance(x, str) + else "" in x["text"] + if isinstance(x, dict) and x["type"] == "text" + else False + ) + for x in ac_content_chunks + ] + ) + if message["role"] == "assistant" and not contains_think_token: + # Matching the paper, we force the assistant to start with . Some SFT datasets include + # in the assistant messages, we so don't need to re-add it in those cases. + ob_str += "\n" + # Observation (prompt) part + if "tool_calls" in message: + ac_content_chunks += [ + TextPart( + type="text", + text="\n".join( + [ + f"\n{json.dumps(_tool_call_payload(tool_call))}\n" + for tool_call in message["tool_calls"] + ] + ), + ) + ] + ac_content_chunks += [TextPart(type="text", text="<|im_end|>")] + # Action part + + ac_content_chunks_encoded: list[hpcai.ModelInputChunk] = [ + image_to_chunk( + image_or_str=x["image"], + image_processor=cast(ImageProcessorProtocol, self.image_processor), + ) + if x["type"] == "image" + else hpcai.EncodedTextChunk( + tokens=self.tokenizer.encode(x["text"], add_special_tokens=False) + ) + for x in ac_content_chunks + ] + + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) ) + return RenderedMessage(prefix=prefix, content=ac_content_chunks_encoded) class DeepSeekV3Renderer(Renderer): """ Format like this (no newlines between messages): - <|begin_of_sentence|><|User|>What can you help me with?<|Assistant|>Thinking...I can help you with...<|end_of_centence|> + <|begin_of_sentence|><|User|>What can you help me with?<|Assistant|>Thinking...I can help you with...<|end_of_sentence|> For no-think, just use <|Assistant|> + Deepseek renderer does not support the system role out of the box. You can set system_role_as_user to True to automatically convert the system role to the user role. """ - def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]: + def __init__(self, tokenizer: Tokenizer, system_role_as_user: bool = False): + super().__init__(tokenizer) + self.system_role_as_user = system_role_as_user + + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("thinking") is None, "TODO: support CoT in DsV3 renderer" - if message["role"] == "user": + assert isinstance(message["content"], str), ( + "DeepSeekV3Renderer only supports message with string content" + ) + if message["role"] == "user" or (self.system_role_as_user and message["role"] == "system"): role_token = self._get_special_token("User") elif message["role"] == "assistant": role_token = self._get_special_token("Assistant") @@ -540,39 +921,10 @@ def _render_message(self, message: Message) -> tuple[list[int], list[int], list[ if message["role"] == "assistant": # end_of_message only for assistant in dsv3 ac.append(self._end_message_token) - # Action part that's only included in the last message in SFT - ac_tail = [] # No action tail needed for DsV3 format - return (ob, ac, ac_tail) - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - tokens: list[int] = [] - tokens.extend(self._bos_tokens) - for message in messages: - ob_part, action_part, action_tail = self._render_message(message) - tokens.extend(ob_part) - tokens.extend(action_part) - new_partial_message = Message(role=role, content="") - ob_part, _action_part, _action_tail = self._render_message(new_partial_message) - tokens.extend(ob_part) - tokens.extend(self.tokenizer.encode(prefill or "", add_special_tokens=False)) - return hpcai.ModelInput.from_ints(tokens) - - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get tokens and weights for action corresponding to final message - """ - return build_supervised_example( - self._bos_tokens, - lambda _idx, message: self._render_message(message), - messages, - train_on_what, - ) + prefix = hpcai.types.EncodedTextChunk(tokens=ob) + content: list[hpcai.ModelInputChunk] = [hpcai.types.EncodedTextChunk(tokens=ac)] + return RenderedMessage(prefix=prefix, content=content) def _get_special_token(self, name: str) -> int: sep = chr(65372) @@ -601,14 +953,17 @@ class DeepSeekV3DisableThinkingRenderer(DeepSeekV3Renderer): Renderer that disables thinking for DsV3 models """ - def _render_message(self, message: Message) -> tuple[list[int], list[int], list[int]]: + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: + assert isinstance(message["content"], str), ( + "DeepSeekV3DisableThinkingRenderer only supports message with string content" + ) if ( message["role"] == "assistant" and not message["content"].startswith("") and not message["content"].startswith("") ): message["content"] = "" + message["content"] - return super()._render_message(message) + return super().render_message(idx, message, is_last) def build_generation_prompt( self, messages: list[Message], role: Role = "assistant", prefill: str | None = None @@ -617,6 +972,240 @@ def build_generation_prompt( return super().build_generation_prompt(messages, role, prefill) +class KimiK2Renderer(Renderer): + """ + Format for moonshotai/Kimi-K2-Thinking: + <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|> + <|im_user|>user<|im_middle|>What can you help me with?<|im_end|> + <|im_assistant|>assistant<|im_middle|>reasoningI can help you with...<|im_end|> + + Historical assistant messages use empty blocks, while the final assistant + response preserves reasoning_content in the thinking block. + """ + + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: + """ + Render a message. For assistant messages, is_last controls whether thinking is preserved + (True) or stripped to empty (False). + """ + assert isinstance(message["content"], str), ( + "KimiK2Renderer only supports message with string content" + ) + role = message["role"] + role_name = message.get("name", role) + + # Build role token based on role type + if role == "user": + ob_str = f"<|im_user|>{role_name}<|im_middle|>" + elif role == "assistant": + ob_str = f"<|im_assistant|>{role_name}<|im_middle|>" + elif role == "system": + ob_str = f"<|im_system|>{role_name}<|im_middle|>" + elif role == "tool": + ob_str = f"<|im_system|>{role_name}<|im_middle|>" + # Tool responses have special formatting + tool_call_id = message.get("tool_call_id", "") + ob_str += f"## Return of {tool_call_id}\n" + else: + ob_str = f"<|im_system|>{role_name}<|im_middle|>" + + # Build action content + ac_str = "" + if role == "assistant": + # For the last assistant message (is_last=True), preserve thinking; otherwise use empty think block + thinking = message.get("thinking", "") + if is_last and thinking: + ac_str = f"{thinking}" + else: + ac_str = "" + ac_str += message["content"] + + # Handle tool calls + if "tool_calls" in message and message["tool_calls"]: + ac_str += "<|tool_calls_section_begin|>" + for tool_call in message["tool_calls"]: + tool_id = tool_call.id or "" + args = tool_call.function.arguments + ac_str += f"<|tool_call_begin|>{tool_id}<|tool_call_argument_begin|>{args}<|tool_call_end|>" + ac_str += "<|tool_calls_section_end|>" + else: + ac_str = message["content"] + + ac_str += "<|im_end|>" + + prefix = hpcai.types.EncodedTextChunk(tokens=self.tokenizer.encode(ob_str)) + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk(tokens=self.tokenizer.encode(ac_str)) + ] + return RenderedMessage(prefix=prefix, content=content) + + def _get_default_system_chunk(self) -> hpcai.types.EncodedTextChunk: + """Returns chunk for the default system message if none is present.""" + system_str = "<|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>" + return hpcai.types.EncodedTextChunk(tokens=self.tokenizer.encode(system_str)) + + def build_generation_prompt( + self, messages: list[Message], role: Role = "assistant", prefill: str | None = None + ) -> hpcai.ModelInput: + chunks: list[hpcai.types.ModelInputChunk] = [] + + # Add default system prompt if no system message present + if len(messages) == 0 or messages[0]["role"] != "system": + chunks.append(self._get_default_system_chunk()) + + for idx, message in enumerate(messages): + # For generation prompt, no message is "last assistant" since we're generating new response + rendered_message = self.render_message(idx, message, is_last=False) + ob_chunk = rendered_message.get("prefix") + action_chunks = rendered_message["content"] + if ob_chunk: + chunks.append(ob_chunk) + chunks.extend([x for x in action_chunks if x]) + + # Add generation prompt for new assistant message + gen_prompt = f"<|im_assistant|>{role}<|im_middle|>" + chunks.append(hpcai.types.EncodedTextChunk(tokens=self.tokenizer.encode(gen_prompt))) + if prefill: + chunks.append(hpcai.types.EncodedTextChunk(tokens=self.tokenizer.encode(prefill))) + return hpcai.ModelInput(chunks=chunks) + + def build_supervised_example( + self, + messages: list[Message], + train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, + ) -> tuple[hpcai.ModelInput, torch.Tensor]: + """ + Override to properly handle thinking preservation for the last assistant message. + """ + # Find last non-tool-call assistant message index + last_assistant_idx = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx]["role"] == "assistant" and "tool_calls" not in messages[idx]: + last_assistant_idx = idx + break + + model_input_chunks_weights: list[tuple[hpcai.types.ModelInputChunk, float]] = [] + + # Add default system prompt if needed + if len(messages) == 0 or messages[0]["role"] != "system": + model_input_chunks_weights.append((self._get_default_system_chunk(), 0.0)) + + for idx, message in enumerate(messages): + if train_on_what == TrainOnWhat.CUSTOMIZED: + assert "trainable" in message, ( + "When using CUSTOMIZED train_on_what, each message must have a trainable field" + ) + else: + assert "trainable" not in message, ( + "When using non-CUSTOMIZED train_on_what, each message must not have a trainable field" + ) + + is_last_message = idx == len(messages) - 1 + is_assistant = message["role"] == "assistant" + is_user_or_system = message["role"] in ["user", "system"] + + # For Kimi K2, preserve thinking only for last non-tool-call assistant + is_last_assistant = idx >= last_assistant_idx and is_assistant + rendered_message = self.render_message(idx, message, is_last=is_last_assistant) + + ob_part = rendered_message.get("prefix") + action_parts = rendered_message.get("content") + + ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS) + if ob_part: + model_input_chunks_weights += [(ob_part, ob_weight)] + + match train_on_what: + case TrainOnWhat.LAST_ASSISTANT_MESSAGE: + action_has_weight = is_last_message and is_assistant + case TrainOnWhat.ALL_ASSISTANT_MESSAGES: + action_has_weight = is_assistant + case TrainOnWhat.ALL_MESSAGES: + action_has_weight = True + case TrainOnWhat.ALL_TOKENS: + action_has_weight = True + case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: + action_has_weight = is_user_or_system + case TrainOnWhat.CUSTOMIZED: + action_has_weight = message.get("trainable", False) + case _: + raise ValueError(f"Unknown train_on_what: {train_on_what}") + + model_input_chunks_weights += [ + (action_part, int(action_has_weight)) for action_part in action_parts if action_part + ] + + weights_data = [w for chunk, w in model_input_chunks_weights for _ in range(chunk.length)] + weights_tensor = torch.tensor(weights_data) + + model_input_chunks = [chunk for chunk, _ in model_input_chunks_weights] + return hpcai.ModelInput(chunks=model_input_chunks), weights_tensor + + @property + def _end_message_token(self) -> int: + tokens = self.tokenizer.encode("<|im_end|>") + assert len(tokens) == 1, f"Expected single token for <|im_end|>, got {len(tokens)}" + return tokens[0] + + def get_stop_sequences(self) -> list[int]: + return [self._end_message_token] + + def parse_response(self, response: list[int]) -> tuple[Message, bool]: + assistant_message, parse_success = parse_response_for_stop_token( + response, self.tokenizer, self._end_message_token + ) + if not parse_success: + return assistant_message, False + + content = assistant_message["content"] + assert isinstance(content, str) + + # Extract thinking content if present + think_match = re.search(r"(.*?)", content, re.DOTALL) + if think_match: + thinking = think_match.group(1) + # Remove the think block from content + content = content[think_match.end() :].lstrip() + assistant_message["thinking"] = thinking + assistant_message["content"] = content + + # Handle tool calls if present + if "<|tool_calls_section_begin|>" in content: + tool_section_match = re.search( + r"<\|tool_calls_section_begin\|>(.*?)<\|tool_calls_section_end\|>", + content, + re.DOTALL, + ) + if tool_section_match: + tool_section = tool_section_match.group(1) + tool_calls: list[ToolCall] = [] + + # Parse individual tool calls + tool_call_pattern = r"<\|tool_call_begin\|>(.*?)<\|tool_call_argument_begin\|>(.*?)<\|tool_call_end\|>" + for match in re.finditer(tool_call_pattern, tool_section, re.DOTALL): + tool_id = match.group(1) + args_str = match.group(2) + # Try to parse as JSON to validate, but store as string + try: + json.loads(args_str) + tool_calls.append( + ToolCall( + function=ToolCall.FunctionBody(name="", arguments=args_str), + id=tool_id if tool_id else None, + ) + ) + except json.JSONDecodeError: + return assistant_message, False + + if tool_calls: + assistant_message["tool_calls"] = tool_calls + # Remove tool section from content + content = content[: content.find("<|tool_calls_section_begin|>")] + assistant_message["content"] = content + + return assistant_message, True + + class GptOssRenderer(Renderer): """ Format like this (no newlines between messages, last message should end with <|return|> but be replaced by <|end|> when continuing the convo): @@ -646,10 +1235,11 @@ def __init__( "Reasoning effort must be set iff using system prompt" ) - def _render_message( - self, message: Message, is_last: bool = False - ) -> tuple[list[int], list[int], list[int]]: + def render_message(self, idx: int, message: Message, is_last: bool = False) -> RenderedMessage: assert message.get("tool_calls") is None, "TODO: support tools in gpt-oss renderer" + assert isinstance(message["content"], str), ( + "GptOssRenderer only supports message with string content" + ) # Observation (prompt) part ob_str = f"<|start|>{message['role']}" # Action part @@ -659,7 +1249,8 @@ def _render_message( # Assistant channels. See https://cookbook.openai.com/articles/openai-harmony thinking = message.get("thinking") - content = message.get("content", "") + message_content = message.get("content", "") + assert isinstance(message_content, str), "GptOssRenderer only supports string content" # Analysis channel (CoT) if thinking: @@ -668,7 +1259,7 @@ def _render_message( ac_str += f"<|channel|>analysis<|message|>{thinking}<|end|><|start|>assistant" # Final channel (Response Content) - ac_str += f"<|channel|>final<|message|>{content}" + ac_str += f"<|channel|>final<|message|>{message_content}" else: assert message.get("thinking") is None, ( "Thinking is only allowed for assistant messages" @@ -681,13 +1272,15 @@ def _render_message( # <|return|> ends the last-message in harmony (but should be replaced by <|end|> when continuing the convo) ac_str += "<|return|>" - # Action part that's only included in the last message in SFT - ac_tail_str = "" # No action tail needed for gpt-oss format - return ( - self.tokenizer.encode(ob_str, add_special_tokens=False), - self.tokenizer.encode(ac_str, add_special_tokens=False), - self.tokenizer.encode(ac_tail_str, add_special_tokens=False), + prefix = hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ob_str, add_special_tokens=False) ) + content: list[hpcai.ModelInputChunk] = [ + hpcai.types.EncodedTextChunk( + tokens=self.tokenizer.encode(ac_str, add_special_tokens=False) + ) + ] + return RenderedMessage(prefix=prefix, content=content) def _build_system_prompt(self) -> str: current_date = ( @@ -699,48 +1292,14 @@ def _build_system_prompt(self) -> str: current_date=current_date, reasoning_effort=self.reasoning_effort ) - def build_generation_prompt( - self, messages: list[Message], role: Role = "assistant", prefill: str | None = None - ) -> hpcai.ModelInput: - tokens: list[int] = [] - tokens.extend(self._bos_tokens) + @property + def _bos_tokens(self) -> list[int]: + tokens = [] if self.use_system_prompt: tokens.extend( self.tokenizer.encode(self._build_system_prompt(), add_special_tokens=False) ) - for message in messages: - ob_part, action_part, action_tail = self._render_message(message) - tokens.extend(ob_part) - tokens.extend(action_part) - new_partial_message = Message(role=role, content="") - ob_part, _action_part, _action_tail = self._render_message(new_partial_message) - tokens.extend(ob_part) - tokens.extend(self.tokenizer.encode(prefill or "", add_special_tokens=False)) - return hpcai.ModelInput.from_ints(tokens) - - def build_supervised_example( - self, - messages: list[Message], - train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Get tokens and weights for action corresponding to final message - """ - start_tokens = self._bos_tokens - if self.use_system_prompt: - start_tokens.extend( - self.tokenizer.encode(self._build_system_prompt(), add_special_tokens=False) - ) - return build_supervised_example( - start_tokens, - lambda _idx, message: self._render_message(message, is_last=_idx == len(messages) - 1), - messages, - train_on_what, - ) - - @property - def _bos_tokens(self) -> list[int]: - return [] + return tokens @property def _return_token(self) -> int: @@ -755,13 +1314,18 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]: return parse_response_for_stop_token(response, self.tokenizer, self._return_token) -def get_renderer(name: str, tokenizer: Tokenizer) -> Renderer: +def get_renderer( + name: str, tokenizer: Tokenizer, image_processor: ImageProcessor | None = None +) -> Renderer: if name == "role_colon": return RoleColonRenderer(tokenizer) elif name == "llama3": return Llama3Renderer(tokenizer) elif name == "qwen3": return Qwen3Renderer(tokenizer) + elif name == "qwen3_vl": + assert image_processor is not None, "qwen3_vl renderer requires an image_processor" + return Qwen3VLRenderer(tokenizer, image_processor) elif name == "qwen3_disable_thinking": return Qwen3DisableThinkingRenderer(tokenizer) elif name == "qwen3_instruct": @@ -770,6 +1334,8 @@ def get_renderer(name: str, tokenizer: Tokenizer) -> Renderer: return DeepSeekV3Renderer(tokenizer) elif name == "deepseekv3_disable_thinking": return DeepSeekV3DisableThinkingRenderer(tokenizer) + elif name == "kimi_k2": + return KimiK2Renderer(tokenizer) elif name == "gpt_oss_no_sysprompt": return GptOssRenderer(tokenizer, use_system_prompt=False) elif name == "gpt_oss_low_reasoning": @@ -779,4 +1345,4 @@ def get_renderer(name: str, tokenizer: Tokenizer) -> Renderer: elif name == "gpt_oss_high_reasoning": return GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort="high") else: - raise ValueError(f"Unknown renderer: {name}") + raise ValueError(f"Unknown renderer: {name}") \ No newline at end of file diff --git a/src/hpcai/types/__init__.py b/src/hpcai/types/__init__.py index 68f63be..2fc457c 100644 --- a/src/hpcai/types/__init__.py +++ b/src/hpcai/types/__init__.py @@ -86,6 +86,7 @@ from .future_retrieve_response import FutureRetrieveResponse as FutureRetrieveResponse from .compute_logprobs_response import ComputeLogprobsResponse as ComputeLogprobsResponse from .image_asset_pointer_chunk import ImageAssetPointerChunk as ImageAssetPointerChunk +from .image_chunk import ImageChunk as ImageChunk from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam diff --git a/src/hpcai/types/image_chunk.py b/src/hpcai/types/image_chunk.py new file mode 100644 index 0000000..03396a8 --- /dev/null +++ b/src/hpcai/types/image_chunk.py @@ -0,0 +1,54 @@ +# Copyright 2026 Thinking Machines Lab +# +# Licensed under the Apache License, Version 2.0 +# +# Modifications: +# - Adapted for HPC-AI cloud fine-tuning workflow +# Copyright © 2026 HPC-AI.COM + +import base64 +from typing import Union + +from pydantic import field_serializer, field_validator +from typing_extensions import Literal + +from .._models import StrictBase + +__all__ = ["ImageChunk"] + + +class ImageChunk(StrictBase): + data: bytes + """Image data as bytes""" + + format: Literal["png", "jpeg"] + """Image format""" + + expected_tokens: int | None = None + """Expected number of tokens this image represents. + This is only advisory: the hpcai backend will compute the number of tokens + from the image, and we can fail requests quickly if the tokens does not + match expected_tokens.""" + + type: Literal["image"] = "image" + + @field_validator("data", mode="before") + @classmethod + def validate_data(cls, value: Union[bytes, str]) -> bytes: + """Deserialize base64 string to bytes if needed.""" + if isinstance(value, str): + return base64.b64decode(value) + return value + + @field_serializer("data") + def serialize_data(self, value: bytes) -> str: + """Serialize bytes to base64 string for JSON.""" + return base64.b64encode(value).decode("utf-8") + + @property + def length(self) -> int: + if self.expected_tokens is None: + raise ValueError( + "ImageChunk expected_tokens needs to be set in order to compute the length" + ) + return self.expected_tokens \ No newline at end of file diff --git a/src/hpcai/types/model_input_chunk.py b/src/hpcai/types/model_input_chunk.py index 2ca18fd..d2feb8d 100644 --- a/src/hpcai/types/model_input_chunk.py +++ b/src/hpcai/types/model_input_chunk.py @@ -14,9 +14,10 @@ from .._utils import PropertyInfo from .encoded_text_chunk import EncodedTextChunk from .image_asset_pointer_chunk import ImageAssetPointerChunk +from .image_chunk import ImageChunk __all__ = ["ModelInputChunk"] ModelInputChunk: TypeAlias = Annotated[ - Union[EncodedTextChunk, ImageAssetPointerChunk], PropertyInfo(discriminator="type") + Union[EncodedTextChunk, ImageAssetPointerChunk, ImageChunk], PropertyInfo(discriminator="type") ] From 852800baf50ae3b7ba66db1e064d42aba1ff77a7 Mon Sep 17 00:00:00 2001 From: Liu Nazhou <1171509797@qq.com> Date: Tue, 27 Jan 2026 07:38:50 +0000 Subject: [PATCH 2/4] fix bugs --- src/hpcai/cookbook/data.py | 4 ++-- src/hpcai/cookbook/supervised/data.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/hpcai/cookbook/data.py b/src/hpcai/cookbook/data.py index ceebf1e..bd7418f 100644 --- a/src/hpcai/cookbook/data.py +++ b/src/hpcai/cookbook/data.py @@ -8,7 +8,7 @@ import torch from hpcai.types import Datum, ModelInput, TensorData from hpcai.cookbook.renderers import Message, Renderer, TrainOnWhat -from hpcai.cookbook.common import datum_from_tokens_weights +from hpcai.cookbook.supervised.common import datum_from_model_input_weights def datum_from_tokens_weights( @@ -49,4 +49,4 @@ def conversation_to_datum( ) -> Datum: """Common function to process a list of messages into a Datum.""" tokens, weights = renderer.build_supervised_example(conversation, train_on_what=train_on_what) - return datum_from_tokens_weights(tokens, weights, max_length) + return datum_from_model_input_weights(tokens, weights, max_length) diff --git a/src/hpcai/cookbook/supervised/data.py b/src/hpcai/cookbook/supervised/data.py index 7993f70..7f3e57f 100644 --- a/src/hpcai/cookbook/supervised/data.py +++ b/src/hpcai/cookbook/supervised/data.py @@ -20,7 +20,6 @@ from hpcai.cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset from hpcai.types import Datum, ModelInput, TensorData from hpcai.cookbook.renderers import Message, Renderer, TrainOnWhat -from hpcai.cookbook.common import datum_from_tokens_weights def datum_from_tokens_weights( @@ -61,7 +60,7 @@ def conversation_to_datum( ) -> Datum: """Common function to process a list of messages into a Datum.""" tokens, weights = renderer.build_supervised_example(conversation, train_on_what=train_on_what) - return datum_from_tokens_weights(tokens, weights, max_length) + return datum_from_model_input_weights(tokens, weights, max_length) def _one_of(a: Any, b: Any) -> bool: From dd1352589b05461ab27f8eabd9810d631acf2acb Mon Sep 17 00:00:00 2001 From: Liu Nazhou <1171509797@qq.com> Date: Tue, 27 Jan 2026 07:41:06 +0000 Subject: [PATCH 3/4] update gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 283cc7c..c0291d0 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,9 @@ ENV/ env.bak/ venv.bak/ +code.diff +metrics.jsonl + # Spyder project settings .spyderproject .spyproject From 5aea5c595ef9f67ae3e51621ceaac1a77febb5ee Mon Sep 17 00:00:00 2001 From: Liu Nazhou <1171509797@qq.com> Date: Thu, 29 Jan 2026 10:10:11 +0000 Subject: [PATCH 4/4] Added image_chunk_param --- src/hpcai/types/__init__.py | 1 + src/hpcai/types/image_chunk_param.py | 28 ++++++++++++++++++++++ src/hpcai/types/model_input_chunk_param.py | 3 ++- 3 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 src/hpcai/types/image_chunk_param.py diff --git a/src/hpcai/types/__init__.py b/src/hpcai/types/__init__.py index 2fc457c..1a0e55d 100644 --- a/src/hpcai/types/__init__.py +++ b/src/hpcai/types/__init__.py @@ -90,6 +90,7 @@ from .training_optim_step_params import TrainingOptimStepParams as _TrainingOptimStepParams from .weight_save_for_sampler_params import WeightSaveForSamplerParams as _WeightSaveForSamplerParams from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam as _ImageAssetPointerChunkParam +from .image_chunk_param import ImageChunkParam as _ImageChunkParam from .session_end_event_param import SessionEndEventParam as _SessionEndEventParam from .session_start_event_param import SessionStartEventParam as _SessionStartEventParam from .unhandled_exception_event import UnhandledExceptionEvent as UnhandledExceptionEvent diff --git a/src/hpcai/types/image_chunk_param.py b/src/hpcai/types/image_chunk_param.py new file mode 100644 index 0000000..650a9de --- /dev/null +++ b/src/hpcai/types/image_chunk_param.py @@ -0,0 +1,28 @@ +# Copyright 2026 Thinking Machines Lab +# +# Licensed under the Apache License, Version 2.0 +# +# Modifications: +# - Adapted for HPC-AI cloud fine-tuning workflow +# Copyright © 2026 HPC-AI.COM + +from typing_extensions import Literal +from typing_extensions import TypedDict + +__all__ = ["ImageChunkParam"] + + +class ImageChunkParam(TypedDict, total=False): + data: bytes + """Image data as bytes""" + + format: Literal["png", "jpeg"] + """Image format""" + + expected_tokens: int | None = None + """Expected number of tokens this image represents. + This is only advisory: the hpcai backend will compute the number of tokens + from the image, and we can fail requests quickly if the tokens does not + match expected_tokens.""" + + type: Literal["image"] = "image" \ No newline at end of file diff --git a/src/hpcai/types/model_input_chunk_param.py b/src/hpcai/types/model_input_chunk_param.py index 48093cc..e943aee 100644 --- a/src/hpcai/types/model_input_chunk_param.py +++ b/src/hpcai/types/model_input_chunk_param.py @@ -15,7 +15,8 @@ from .encoded_text_chunk_param import EncodedTextChunkParam from .image_asset_pointer_chunk_param import ImageAssetPointerChunkParam +from .image_chunk_param import ImageChunkParam __all__ = ["ModelInputChunkParam"] -ModelInputChunkParam: TypeAlias = Union[EncodedTextChunkParam, ImageAssetPointerChunkParam] +ModelInputChunkParam: TypeAlias = Union[EncodedTextChunkParam, ImageAssetPointerChunkParam, ImageChunkParam]