Downstream Models
Classification Models
- class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassification(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForSequenceClassification
This model is designed for multi-label classification tasks, where a single sequence can be assigned several labels at once. It extends the sequence classification model in two ways. It uses a sigmoid activation instead of softmax, so each label receives an independent probability, and it is trained with binary cross-entropy loss.
- Variables:
softmax (torch.nn.Sigmoid) – Sigmoid layer for multi-label probability computation. It replaces the parent class's softmax attribute under the same name.
loss_fn (torch.nn.BCELoss) – Binary cross-entropy loss function for training.
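The sigmoid-plus-threshold behavior described above can be sketched in plain PyTorch. This is a minimal illustration, not the OmniGenBench implementation. The `pooled` tensor stands in for the model's sequence representation. The 0.5 decision threshold is an assumption, because the documentation does not state which threshold `predict` applies.

```python
import torch
import torch.nn as nn

# Minimal sketch of a multi-label head. Assumes a pooled (batch, hidden)
# representation is already available; all names here are illustrative.
torch.manual_seed(0)
batch, hidden, num_labels = 2, 16, 4

pooled = torch.randn(batch, hidden)
classifier = nn.Linear(hidden, num_labels)
sigmoid = nn.Sigmoid()   # independent probability per label
loss_fn = nn.BCELoss()   # expects probabilities in [0, 1]

probs = sigmoid(classifier(pooled))          # (batch, num_labels)
labels = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                       [0.0, 1.0, 0.0, 0.0]])  # multi-hot targets
loss = loss_fn(probs, labels)

# Hypothetical threshold: every label whose probability exceeds it is active,
# so a sequence can receive zero, one, or several labels.
threshold = 0.5
predictions = (probs > threshold).long()
```

BCELoss consumes probabilities, so the sigmoid has to run before the loss. `torch.nn.BCEWithLogitsLoss` fuses the two steps and is the numerically safer choice when you start from raw logits.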
- inference(sequence_or_inputs, **kwargs)
Performs multi-label inference with human-readable output. It converts logits to binary labels and provides confidence scores.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Human-readable binary labels for each sequence
logits: Raw logits from the model
confidence: Confidence scores for predictions
last_hidden_state: Final hidden states
Example
>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # tensor([1, 0, 1, 0])
- loss_function(logits, labels)
Calculates the binary cross-entropy loss for multi-label classification.
- Parameters:
logits (torch.Tensor) – Predicted logits from the model.
labels (torch.Tensor) – Ground truth multi-label targets.
- Returns:
torch.Tensor – The computed loss value.
Example
>>> loss = model.loss_function(logits, labels)
- predict(sequence_or_inputs, **kwargs)
This method takes raw sequences or tokenized inputs and returns multi-label predictions. It applies a threshold to determine which labels are active for each sequence.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Multi-label predictions for each sequence
logits: Raw logits from the model
last_hidden_state: Final hidden states
Example
>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([1, 0, 1, 0])
- class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForSequenceClassificationWith2DStructure
- inference(sequence_or_inputs, **kwargs)
This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Human-readable class labels for each sequence
logits: Raw logits from the model
confidence: Confidence scores for predictions
last_hidden_state: Final hidden states
Example
>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])  # 0.95
- loss_function(logits, labels)
This method computes the cross-entropy loss between the predicted logits and the ground truth labels.
- Parameters:
logits (torch.Tensor) – Predicted logits from the model.
labels (torch.Tensor) – Ground truth labels.
- Returns:
torch.Tensor – The computed loss value.
Example
>>> loss = model.loss_function(logits, labels)
- predict(sequence_or_inputs, **kwargs)
This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Predicted class indices for each sequence
logits: Raw logits from the model
last_hidden_state: Final hidden states
Example
>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
- class omnigenbench.src.model.classification.model.OmniModelForSequenceClassification(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
This model is designed for sequence-level classification tasks, where the entire input sequence is assigned to one of several categories. It extends the base OmniModel with sequence-level classification capabilities.
- Variables:
pooler (OmniPooling) – Pooling layer for sequence-level representation.
softmax (torch.nn.Softmax) – Softmax layer for probability computation.
classifier (torch.nn.Linear) – Linear classification head.
loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training.
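Here is a rough sketch of the pooled-classification path: pool, then the linear head, then softmax and cross-entropy. OmniPooling's actual strategy is not documented here, so masked mean pooling is used as a stand-in. The `masked_mean_pool` helper and all shapes are hypothetical.

```python
import torch
import torch.nn as nn

def masked_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors while ignoring padding (mask == 0). Hypothetical helper."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (B, L, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # (B, H)
    counts = mask.sum(dim=1).clamp(min=1.0)                      # avoid division by zero
    return summed / counts

torch.manual_seed(0)
B, L, H, num_labels = 2, 5, 8, 3
last_hidden_state = torch.randn(B, L, H)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

classifier = nn.Linear(H, num_labels)
logits = classifier(masked_mean_pool(last_hidden_state, attention_mask))  # (B, num_labels)
probs = torch.softmax(logits, dim=-1)          # one distribution per sequence
labels = torch.tensor([0, 2])
loss = nn.CrossEntropyLoss()(logits, labels)   # fed raw logits, not probs
predicted_class = probs.argmax(dim=-1)
```

`torch.nn.CrossEntropyLoss` applies log-softmax internally, so it receives the raw logits. The softmax output is only needed to report probabilities or pick the top class.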
- forward(**inputs)
This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.
- Parameters:
**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.
- Returns:
dict –
- A dictionary containing:
logits: Sequence-level classification logits
last_hidden_state: Final hidden states from the base model
labels: Ground truth labels (if provided)
Example
>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )
- inference(sequence_or_inputs, **kwargs)
This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Human-readable class labels for each sequence
logits: Raw logits from the model
confidence: Confidence scores for predictions
last_hidden_state: Final hidden states
Example
>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])  # 0.95
- loss_function(logits, labels)
This method computes the cross-entropy loss between the predicted logits and the ground truth labels.
- Parameters:
logits (torch.Tensor) – Predicted logits from the model.
labels (torch.Tensor) – Ground truth labels.
- Returns:
torch.Tensor – The computed loss value.
Example
>>> loss = model.loss_function(logits, labels)
- predict(sequence_or_inputs, **kwargs)
This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Predicted class indices for each sequence
logits: Raw logits from the model
last_hidden_state: Final hidden states
Example
>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
- class omnigenbench.src.model.classification.model.OmniModelForSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForSequenceClassification
- forward(**inputs)
This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.
- Parameters:
**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.
- Returns:
dict –
- A dictionary containing:
logits: Sequence-level classification logits
last_hidden_state: Final hidden states from the base model
labels: Ground truth labels (if provided)
Example
>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )
- class omnigenbench.src.model.classification.model.OmniModelForTokenClassification(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
This model is designed for token-level classification tasks such as sequence labeling, where each token in the input sequence is classified into one of several categories. It extends the base OmniModel with token-level classification capabilities.
- Variables:
softmax (torch.nn.Softmax) – Softmax layer for probability computation.
classifier (torch.nn.Linear) – Linear classification head.
loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training.
- forward(**inputs)
Forward pass for token classification.
This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.
- Parameters:
**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.
- Returns:
dict –
- A dictionary containing:
logits: Token-level classification logits
last_hidden_state: Final hidden states from the base model
labels: Ground truth labels (if provided)
Example
>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )
- inference(sequence_or_inputs, **kwargs)
Performs token-level inference with human-readable output.
This method provides processed, human-readable token-level predictions. It converts logits to class labels and handles special tokens appropriately.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Human-readable class labels for each token
logits: Raw logits from the model
confidence: Confidence scores for predictions
last_hidden_state: Final hidden states
Example
>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # ['A', 'T', 'C', 'G', ...]
- loss_function(logits, labels)
Calculates the cross-entropy loss for token classification.
This method computes the cross-entropy loss between the predicted logits and the ground truth labels, ignoring padding tokens.
- Parameters:
logits (torch.Tensor) – Predicted logits from the model.
labels (torch.Tensor) – Ground truth labels.
- Returns:
torch.Tensor – The computed loss value.
Example
>>> loss = model.loss_function(logits, labels)
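One common way to make a token-level cross-entropy skip padding is to give those positions the loss's `ignore_index` label. The sketch below assumes the -100 convention, which is PyTorch's default `ignore_index`. The documentation does not say which mechanism `loss_function` actually uses.

```python
import torch
import torch.nn as nn

# Assumption: padding / special positions carry the label -100,
# the default ignore_index of nn.CrossEntropyLoss.
IGNORE = -100
torch.manual_seed(0)
B, L, num_labels = 2, 4, 3
logits = torch.randn(B, L, num_labels)
labels = torch.tensor([[0, 1, 2, IGNORE],
                       [2, IGNORE, IGNORE, IGNORE]])

loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE)
# CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten batch and length.
loss = loss_fn(logits.view(-1, num_labels), labels.view(-1))

# The same value computed by hand over the non-ignored positions only.
keep = labels.view(-1) != IGNORE
manual = nn.functional.cross_entropy(logits.view(-1, num_labels)[keep], labels.view(-1)[keep])
```

With the default `reduction="mean"`, the loss is averaged over the non-ignored tokens only. Padded positions therefore neither add to the loss nor dilute it.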
- predict(sequence_or_inputs, **kwargs)
Performs token-level prediction on raw inputs.
This method takes raw sequences or tokenized inputs and returns token-level predictions. It processes the inputs through the model and returns the predicted class for each token.
- Parameters:
sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).
**kwargs – Additional arguments for tokenization and inference.
- Returns:
dict –
- A dictionary containing:
predictions: Predicted class indices for each token
logits: Raw logits from the model
last_hidden_state: Final hidden states
Example
>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'].shape)  # (seq_len,)
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
- class omnigenbench.src.model.classification.model.OmniModelForTokenClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForTokenClassification
- forward(**inputs)
Forward pass for token classification.
This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.
- Parameters:
**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.
- Returns:
dict –
- A dictionary containing:
logits: Token-level classification logits
last_hidden_state: Final hidden states from the base model
labels: Ground truth labels (if provided)
Example
>>> import torch
>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )
Regression Models
Regression models for the OmniGenome framework.
This module provides various regression model implementations for genomic sequence analysis, including token-level regression, sequence-level regression, structural imputation, and matrix regression/classification tasks.
- class omnigenbench.src.model.regression.model.OmniModelForMatrixClassification(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
This model performs classification on matrix representations of genomic sequences. It is useful for tasks such as structure classification, contact map classification, or other matrix-based genomic analyses.
- Variables:
resnet – ResNet backbone for processing matrix inputs
classifier – Linear layer for classification output
loss_fn – Cross-entropy loss function
- forward(**inputs)
Forward pass for matrix classification.
- Parameters:
**inputs – Input tensors including matrix representations and labels
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- inference(sequence_or_inputs, **kwargs)
Perform inference for matrix classification.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- loss_function(logits, labels)
Compute the loss for matrix classification.
- Parameters:
logits (torch.Tensor) – Model predictions
labels (torch.Tensor) – Ground truth labels
- Returns:
torch.Tensor – Computed loss value
- predict(sequence_or_inputs, **kwargs)
Generate predictions for matrix classification.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- class omnigenbench.src.model.regression.model.OmniModelForMatrixRegression(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
This model performs regression on matrix representations of genomic sequences. It is useful for tasks such as contact map prediction, structure prediction, or other matrix-based genomic analyses.
- Variables:
resnet – ResNet backbone for processing matrix inputs
classifier – Linear layer for regression output
loss_fn – Mean squared error loss function
- forward(**inputs)
Forward pass for matrix regression.
- Parameters:
**inputs – Input tensors including matrix representations and labels
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- inference(sequence_or_inputs, **kwargs)
Perform inference for matrix regression.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- loss_function(logits, labels)
Compute the loss for matrix regression.
- Parameters:
logits (torch.Tensor) – Model predictions
labels (torch.Tensor) – Ground truth labels
- Returns:
torch.Tensor – Computed loss value
- predict(sequence_or_inputs, **kwargs)
Generate predictions for matrix regression.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- class omnigenbench.src.model.regression.model.OmniModelForSequenceRegression(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
This model performs regression at the sequence level, predicting a single continuous value for the entire input sequence. It is useful for tasks such as predicting overall expression levels, binding affinities, or other sequence-level properties.
- Variables:
pooler – OmniPooling layer for sequence-level representation
classifier – Linear layer for regression output
loss_fn – Mean squared error loss function
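For sequence-level regression, the head reduces to one linear output per sequence, trained with mean squared error. This is a minimal sketch that assumes a single regression target (one output unit) and a precomputed pooled representation. The target values are arbitrary illustration data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, H = 4, 8
pooled = torch.randn(B, H)                   # stand-in for the pooler output
regressor = nn.Linear(H, 1)                  # one continuous value per sequence
preds = regressor(pooled).squeeze(-1)        # (B,)
targets = torch.tensor([0.5, 1.2, -0.3, 2.0])  # illustrative targets, e.g. expression levels

loss = nn.MSELoss()(preds, targets)
manual = ((preds - targets) ** 2).mean()     # MSE written out explicitly
```

Squeezing the last dimension matters. Without it, `preds` has shape (B, 1) and `MSELoss` broadcasts it against (B,) targets, silently producing a (B, B) comparison and a warning.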
- forward(**inputs)
Forward pass for sequence-level regression.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, and labels
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- inference(sequence_or_inputs, **kwargs)
Perform inference for sequence-level regression.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- loss_function(logits, labels)
Compute the loss for sequence-level regression.
- Parameters:
logits (torch.Tensor) – Model predictions
labels (torch.Tensor) – Ground truth labels
- Returns:
torch.Tensor – Computed loss value
- predict(sequence_or_inputs, **kwargs)
Generate predictions for sequence-level regression.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- class omnigenbench.src.model.regression.model.OmniModelForSequenceRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForSequenceRegression
This model extends the basic sequence regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.
- forward(**inputs)
Forward pass for 2D structure-aware sequence regression.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, labels, and structural info
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- class omnigenbench.src.model.regression.model.OmniModelForStructuralImputation(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForSequenceRegression
This model is specialized for imputing missing structural information in genomic sequences. It extends the sequence regression model with additional embedding capabilities for structural features.
- Variables:
embedding – Embedding layer for structural features
loss_fn – Mean squared error loss function
- forward(**inputs)
Forward pass for structural imputation.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, and labels
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- class omnigenbench.src.model.regression.model.OmniModelForTokenRegression(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModel
Token-level regression model for genomic sequences.
This model performs regression at the token level, predicting continuous values for each token in the input sequence. It’s useful for tasks like predicting binding affinities, expression levels, or other continuous properties at each position in a genomic sequence.
- Variables:
classifier – Linear layer for regression output
loss_fn – Mean squared error loss function
- forward(**inputs)
Forward pass for token-level regression.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, and labels
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
- inference(sequence_or_inputs, **kwargs)
Perform inference for token-level regression, excluding special tokens.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
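Excluding special tokens from token-level outputs amounts to dropping the positions the tokenizer added. This sketch rests on two assumptions: single-nucleotide tokenization, and exactly one leading and one trailing special token. Both are encoded in a hypothetical `special_tokens_mask`. Real tokenizers, and the library's own logic, may differ.

```python
import torch

torch.manual_seed(0)
seq = "ATCG"
L = len(seq) + 2                       # assumed layout: [CLS] A T C G [SEP]
token_preds = torch.randn(1, L)        # one regression value per token, batch of 1

# Hypothetical mask: True marks tokens added by the tokenizer.
special_tokens_mask = torch.tensor([[1, 0, 0, 0, 0, 1]], dtype=torch.bool)

# Boolean indexing keeps only the non-special positions (flattened to 1-D).
per_base = token_preds[~special_tokens_mask]
```

Under these assumptions, `per_base` has one value per nucleotide in `seq`, which is the alignment a per-position regression target needs.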
- loss_function(logits, labels)
Compute the loss for token-level regression.
- Parameters:
logits (torch.Tensor) – Model predictions
labels (torch.Tensor) – Ground truth labels
- Returns:
torch.Tensor – Computed loss value
- predict(sequence_or_inputs, **kwargs)
Generate predictions for token-level regression.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
- class omnigenbench.src.model.regression.model.OmniModelForTokenRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)
Bases: OmniModelForTokenRegression
This model extends the basic token regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.
- forward(**inputs)
Forward pass for 2D structure-aware token regression.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, labels, and structural info
- Returns:
dict – Dictionary containing logits, last_hidden_state, and labels
ResNet implementation for genomic sequence analysis.
This module provides a ResNet architecture adapted for processing genomic sequences and their structural representations. It includes basic blocks, bottleneck blocks, and a complete ResNet implementation optimized for genomic data.
- class omnigenbench.src.model.regression.resnet.BasicBlock(inplanes: int, planes: int, stride: int = 1, downsample=None, groups: int = 1, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)
Bases: Module
This block implements a basic residual connection with two convolutions (a 3x3 followed by a 5x5) and uses layer normalization. It is adapted for processing genomic sequence data.
- Variables:
expansion (int) – Expansion factor for the block (default: 1)
conv1 – First 3x3 convolution layer
bn1 – First layer normalization
conv2 – Second 5x5 convolution layer
bn2 – Second layer normalization
relu – ReLU activation function
drop – Dropout layer
downsample – Downsampling layer for residual connection
stride – Stride for the convolutions
- expansion: int = 1
- forward(x: Tensor) → Tensor
Forward pass through the BasicBlock.
- Parameters:
x (Tensor) – Input tensor [batch_size, channels, height, width]
- Returns:
Tensor – Output tensor; it has the same shape as the input when stride is 1 and no downsample module is set
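Here is a minimal sketch of a residual block that matches the attribute list above: a 3x3 and then a 5x5 convolution, layer normalization, ReLU, dropout, and an identity shortcut. Several details are assumptions: where dropout sits, the dropout rate, and the channel-wise LayerNorm applied via `permute`. `SketchBasicBlock` and `ChannelLayerNorm` are hypothetical names. The downsample path is omitted, so input and output shapes match.

```python
import torch
import torch.nn as nn

class ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel dimension of a (B, C, H, W) tensor (hypothetical)."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize them, then restore (B, C, H, W).
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class SketchBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, planes: int, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn1 = ChannelLayerNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=5, padding=2, bias=False)
        self.bn2 = ChannelLayerNorm(planes)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        out = self.drop(self.relu(self.bn1(self.conv1(x))))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)  # residual connection, then activation

block = SketchBasicBlock(planes=8)
x = torch.randn(2, 8, 6, 6)
y = block(x)
```

Padding 1 for the 3x3 kernel and 2 for the 5x5 kernel keeps height and width unchanged. That is what lets the identity be added without a downsample module.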
- class omnigenbench.src.model.regression.resnet.Bottleneck(inplanes: int, planes: int, stride: int = 1, downsample: Module | None = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)
Bases: Module
This block implements a bottleneck residual connection with three convolutions (1x1, 3x3, 1x1) and is designed for deeper networks. It is adapted from the original ResNet V1.5 implementation.
- Variables:
expansion (int) – Expansion factor for the block (default: 4)
conv1 – First 1x1 convolution layer
bn1 – First batch normalization
conv2 – Second 3x3 convolution layer
bn2 – Second batch normalization
conv3 – Third 1x1 convolution layer
bn3 – Third batch normalization
relu – ReLU activation function
downsample – Downsampling layer for residual connection
stride – Stride for the convolutions
- expansion: int = 4
- forward(x: Tensor) → Tensor
Forward pass through the Bottleneck block.
- Parameters:
x (Tensor) – Input tensor [batch_size, channels, height, width]
- Returns:
Tensor – Output tensor; it has the same shape as the input when stride is 1 and no downsample module is set
- class omnigenbench.src.model.regression.resnet.ResNet(channels, block: Type[BasicBlock | Bottleneck], layers: List[int], zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 1, replace_stride_with_dilation=None, norm_layer=None)
Bases: Module
This ResNet implementation is designed for processing genomic sequences and their structural representations. It uses layer normalization instead of batch normalization and is adapted to the characteristics of genomic data.
- Variables:
_norm_layer – Normalization layer type
inplanes – Number of input channels for the first layer
dilation – Dilation factor for convolutions
groups – Number of groups for grouped convolutions
base_width – Base width for bottleneck blocks
conv1 – Initial convolution layer
bn1 – Initial normalization layer
relu – ReLU activation function
layer1 – First layer of ResNet blocks
fc1 – Final fully connected layer
- forward(x: Tensor) → Tensor
Forward pass through the ResNet.
- Parameters:
x (Tensor) – Input tensor [batch_size, channels, height, width]
- Returns:
Tensor – Output tensor after processing through ResNet
- omnigenbench.src.model.regression.resnet.conv1x1(in_planes, out_planes, stride=1)
1x1 convolution.
- Parameters:
in_planes (int) – Number of input channels
out_planes (int) – Number of output channels
stride (int) – Stride for the convolution (default: 1)
- Returns:
nn.Conv2d – 1x1 convolution layer
- omnigenbench.src.model.regression.resnet.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)
3x3 convolution with padding.
- Parameters:
in_planes (int) – Number of input channels
out_planes (int) – Number of output channels
stride (int) – Stride for the convolution (default: 1)
groups (int) – Number of groups for grouped convolution (default: 1)
dilation (int) – Dilation factor for the convolution (default: 1)
- Returns:
nn.Conv2d – 3x3 convolution layer
- omnigenbench.src.model.regression.resnet.conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1)
5x5 convolution with padding.
- Parameters:
in_planes (int) – Number of input channels
out_planes (int) – Number of output channels
stride (int) – Stride for the convolution (default: 1)
groups (int) – Number of groups for grouped convolution (default: 1)
dilation (int) – Dilation factor for the convolution (default: 1)
- Returns:
nn.Conv2d – 5x5 convolution layer
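The "with padding" helpers choose padding so that a stride-1 convolution preserves height and width. The convolution output size is H_out = floor((H + 2p - d(k - 1) - 1) / s) + 1. Choosing p = d(k - 1) / 2 therefore gives H_out = H at s = 1. For 3x3 that means padding = dilation, as in the torchvision-style helper. For 5x5 it means padding = 2 * dilation. Whether OmniGenBench uses exactly these padding values and `bias=False` is an assumption.

```python
import torch
import torch.nn as nn

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    # A 1x1 kernel needs no padding to preserve spatial size.
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

def conv3x3(in_planes: int, out_planes: int, stride: int = 1,
            groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    # Effective kernel extent is 2*dilation + 1, so padding = dilation keeps H, W.
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, dilation=dilation, bias=False)

def conv5x5(in_planes: int, out_planes: int, stride: int = 1,
            groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    # Effective kernel extent is 4*dilation + 1, so padding = 2*dilation keeps H, W.
    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                     padding=2 * dilation, groups=groups, dilation=dilation, bias=False)

x = torch.randn(1, 4, 10, 10)
```

With stride 2, the same formula halves a 10x10 map to 5x5. That is how these helpers can double as downsampling layers.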
- omnigenbench.src.model.regression.resnet.resnet_b16(channels=128, bbn=16)
This function creates a ResNet model built from bbn basic blocks (16 by default) for processing genomic sequences and their structural representations.
- Parameters:
channels (int) – Number of input channels (default: 128)
bbn (int) – Number of basic blocks (default: 16)
- Returns:
ResNet – Configured ResNet model
Embedding Models
- class omnigenbench.src.model.embedding.model.OmniModelForEmbedding(model_name_or_path, *args, **kwargs)
Bases: Module
This class provides a unified interface for loading pre-trained models and generating embeddings from genomic sequences. It supports several aggregation methods and batch processing for efficient embedding generation.
- Variables:
tokenizer – The tokenizer for processing input sequences
model – The pre-trained model for generating embeddings
_device – The device (CPU/GPU) where the model is loaded
Example
>>> from omnigenbench import OmniModelForEmbedding
>>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
>>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
>>> embeddings = model.batch_encode(sequences)
>>> print(f"Embeddings shape: {embeddings.shape}")
Embeddings shape: torch.Size([2, 768])
- batch_encode(sequences, batch_size=8, max_length=512, agg='head', require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)
Batch-encode sequences into aggregated (pooled) embeddings.
- Parameters:
sequences (List[str]) – Input DNA (or RNA) sequences
batch_size (int, default=8) – Number of sequences processed per batch
max_length (int, default=512) – Length to which the tokenizer truncates or pads
agg (str, default='head') – Aggregation method: head, mean, or tail
require_grad (bool, default=False) – If True, keep the computation graph so gradients can flow during fine-tuning
return_on_cpu (bool, default=True) – If True, move the result to the CPU to relieve device memory; otherwise keep it on the model's device
use_autocast (bool, default=False) – Enable autocast mixed precision (CUDA only)
amp_dtype (torch.dtype, optional) – Data type used by autocast
- Returns:
torch.Tensor – Pooled embeddings with shape (num_sequences, hidden_size)
- Compatibility:
Existing calls need no changes; all new parameters have default values.
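The three `agg` options can be read as different reductions over the token axis of `last_hidden_state`. The sketch below is one plausible interpretation, not the library's code. It assumes that `head` takes the first token and `mean` averages the non-padding tokens. It also assumes right-side padding, with `tail` taking the last non-padding token. The `pool` function is hypothetical.

```python
import torch

def pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor,
         agg: str = "head") -> torch.Tensor:
    """Reduce (B, L, H) token embeddings to (B, H). Hypothetical helper."""
    if agg == "head":
        return last_hidden_state[:, 0]  # first token of each sequence
    mask = attention_mask.to(last_hidden_state.dtype)
    if agg == "mean":
        # Average only over positions where attention_mask == 1.
        summed = (last_hidden_state * mask.unsqueeze(-1)).sum(dim=1)
        return summed / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
    if agg == "tail":
        # Assumes right padding: the last real token sits at (count - 1).
        last_idx = attention_mask.sum(dim=1) - 1
        return last_hidden_state[torch.arange(last_hidden_state.size(0)), last_idx]
    raise ValueError(f"unknown agg: {agg!r}")

h = torch.arange(2 * 3 * 2, dtype=torch.float32).view(2, 3, 2)
m = torch.tensor([[1, 1, 1],
                  [1, 1, 0]])
```

Masking matters for `mean` and `tail`. Without it, padded positions in shorter sequences would leak into the pooled vector.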
- batch_encode_tokens(sequences, batch_size=8, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)
Encode sequences to token-level embeddings (last_hidden_state).
- Parameters:
sequences (List[str]) – Input DNA/RNA sequences for token-level encoding
batch_size (int, default=8) – Number of sequences to process per batch
max_length (int, default=512) – Maximum sequence length for tokenization
use_autocast (bool, default=False) – Enable mixed precision (CUDA only)
amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision
require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning
return_on_cpu (bool, default=True) – Transfer outputs to CPU memory
- Returns:
torch.Tensor – Token embeddings with shape (num_sequences, max_length, hidden_size)
Note
When require_grad=True, gradients flow through the transformer model for end-to-end training. Set return_on_cpu=False to keep tensors on GPU device for downstream processing.
- compute_similarity(embedding1, embedding2, dim=0)
Compute cosine similarity between two embeddings.
- Parameters:
embedding1 (torch.Tensor) – The first embedding
embedding2 (torch.Tensor) – The second embedding
dim (int, optional) – Dimension along which to compute cosine similarity. Defaults to 0
- Returns:
float – Cosine similarity score between -1 and 1
Example
>>> emb1 = model.encode("ATCGGCTA")
>>> emb2 = model.encode("GGCTAGCTA")
>>> similarity = model.compute_similarity(emb1, emb2)
>>> print(f"Cosine similarity: {similarity:.4f}")
Cosine similarity: 0.8234
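`compute_similarity` can be reproduced with `torch.nn.functional.cosine_similarity`. Whether the method actually delegates to that function is an assumption. The `dim` argument matters. For 1-D embeddings such as those returned by `encode`, `dim=0` compares the two vectors as wholes. For batched (N, hidden_size) embeddings, `dim=1` yields one score per row.

```python
import torch
import torch.nn.functional as F

emb1 = torch.tensor([1.0, 2.0, 3.0])
emb2 = torch.tensor([2.0, 4.0, 6.5])

# 1-D embeddings: dim=0 treats each tensor as a single vector.
sim = F.cosine_similarity(emb1, emb2, dim=0).item()
manual = (emb1 @ emb2 / (emb1.norm() * emb2.norm())).item()

# Batched embeddings of shape (N, H): dim=1 gives one similarity per pair of rows.
torch.manual_seed(0)
a = torch.randn(3, 4)
b = torch.randn(3, 4)
row_sims = F.cosine_similarity(a, b, dim=1)
```

Passing `dim=0` with batched (N, H) inputs would instead compare columns across the batch. That is rarely what is wanted.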
- property device
Get the current device for the underlying model.
- Returns:
torch.device – The device where the model currently resides.
Note
This queries the model parameters directly so it stays correct when external frameworks (e.g., Accelerate/DDP) move the module across devices after initialization.
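The device lookup described in the note can be sketched by reading the first parameter's device every time the property is accessed. `DeviceAware` is a hypothetical wrapper for illustration only; it is not the library's class.

```python
import torch
import torch.nn as nn

class DeviceAware(nn.Module):
    """Hypothetical wrapper illustrating a parameter-based device property."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)

    @property
    def device(self) -> torch.device:
        # Query the parameters on each access instead of caching a value,
        # so the answer stays correct after .to(...) or framework-driven moves.
        return next(self.model.parameters()).device

wrapper = DeviceAware()
before = wrapper.device
if torch.cuda.is_available():
    wrapper.to("cuda")
after = wrapper.device
```

A value cached in `__init__` would keep reporting the original device after the module is moved. Reading from the parameters avoids that.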
- encode(sequence, max_length=512, agg='head', keep_dim=False, require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)
Encode a single sequence into a pooled embedding.
- Parameters:
sequence (str) – Input sequence
max_length (int, default=512) – Tokenizer truncation/padding length
agg (str, default='head') – Aggregation strategy: head, mean, or tail
keep_dim (bool, default=False) – If True, keep the batch dimension
require_grad (bool, default=False) – If True, keep the computation graph for fine-tuning
return_on_cpu (bool, default=True) – If True, move the result to the CPU
use_autocast (bool, default=False) – Enable autocast mixed precision
amp_dtype (torch.dtype, optional) – Data type used by autocast
- Returns:
torch.Tensor – Pooled embedding with shape (hidden_size,), or (1, hidden_size) when keep_dim=True
- encode_tokens(sequence, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)
Encode a single sequence to token-level embeddings.
- Parameters:
sequence (str) – Input DNA/RNA sequence for token-level encoding
max_length (int, default=512) – Maximum sequence length for tokenization
use_autocast (bool, default=False) – Enable mixed precision (CUDA only)
amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision
require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning
return_on_cpu (bool, default=True) – Transfer output to CPU memory
- Returns:
torch.Tensor – Token embeddings with shape (max_length, hidden_size)
Example
>>> model = OmniModelForEmbedding("yangheng/OmniGenome-52M")
>>> sequence = "ATCGATCGATCG"
>>> token_embeddings = model.encode_tokens(sequence, max_length=200)
>>> print(f"Token embeddings shape: {token_embeddings.shape}")
Token embeddings shape: torch.Size([200, 768])
- load_embeddings(embedding_path)[source]
Load embeddings from a file.
- Parameters:
embedding_path (str) – Path to the saved embeddings
- Returns:
torch.Tensor – The loaded embeddings
Example
>>> embeddings = model.load_embeddings("embeddings.pt")
>>> print(f"Loaded embeddings shape: {embeddings.shape}")
Loaded embeddings shape: torch.Size([100, 768])
- save_embeddings(embeddings, output_path)[source]
Save the generated embeddings to a file.
- Parameters:
embeddings (torch.Tensor) – The embeddings to save
output_path (str) – Path to save the embeddings
Example
>>> embeddings = model.batch_encode(sequences)
>>> model.save_embeddings(embeddings, "embeddings.pt")
>>> print("Embeddings saved successfully")
MLM Models
Masked Language Model (MLM) for genomic sequences.
This module provides a masked language model implementation specifically designed for genomic sequences. It supports masked language modeling tasks where tokens are randomly masked and the model learns to predict the original tokens.
- class omnigenbench.src.model.mlm.model.OmniModelForMLM(config_or_model, tokenizer, *args, **kwargs)[source]
Bases:
OmniModel
Masked Language Model for genomic sequences.
This model implements masked language modeling for genomic sequences, where tokens are randomly masked and the model learns to predict the original tokens. It’s useful for pre-training genomic language models and understanding sequence patterns and dependencies.
- Variables:
loss_fn – Cross-entropy loss function for masked language modeling
- forward(**inputs)[source]
Forward pass for masked language modeling.
- Parameters:
**inputs – Input tensors including input_ids, attention_mask, and labels
- Returns:
dict – Dictionary containing loss, logits, and last_hidden_state
- inference(sequence_or_inputs, **kwargs)[source]
Perform inference for masked language modeling, decoding predictions to sequences.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing decoded predictions, logits, and last_hidden_state
- loss_function(logits, labels)[source]
Compute the loss for masked language modeling.
- Parameters:
logits (torch.Tensor) – Model predictions [batch_size, seq_len, vocab_size]
labels (torch.Tensor) – Ground truth labels [batch_size, seq_len]
- Returns:
torch.Tensor – Computed cross-entropy loss value
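In masked language modeling the loss is normally computed only at masked positions. PyTorch's `torch.nn.CrossEntropyLoss` ignores targets equal to `-100` by default, and Hugging Face Transformers uses the same convention for MLM labels; whether `OmniModelForMLM` builds its labels exactly this way is an assumption here. The plain-Python sketch below computes the mean cross-entropy over non-ignored positions for logits shaped [batch_size][seq_len][vocab_size] and labels shaped [batch_size][seq_len].

```python
import math
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def masked_cross_entropy(logits: List[List[List[float]]], labels: List[List[int]]) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    logits: [batch_size][seq_len][vocab_size] unnormalized scores
    labels: [batch_size][seq_len] target token ids, or IGNORE_INDEX
    """
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logits, labels):
        for scores, target in zip(seq_logits, seq_labels):
            if target == IGNORE_INDEX:
                continue  # unmasked position: contributes nothing to the loss
            # -log softmax(scores)[target], using log-sum-exp for numerical stability
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[target]
            count += 1
    if count == 0:
        raise ValueError("no positions to score: every label is IGNORE_INDEX")
    return total / count


# One sequence of three tokens, vocabulary of size 4; only position 1 was masked.
logits = [[[0.1, 0.2, 0.3, 0.4], [0.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]]
labels = [[IGNORE_INDEX, 2, IGNORE_INDEX]]
print(round(masked_cross_entropy(logits, labels), 4))  # 1.3863, i.e. ln(4) for uniform scores
```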
- predict(sequence_or_inputs, **kwargs)[source]
Generate predictions for masked language modeling.
- Parameters:
sequence_or_inputs – Input sequences or pre-processed inputs
**kwargs – Additional keyword arguments
- Returns:
dict – Dictionary containing predictions, logits, and last_hidden_state
RNA Design Models
RNA design model using masked language modeling and evolutionary algorithms.
This module provides an RNA design model that combines masked language modeling with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.
- class omnigenbench.src.model.rna_design.model.OmniModelForRNADesign(model='yangheng/OmniGenome-186M', device=None, parallel=False, *args, **kwargs)[source]
Bases:
Module
RNA design model using masked language modeling and evolutionary algorithms.
This model combines a pre-trained masked language model with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.
- Variables:
device – Device to run the model on (CPU or GPU)
parallel – Whether to use parallel processing for structure prediction
tokenizer – Tokenizer for processing RNA sequences
model – Pre-trained masked language model
- design(structure, mutation_ratio=0.5, num_population=100, num_generation=100)[source]
Design RNA sequences for a target structure using evolutionary algorithms.
- Parameters:
structure (str) – Target RNA structure in dot-bracket notation
mutation_ratio (float) – Ratio of tokens to mutate (default: 0.5)
num_population (int) – Population size (default: 100)
num_generation (int) – Number of generations (default: 100)
- Returns:
list – List of designed RNA sequences with their fitness scores
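The target `structure` uses dot-bracket notation: matching parentheses mark base pairs and dots mark unpaired nucleotides. The fitness function inside `design` is not documented here, so the sketch below only shows generic building blocks such a search could use: parsing dot-bracket into base pairs, checking that a candidate sequence can form those pairs with Watson-Crick or G-U wobble pairs, and the base-pair distance between two structures. Folding candidates (for example with ViennaRNA) is outside this sketch, and all helper names are hypothetical.

```python
from typing import List, Tuple

# Canonical Watson-Crick pairs plus G-U wobble pairs.
ALLOWED_PAIRS = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def parse_dot_bracket(structure: str) -> List[Tuple[int, int]]:
    """Return sorted (i, j) base-pair indices for a dot-bracket string."""
    stack: List[int] = []
    pairs: List[Tuple[int, int]] = []
    for j, ch in enumerate(structure):
        if ch == "(":
            stack.append(j)
        elif ch == ")":
            if not stack:
                raise ValueError(f"unmatched ')' at position {j}")
            pairs.append((stack.pop(), j))
        elif ch != ".":
            raise ValueError(f"unexpected character {ch!r} at position {j}")
    if stack:
        raise ValueError(f"unmatched '(' at position {stack[-1]}")
    return sorted(pairs)


def is_compatible(sequence: str, structure: str) -> bool:
    """True if every base pair in `structure` is an allowed pair in `sequence`."""
    if len(sequence) != len(structure):
        return False
    return all((sequence[i], sequence[j]) in ALLOWED_PAIRS for i, j in parse_dot_bracket(structure))


def base_pair_distance(s1: str, s2: str) -> int:
    """Number of base pairs present in exactly one of the two structures."""
    return len(set(parse_dot_bracket(s1)) ^ set(parse_dot_bracket(s2)))


target = "((..))"
print(parse_dot_bracket(target))             # [(0, 5), (1, 4)]
print(is_compatible("GCAAGC", target))       # True
print(is_compatible("GAAAAC", target))       # False (A-A cannot pair)
print(base_pair_distance(target, "(....)"))  # 1
```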
Sequence-to-Sequence Models
Sequence-to-sequence model for genomic sequences.
This module provides a sequence-to-sequence model implementation for genomic sequences. It’s designed for tasks where the input and output are both sequences, such as sequence translation, structure prediction, or sequence transformation tasks.
- class omnigenbench.src.model.seq2seq.model.OmniModelForSeq2Seq(config_or_model, tokenizer, *args, **kwargs)[source]
Bases:
OmniModel
This model implements a sequence-to-sequence architecture for genomic sequences, where the input is one sequence and the output is another sequence. It’s useful for tasks like sequence translation, structure prediction, or sequence transformation. The model can be extended to implement specific seq2seq tasks by overriding the forward, predict, and inference methods.
Augmentation Models
Data augmentation model for genomic sequences.
This module provides a data augmentation model that uses masked language modeling to generate augmented versions of genomic sequences. It’s useful for expanding training datasets and improving model robustness.
- class omnigenbench.src.model.augmentation.model.OmniModelForAugmentation(model_name_or_path=None, noise_ratio=0.15, max_length=1026, instance_num=1, *args, **kwargs)[source]
Bases:
Module
Data augmentation model for genomic sequences using masked language modeling. This model uses a pre-trained masked language model to generate augmented versions of genomic sequences by randomly masking tokens and predicting replacements. It’s useful for expanding training datasets and improving model generalization.
- Variables:
tokenizer – Tokenizer for processing genomic sequences
model – Pre-trained masked language model
device – Device to run the model on (CPU or GPU)
noise_ratio – Proportion of tokens to mask for augmentation
max_length – Maximum sequence length for tokenization
k – Number of augmented instances to generate per sequence
- apply_noise_to_sequence(seq)[source]
Apply noise to a single sequence by randomly masking tokens.
- Parameters:
seq (str) – Input genomic sequence
- Returns:
str – Sequence with randomly masked tokens
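A minimal sketch of the masking step, assuming one token per nucleotide and a hypothetical `<mask>` placeholder; the real method operates on tokenizer output and uses the model's own mask token. `noise_ratio` plays the role of the constructor argument of the same name (default 0.15). Passing a seeded `random.Random` makes the result reproducible.

```python
import random
from typing import Optional

MASK = "<mask>"  # hypothetical placeholder; the real mask token comes from the tokenizer


def apply_noise(seq: str, noise_ratio: float = 0.15, rng: Optional[random.Random] = None) -> str:
    """Replace round(noise_ratio * len(seq)) randomly chosen positions with MASK.

    Assumes one token per nucleotide, which is a simplification.
    """
    rng = rng or random.Random()
    tokens = list(seq)
    n_mask = round(noise_ratio * len(tokens))
    # sample distinct positions so no position is masked twice
    for i in rng.sample(range(len(tokens)), n_mask):
        tokens[i] = MASK
    return "".join(tokens)


noisy = apply_noise("ATCGATCGATCGATCGATCG", noise_ratio=0.15, rng=random.Random(0))
print(noisy.count(MASK))  # 3 (15% of 20 positions)
```

The masked positions are what the MLM later fills in, so each call with a different random state can yield a different augmented instance.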
- augment(seq, k=None)[source]
Generate multiple augmented instances for a single sequence.
- Parameters:
seq (str) – Input genomic sequence
k (int, optional) – Number of augmented instances to generate (default: None, uses self.k)
- Returns:
list – List of augmented sequences
- augment_from_file(input_file, output_file)[source]
Run the complete augmentation pipeline, from an input file to an output file.
This method loads sequences from an input file, augments them using the MLM model, and saves the augmented sequences to an output file.
- Parameters:
input_file (str) – Path to the input file containing sequences
output_file (str) – Path to the output file where augmented sequences will be saved
- augment_sequence(seq)[source]
Perform augmentation on a single sequence by predicting masked tokens.
- Parameters:
seq (str) – Input genomic sequence with masked tokens
- Returns:
str – Augmented sequence with predicted tokens replacing masked tokens
- augment_sequences(sequences)[source]
Augment a list of sequences by applying noise and performing MLM-based predictions.
- Parameters:
sequences (list) – List of genomic sequences to augment
- Returns:
list – List of all augmented sequences
- load_sequences_from_file(input_file)[source]
Load sequences from a JSON file.
- Parameters:
input_file (str) – Path to the input JSON file containing sequences
- Returns:
list – List of sequences loaded from the file
- save_augmented_sequences(augmented_sequences, output_file)[source]
Save augmented sequences to a JSON file.
- Parameters:
augmented_sequences (list) – List of augmented sequences to save
output_file (str) – Path to the output JSON file
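The exact JSON layout read by `load_sequences_from_file` and written by `save_augmented_sequences` is not documented here. Assuming a plain JSON array of sequence strings (an assumption, not the library's confirmed format), the round trip with Python's standard `json` module looks like the sketch below; `save_sequences` and `load_sequences` are hypothetical helper names.

```python
import json
import os
import tempfile
from typing import List


def save_sequences(sequences: List[str], path: str) -> None:
    """Write sequences as a JSON array (assumed format)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(sequences, f, indent=2)


def load_sequences(path: str) -> List[str]:
    """Read a JSON array of sequence strings (assumed format)."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list) or not all(isinstance(s, str) for s in data):
        raise ValueError("expected a JSON array of strings")
    return data


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "augmented.json")
    save_sequences(["ATCGATCG", "GGCTAGCTA"], path)
    print(load_sequences(path))  # ['ATCGATCG', 'GGCTAGCTA']
```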
Model Utilities
This module provides utility classes and functions for handling model inputs, pooling operations, and attention mechanisms used across different OmniGenome model types.
- class omnigenbench.src.model.module_utils.InteractingAttention(embed_size, num_heads=24)[source]
Bases:
Module
An interacting attention mechanism for sequence modeling.
This class implements a multi-head attention mechanism with residual connections and layer normalization. It’s designed for processing sequences where different parts of the sequence need to interact with each other.
- Variables:
attention – Multi-head attention layer
layer_norm – Layer normalization for residual connections
fc_out – Output projection layer
- forward(query, keys, values)[source]
Forward pass through the interacting attention mechanism.
- Parameters:
query (torch.Tensor) – Query tensor [batch_size, query_len, embed_size]
keys (torch.Tensor) – Key tensor [batch_size, key_len, embed_size]
values (torch.Tensor) – Value tensor [batch_size, value_len, embed_size]
- Returns:
torch.Tensor – Output tensor with same shape as query
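The core of this layer is scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, combined with a residual connection and layer normalization as described above. The plain-Python sketch below is deliberately simplified: a single head, no learned Q/K/V projections, no `fc_out`, and no learned scale or shift in the normalization. The real module uses `num_heads` heads (24 by default), and the order of residual, normalization, and output projection shown here is an assumption.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def layer_norm(row: List[float], eps: float = 1e-5) -> List[float]:
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(row) / len(row)
    var = sum((x - mean) ** 2 for x in row) / len(row)
    return [(x - mean) / math.sqrt(var + eps) for x in row]


def attention_block(query: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head attention + residual + layer norm.

    query: [query_len][d], keys: [key_len][d], values: [key_len][d]
    Returns [query_len][d], the same shape as `query`.
    """
    d = len(query[0])
    out: Matrix = []
    for q in query:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        attended = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d)]
        # residual connection back to the query, then normalization
        out.append(layer_norm([a + qj for a, qj in zip(attended, q)]))
    return out


q = [[1.0, 0.0, -1.0]]
kv = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
result = attention_block(q, kv, kv)
print(len(result), len(result[0]))  # 1 3  (same shape as the query)
```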
- class omnigenbench.src.model.module_utils.OmniPooling(config, *args, **kwargs)[source]
Bases:
Module
A flexible pooling layer for OmniGenome models that handles different input formats.
This class provides a unified interface for pooling operations across different model architectures, supporting both causal language models and encoder-based models. It can handle various input formats including tuples, dictionaries, BatchEncoding objects, and tensors.
- Variables:
config – Model configuration object containing architecture and tokenizer settings
pooler – BertPooler instance for non-causal models, None for causal models
- forward(inputs, last_hidden_state)[source]
Perform pooling operation on the last hidden state.
This method handles different input formats and applies the appropriate pooling: for causal language models, it uses the hidden state of the last non-padded token; for encoder models, it applies the BertPooler.
- Parameters:
inputs – Input data in various formats (tuple, dict, BatchEncoding, or tensor)
last_hidden_state (torch.Tensor) – Hidden states from the model [batch_size, seq_len, hidden_size]
- Returns:
torch.Tensor – Pooled representation [batch_size, hidden_size]
- Raises:
ValueError – If input format is not supported or cannot be parsed
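Both pooling paths can be illustrated without a model. For causal models, the sketch picks each row's last position whose attention mask is 1, which works for both right and left padding. For encoder models, Hugging Face's `BertPooler` takes the first token's hidden state and applies a dense layer followed by `tanh`; the sketch mirrors that with a toy weight matrix. The function names and list-based inputs are hypothetical, not OmniGenBench APIs.

```python
import math
from typing import List

Matrix = List[List[float]]


def last_token_pool(hidden: List[Matrix], attention_mask: List[List[int]]) -> Matrix:
    """Causal-model path: pick each row's last position with attention_mask == 1."""
    pooled = []
    for states, mask in zip(hidden, attention_mask):
        valid = [i for i, m in enumerate(mask) if m == 1]
        if not valid:
            raise ValueError("row has no non-padded tokens")
        pooled.append(list(states[valid[-1]]))
    return pooled


def bert_style_pool(hidden: List[Matrix], weight: Matrix, bias: List[float]) -> Matrix:
    """Encoder path in the style of BertPooler: tanh(W @ h[0] + b) for each row."""
    pooled = []
    for states in hidden:
        first = states[0]  # hidden state of the first token
        pooled.append([math.tanh(sum(w * x for w, x in zip(w_row, first)) + b)
                       for w_row, b in zip(weight, bias)])
    return pooled


# Batch of 2 sequences, seq_len 3, hidden_size 2; the second row is right-padded.
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.0], [2.0, 2.0], [9.0, 9.0]],
]
mask = [[1, 1, 1], [1, 1, 0]]
print(last_token_pool(hidden, mask))  # [[0.5, 0.6], [2.0, 2.0]]

identity = [[1.0, 0.0], [0.0, 1.0]]
print(bert_style_pool(hidden, identity, [0.0, 0.0]))  # tanh of each row's first-token vector
```

Either way the output has shape [batch_size][hidden_size], matching the documented return shape.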