2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ numpy
accelerate>=0.20.3
transformers>=4.34.1
torch
aim==3.17.5
aim==3.18.1
since there's a desire to implement multiple trackers, did we want to make the dependency (and imports) optional, just used when available and configured?

Collaborator Author

Can be done... but then do we still list these in requirements.txt?

Or do we throw an error and ask the user to install the required tracker before it is imported in the code?
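To illustrate the optional-dependency route being discussed here — a rough sketch of mine, not code from this PR — aim could be imported lazily and only raise an actionable error when the aim tracker is actually requested:

# Hypothetical sketch (not part of this PR): treat 'aim' as an optional dependency.
def get_aim_callback(experiment=None):
    try:
        # Imported lazily so the rest of the package works without aim installed.
        from aim.hugging_face import AimCallback
    except ImportError as exc:
        raise ImportError(
            "The 'aim' tracker was requested but the aim package is not installed. "
            "Install it (e.g. `pip install aim`) or unset the tracker."
        ) from exc
    return AimCallback(experiment=experiment)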

sentencepiece
tokenizers>=0.13.3
tqdm
4 changes: 4 additions & 0 deletions tuning/config/configs.py
@@ -57,3 +57,7 @@ class TrainingArguments(transformers.TrainingArguments):
        default=False,
        metadata={"help": "Packing to be enabled in SFT Trainer, default is False"},
    )
    tracker: str.lower = field(
        default=None,
        metadata={"help": "Experiment tracker to use. Requires additional configs, see tuning/config/tracker_configs.py"},
    )
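For context (my addition, not part of the diff): because these dataclasses are parsed with HfArgumentParser, the field name becomes the CLI flag and the str.lower annotation is used as the argparse type, so the value the user passes is presumably lowercased on parse. A minimal, hypothetical sketch:

# Hypothetical sketch of how the new `tracker` field is expected to parse (not in the PR).
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ExampleArgs:
    # str.lower as the annotation makes argparse lowercase the supplied value.
    tracker: str.lower = field(default=None, metadata={"help": "Experiment tracker to use."})


parser = HfArgumentParser(ExampleArgs)
(example_args,) = parser.parse_args_into_dataclasses(args=["--tracker", "Aim"])
print(example_args.tracker)  # expected: "aim"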
14 changes: 14 additions & 0 deletions tuning/config/tracker_configs.py
@@ -0,0 +1,14 @@
from dataclasses import dataclass

@dataclass
class AimConfig:
    # Name of the experiment
    experiment: str = None
    # 'aim_repo' can point to a locally accessible directory (e.g., '~/.aim')
    # or to a remote repository hosted on a server.
    # When 'aim_remote_server_ip' and 'aim_remote_server_port' are set, the tracker
    # connects to that remote aim repo; otherwise 'aim_repo' specifies the directory,
    # with a default of None representing '.aim'.
    aim_repo: str = None
    aim_remote_server_ip: str = None
    aim_remote_server_port: int = None
    # Location to which the run_hash is exported
    aim_run_hash_export_path: str = None
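A hedged usage sketch (my addition; the experiment name, paths, IP and port below are made up) showing how AimConfig distinguishes a local repo from a remote server:

# Hypothetical usage of AimConfig as defined above (not part of the PR itself).
from tuning.config.tracker_configs import AimConfig

# Local repo: runs are written to ~/.aim and the run hash is exported to a file.
local_cfg = AimConfig(
    experiment="llama-7b-sft",            # example name, not from the PR
    aim_repo="~/.aim",
    aim_run_hash_export_path="/tmp/aim_run_hash.txt",
)

# Remote repo: ip and port are set, so the tracker talks to an aim server instead.
remote_cfg = AimConfig(
    experiment="llama-7b-sft",
    aim_remote_server_ip="192.168.1.10",
    aim_remote_server_port=53800,
)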
64 changes: 47 additions & 17 deletions tuning/sft_trainer.py
@@ -2,7 +2,7 @@
from datetime import datetime
from typing import Optional, Union
import json
import os
import os, time

# Third Party
from peft.utils.other import fsdp_auto_wrap_policy
@@ -22,11 +22,12 @@
import transformers

# Local
from tuning.aim_loader import get_aimstack_callback
from tuning.config import configs, peft_config
from tuning.config import configs, peft_config, tracker_configs
from tuning.data import tokenizer_data_utils
from tuning.utils.config_utils import get_hf_peft_config
from tuning.utils.data_type_utils import get_torch_dtype
from tuning.tracker.tracker import Tracker
from tuning.tracker.aimstack_tracker import AimStackTracker


class PeftSavingCallback(TrainerCallback):
@@ -39,7 +40,6 @@ def on_save(self, args, state, control, **kwargs):
if "pytorch_model.bin" in os.listdir(checkpoint_path):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

@@ -84,6 +84,7 @@ def train(
peft_config: Optional[
Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
] = None,
tracker_config: Optional[Union[tracker_configs.AimConfig]] = None
):
"""Call the SFTTrainer

@@ -97,7 +98,6 @@
The peft configuration to pass to trainer
"""
run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

logger = logging.get_logger("sft_trainer")

# Validate parameters
@@ -115,13 +115,28 @@
train_args.fsdp = ""
train_args.fsdp_config = {"xla": False}

# Initialize the tracker early so we can calculate custom metrics like model_load_time.
tracker_name = train_args.tracker
if tracker_name == 'aim':
    if tracker_config is not None:
        tracker = AimStackTracker(tracker_config)
    else:
        logger.error("Tracker name is set to %s but tracker config is None; falling back to a no-op tracker.", tracker_name)
        tracker = Tracker()
else:
    logger.info("No tracker configured; using a no-op tracker API which does nothing.")
    tracker = Tracker()

task_type = "CAUSAL_LM"

model_load_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=train_args.cache_dir,
torch_dtype=get_torch_dtype(model_args.torch_dtype),
use_flash_attention_2=model_args.use_flash_attn,
)
model_load_time = time.time() - model_load_time
tracker.track(metric=model_load_time, name='model_load_time')

peft_config = get_hf_peft_config(task_type, peft_config)

@@ -212,11 +227,6 @@ def train(
formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")

aim_callback = get_aimstack_callback()
file_logger_callback = FileLoggingCallback(logger)
peft_saving_callback = PeftSavingCallback()
callbacks = [aim_callback, peft_saving_callback, file_logger_callback]

if train_args.packing:
logger.info("Packing is set to True")
data_collator = None
@@ -242,6 +252,15 @@
)
packing = False

# collect and register callbacks
file_logger_callback = FileLoggingCallback(logger)
peft_saving_callback = PeftSavingCallback()
callbacks = [peft_saving_callback, file_logger_callback]

tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
callbacks.append(tracker_callback)

trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
@@ -271,6 +290,7 @@ def main(**kwargs):
configs.TrainingArguments,
peft_config.LoraConfig,
peft_config.PromptTuningConfig,
tracker_configs.AimConfig,
)
)
parser.add_argument(
@@ -285,16 +305,26 @@
training_args,
lora_config,
prompt_tuning_config,
peft_method,
aim_config,
additional,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if peft_method.peft_method == "lora":
tune_config = lora_config
elif peft_method.peft_method == "pt":
tune_config = prompt_tuning_config

peft_method = additional.peft_method
if peft_method =="lora":
tune_config=lora_config
elif peft_method =="pt":
tune_config=prompt_tuning_config
else:
tune_config=None

tracker_name = training_args.tracker
if tracker_name == "aim":
tracker_config=aim_config
else:
tune_config = None
train(model_args, data_args, training_args, tune_config)
tracker_config=None

train(model_args, data_args, training_args, tune_config, tracker_config)


if __name__ == "__main__":
Empty file added tuning/tracker/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tuning/tracker/aimstack_tracker.py
@@ -0,0 +1,47 @@
# Standard
import os

# Third Party
from aim.hugging_face import AimCallback

# Local
from tuning.tracker.tracker import Tracker


class AimStackTracker(Tracker):

    def __init__(self, tracker_config):
        super().__init__(tracker_config)
        c = self.config
        exp = c.experiment
        ip = c.aim_remote_server_ip
        port = c.aim_remote_server_port
        repo = c.aim_repo
        hash_export_path = c.aim_run_hash_export_path

        if ip is not None and port is not None:
            # Remote aim repo reachable over the network.
            aim_callback = AimCallback(
                repo="aim://" + ip + ":" + str(port) + "/",
                experiment=exp,
            )
        elif repo:
            # Locally accessible aim repo directory.
            aim_callback = AimCallback(repo=repo, experiment=exp)
        else:
            # Default repo ('.aim' in the current directory).
            aim_callback = AimCallback(experiment=exp)

        run = aim_callback.experiment  # Initialize Aim run
        run_hash = run.hash  # Extract the hash

        # store the run hash
        if hash_export_path:
            with open(hash_export_path, 'w') as f:
                f.write(str(run_hash) + '\n')

        # Save internal state
        self.hf_callback = aim_callback
        self.run = run

    def get_hf_callback(self):
        return self.hf_callback

    def track(self, metric, name, stage='additional_metrics'):
        context = {'subset': stage}
        self.run.track(metric, name=name, context=context)
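Another small usage sketch of mine (values illustrative, not from the PR), mirroring how sft_trainer.py drives this class: build the tracker from an AimConfig, register its HF callback, and track a custom metric.

# Hypothetical usage, mirroring the calls made in sft_trainer.py above (not in the PR).
from tuning.config.tracker_configs import AimConfig
from tuning.tracker.aimstack_tracker import AimStackTracker

tracker = AimStackTracker(AimConfig(experiment="demo", aim_repo="~/.aim"))

callbacks = []
tracker_callback = tracker.get_hf_callback()
if tracker_callback is not None:
    callbacks.append(tracker_callback)  # passed to SFTTrainer(callbacks=...)

tracker.track(metric=12.3, name="model_load_time")  # value is illustrative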
11 changes: 11 additions & 0 deletions tuning/tracker/tracker.py
@@ -0,0 +1,11 @@
# Generic Tracker API

class Tracker:
    def __init__(self, tracker_config=None) -> None:
        self.config = tracker_config

    def get_hf_callback(self):
        # The base tracker exposes no HuggingFace callback.
        return None

    def track(self, metric, name, stage):
        # The base tracker is a no-op; subclasses implement actual metric logging.
        pass
pass