Trainers
Base Trainer
Native training utilities.
This module provides a native PyTorch training framework for genomic models, including automatic mixed precision training, early stopping, metric tracking, and model checkpointing.
- class omnigenbench.src.trainer.trainer.Trainer(model: Module, device: str | device | None = None, **kwargs)
Bases: BaseTrainer
Native PyTorch trainer for genomic models.
This trainer provides a complete training framework with automatic mixed precision, early stopping, metric tracking, and model checkpointing using native PyTorch without distributed training dependencies.
- Variables:
device – Device to run training on (CPU or GPU)
fast_dtype – Data type for mixed precision training
scaler – Gradient scaler for mixed precision training
Example
>>> trainer = Trainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()
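The early stopping mentioned above can be illustrated with a small, framework-free sketch. It is not taken from the OmniGenBench source: `EarlyStopper`, `patience`, `higher_is_better`, and `epochs_run` are hypothetical names, and the rule shown (stop after `patience` consecutive evaluations without improvement) is only one common convention.

```python
from typing import List, Optional


class EarlyStopper:
    """Hypothetical helper: stop after `patience` evaluations without improvement."""

    def __init__(self, patience: int = 3, higher_is_better: bool = True) -> None:
        self.patience = patience
        self.higher_is_better = higher_is_better
        self.best: Optional[float] = None
        self.bad_evals = 0

    def update(self, value: float) -> bool:
        """Record one evaluation result; return True when training should stop."""
        improved = self.best is None or (
            value > self.best if self.higher_is_better else value < self.best
        )
        if improved:
            self.best = value      # a real trainer might also checkpoint here
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


def epochs_run(scores: List[float], patience: int) -> int:
    """Return how many epochs run before the stopper fires (or all of them)."""
    stopper = EarlyStopper(patience=patience)
    for epoch, score in enumerate(scores, start=1):
        if stopper.update(score):
            return epoch
    return len(scores)
```

With `patience=2`, a metric sequence that peaks at the second epoch and then declines stops training after the fourth epoch, before a later recovery would be seen. This is the usual trade-off when choosing a patience value.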
- evaluate() -> Dict[str, Any]
Evaluate the model on the validation dataset.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) -> None
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments
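The reference does not say what happens when the target already exists and `overwrite` is False. The sketch below assumes the save is refused with an error; the real method may instead skip or warn. `guarded_save`, `demo`, and the file name `model.bin` are hypothetical and are not part of OmniGenBench.

```python
import os
import tempfile


def guarded_save(payload: bytes, path_to_save: str, overwrite: bool = False) -> None:
    """Hypothetical stand-in for a save routine; not OmniGenBench's implementation."""
    target = os.path.join(path_to_save, "model.bin")  # file name is an assumption
    if os.path.exists(target) and not overwrite:
        raise FileExistsError(f"{target} exists; pass overwrite=True to replace it")
    os.makedirs(path_to_save, exist_ok=True)
    with open(target, "wb") as handle:
        handle.write(payload)


def demo() -> str:
    """Save twice into a temporary directory and report the final file contents."""
    with tempfile.TemporaryDirectory() as tmp:
        guarded_save(b"v1", tmp)
        try:
            guarded_save(b"v2", tmp)  # second save without overwrite is refused
        except FileExistsError:
            guarded_save(b"v2", tmp, overwrite=True)
            with open(os.path.join(tmp, "model.bin"), "rb") as handle:
                return handle.read().decode()
    return "no conflict"
```

Defaulting `overwrite` to False protects an earlier checkpoint from being replaced by accident, at the cost of requiring an explicit flag when re-running an experiment into the same directory.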
- test() -> Dict[str, Any]
Test the model on the test dataset.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- unwrap_model(model: Module | None = None) -> Module
Unwrap the model from any distributed training wrappers.
- Parameters:
model (Optional[torch.nn.Module]) – Model to unwrap (default: None, uses self.model)
- Returns:
torch.nn.Module – The unwrapped model
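As an illustration of what unwrapping usually means: PyTorch's `DataParallel` and `DistributedDataParallel` wrappers expose the wrapped network as a `.module` attribute. The sketch below mimics that with plain Python classes so it runs without PyTorch. `Wrapper`, `TinyModel`, and `unwrap` are hypothetical, and whether `unwrap_model` follows exactly this logic is an assumption.

```python
from typing import Any


class Wrapper:
    """Stand-in for a wrapper that exposes the inner model as `.module`."""

    def __init__(self, module: Any) -> None:
        self.module = module


class TinyModel:
    """Placeholder for a torch.nn.Module; it has no `.module` attribute."""


def unwrap(model: Any) -> Any:
    """Peel off nested wrappers until no `.module` attribute remains."""
    while hasattr(model, "module"):
        model = model.module
    return model


inner = TinyModel()
wrapped = Wrapper(Wrapper(inner))  # two nested wrapper layers
restored = unwrap(wrapped)         # the same object as `inner`
```

Unwrapping before saving keeps wrapper-specific prefixes out of the stored parameter names, so the checkpoint can be loaded into a plain, unwrapped model later.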
HuggingFace Trainer
HuggingFace trainer integration for genomic models.
This module provides HuggingFace trainer wrappers for genomic models, enabling seamless integration with the HuggingFace training ecosystem while maintaining OmniGenome-specific functionality.
- class omnigenbench.src.trainer.hf_trainer.HFTrainer(model: Module, training_args: TrainingArguments | None = None, **kwargs)
Bases: BaseTrainer
HuggingFace trainer wrapper for genomic models.
This class extends the OmniGenome BaseTrainer to delegate training to a HuggingFace Trainer while keeping OmniGenome-specific metadata and functionality.
- Variables:
hf_trainer – The underlying HuggingFace Trainer instance
training_args – HuggingFace TrainingArguments instance
metadata – Dictionary containing OmniGenome library information
Example
>>> from transformers import TrainingArguments
>>> training_args = TrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     training_args=training_args
... )
>>> metrics = trainer.train()
- evaluate() -> Dict[str, Any]
Evaluate the model on the validation dataset.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- get_model(**kwargs) -> Module
Get the trained model.
- Parameters:
**kwargs – Additional keyword arguments
- Returns:
torch.nn.Module – The trained model
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) -> None
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments
- test() -> Dict[str, Any]
Test the model on the test dataset.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- train(path_to_save: str | None = None, **kwargs) -> Dict[str, Any]
Train the model using HuggingFace Trainer.
- Parameters:
path_to_save (Optional[str]) – Path to save the trained model
**kwargs – Additional keyword arguments
- Returns:
Dict[str, Any] – Training metrics and results
- class omnigenbench.src.trainer.hf_trainer.HFTrainingArguments(*args, **kwargs)
Bases: TrainingArguments
HuggingFace training arguments wrapper for genomic models.
This class extends the HuggingFace TrainingArguments to include OmniGenome-specific metadata while maintaining full compatibility with the HuggingFace training ecosystem.
- Variables:
metadata – Dictionary containing OmniGenome library information
Example
>>> training_args = HFTrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(model=model, training_args=training_args)
Accelerate Trainer
Accelerate-based distributed training utilities.
This module provides a HuggingFace Accelerate-based distributed training framework for genomic models, including automatic mixed precision, distributed training support, early stopping, and model checkpointing.
- class omnigenbench.src.trainer.accelerate_trainer.AccelerateTrainer(model: Module, **kwargs)
Bases: BaseTrainer
HuggingFace Accelerate-based distributed trainer for genomic models.
This trainer provides distributed training with automatic mixed precision, gradient accumulation, and early stopping. It supports single-GPU and multi-GPU training through HuggingFace Accelerate.
- Variables:
accelerator – HuggingFace Accelerate instance for distributed training
early_stop_flag – Tensor for coordinating early stopping across processes
Example
>>> trainer = AccelerateTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()
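The `early_stop_flag` variable is described only as a tensor that coordinates early stopping across processes. A common way to do this, assumed here rather than confirmed from the source, is for each process to raise a 0/1 flag and then sum the flags with an all-reduce, so every process reaches the same decision. The sketch simulates that in plain Python; `all_reduce_sum` and `should_stop` are hypothetical names and no real collective communication takes place.

```python
from typing import List


def all_reduce_sum(flags: List[int]) -> List[int]:
    """Simulate a sum all-reduce: every rank ends up with the same total."""
    total = sum(flags)
    return [total for _ in flags]


def should_stop(local_decisions: List[bool]) -> List[bool]:
    """Each rank raises a 0/1 flag; after the reduce, all ranks agree on stopping."""
    flags = [1 if decision else 0 for decision in local_decisions]
    reduced = all_reduce_sum(flags)
    return [value > 0 for value in reduced]


# Only rank 1 wants to stop, yet every rank receives the same decision.
decisions = should_stop([False, True, False])
```

Agreeing on a single decision matters because a process that leaves the training loop alone would leave the others waiting at the next collective operation.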
- evaluate() -> Dict[str, Any]
Evaluate the model on the validation dataset.
This method runs the model in evaluation mode and computes metrics on the validation dataset. It handles distributed evaluation and gathers results from all processes.
- Returns:
Dict[str, Any] – Dictionary containing evaluation metrics
- save_model(path_to_save: str, overwrite: bool = False, **kwargs) -> None
Save the trained model.
- Parameters:
path_to_save (str) – Path to save the model
overwrite (bool) – Whether to overwrite existing files (default: False)
**kwargs – Additional keyword arguments for model saving
- test() -> Dict[str, Any]
Test the model on the test dataset.
This method runs the model in evaluation mode and computes metrics on the test dataset. It handles distributed testing and gathers results from all processes.
- Returns:
Dict[str, Any] – Dictionary containing test metrics
- train(path_to_save: str | None = None, **kwargs) -> Dict[str, Any]
Train the model using distributed training.
This method performs the complete training loop with validation, early stopping, and model checkpointing. It handles distributed training across multiple GPUs and processes.
- Parameters:
path_to_save (Optional[str]) – Path to save the trained model
**kwargs – Additional keyword arguments for model saving
- Returns:
Dict[str, Any] – Dictionary containing training metrics