Abstract Classes
OmniGenBench provides a set of abstract base classes that define the core interfaces for datasets, models, metrics, and tokenizers. These abstract classes are designed to be subclassed, allowing users to implement custom logic for new data formats, model architectures, evaluation metrics, or sequence representations.
How to Use Abstract Classes:
Start by exploring the abstract base classes for datasets, models, metrics, and tokenizers.
To add new functionality, subclass the relevant abstract class and implement the required methods.
The package uses these abstract classes as the foundation for all built-in and user-extended components, ensuring consistency and interoperability.
Main Abstract Classes:
OmniDataset: Base class for datasets. Subclass to support new data formats or preprocessing logic.
OmniModel: Base class for models. Subclass to implement custom architectures or tasks.
OmniMetric: Base class for evaluation metrics. Subclass to define new metrics for benchmarking.
OmniTokenizer: Base class for tokenizers. Subclass to support new sequence representations.
Refer to the API documentation below for details on each abstract class, including its methods and usage examples.
OmniModel
- class omnigenbench.src.abc.abstract_model.OmniModel(config_or_model, tokenizer, *args, **kwargs)
Bases: Module
This class provides a unified interface for all genomic models in the OmniGenome framework. It handles model initialization, forward passes, loss computation, prediction, inference, and model persistence.
- static from_pretrained(model_name_or_path, tokenizer, *args, **kwargs)
Loads a pre-trained model and tokenizer.
- Parameters:
model_name_or_path – The name or path of the pre-trained model.
tokenizer – The tokenizer to use.
args – Additional positional arguments.
kwargs – Additional keyword arguments.
- Returns:
An instance of OmniModel.
- inference(sequence_or_inputs, **kwargs)
This method takes raw sequences or tokenized inputs and returns processed predictions that are ready for human consumption. It typically includes post-processing steps like converting logits to class labels or probabilities.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict – A dictionary containing the processed predictions, typically including ‘predictions’, ‘confidence’, and other human-readable outputs.
Example
>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # Class labels
>>> # Inference on multiple sequences
>>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
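The post-processing step mentioned for inference (turning logits into class labels and probabilities) can be pictured with a small framework-free sketch. This is not OmniGenBench's implementation: the helper name `postprocess_logits`, the `id2label` mapping, and the example logits are all hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def postprocess_logits(batch_logits, id2label):
    """Hypothetical helper: map per-sequence logits to labels and confidences."""
    predictions, confidence = [], []
    for logits in batch_logits:
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax index
        predictions.append(id2label[best])
        confidence.append(probs[best])
    return {"predictions": predictions, "confidence": confidence}

# Made-up raw outputs for two sequences and a made-up label mapping
raw_logits = [[2.0, 0.5], [0.1, 1.7]]
result = postprocess_logits(raw_logits, {0: "negative", 1: "positive"})
print(result["predictions"])
```

A method like predict would correspond to returning `raw_logits` unchanged, while inference would correspond to returning `result`.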
- last_hidden_state_forward(**inputs)
Performs a forward pass to get the last hidden state from the base model. It also handles compatibility with different model architectures by mapping input parameters appropriately.
- Parameters:
**inputs – The inputs to the model, compatible with the base model’s forward method. Typically includes ‘input_ids’, ‘attention_mask’, and other model-specific parameters.
- Returns:
torch.Tensor – The last hidden state tensor.
Example
>>> inputs = {
...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
... }
>>> hidden_states = model.last_hidden_state_forward(**inputs)
- load(path, **kwargs)
Loads the model, tokenizer, and metadata from a directory.
- Parameters:
path – The directory to load the model from.
**kwargs – Additional arguments.
- Returns:
The loaded model instance.
- loss_function(logits, labels)
Calculates the loss. This method should be implemented by concrete model classes to define how the loss is calculated for their specific task (classification, regression, etc.).
- Parameters:
logits (torch.Tensor) – The model’s output logits.
labels (torch.Tensor) – The ground truth labels.
- Returns:
torch.Tensor – The calculated loss.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Example
>>> # In a classification model
>>> loss = model.loss_function(logits, labels)
- model_info()
Prints and returns detailed information about the model.
- Returns:
A string containing the model information.
- predict(sequence_or_inputs, **kwargs)
This method takes raw sequences or tokenized inputs and returns the raw model outputs (logits, hidden states, etc.) without post-processing. It’s useful for getting the model’s direct predictions for further processing.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict – A dictionary containing the raw model outputs, typically including logits, last_hidden_state, and other model-specific outputs.
Example
>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
- save(path, overwrite=False, dtype=torch.float16, **kwargs)
Saves the model, tokenizer, and metadata to a directory.
- Parameters:
path – The directory to save the model to.
overwrite – Whether to overwrite the directory if it exists.
dtype – The data type to save the model weights in.
kwargs – Additional arguments.
- set_loss_fn(loss_function)
Sets a custom loss function for the model. The loss function should be compatible with the model’s output format.
- Parameters:
loss_function (callable) – A callable loss function that takes logits and labels as arguments.
Example
>>> import torch.nn as nn
>>> model.set_loss_fn(nn.CrossEntropyLoss())
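Any callable that accepts logits and labels fits the contract described above. The sketch below uses a stand-in class (`ToyModel`, hypothetical) and a pure-Python cross-entropy so it runs without PyTorch; how the real OmniModel wires the custom loss into loss_function is an assumption here.

```python
import math

class ToyModel:
    """Stand-in for an OmniModel subclass; only the loss hooks are sketched."""

    def __init__(self):
        self.loss_fn = None

    def set_loss_fn(self, loss_function):
        self.loss_fn = loss_function

    def loss_function(self, logits, labels):
        # Assumption: delegate to the custom callable when one has been set
        if self.loss_fn is None:
            raise NotImplementedError("No loss function has been set.")
        return self.loss_fn(logits, labels)

def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch of logit lists and integer labels."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)

model = ToyModel()
model.set_loss_fn(cross_entropy)
loss = model.loss_function([[2.0, 0.0], [0.0, 2.0]], [0, 1])
print(round(loss, 4))
```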
- omnigenbench.src.abc.abstract_model.count_parameters(model)
This function iterates through all parameters of a PyTorch model and counts only those that require gradients (i.e., trainable parameters).
- Parameters:
model (torch.nn.Module) – A PyTorch model.
- Returns:
int – The total number of trainable parameters.
Example
>>> model = OmniModelForSequenceClassification(config, tokenizer)
>>> num_params = count_parameters(model)
>>> print(f"Model has {num_params} trainable parameters")
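The counting rule described above (sum the element counts of parameters that require gradients) can be sketched without PyTorch. `FakeParam` and `count_trainable` are hypothetical stand-ins; on a real `torch.nn.Module` the equivalent expression is `sum(p.numel() for p in model.parameters() if p.requires_grad)`.

```python
from dataclasses import dataclass
from math import prod

@dataclass
class FakeParam:
    """Stand-in for a tensor parameter: just a shape and a trainable flag."""
    shape: tuple
    requires_grad: bool = True

    def numel(self):
        return prod(self.shape)

def count_trainable(params):
    """Sum element counts over parameters that require gradients."""
    return sum(p.numel() for p in params if p.requires_grad)

params = [
    FakeParam((768, 768)),                         # trainable weight
    FakeParam((768,)),                             # trainable bias
    FakeParam((4096, 768), requires_grad=False),   # frozen, not counted
]
print(count_trainable(params))
```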
OmniDataset
- class omnigenbench.src.abc.abstract_dataset.OmniDataset(data_source, tokenizer, max_length=None, **kwargs)
Bases: Dataset
A unified interface for genomic datasets in the OmniGenome framework. It handles data loading, preprocessing, tokenization, and provides a PyTorch-compatible dataset interface.
The class supports various data formats and can handle different types of genomic tasks including classification, regression, and token-level tasks.
- Variables:
tokenizer – The tokenizer to use for processing sequences.
max_length (int) – The maximum sequence length for tokenization.
label2id (dict) – Mapping from labels to integer IDs.
id2label (dict) – Mapping from integer IDs to labels.
shuffle (bool) – Whether to shuffle the data.
structure_in (bool) – Whether to include secondary structure information.
drop_long_seq (bool) – Whether to drop sequences longer than max_length.
metadata (dict) – Metadata about the dataset including version info.
rna2structure (RNA2StructureCache) – Cache for RNA structure predictions.
- get_column(column_name)
Returns all values for a specific column in the dataset.
- Parameters:
column_name (str) – The name of the column.
- Returns:
list – A list of values from the specified column.
- get_inputs_length()
Calculates and returns statistics about sequence and label lengths.
- Returns:
dict – A dictionary with length statistics (min, max, avg).
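As a rough illustration of the kind of statistics described, the sketch below computes min, max, and average over sequence lengths. The helper name `length_stats` and the key names are assumptions; the real method's return layout (for example, how sequence and label statistics are grouped) is not specified here.

```python
def length_stats(lengths):
    """Return min, max, and average of a non-empty list of lengths."""
    return {
        "min": min(lengths),
        "max": max(lengths),
        "avg": sum(lengths) / len(lengths),
    }

sequences = ["ATCG", "ATCGATCG", "AT"]
stats = length_stats([len(s) for s in sequences])
print(stats)
```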
- get_labels()
Returns the set of unique labels in the dataset.
- Returns:
set – The set of unique labels.
- load_data_source(data_source, **kwargs)
Loads data from a file or list of files.
- Parameters:
data_source (str or list) – Path to the data file or a list of paths.
**kwargs – Additional keyword arguments, e.g., max_examples.
- Returns:
list – A list of examples.
- prepare_input(instance, **kwargs)
Prepares a single data instance for the model. Must be implemented by subclasses.
- Parameters:
instance (dict) – A single data instance (e.g., a dictionary).
**kwargs – Additional keyword arguments for tokenization.
- Returns:
dict – A dictionary of tokenized inputs.
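A subclass supplies prepare_input to turn one raw record into model-ready fields. The sketch below shows the shape of such an override using a stand-in base class (`ToyDataset`, hypothetical) rather than the real OmniDataset; the record keys `sequence` and `label`, the vocabulary, and the padding scheme are illustrative assumptions.

```python
class ToyDataset:
    """Stand-in base class; the real OmniDataset does far more."""

    def __init__(self, examples, max_length):
        self.max_length = max_length
        self.data = [self.prepare_input(ex) for ex in examples]

    def prepare_input(self, instance, **kwargs):
        raise NotImplementedError

class ToyClassificationDataset(ToyDataset):
    VOCAB = {"<pad>": 0, "A": 1, "T": 2, "C": 3, "G": 4}  # made-up vocabulary

    def prepare_input(self, instance, **kwargs):
        # Truncate, map characters to IDs, then right-pad to max_length
        ids = [self.VOCAB[ch] for ch in instance["sequence"][: self.max_length]]
        pad = self.max_length - len(ids)
        return {
            "input_ids": ids + [self.VOCAB["<pad>"]] * pad,
            "attention_mask": [1] * len(ids) + [0] * pad,
            "labels": int(instance["label"]),
        }

dataset = ToyClassificationDataset(
    [{"sequence": "ATCG", "label": "1"}], max_length=6
)
print(dataset.data[0])
```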
- print_label_distribution()
Prints the distribution of 0-dimensional (scalar) labels. This is useful for classification tasks where each sample has a single label.
- sample(n=1)
Returns a random sample of n items from the dataset.
- Parameters:
n (int) – The number of samples to return.
- Returns:
list – A list of data samples.
- to(device)
Moves all tensor data in the dataset to the specified device.
- Parameters:
device (str or torch.device) – The target device.
- Returns:
OmniDataset – The dataset itself.
- class omnigenbench.src.abc.abstract_dataset.OmniGenomeDict(*args, **kwargs)
Bases: dict
This class extends the standard Python dictionary to provide a convenient method for moving all tensor values to a specific device (CPU/GPU).
- to(device)
Moves all tensor values in the dictionary to the specified device.
- Parameters:
device (str or torch.device) – The target device (e.g., ‘cuda:0’ or ‘cpu’).
- Returns:
OmniGenomeDict – The dictionary itself, with tensors moved to the new device.
Example
>>> data = OmniGenomeDict({'input_ids': torch.tensor([1, 2, 3])})
>>> data.to('cuda:0')  # Moves tensors to GPU
- omnigenbench.src.abc.abstract_dataset.covert_input_to_tensor(data)
This function traverses through nested dictionaries and lists, converting numerical values to PyTorch tensors while preserving the structure.
- Parameters:
data (list or dict) – A list or dictionary containing data samples.
- Returns:
list or dict – The data structure with numerical values converted to tensors.
Example
>>> data = [{'input_ids': [1, 2, 3], 'labels': [0]}]
>>> tensor_data = covert_input_to_tensor(data)
>>> print(type(tensor_data[0]['input_ids']))  # <class 'torch.Tensor'>
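The structure-preserving traversal described for covert_input_to_tensor can be sketched as a short recursive function. This is one plausible reading, not the library's code: `convert_nested` and `fake_tensor` are hypothetical, and treating a flat list of numbers as a single tensor is an assumption inferred from the example above.

```python
from numbers import Number

def convert_nested(data, to_tensor):
    """Recursively apply to_tensor to numeric leaves, keeping containers intact."""
    if isinstance(data, dict):
        return {k: convert_nested(v, to_tensor) for k, v in data.items()}
    if isinstance(data, list):
        # Assumption: a flat list of numbers becomes one tensor-like object
        if data and all(isinstance(x, Number) for x in data):
            return to_tensor(data)
        return [convert_nested(x, to_tensor) for x in data]
    if isinstance(data, Number):
        return to_tensor(data)
    return data  # strings and other values are left unchanged

def fake_tensor(value):
    """Stand-in for torch.tensor so the sketch runs without PyTorch."""
    return ("tensor", value)

samples = [{"input_ids": [1, 2, 3], "labels": [0], "id": "seq1"}]
converted = convert_nested(samples, fake_tensor)
print(converted[0]["input_ids"])
```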
OmniTokenizer
- class omnigenbench.src.abc.abstract_tokenizer.OmniTokenizer(base_tokenizer=None, max_length=512, **kwargs)
Bases: object
This class provides a unified interface for tokenizers in the OmniGenome framework. It wraps underlying tokenizers (typically from Hugging Face) and provides additional functionality for genomic sequence processing. It also supports custom tokenizer wrappers for specialized genomic tasks.
- Variables:
base_tokenizer – The underlying tokenizer instance (e.g., from Hugging Face).
max_length (int) – The default maximum sequence length.
metadata (dict) – Metadata about the tokenizer including version info.
u2t (bool) – Whether to convert ‘U’ to ‘T’.
t2u (bool) – Whether to convert ‘T’ to ‘U’.
add_whitespace (bool) – Whether to add whitespace between characters.
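The three boolean variables above suggest a simple preprocessing step applied before tokenization. The sketch below shows one plausible effect of each flag; `preprocess_sequence` is a hypothetical helper, and the order in which the real tokenizer applies these transformations is an assumption.

```python
def preprocess_sequence(sequence, u2t=False, t2u=False, add_whitespace=False):
    """Hypothetical helper showing one plausible effect of the three flags."""
    if u2t:
        sequence = sequence.replace("U", "T")   # RNA alphabet -> DNA alphabet
    if t2u:
        sequence = sequence.replace("T", "U")   # DNA alphabet -> RNA alphabet
    if add_whitespace:
        sequence = " ".join(sequence)           # one space between characters
    return sequence

print(preprocess_sequence("AUCG", u2t=True))
print(preprocess_sequence("ATCG", t2u=True, add_whitespace=True))
```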
- decode(sequence, **kwargs)
Converts a list of token IDs back into a sequence. Must be implemented by subclasses.
- Parameters:
sequence (list) – A list of token IDs.
**kwargs – Additional arguments.
- Returns:
str – The decoded sequence.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Example
>>> # In a nucleotide tokenizer
>>> sequence = tokenizer.decode([1, 2, 3, 4])
>>> print(sequence)  # "ATCG"
- encode(sequence, **kwargs)
Converts a sequence into a list of token IDs. Must be implemented by subclasses.
- Parameters:
sequence (str) – The input sequence.
**kwargs – Additional arguments.
- Returns:
list – A list of token IDs.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Example
>>> # In a nucleotide tokenizer
>>> token_ids = tokenizer.encode("ATCGATCG")
>>> print(token_ids)  # [1, 2, 3, 4, 1, 2, 3, 4]
- static from_pretrained(model_name_or_path, **kwargs)
Loads a tokenizer from a pre-trained model path.
- Parameters:
model_name_or_path (str) – The name or path of the pre-trained model.
**kwargs – Additional arguments for the tokenizer.
- Returns:
OmniTokenizer – An instance of a tokenizer.
Example
>>> # Load from a pre-trained model
>>> tokenizer = OmniTokenizer.from_pretrained("model_name")
>>> # Load with custom parameters
>>> tokenizer = OmniTokenizer.from_pretrained("model_name", trust_remote_code=True)
- save_pretrained(save_directory)
Saves the base tokenizer to a directory.
- Parameters:
save_directory (str) – The directory to save the tokenizer to.
Example
>>> tokenizer.save_pretrained("./saved_tokenizer")
- tokenize(sequence, **kwargs)
Converts a sequence into a list of tokens. Must be implemented by subclasses.
- Parameters:
sequence (str) – The input sequence.
**kwargs – Additional arguments.
- Returns:
list – A list of tokens.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Example
>>> # In a nucleotide tokenizer
>>> tokens = tokenizer.tokenize("ATCGATCG")
>>> print(tokens)  # ['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']
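The three methods that subclasses must implement (tokenize, encode, decode) fit together as in the minimal sketch below. It uses a stand-in base class (`ToyTokenizerBase`, hypothetical) instead of the real OmniTokenizer, and a made-up vocabulary chosen to match the IDs shown in the examples above.

```python
class ToyTokenizerBase:
    """Stand-in for the abstract tokenizer interface."""

    def tokenize(self, sequence, **kwargs):
        raise NotImplementedError

    def encode(self, sequence, **kwargs):
        raise NotImplementedError

    def decode(self, sequence, **kwargs):
        raise NotImplementedError

class NucleotideTokenizer(ToyTokenizerBase):
    VOCAB = {"A": 1, "T": 2, "C": 3, "G": 4}  # made-up single-character vocabulary

    def __init__(self):
        self.id2token = {i: t for t, i in self.VOCAB.items()}

    def tokenize(self, sequence, **kwargs):
        return list(sequence)                  # one token per nucleotide

    def encode(self, sequence, **kwargs):
        return [self.VOCAB[tok] for tok in self.tokenize(sequence)]

    def decode(self, sequence, **kwargs):
        return "".join(self.id2token[i] for i in sequence)

tokenizer = NucleotideTokenizer()
token_ids = tokenizer.encode("ATCGATCG")
print(token_ids)
print(tokenizer.decode(token_ids))
```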
OmniMetric
- class omnigenbench.src.abc.abstract_metric.OmniMetric(metric_func=None, ignore_y=None, *args, **kwargs)
Bases: object
This class provides a unified interface for evaluation metrics in the OmniGenome framework. It integrates with scikit-learn’s metric functions and provides additional functionality for handling genomic data evaluation. The class automatically exposes all scikit-learn metrics as attributes, making them easily accessible for evaluation tasks.
- Variables:
metric_func (callable) – A callable metric function from sklearn.metrics.
ignore_y (any) – A value in the ground truth labels to be ignored during metric computation.
metadata (dict) – Metadata about the metric including version info.
- compute(y_true, y_pred) -> dict
Computes the metric. This method must be implemented by subclasses.
- Parameters:
y_true – Ground truth labels.
y_pred – Predicted labels.
- Returns:
dict – A dictionary with the metric name as key and its value.
- Raises:
NotImplementedError – If the method is not implemented by the subclass.
Example
>>> # In a classification metric
>>> result = metric.compute(y_true, y_pred)
>>> print(result)  # {'accuracy': 0.85}
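A compute override typically drops positions whose ground truth equals ignore_y and then returns a dictionary keyed by metric name. The sketch below shows that pattern with a stand-in base class (`ToyMetric`, hypothetical); the value -100 is only an illustrative padding label, and the filtering logic is an assumption about how ignore_y is meant to be used.

```python
class ToyMetric:
    """Stand-in for the abstract metric interface."""

    def __init__(self, ignore_y=None):
        self.ignore_y = ignore_y

    def compute(self, y_true, y_pred):
        raise NotImplementedError

class AccuracyMetric(ToyMetric):
    def compute(self, y_true, y_pred):
        # Keep only positions whose ground truth is not the ignored value
        pairs = [(t, p) for t, p in zip(y_true, y_pred) if t != self.ignore_y]
        if not pairs:
            raise ValueError("No labels left after filtering ignore_y.")
        correct = sum(1 for t, p in pairs if t == p)
        return {"accuracy": correct / len(pairs)}

metric = AccuracyMetric(ignore_y=-100)
result = metric.compute([1, 0, -100, 1], [1, 1, 0, 1])
print(result)
```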
- static flatten(y_true, y_pred)
Flattens the ground truth and prediction arrays. It handles various input formats and converts them to 1D numpy arrays.
- Parameters:
y_true – Ground truth labels in any format that can be converted to numpy array.
y_pred – Predicted labels in any format that can be converted to numpy array.
- Returns:
tuple – A tuple of flattened y_true and y_pred as numpy arrays.
Example
>>> y_true = [[1, 2], [3, 4]]
>>> y_pred = [[1, 2], [3, 4]]
>>> flat_true, flat_pred = OmniMetric.flatten(y_true, y_pred)
>>> print(flat_true.shape)  # (4,)
OmniLoRA
This module provides Low-Rank Adaptation (LoRA) implementation for efficient fine-tuning of large genomic language models. LoRA reduces the number of trainable parameters by adding low-rank adaptation layers to existing model weights.
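The parameter saving can be made concrete with a little arithmetic. In the standard LoRA formulation, a frozen weight of shape d_out x d_in is augmented with two trainable factors of shapes d_out x r and r x d_in. The dimensions and rank below are illustrative values, not OmniGenBench defaults, and `lora_param_counts` is a hypothetical helper.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, lora) trainable-parameter counts for one linear weight."""
    full = d_in * d_out                 # fine-tuning the whole weight matrix
    lora = rank * (d_in + d_out)        # factors of shape (d_out, r) and (r, d_in)
    return full, lora

full, lora = lora_param_counts(d_in=768, d_out=768, rank=8)
print(full, lora, f"{lora / full:.2%}")
```

For these example values the low-rank factors hold roughly 2% of the parameters of the full matrix, which is where the efficiency of LoRA fine-tuning comes from.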
- class omnigenbench.src.lora.lora_model.OmniLoraModel(model, **kwargs)
Bases: Module
This class provides a wrapper around LoRA-adapted models, enabling efficient fine-tuning of large genomic language models while maintaining compatibility with the OmniGenome framework.
- Variables:
lora_model – The underlying LoRA-adapted model
config – Model configuration
device – Device the model is running on
dtype – Data type of the model parameters
- config()
Get the configuration from the base model.
- Returns:
The configuration from the base model
- forward(*args, **kwargs)
Forward pass through the LoRA model.
- Parameters:
*args – Positional arguments for the forward pass
**kwargs – Keyword arguments for the forward pass
- Returns:
The output from the LoRA model
- last_hidden_state_forward(**kwargs)
Forward pass to get the last hidden state.
- Parameters:
**kwargs – Keyword arguments for the forward pass
- Returns:
Last hidden state from the base model
- model()
Get the base model.
- Returns:
The base model
- model_info()
Get information about the LoRA model.
- Returns:
Model information from the base model
- predict(*args, **kwargs)
Generate predictions using the LoRA model.
- Parameters:
*args – Positional arguments for prediction
**kwargs – Keyword arguments for prediction
- Returns:
Model predictions
- save(*args, **kwargs)
Save the LoRA model.
- Parameters:
*args – Positional arguments for saving
**kwargs – Keyword arguments for saving
- Returns:
Result of the save operation
- set_loss_fn(fn)
Set the loss function for the LoRA model.
- Parameters:
fn – Loss function to set
- Returns:
Result of setting the loss function
- to(*args, **kwargs)
Move the model to a specific device and data type.
- Parameters:
*args – Device specification (e.g., ‘cuda’, ‘cpu’)
**kwargs – Additional arguments including dtype
- Returns:
self – The model instance
- tokenizer()
Get the tokenizer from the base model.
- Returns:
The tokenizer from the base model
- omnigenbench.src.lora.lora_model.auto_lora_model(model, **kwargs)
This function automatically identifies suitable target modules and creates a LoRA-adapted version of the input model. It handles configuration setup and parameter freezing for efficient fine-tuning.
- Parameters:
model – The base model to adapt with LoRA
**kwargs – Additional LoRA configuration parameters
- Returns:
The LoRA-adapted model
- Raises:
AssertionError – If no target modules are found for LoRA injection
- omnigenbench.src.lora.lora_model.find_linear_target_modules(model, keyword_filter=None, use_full_path=True)
This function searches through a model’s modules to identify linear layers that can be adapted using LoRA. It supports filtering by keyword patterns to target specific types of layers.
- Parameters:
model – The model to search for linear modules
keyword_filter (str, list, tuple, optional) – Keywords to filter modules by name
use_full_path (bool) – Whether to return full module paths or just names (default: True)
- Returns:
list – Sorted list of linear module names that can be targeted for LoRA
- Raises:
TypeError – If keyword_filter is not None, str, or a list/tuple of str
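The filtering behaviour described for find_linear_target_modules can be sketched on plain data. The function below is a hypothetical stand-in that works on (path, kind) pairs instead of a real model, so it runs without PyTorch; the substring matching rule and the meaning given to use_full_path (full dotted path versus last path component) are assumptions.

```python
def filter_linear_modules(named_modules, keyword_filter=None, use_full_path=True):
    """Hypothetical stand-in: select linear layers from (path, kind) pairs."""
    if keyword_filter is None:
        keywords = None
    elif isinstance(keyword_filter, str):
        keywords = [keyword_filter]
    elif isinstance(keyword_filter, (list, tuple)) and all(
        isinstance(k, str) for k in keyword_filter
    ):
        keywords = list(keyword_filter)
    else:
        raise TypeError("keyword_filter must be None, str, or a list/tuple of str")

    names = set()
    for path, kind in named_modules:
        if kind != "Linear":
            continue  # only linear layers are candidates for LoRA
        if keywords is not None and not any(k in path for k in keywords):
            continue  # assumed rule: keep paths containing any keyword
        names.add(path if use_full_path else path.rsplit(".", 1)[-1])
    return sorted(names)

# Made-up module listing resembling a transformer encoder layer
modules = [
    ("encoder.layer.0.attention.query", "Linear"),
    ("encoder.layer.0.attention.key", "Linear"),
    ("encoder.layer.0.output.dense", "Linear"),
    ("encoder.layer.0.output.norm", "LayerNorm"),
]
short_names = filter_linear_modules(
    modules, keyword_filter=["query", "key"], use_full_path=False
)
print(short_names)
```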