Trainers

Base Trainer

Native training utilities.

This module provides a native PyTorch training framework for genomic models, including automatic mixed precision training, early stopping, metric tracking, and model checkpointing.
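Early stopping of this kind is usually patience-based: training halts once the monitored validation metric has not improved for a fixed number of consecutive evaluations. The following is a minimal, framework-free sketch of that rule, assuming a metric where higher is better. The EarlyStopper class and its patience and min_delta parameters are hypothetical illustrations, not part of the trainer's API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EarlyStopper:
    """Hypothetical patience-based early stopping helper (higher metric is better)."""

    patience: int = 3           # evaluations allowed without improvement
    min_delta: float = 0.0      # minimum increase that counts as an improvement
    best: Optional[float] = None
    bad_evals: int = 0

    def step(self, metric: float) -> bool:
        """Record one evaluation result and return True if training should stop."""
        if self.best is None or metric > self.best + self.min_delta:
            # New best score: this is where a trainer would typically checkpoint.
            self.best = metric
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopper(patience=2)
history = [0.70, 0.75, 0.74, 0.73, 0.80]
for epoch, score in enumerate(history):
    if stopper.step(score):
        print(f"stopping after epoch {epoch}, best={stopper.best}")
        break
```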

class omnigenbench.src.trainer.trainer.Trainer(model: Module, device: str | device | None = None, **kwargs)

Bases: BaseTrainer

Native PyTorch trainer for genomic models.

This trainer implements the complete training loop, with automatic mixed precision, early stopping, metric tracking, and model checkpointing, in native PyTorch and without depending on a distributed training library.

Variables:
  • device – Device to run training on (CPU or GPU)

  • fast_dtype – Data type for mixed precision training

  • scaler – Gradient scaler for mixed precision training
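In a native PyTorch trainer, the scaler is most likely a torch GradScaler (torch.cuda.amp.GradScaler, exposed as torch.amp.GradScaler in recent releases). It applies dynamic loss scaling. The loss is multiplied by a large factor before the backward pass so that small float16 gradients do not underflow. If the gradients then contain inf or NaN, the optimizer step is skipped and the scale is reduced. After a run of overflow-free steps, the scale is increased again. The pure-Python model below mirrors that policy using GradScaler's default hyperparameters. The DynamicLossScaler class is a hypothetical illustration, not the trainer's code.

```python
import math
from typing import List


class DynamicLossScaler:
    """Hypothetical pure-Python model of GradScaler's dynamic loss-scaling policy."""

    def __init__(self, init_scale: float = 2.0**16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000) -> None:
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0  # consecutive steps with finite gradients

    def step(self, scaled_grads: List[float]) -> bool:
        """Unscale gradients and update the scale.

        Returns True if the optimizer step would be applied, False if it is skipped.
        """
        grads = [g / self.scale for g in scaled_grads]
        if not all(math.isfinite(g) for g in grads):
            # Overflow detected: skip the step and shrink the scale.
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
            return False
        self._growth_tracker += 1
        if self._growth_tracker == self.growth_interval:
            # Enough stable steps in a row: try a larger scale.
            self.scale *= self.growth_factor
            self._growth_tracker = 0
        return True


scaler = DynamicLossScaler(growth_interval=3)
applied = scaler.step([1.0e4, -2.0e3])        # finite gradients: step applied
skipped = scaler.step([float("inf"), 1.0])    # overflow: step skipped, scale halved
print(applied, skipped, scaler.scale)         # True False 32768.0
```

Loss scaling matters mainly for float16. bfloat16 has the same exponent range as float32, so gradient scaling is generally unnecessary when fast_dtype is bfloat16.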

Example

>>> trainer = Trainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()

evaluate() → Dict[str, Any]

Evaluate the model on the validation dataset.

Returns:

Dict[str, Any] – Dictionary containing evaluation metrics

save_model(path_to_save: str, overwrite: bool = False, **kwargs) → None

Save the trained model.

Parameters:
  • path_to_save (str) – Path to save the model

  • overwrite (bool) – Whether to overwrite existing files (default: False)

  • **kwargs – Additional keyword arguments
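The documentation does not say what save_model does when path_to_save already exists and overwrite is False. One common convention is to refuse the write and raise an error, and the standard-library sketch below illustrates that guard. The save_checkpoint function, the FileExistsError choice, and the state.json layout are all hypothetical; they are not the trainer's actual behavior or checkpoint format.

```python
import json
import os
import shutil
import tempfile
from typing import Any, Dict


def save_checkpoint(path_to_save: str, state: Dict[str, Any], overwrite: bool = False) -> None:
    """Hypothetical save routine showing how an overwrite flag can be honored."""
    if os.path.exists(path_to_save):
        if not overwrite:
            raise FileExistsError(f"{path_to_save} already exists; pass overwrite=True to replace it")
        # Remove the previous checkpoint, whether it is a directory or a single file.
        if os.path.isdir(path_to_save):
            shutil.rmtree(path_to_save)
        else:
            os.remove(path_to_save)
    os.makedirs(path_to_save)
    with open(os.path.join(path_to_save, "state.json"), "w", encoding="utf-8") as f:
        json.dump(state, f)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "checkpoint")
    save_checkpoint(target, {"epoch": 1})
    try:
        save_checkpoint(target, {"epoch": 2})  # overwrite=False: refused
        refused = False
    except FileExistsError:
        refused = True
    save_checkpoint(target, {"epoch": 2}, overwrite=True)
    with open(os.path.join(target, "state.json"), encoding="utf-8") as f:
        final_state = json.load(f)
```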

test() → Dict[str, Any]

Test the model on the test dataset.

Returns:

Dict[str, Any] – Dictionary containing test metrics

unwrap_model(model: Module | None = None) → Module

Unwrap the model from any distributed training wrappers.

Parameters:

model (Optional[torch.nn.Module]) – Model to unwrap (default: None, uses self.model)

Returns:

torch.nn.Module – The unwrapped model
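Wrappers such as torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel store the original network in a .module attribute. Unwrapping therefore usually means following that attribute until it is gone. The framework-free sketch below shows the idea with hypothetical stand-in classes. It is an assumption about how such a helper can work, not the trainer's implementation.

```python
from typing import Any


class Net:
    """Hypothetical stand-in for a plain torch.nn.Module."""


class Wrapper:
    """Hypothetical stand-in for DataParallel / DistributedDataParallel."""

    def __init__(self, module: Any) -> None:
        self.module = module  # wrappers keep the original model here


def unwrap(model: Any) -> Any:
    """Follow .module attributes until the innermost model is reached."""
    while hasattr(model, "module"):
        model = model.module
    return model


net = Net()
wrapped = Wrapper(Wrapper(net))
inner = unwrap(wrapped)
```

An ordinary model could define its own attribute named module. A more defensive helper would therefore check isinstance against the known wrapper classes instead of relying on hasattr.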

HuggingFace Trainer

HuggingFace trainer integration for genomic models.

This module provides HuggingFace trainer wrappers for genomic models, enabling seamless integration with the HuggingFace training ecosystem while maintaining OmniGenome-specific functionality.

class omnigenbench.src.trainer.hf_trainer.HFTrainer(model: Module, training_args: TrainingArguments | None = None, **kwargs)

Bases: BaseTrainer

HuggingFace trainer wrapper for genomic models.

This class extends the OmniGenome BaseTrainer and delegates training to an underlying HuggingFace Trainer, while keeping OmniGenome-specific metadata and functionality.

Variables:
  • hf_trainer – The underlying HuggingFace Trainer instance

  • training_args – HuggingFace TrainingArguments instance

  • metadata – Dictionary containing OmniGenome library information

Example

>>> from transformers import TrainingArguments
>>> training_args = TrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     training_args=training_args
... )
>>> metrics = trainer.train()

evaluate() → Dict[str, Any]

Evaluate the model on the validation dataset.

Returns:

Dict[str, Any] – Dictionary containing evaluation metrics

get_model(**kwargs) → Module

Get the trained model.

Parameters:

**kwargs – Additional keyword arguments

Returns:

torch.nn.Module – The trained model

save_model(path_to_save: str, overwrite: bool = False, **kwargs) → None

Save the trained model.

Parameters:
  • path_to_save (str) – Path to save the model

  • overwrite (bool) – Whether to overwrite existing files (default: False)

  • **kwargs – Additional keyword arguments

test() → Dict[str, Any]

Test the model on the test dataset.

Returns:

Dict[str, Any] – Dictionary containing test metrics

train(path_to_save: str | None = None, **kwargs) → Dict[str, Any]

Train the model using the underlying HuggingFace Trainer.

Parameters:
  • path_to_save (Optional[str]) – Path to save the trained model

  • **kwargs – Additional keyword arguments

Returns:

Dict[str, Any] – Training metrics and results

class omnigenbench.src.trainer.hf_trainer.HFTrainingArguments(*args, **kwargs)

Bases: TrainingArguments

HuggingFace training arguments wrapper for genomic models.

This class extends the HuggingFace TrainingArguments to include OmniGenome-specific metadata while maintaining full compatibility with the HuggingFace training ecosystem.

Variables:

metadata – Dictionary containing OmniGenome library information
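The (*args, **kwargs) signature suggests that HFTrainingArguments forwards every argument unchanged to TrainingArguments and then attaches its metadata attribute. The sketch below reproduces that subclassing pattern with a small stand-in dataclass, so it runs without transformers installed. BaseArguments, ArgsWithMetadata, and the metadata keys and values are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class BaseArguments:
    """Hypothetical stand-in for transformers.TrainingArguments (also a dataclass)."""

    output_dir: str
    num_train_epochs: float = 3.0
    per_device_train_batch_size: int = 8


class ArgsWithMetadata(BaseArguments):
    """Forward all arguments to the parent, then attach library metadata."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Placeholder keys and values; the real metadata contents are not documented here.
        self.metadata: Dict[str, str] = {
            "library_name": "example-library",
            "library_version": "0.0.0",
        }


args = ArgsWithMetadata(output_dir="./output", num_train_epochs=3, per_device_train_batch_size=16)
```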

Example

>>> training_args = HFTrainingArguments(
...     output_dir="./output",
...     num_train_epochs=3,
...     per_device_train_batch_size=16,
... )
>>> trainer = HFTrainer(model=model, training_args=training_args)

Accelerate Trainer

Accelerate-based distributed training utilities.

This module provides a HuggingFace Accelerate-based distributed training framework for genomic models, including automatic mixed precision training, multi-GPU support, early stopping, and model checkpointing.

class omnigenbench.src.trainer.accelerate_trainer.AccelerateTrainer(model: Module, **kwargs)

Bases: BaseTrainer

HuggingFace Accelerate-based distributed trainer for genomic models.

This trainer provides distributed training with automatic mixed precision, gradient accumulation, and early stopping. It supports both single-GPU and multi-GPU training through HuggingFace Accelerate.
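Gradient accumulation sums gradients over several micro-batches and performs one optimizer step per cycle. Dividing each micro-batch loss by the number of accumulation steps makes the accumulated gradient equal to the gradient of the mean loss over the combined batch, provided the micro-batches are the same size. The pure-Python check below demonstrates this for a one-parameter least-squares model. accumulation_steps and the other names are hypothetical and are not AccelerateTrainer arguments.

```python
import math
from typing import List, Tuple


def grad_mse(w: float, batch: List[Tuple[float, float]]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 6.5), (4.0, 7.0)]
w, lr = 0.5, 0.01
accumulation_steps = 2
micro_batches = [data[0:2], data[2:4]]

# Accumulate each micro-batch gradient scaled by 1 / accumulation_steps,
# which is what dividing the micro-batch loss by accumulation_steps achieves.
accumulated = 0.0
for batch in micro_batches:
    accumulated += grad_mse(w, batch) / accumulation_steps

full_batch = grad_mse(w, data)
assert math.isclose(accumulated, full_batch)  # both equal -20.75

# One optimizer step for the whole accumulation cycle.
w -= lr * accumulated
```

In Accelerate itself, accumulation is usually configured with Accelerator(gradient_accumulation_steps=...) and the accelerator.accumulate(model) context manager. This documentation does not state whether AccelerateTrainer uses that mechanism.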

Variables:
  • accelerator – HuggingFace Accelerate Accelerator instance used for distributed training

  • early_stop_flag – Tensor for coordinating early stopping across processes
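In multi-process training, every process must leave the training loop at the same step. Otherwise the processes that keep going will hang while they wait inside the next collective operation. A shared flag tensor prevents this by combining each process's local decision through a reduction, so that all processes stop if any one of them sets the flag. The simulation below models that reduction in plain Python as a max over per-process flags. The real trainer would perform it with a collective call on early_stop_flag, and this documentation does not specify whether it reduces with max, with sum, or by broadcasting from the main process. all_reduce_max is a hypothetical helper.

```python
from typing import List


def all_reduce_max(local_flags: List[int]) -> List[int]:
    """Simulate an all-reduce with MAX: every process receives the global maximum."""
    global_flag = max(local_flags)
    return [global_flag] * len(local_flags)


# One entry per process: its local early-stopping decision (1 = stop).
local_decisions = [0, 0, 1, 0]
synced = all_reduce_max(local_decisions)
should_stop = [flag == 1 for flag in synced]
print(should_stop)  # [True, True, True, True]: all processes leave the loop together
```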

Example

>>> trainer = AccelerateTrainer(
...     model=model,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
...     epochs=10,
...     batch_size=32,
...     optimizer=optimizer
... )
>>> metrics = trainer.train()

evaluate() → Dict[str, Any]

Evaluate the model on the validation dataset.

This method runs the model in evaluation mode and computes metrics on the validation dataset. It handles distributed evaluation and gathers results from all processes.

Returns:

Dict[str, Any] – Dictionary containing evaluation metrics
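When the validation set is split across processes, the sampler typically pads the index list by repeating samples so that every process runs the same number of steps. Gathering predictions from all processes therefore returns a few duplicates, which must be dropped before metrics are computed. HuggingFace Accelerate's Accelerator.gather_for_metrics performs this clean-up. The simulation below reproduces the padding and the truncation in plain Python. It assumes the index assignment of torch.utils.data.DistributedSampler with shuffle=False and drop_last=False, with one sample per process per step. shard_indices is a hypothetical helper, and the trainer's own gathering code may differ.

```python
import math
from typing import List


def shard_indices(n: int, world_size: int) -> List[List[int]]:
    """Mimic DistributedSampler(shuffle=False, drop_last=False) index assignment (n > 0)."""
    indices = list(range(n))
    total = math.ceil(n / world_size) * world_size
    padding = total - n
    # Pad by repeating samples from the start so every process gets the same count.
    indices += (indices * math.ceil(padding / n))[:padding]
    return [indices[rank:total:world_size] for rank in range(world_size)]


n_samples, world_size = 10, 4
shards = shard_indices(n_samples, world_size)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]

# Each evaluation step, every process contributes one prediction and the
# gather concatenates them in rank order.
gathered: List[int] = []
for step in range(len(shards[0])):
    gathered.extend(shard[step] for shard in shards)

# 12 gathered items for 10 real samples: the padded duplicates sit at the end.
predictions_for_metrics = gathered[:n_samples]
```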

save_model(path_to_save: str, overwrite: bool = False, **kwargs) → None

Save the trained model.

Parameters:
  • path_to_save (str) – Path to save the model

  • overwrite (bool) – Whether to overwrite existing files (default: False)

  • **kwargs – Additional keyword arguments for model saving

test() → Dict[str, Any]

Test the model on the test dataset.

This method runs the model in evaluation mode and computes metrics on the test dataset. It handles distributed testing and gathers results from all processes.

Returns:

Dict[str, Any] – Dictionary containing test metrics

train(path_to_save: str | None = None, **kwargs) → Dict[str, Any]

Train the model using distributed training.

This method performs the complete training loop with validation, early stopping, and model checkpointing. It handles distributed training across multiple GPUs and processes.

Parameters:
  • path_to_save (Optional[str]) – Path to save the trained model

  • **kwargs – Additional keyword arguments for model saving

Returns:

Dict[str, Any] – Dictionary containing training metrics