Skip to content

Add TRL example notebook #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions integrations/model-training/trl/notebooks/Comet_with_trl.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"https://cdn.comet.ml/img/notebook_logo.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Comet](https://www.comet.com/site/products/ml-experiment-tracking/?utm_campaign=ray_train&utm_medium=colab) is an MLOps Platform that is designed to help Data Scientists and Teams build better models faster! Comet provides tooling to track, Explain, Manage, and Monitor your models in a single place! It works with Jupyter Notebooks and Scripts and most importantly it's 100% free to get started!\n",
"\n",
"[TRL](https://github.com/huggingface/trl) is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).\n",
"\n",
"Instrument your runs with Comet to start managing experiments, create dataset versions and track hyperparameters for faster and easier reproducibility and collaboration.\n",
"\n",
"[Find more information about our integration with TRL](https://www.comet.ml/docs/v2/integrations/ml-frameworks/trl/)\n",
"\n",
"Get a preview for what's to come. Check out a completed experiment created from this notebook [here](TODO).\n",
"\n",
"This example is based on the [following Ray Train Lightning example](https://docs.ray.io/en/latest/train/getting-started-pytorch-lightning.html)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZYchV5RWwdv5"
},
"source": [
"# Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DJnmqphuY2eI"
},
"outputs": [],
"source": [
"%pip install \"comet_ml>=3.47.1\" \"trl>=0.13.0\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "crOcPHobwhGL"
},
"source": [
"# Initialize Comet"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HNQRM0U3caiY"
},
"outputs": [],
"source": [
"import comet_ml\n",
"\n",
"comet_ml.login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cgqwGSwtzVWD"
},
"source": [
"# Import Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e-5rRYaUw5AF"
},
"outputs": [],
"source": [
"import torch\n",
"from datasets import load_dataset\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from trl import (\n",
" DPOConfig,\n",
" DPOTrainer,\n",
" ModelConfig,\n",
" ScriptArguments,\n",
" TrlParser,\n",
" get_kbit_device_map,\n",
" get_peft_config,\n",
" get_quantization_config,\n",
")\n",
"from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJuThf1TxP_G"
},
"source": [
"# Load your dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"trl-lib/ultrafeedback_binarized\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForCausalLM.from_pretrained(\n",
" \"Qwen/Qwen2-0.5B-Instruct\",\n",
")\n",
"ref_model = AutoModelForCausalLM.from_pretrained(\n",
" \"Qwen/Qwen2-0.5B-Instruct\",\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
" \"Qwen/Qwen2-0.5B-Instruct\",\n",
")\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"if tokenizer.chat_template is None:\n",
" tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE\n",
"\n",
"training_args = DPOConfig(\n",
" output_dir=\"/tmp\",\n",
" learning_rate=5.0e-7,\n",
" max_steps=10,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=8,\n",
" logging_steps=1,\n",
" eval_strategy=\"steps\",\n",
" eval_steps=5,\n",
" report_to=[\"comet_ml\"],\n",
")\n",
"trainer = DPOTrainer(\n",
" model,\n",
" ref_model,\n",
" args=training_args,\n",
" train_dataset=dataset[\"train\"],\n",
" eval_dataset=dataset[\"test\"],\n",
" processing_class=tokenizer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"comet_ml.end()"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading