Source code for omnigenbench.auto.auto_train.auto_train

# -*- coding: utf-8 -*-
# file: auto_train.py
# time: 11:54 14/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import os
import time
import warnings

import findfile
import torch
from metric_visualizer import MetricVisualizer
from transformers import TrainingArguments, Trainer as HFTrainer

from ...auto.config.auto_config import AutoConfig
from ...src.lora.lora_model import OmniLoraModel
from ...src.abc.abstract_tokenizer import OmniTokenizer
from ...src.misc.utils import (
    seed_everything,
    fprint,
    load_module_from_path,
    clean_temp_checkpoint,
)
from ...src.trainer.accelerate_trainer import AccelerateTrainer
from ...src.trainer.trainer import Trainer

# Output directory (relative to the current working directory) under which
# AutoTrain writes its metric-visualizer (.mv/.csv) files and saved models.
autotrain_evaluations = "./autotrain_evaluations"


class AutoTrain:
    """
    Train a genomic model on a dataset with minimal configuration.

    The class locates the dataset's ``AutoConfig``, builds the model, tokenizer
    and datasets from it, trains once per configured seed and records the
    final test metrics in a ``MetricVisualizer``.

    AutoTrain supports:
        - Single dataset training with multiple seeds
        - Different trainer backends (native, accelerate, huggingface)
        - Automatic metric visualization and result tracking
        - Configurable training parameters

    Attributes:
        dataset (str): The name or path of the dataset to use for training
            (trailing ``/`` removed).
        model_name_or_path (str): The name or path of the model to train.
        model_name (str): Last path component of ``model_name_or_path``, or
            the class name when a non-string object was given.
        tokenizer: The tokenizer to use for training.
        autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
        overwrite (bool): Whether to overwrite existing training results.
        trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
        mv_path (str): Path to the metric visualizer file.
        mv (MetricVisualizer): The metric visualizer instance.
    """

    def __init__(
        self,
        dataset,
        model_name_or_path,
        tokenizer=None,
        **kwargs,
    ):
        """
        Initialize the AutoTrain instance.

        Args:
            dataset (str): The name or path of the dataset to use for training.
            model_name_or_path (str): The model instance, model name or model
                path of the model to train.
            tokenizer: The tokenizer to use. If None, it will be loaded from
                the model path.
            **kwargs: Additional keyword arguments.
                - autocast (str): The autocast precision to use
                  ('fp16', 'bf16', etc.). Defaults to 'fp16'.
                - overwrite (bool): Whether to overwrite existing training
                  results. Defaults to False.
                - trainer (str): The trainer to use
                  ('native', 'accelerate', 'hf_trainer').
                  Defaults to 'accelerate'.

        Example:
            >>> # Initialize with a dataset and model
            >>> trainer = AutoTrain("dataset_name", "model_name")

            >>> # Initialize with custom settings
            >>> trainer = AutoTrain("dataset_name", "model_name",
            ...                     autocast="bf16", trainer="accelerate")
        """
        self.dataset = dataset.rstrip("/")
        self.autocast = kwargs.pop("autocast", "fp16")
        self.overwrite = kwargs.pop("overwrite", False)
        self.trainer = kwargs.pop("trainer", "accelerate")

        self.model_name_or_path = model_name_or_path
        self.tokenizer = tokenizer
        if isinstance(self.model_name_or_path, str):
            self.model_name_or_path = self.model_name_or_path.rstrip("/")
            self.model_name = self.model_name_or_path.split("/")[-1]
        else:
            self.model_name = self.model_name_or_path.__class__.__name__
        if isinstance(tokenizer, str):
            self.tokenizer = tokenizer.rstrip("/")

        os.makedirs(autotrain_evaluations, exist_ok=True)
        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        # Use the normalised dataset name so a trailing "/" in the argument
        # cannot leak into the file name / lookup key.
        mv_name = f"{self.dataset}-{self.model_name}"
        self.mv_path = f"{autotrain_evaluations}/{mv_name}-{time_str}.mv"

        mv_paths = findfile.find_files(
            autotrain_evaluations,
            [self.dataset, self.model_name, ".mv"],
        )
        if mv_paths and not self.overwrite:
            # Resume from the last matching record file that was found.
            self.mv = MetricVisualizer.load(mv_paths[-1])
            self.mv.summary(round=4)
        else:
            self.mv = MetricVisualizer(self.mv_path)
        self.train_info()

    def train_info(self):
        """
        Print and return information about the current training setup.

        Returns:
            str: A string containing training setup information.

        Example:
            >>> info = trainer.train_info()
            >>> print(info)
        """
        info = f"Dataset Root: {self.dataset}\n"
        info += f"Model Name or Path: {self.model_name}\n"
        info += f"Tokenizer: {self.tokenizer}\n"
        info += f"Metric Visualizer Path: {self.mv_path}\n"
        fprint(info)
        return info

    def run(self, **kwargs):
        """
        Load the dataset configuration and train the model once per seed.

        Keyword arguments are merged into the dataset's ``AutoConfig`` (a
        warning is issued for keys the config did not define). For every seed
        the model and the train/test/valid datasets are rebuilt, the selected
        trainer backend is run, the model is optionally saved and the final
        test metrics are logged to the metric visualizer.

        Args:
            **kwargs: Additional keyword arguments that will override the
                default parameters in the dataset configuration. Recognised
                extras include ``lora_config`` (dict passed to
                ``OmniLoraModel``) and ``save_model`` (bool, default True).

        Raises:
            FileNotFoundError: If no config file is found for the dataset.
            ValueError: If the config module holds no ``AutoConfig`` instance
                or ``model_name_or_path`` is empty.

        Example:
            >>> # Run training with default settings
            >>> trainer.run()

            >>> # Run with custom parameters
            >>> trainer.run(learning_rate=1e-4, batch_size=16)
        """
        clean_temp_checkpoint(1)  # clean temp checkpoint older than 1 day

        _kwargs = kwargs.copy()
        train_config, train_config_path = self._load_train_config()
        fprint(f"Loaded config for {self.dataset} from {train_config_path}")
        fprint(train_config.args)

        # Init Tokenizer and Model
        # NOTE(review): a string tokenizer is passed through unchanged here;
        # confirm the model/dataset classes accept a path in that case.
        if not self.tokenizer:
            tokenizer = OmniTokenizer.from_pretrained(
                self.model_name_or_path, trust_remote_code=True
            )
        else:
            tokenizer = self.tokenizer

        for key, value in _kwargs.items():
            if key in train_config:
                fprint("Override", key, "with", value, "according to the input kwargs")
            else:
                warnings.warn(
                    f"kwarg: {key} not found in train_config while setting {key} = {value}"
                )
            train_config.update({key: value})

        # Anything now held by the config must not be forwarded a second time
        # to the dataset / trainer constructors.
        _kwargs = {k: v for k, v in _kwargs.items() if k not in train_config}

        fprint(
            f"Autotrain Config for {self.dataset}:",
            "\n".join([f"{k}: {v}" for k, v in train_config.items()]),
        )

        if not isinstance(train_config["seeds"], list):
            train_config["seeds"] = [train_config["seeds"]]
        random_seeds = train_config["seeds"]

        # Loop-invariant values.
        batch_size = train_config.get("batch_size", 8)
        record_name = f"{os.path.basename(self.dataset)}-{self.model_name}".split(
            "/"
        )[-1]
        # The flag is read from the merged config because every user kwarg has
        # been moved there; the leftover-kwargs lookup is kept as a fallback.
        save_model = train_config.get(
            "save_model", _kwargs.get("save_model", True)
        )

        for seed in random_seeds:
            # check if the record exists
            records = self.mv.transpose()
            if record_name in records and len(
                list(records[record_name].values())[0]
            ) >= len(random_seeds):
                continue

            seed_everything(seed)
            if not self.model_name_or_path:
                raise ValueError(
                    "model_name_or_path is not specified. Please provide a valid model name or path."
                )
            model_cls = train_config["model_cls"]
            model = model_cls(
                self.model_name_or_path,
                tokenizer=tokenizer,
                label2id=train_config.label2id,
                num_labels=train_config["num_labels"],
                trust_remote_code=True,
                ignore_mismatched_sizes=True,
            )

            if kwargs.get("lora_config", None) is not None:
                fprint("Applying LoRA to the model with config:", kwargs["lora_config"])
                model = OmniLoraModel(model, **kwargs.get("lora_config", {}))

            # Init Trainer
            dataset_cls = train_config["dataset_cls"]
            if hasattr(model.config, "max_position_embeddings"):
                max_length = min(
                    train_config["max_length"],
                    model.config.max_position_embeddings,
                )
            else:
                max_length = train_config["max_length"]

            # Construction order (train, test, valid) is kept as in the
            # original implementation.
            train_set = self._build_dataset(
                dataset_cls,
                train_config,
                tokenizer,
                max_length,
                "train",
                train_config.get("shuffle", True),
                _kwargs,
            )
            test_set = self._build_dataset(
                dataset_cls, train_config, tokenizer, max_length, "test", False, _kwargs
            )
            valid_set = self._build_dataset(
                dataset_cls, train_config, tokenizer, max_length, "valid", False, _kwargs
            )

            if self.trainer == "hf_trainer":
                trainer, metrics = self._train_with_hf(
                    model,
                    train_config,
                    train_set,
                    valid_set,
                    test_set,
                    batch_size,
                    kwargs,
                )
            else:
                trainer, metrics = self._train_with_native(
                    model,
                    train_config,
                    train_set,
                    valid_set,
                    test_set,
                    batch_size,
                    seed,
                    _kwargs,
                )

            if save_model:
                fprint(
                    f"Saving model to {autotrain_evaluations}/{self.dataset}/{self.model_name}"
                )
                save_path = os.path.join(
                    autotrain_evaluations, self.dataset, self.model_name
                )
                os.makedirs(save_path, exist_ok=True)
                if self.trainer == "hf_trainer":
                    # The HuggingFace Trainer's save_model takes no
                    # ``overwrite`` keyword.
                    trainer.save_model(save_path)
                else:
                    trainer.save_model(save_path, overwrite=True)

            if metrics:
                for key, value in self._final_test_metrics(metrics).items():
                    try:
                        value = float(value)
                    except (TypeError, ValueError):
                        pass  # keep non-numeric values as they are
                    self.mv.log(f"{record_name}", f"{key}", value)

                self.mv.summary(round=4)
                self.mv.dump(self.mv_path)
                self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))

            # Drop the references held here before emptying the CUDA cache;
            # the optimizer (native backends) is only referenced via the trainer.
            del model, trainer
            torch.cuda.empty_cache()

    def _load_train_config(self):
        """Locate the dataset's config file and return ``(AutoConfig, path)``."""
        train_config_path = findfile.find_file(
            self.dataset,
            f"{self.dataset}.config".split("."),
        )
        # Report a missing config here rather than inside the module loader.
        if not train_config_path:
            raise FileNotFoundError(
                f"Could not find a config file for dataset {self.dataset}"
            )
        config = load_module_from_path("config", train_config_path)

        train_config = None
        for attr_name in dir(config):
            attr = getattr(config, attr_name)
            if isinstance(attr, AutoConfig):
                # No early exit: the last AutoConfig attribute wins, as before.
                train_config = attr
        if train_config is None:
            raise ValueError(
                f"Could not find AutoConfig instance in {train_config_path}"
            )
        return train_config, train_config_path

    @staticmethod
    def _build_dataset(
        dataset_cls, train_config, tokenizer, max_length, split, shuffle, extra_kwargs
    ):
        """Instantiate ``dataset_cls`` for the ``<split>_file`` entry of the config."""
        return dataset_cls(
            data_source=train_config[f"{split}_file"],
            tokenizer=tokenizer,
            label2id=train_config["label2id"],
            max_length=max_length,
            structure_in=train_config.get("structure_in", False),
            max_examples=train_config.get("max_examples", None),
            shuffle=shuffle,
            drop_long_seq=train_config.get("drop_long_seq", False),
            **extra_kwargs,
        )

    @staticmethod
    def _final_test_metrics(metrics):
        """Return the last test-metrics dict from either backend's result.

        The native trainers are read as a list of per-evaluation dicts under
        ``"test"`` (last entry used); the HuggingFace path stores a single
        dict there.
        """
        test_metrics = metrics["test"]
        if isinstance(test_metrics, (list, tuple)):
            return test_metrics[-1] if test_metrics else {}
        return test_metrics

    def _train_with_hf(
        self, model, train_config, train_set, valid_set, test_set, batch_size, kwargs
    ):
        """Train and evaluate with the HuggingFace Trainer; return ``(trainer, metrics)``."""
        # Only forward user kwargs that name a TrainingArguments attribute.
        hf_kwargs = {
            k: v
            for k, v in kwargs.items()
            if hasattr(TrainingArguments, k) and k != "output_dir"
        }
        # Pop once so the train and eval batch sizes always agree.
        hf_batch_size = hf_kwargs.pop("batch_size", batch_size)
        training_args = TrainingArguments(
            output_dir=f"./autotrain_evaluations/{self.model_name}",
            num_train_epochs=hf_kwargs.pop(
                "num_train_epochs", train_config["epochs"]
            ),
            per_device_train_batch_size=hf_batch_size,
            per_device_eval_batch_size=hf_batch_size,
            gradient_accumulation_steps=hf_kwargs.pop(
                "gradient_accumulation_steps", 1
            ),
            learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
            weight_decay=hf_kwargs.pop("weight_decay", 0),
            eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
            save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
            fp16=hf_kwargs.pop("fp16", True),
            remove_unused_columns=False,
            label_names=["labels"],
            **hf_kwargs,
        )

        # Fall back to the test split when there is no validation data.
        valid_set = valid_set if len(valid_set) else test_set

        compute_metrics = train_config["compute_metrics"]
        if isinstance(compute_metrics, list):
            if len(compute_metrics) > 1:
                fprint(
                    "Multiple metrics not supported by HFTrainer, using the first one metric only."
                )
            compute_metrics = compute_metrics[0] if compute_metrics else None

        trainer = HFTrainer(
            model=model,
            args=training_args,
            train_dataset=train_set,
            eval_dataset=valid_set,
            compute_metrics=compute_metrics,
        )

        # Train and evaluate
        eval_result = trainer.evaluate(valid_set)  # baseline before training
        fprint(eval_result)
        train_result = trainer.train()
        eval_result = trainer.evaluate()
        test_result = trainer.evaluate(test_set if len(test_set) else valid_set)
        metrics = {
            "train": train_result.metrics,
            "eval": eval_result,
            "test": test_result,
        }
        fprint(metrics)
        return trainer, metrics

    def _train_with_native(
        self,
        model,
        train_config,
        train_set,
        valid_set,
        test_set,
        batch_size,
        seed,
        extra_kwargs,
    ):
        """Train with the accelerate or native trainer; return ``(trainer, metrics)``."""
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=train_config.get("learning_rate", 2e-5),
            weight_decay=train_config.get("weight_decay", 0),
        )
        trainer_cls = AccelerateTrainer if self.trainer == "accelerate" else Trainer
        fprint(f"Using Trainer: {trainer_cls}")
        trainer = trainer_cls(
            model=model,
            train_dataset=train_set,
            eval_dataset=valid_set,
            test_dataset=test_set,
            batch_size=batch_size,
            patience=train_config.get("patience", 3),
            epochs=train_config["epochs"],
            gradient_accumulation_steps=train_config.get(
                "gradient_accumulation_steps", 1
            ),
            optimizer=optimizer,
            loss_fn=train_config.get("loss_fn", None),
            compute_metrics=train_config["compute_metrics"],
            seed=seed,
            autocast=self.autocast,
            **extra_kwargs,
        )
        metrics = trainer.train()
        return trainer, metrics