Downstream Models

Classification Models

class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassification(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForSequenceClassification

This model is designed for multi-label classification tasks where a single sequence can be assigned multiple labels simultaneously. It extends the sequence classification model with multi-label capabilities. It uses sigmoid activation instead of softmax to allow multiple labels per sequence and uses binary cross-entropy loss for training.

Variables:
  • softmax (torch.nn.Sigmoid) – Sigmoid layer for multi-label probability computation.

  • loss_fn (torch.nn.BCELoss) – Binary cross-entropy loss function for training.

inference(sequence_or_inputs, **kwargs)[source]

Performs multi-label inference with human-readable output. It converts logits to binary labels and provides confidence scores.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable binary labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # tensor([1, 0, 1, 0])

loss_function(logits, labels)[source]

Calculates the binary cross-entropy loss for multi-label classification.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth multi-label targets.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)
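The binary cross-entropy computation described above can be sketched in plain PyTorch, independent of the OmniGenBench class. This is a minimal illustration with made-up tensor values. It assumes the values handed to `torch.nn.BCELoss` are already sigmoid probabilities, which is what that loss requires.

```python
import torch

# Hypothetical logits for 2 sequences and 4 labels (illustrative values).
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0],
                       [-0.5, 1.5, -2.0, 0.0]])
# Multi-hot float targets: several labels can be active for one sequence.
labels = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                       [0.0, 1.0, 0.0, 1.0]])

# BCELoss expects probabilities in [0, 1], so apply the sigmoid first.
probs = torch.sigmoid(logits)
loss = torch.nn.BCELoss()(probs, labels)

# Same quantity computed directly from logits, which is numerically safer.
loss_from_logits = torch.nn.BCEWithLogitsLoss()(logits, labels)
```

Both values agree up to floating-point error. Note that the targets must be floating-point tensors for either loss.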
predict(sequence_or_inputs, **kwargs)[source]

This method takes raw sequences or tokenized inputs and returns multi-label predictions. It applies a threshold to determine which labels are active for each sequence.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Multi-label predictions for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([1, 0, 1, 0])
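The thresholding step mentioned above might look like the following sketch. The cut-off of 0.5 is an assumption for illustration; the value used inside the library is not stated here.

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])  # hypothetical model output
threshold = 0.5  # assumed cut-off, not taken from the library

# Sigmoid scores each label independently, so several can exceed the threshold.
probs = torch.sigmoid(logits)
predictions = (probs > threshold).long()
print(predictions)  # tensor([[1, 0, 1, 0]])
```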
class omnigenbench.src.model.classification.model.OmniModelForMultiLabelSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForSequenceClassificationWith2DStructure

inference(sequence_or_inputs, **kwargs)[source]

This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])   # 0.95

loss_function(logits, labels)[source]

This method computes the cross-entropy loss between the predicted logits and the ground truth labels.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)

predict(sequence_or_inputs, **kwargs)[source]

This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])

class omnigenbench.src.model.classification.model.OmniModelForSequenceClassification(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model is designed for sequence-level classification tasks where the entire input sequence is classified into one of several categories. It extends the base OmniModel with sequence-level classification capabilities.

Variables:
  • pooler (OmniPooling) – Pooling layer for sequence-level representation.

  • softmax (torch.nn.Softmax) – Softmax layer for probability computation.

  • classifier (torch.nn.Linear) – Linear classification head.

  • loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training.

forward(**inputs)[source]

This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Sequence-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )
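As a rough picture of what a sequence-level forward pass involves, the sketch below pools token states, applies a linear head, and normalizes with softmax. It is not the library's code: the masked mean pooling is an assumed stand-in for `OmniPooling`, and the random tensor stands in for the base model's hidden states.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, hidden_size, num_labels = 2, 4, 8, 3

# Stand-in for the base model's final hidden states.
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

# Assumed pooling: mean over non-padding tokens (OmniPooling may differ).
mask = attention_mask.unsqueeze(-1).float()
pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)

# Linear head followed by softmax over the label dimension.
classifier = torch.nn.Linear(hidden_size, num_labels)
logits = classifier(pooled)
probs = torch.softmax(logits, dim=-1)
```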
inference(sequence_or_inputs, **kwargs)[source]

This method provides processed, human-readable sequence-level predictions. It converts logits to class labels and provides confidence scores.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each sequence

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # "positive"
>>> print(results['confidence'])   # 0.95
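The conversion from logits to a label and a confidence score can be sketched as follows. The `id2label` mapping is hypothetical, and taking the maximum softmax probability as the confidence is an assumption about how the score is derived.

```python
import torch

logits = torch.tensor([[0.2, 3.0]])          # hypothetical output for one sequence
id2label = {0: "negative", 1: "positive"}    # hypothetical label mapping

probs = torch.softmax(logits, dim=-1)
confidence, index = probs.max(dim=-1)        # top probability and its class index
predictions = [id2label[i] for i in index.tolist()]
print(predictions)  # ['positive']
```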
loss_function(logits, labels)[source]

This method computes the cross-entropy loss between the predicted logits and the ground truth labels.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)
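For reference, this is what `torch.nn.CrossEntropyLoss` computes on a small made-up batch. The tensors are illustrative only, and the manual line shows the equivalent negative log-softmax form.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])   # 2 sequences, 3 classes
labels = torch.tensor([0, 2])              # class indices, not one-hot vectors

loss = torch.nn.CrossEntropyLoss()(logits, labels)

# Equivalent: mean negative log-probability of the target class.
log_probs = torch.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(2), labels].mean()
```

`CrossEntropyLoss` applies log-softmax internally, so it takes raw logits rather than probabilities.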
predict(sequence_or_inputs, **kwargs)[source]

This method takes raw sequences or tokenized inputs and returns sequence-level predictions. It processes the inputs through the model and returns the predicted class for each sequence.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each sequence

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'])  # tensor([0])
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])

class omnigenbench.src.model.classification.model.OmniModelForSequenceClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForSequenceClassification

forward(**inputs)[source]

This method performs the forward pass through the model, computing sequence-level logits and applying softmax to produce probability distributions over the label classes.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Sequence-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([0])
... )

class omnigenbench.src.model.classification.model.OmniModelForTokenClassification(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model is designed for token-level classification tasks such as sequence labeling, where each token in the input sequence needs to be classified into different categories. It extends the base OmniModel with token-level classification capabilities.

Variables:
  • softmax (torch.nn.Softmax) – Softmax layer for probability computation.

  • classifier (torch.nn.Linear) – Linear classification head.

  • loss_fn (torch.nn.CrossEntropyLoss) – Loss function for training.

forward(**inputs)[source]

Forward pass for token classification.

This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Token-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )

inference(sequence_or_inputs, **kwargs)[source]

Performs token-level inference with human-readable output.

This method provides processed, human-readable token-level predictions. It converts logits to class labels and handles special tokens appropriately.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Human-readable class labels for each token

  • logits: Raw logits from the model

  • confidence: Confidence scores for predictions

  • last_hidden_state: Final hidden states

Example

>>> # Inference on a single sequence
>>> results = model.inference("ATCGATCG")
>>> print(results['predictions'])  # ['A', 'T', 'C', 'G', ...]

loss_function(logits, labels)[source]

Calculates the cross-entropy loss for token classification.

This method computes the cross-entropy loss between the predicted logits and the ground truth labels, ignoring padding tokens.

Parameters:
  • logits (torch.Tensor) – Predicted logits from the model.

  • labels (torch.Tensor) – Ground truth labels.

Returns:

torch.Tensor – The computed loss value.

Example

>>> loss = model.loss_function(logits, labels)
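One common way to ignore padding tokens in a token-level loss is the `ignore_index` argument of `torch.nn.CrossEntropyLoss`. The sketch below assumes padding positions are labeled `-100`, which is a widespread convention and not something confirmed for this library.

```python
import torch

torch.manual_seed(0)
num_labels = 3
logits = torch.randn(2, 4, num_labels)               # (batch, seq_len, num_labels)
labels = torch.tensor([[0, 1, 2, -100],
                       [1, -100, -100, -100]])       # -100 marks padding (assumed)

# Flatten so every token is one classification example.
flat_logits = logits.view(-1, num_labels)
flat_labels = labels.view(-1)
loss = torch.nn.CrossEntropyLoss(ignore_index=-100)(flat_logits, flat_labels)

# Same result when padded positions are dropped by hand.
keep = flat_labels != -100
manual = torch.nn.functional.cross_entropy(flat_logits[keep], flat_labels[keep])
```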
predict(sequence_or_inputs, **kwargs)[source]

Performs token-level prediction on raw inputs.

This method takes raw sequences or tokenized inputs and returns token-level predictions. It processes the inputs through the model and returns the predicted class for each token.

Parameters:
  • sequence_or_inputs – A sequence (str), list of sequences, or tokenized inputs (dict/tuple).

  • **kwargs – Additional arguments for tokenization and inference.

Returns:

dict

A dictionary containing:
  • predictions: Predicted class indices for each token

  • logits: Raw logits from the model

  • last_hidden_state: Final hidden states

Example

>>> # Predict on a single sequence
>>> outputs = model.predict("ATCGATCG")
>>> print(outputs['predictions'].shape)  # (seq_len,)
>>> # Predict on multiple sequences
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
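Token-level prediction amounts to an argmax over the label dimension at every position. The sketch below uses random logits as a stand-in for model output and filters padding with the attention mask; how the library itself treats padding and special tokens is not shown here.

```python
import torch

torch.manual_seed(0)
seq_len, num_labels = 6, 3
logits = torch.randn(1, seq_len, num_labels)         # stand-in for model output
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # last two positions are padding

predictions = logits.argmax(dim=-1)                  # shape (1, seq_len)
# Keep only the positions that correspond to real tokens.
valid_predictions = predictions[0][attention_mask[0].bool()]
```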
class omnigenbench.src.model.classification.model.OmniModelForTokenClassificationWith2DStructure(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForTokenClassification

forward(**inputs)[source]

Forward pass for token classification.

This method performs the forward pass through the model, computing logits for each token in the input sequence and applying softmax to produce probability distributions.

Parameters:

**inputs – Input tensors including ‘input_ids’, ‘attention_mask’, and optionally ‘labels’.

Returns:

dict

A dictionary containing:
  • logits: Token-level classification logits

  • last_hidden_state: Final hidden states from the base model

  • labels: Ground truth labels (if provided)

Example

>>> outputs = model(
...     input_ids=torch.tensor([[1, 2, 3, 4]]),
...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
...     labels=torch.tensor([[0, 1, 0, 1]])
... )

Regression Models

Regression models for the OmniGenome framework.

This module provides various regression model implementations for genomic sequence analysis, including token-level regression, sequence-level regression, structural imputation, and matrix regression/classification tasks.

class omnigenbench.src.model.regression.model.OmniModelForMatrixClassification(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model performs classification on matrix representations of genomic sequences, useful for tasks like structure classification, contact map classification, or other matrix-based genomic analysis tasks.

Variables:
  • resnet – ResNet backbone for processing matrix inputs

  • classifier – Linear layer for classification output

  • loss_fn – Cross-entropy loss function

forward(**inputs)[source]

Forward pass for matrix classification.

Parameters:

**inputs – Input tensors including matrix representations and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for matrix classification.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for matrix classification.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for matrix classification.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForMatrixRegression(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model performs regression on matrix representations of genomic sequences, useful for tasks like contact map prediction, structure prediction, or other matrix-based genomic analysis tasks.

Variables:
  • resnet – ResNet backbone for processing matrix inputs

  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function
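To make the idea of matrix regression concrete, the sketch below builds a pairwise (L x L) map from token embeddings, runs it through a small convolutional head, and scores it with mean squared error. Everything here is an assumption for illustration: the scaled dot-product pairing, the two-layer head standing in for the ResNet backbone, and the all-zero target are not taken from the library.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, seq_len, hidden = 1, 5, 8
token_embeddings = torch.randn(batch, seq_len, hidden)  # stand-in for last_hidden_state

# Assumed pairwise features: scaled dot product between every pair of positions.
pair_matrix = torch.matmul(token_embeddings, token_embeddings.transpose(1, 2))
pair_matrix = (pair_matrix / hidden ** 0.5).unsqueeze(1)  # (batch, 1, L, L)

# Tiny convolutional head standing in for the ResNet backbone.
head = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=1),
)
predicted = head(pair_matrix).squeeze(1)                  # (batch, L, L)

target = torch.zeros(batch, seq_len, seq_len)             # hypothetical contact map
loss = nn.MSELoss()(predicted, target)
```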

forward(**inputs)[source]

Forward pass for matrix regression.

Parameters:

**inputs – Input tensors including matrix representations and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for matrix regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for matrix regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for matrix regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForSequenceRegression(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model performs regression at the sequence level, predicting a single continuous value for the entire input sequence. It’s useful for tasks like predicting overall expression levels, binding affinities, or other sequence-level properties.

Variables:
  • pooler – OmniPooling layer for sequence-level representation

  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function

forward(**inputs)[source]

Forward pass for sequence-level regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for sequence-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for sequence-level regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value
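As a small worked example of the mean squared error named in the variables above, with made-up predictions and targets:

```python
import torch

predictions = torch.tensor([[0.5], [1.5], [2.0]])  # one value per sequence
targets = torch.tensor([[1.0], [1.0], [2.0]])

# Mean of squared differences: (0.25 + 0.25 + 0.0) / 3
loss = torch.nn.MSELoss()(predictions, targets)
print(loss)  # tensor(0.1667)
```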

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for sequence-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForSequenceRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForSequenceRegression

This model extends the basic sequence regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.

forward(**inputs)[source]

Forward pass for 2D structure-aware sequence regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, labels, and structural info

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

class omnigenbench.src.model.regression.model.OmniModelForStructuralImputation(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForSequenceRegression

This model is specialized for imputing missing structural information in genomic sequences. It extends the sequence regression model with additional embedding capabilities for structural features.

Variables:
  • embedding – Embedding layer for structural features

  • loss_fn – Mean squared error loss function

forward(**inputs)[source]

Forward pass for structural imputation.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

class omnigenbench.src.model.regression.model.OmniModelForTokenRegression(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

Token-level regression model for genomic sequences.

This model performs regression at the token level, predicting continuous values for each token in the input sequence. It’s useful for tasks like predicting binding affinities, expression levels, or other continuous properties at each position in a genomic sequence.

Variables:
  • classifier – Linear layer for regression output

  • loss_fn – Mean squared error loss function

forward(**inputs)[source]

Forward pass for token-level regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for token-level regression, excluding special tokens.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for token-level regression.

Parameters:
  • logits (torch.Tensor) – Model predictions

  • labels (torch.Tensor) – Ground truth labels

Returns:

torch.Tensor – Computed loss value

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for token-level regression.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

class omnigenbench.src.model.regression.model.OmniModelForTokenRegressionWith2DStructure(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModelForTokenRegression

This model extends the basic token regression model to incorporate 2D structural information, useful for RNA structure prediction and other structural genomics tasks.

forward(**inputs)[source]

Forward pass for 2D structure-aware token regression.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, labels, and structural info

Returns:

dict – Dictionary containing logits, last_hidden_state, and labels

ResNet implementation for genomic sequence analysis.

This module provides a ResNet architecture adapted for processing genomic sequences and their structural representations. It includes basic blocks, bottleneck blocks, and a complete ResNet implementation optimized for genomic data.

class omnigenbench.src.model.regression.resnet.BasicBlock(inplanes: int, planes: int, stride: int = 1, downsample=None, groups: int = 1, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)[source]

Bases: Module

A basic residual block with two convolutions. It uses layer normalization and is intended for processing genomic sequence data.

Variables:
  • expansion (int) – Expansion factor for the block (default: 1)

  • conv1 – First convolution layer (3x3)

  • bn1 – First layer normalization

  • conv2 – Second convolution layer (5x5)

  • bn2 – Second layer normalization

  • relu – ReLU activation function

  • drop – Dropout layer

  • downsample – Downsampling layer for residual connection

  • stride – Stride for the convolutions

expansion: int = 1

forward(x: Tensor) → Tensor[source]

Forward pass through the BasicBlock.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

Tensor – Output tensor with same shape as input
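A residual block along the lines described above might be sketched as follows. The class name `TinyResidualBlock` is hypothetical, and `GroupNorm` with a single group is used as an assumed stand-in for the layer normalization; the actual block may normalize differently and handle stride and downsampling, which this sketch omits.

```python
import torch
from torch import nn


class TinyResidualBlock(nn.Module):
    """Hypothetical block: 3x3 conv, norm, ReLU, dropout, 5x5 conv, norm, skip."""

    def __init__(self, channels: int, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.GroupNorm(1, channels)  # assumed stand-in for layer norm
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, padding=2, bias=False)
        self.norm2 = nn.GroupNorm(1, channels)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x                       # saved for the skip connection
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.drop(out)
        out = self.norm2(self.conv2(out))
        return self.relu(out + identity)   # add the input back, then activate


block = TinyResidualBlock(channels=8).eval()
x = torch.randn(2, 8, 16, 16)              # [batch_size, channels, height, width]
y = block(x)
```

Because both convolutions pad to preserve height and width and keep the channel count, the skip connection can be a plain addition.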

class omnigenbench.src.model.regression.resnet.Bottleneck(inplanes: int, planes: int, stride: int = 1, downsample: Module | None = None, groups: int = 1, base_width: int = 64, dilation: int = 1, norm_layer: Callable[[...], Module] | None = None)[source]

Bases: Module

This block implements a bottleneck residual connection with three convolutions (1x1, 3x3, 1x1) and is designed for deeper networks. It’s adapted from the original ResNet V1.5 implementation.

Variables:
  • expansion (int) – Expansion factor for the block (default: 4)

  • conv1 – First 1x1 convolution layer

  • bn1 – First batch normalization

  • conv2 – Second 3x3 convolution layer

  • bn2 – Second batch normalization

  • conv3 – Third 1x1 convolution layer

  • bn3 – Third batch normalization

  • relu – ReLU activation function

  • downsample – Downsampling layer for residual connection

  • stride – Stride for the convolutions

expansion: int = 4

forward(x: Tensor) → Tensor[source]

Forward pass through the Bottleneck block.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

Tensor – Output tensor with same shape as input

class omnigenbench.src.model.regression.resnet.ResNet(channels, block: Type[BasicBlock | Bottleneck], layers: List[int], zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 1, replace_stride_with_dilation=None, norm_layer=None)[source]

Bases: Module

This ResNet implementation is specifically designed for processing genomic sequences and their structural representations. It uses layer normalization instead of batch normalization and is optimized for genomic data characteristics.

Variables:
  • _norm_layer – Normalization layer type

  • inplanes – Number of input channels for the first layer

  • dilation – Dilation factor for convolutions

  • groups – Number of groups for grouped convolutions

  • base_width – Base width for bottleneck blocks

  • conv1 – Initial convolution layer

  • bn1 – Initial normalization layer

  • relu – ReLU activation function

  • layer1 – First layer of ResNet blocks

  • fc1 – Final fully connected layer

forward(x: Tensor) → Tensor[source]

Forward pass through the ResNet.

Parameters:

x (Tensor) – Input tensor [batch_size, channels, height, width]

Returns:

Tensor – Output tensor after processing through ResNet

omnigenbench.src.model.regression.resnet.conv1x1(in_planes, out_planes, stride=1)[source]

1x1 convolution.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

Returns:

nn.Conv2d – 1x1 convolution layer

omnigenbench.src.model.regression.resnet.conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1)[source]

3x3 convolution with padding.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

  • groups (int) – Number of groups for grouped convolution (default: 1)

  • dilation (int) – Dilation factor for the convolution (default: 1)

Returns:

nn.Conv2d – 3x3 convolution layer

omnigenbench.src.model.regression.resnet.conv5x5(in_planes, out_planes, stride=1, groups=1, dilation=1)[source]

5x5 convolution with padding.

Parameters:
  • in_planes (int) – Number of input channels

  • out_planes (int) – Number of output channels

  • stride (int) – Stride for the convolution (default: 1)

  • groups (int) – Number of groups for grouped convolution (default: 1)

  • dilation (int) – Dilation factor for the convolution (default: 1)

Returns:

nn.Conv2d – 5x5 convolution layer
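The "with padding" in these helpers usually means the padding is chosen so that height and width are preserved at stride 1. The sketch below shows one way to write such a helper; `same_conv` is a hypothetical name, and `bias=False` follows the common torchvision convention rather than anything confirmed for this module.

```python
import torch
from torch import nn


def same_conv(in_planes, out_planes, kernel_size, stride=1, groups=1, dilation=1):
    """Hypothetical helper: odd-kernel convolution that keeps H and W at stride 1."""
    padding = dilation * (kernel_size - 1) // 2  # 0 for 1x1, 1 for 3x3, 2 for 5x5
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        dilation=dilation,
        bias=False,  # assumption: a normalization layer follows
    )


x = torch.randn(1, 4, 10, 10)
out3 = same_conv(4, 8, kernel_size=3)(x)
out5 = same_conv(4, 8, kernel_size=5, dilation=2)(x)
```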

omnigenbench.src.model.regression.resnet.resnet_b16(channels=128, bbn=16)[source]

This function creates a ResNet model with 16 basic blocks, optimized for processing genomic sequences and their structural representations.

Parameters:
  • channels (int) – Number of input channels (default: 128)

  • bbn (int) – Number of basic blocks (default: 16)

Returns:

ResNet – Configured ResNet model

Embedding Models

class omnigenbench.src.model.embedding.model.OmniModelForEmbedding(model_name_or_path, *args, **kwargs)[source]

Bases: Module

This class provides a unified interface for loading pre-trained models and generating embeddings from genomic sequences. It supports various aggregation methods and batch processing for efficient embedding generation.

Variables:
  • tokenizer – The tokenizer for processing input sequences

  • model – The pre-trained model for generating embeddings

  • _device – The device (CPU/GPU) where the model is loaded

Example

>>> from omnigenbench import OmniModelForEmbedding
>>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
>>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
>>> embeddings = model.batch_encode(sequences)
>>> print(f"Embeddings shape: {embeddings.shape}")
Embeddings shape: torch.Size([2, 768])

batch_encode(sequences, batch_size=8, max_length=512, agg='head', require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)[source]

Batch encode sequences into aggregated (pooled) embeddings.

Parameters:
  • sequences (List[str]) – Input DNA or RNA sequences.

  • batch_size (int) – Processing batch size.

  • max_length (int) – Tokenizer truncation/padding length.

  • agg (str) – Aggregation method: head, mean, or tail.

  • require_grad (bool) – Whether gradients are required. When True, the computation graph is kept so that backpropagation (fine-tuning) is possible.

  • return_on_cpu (bool) – If True, move the result to the CPU to relieve device memory; otherwise keep it on the model's device.

  • use_autocast (bool) – Enable mixed-precision autocast (CUDA only).

  • amp_dtype (torch.dtype | None) – dtype used for autocast.

Returns:

torch.Tensor – Pooled embeddings with shape (num_sequences, hidden_size).

Compatibility:

Existing calls need no changes; the new parameters have default values.
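The three aggregation options can be pictured with the sketch below. The helper `aggregate` is hypothetical, and its reading of the names is an assumption: "head" as the first token, "mean" as the average over non-padding tokens, and "tail" as the last non-padding token with right-side padding.

```python
import torch


def aggregate(last_hidden_state, attention_mask, agg="head"):
    """Hypothetical pooling of (N, L, H) token embeddings into (N, H)."""
    if agg == "head":
        return last_hidden_state[:, 0]
    if agg == "mean":
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    if agg == "tail":
        # Index of the last real token in each row (assumes right padding).
        last = attention_mask.sum(dim=1).clamp(min=1) - 1
        rows = torch.arange(last_hidden_state.size(0))
        return last_hidden_state[rows, last]
    raise ValueError(f"Unknown aggregation method: {agg}")


hidden = torch.arange(12, dtype=torch.float32).view(2, 3, 2)  # N=2, L=3, H=2
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

head = aggregate(hidden, mask, "head")
mean = aggregate(hidden, mask, "mean")
tail = aggregate(hidden, mask, "tail")
```

Each call returns a tensor of shape (N, H), matching the documented return shape.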

batch_encode_tokens(sequences, batch_size=8, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)[source]

Encode sequences to token-level embeddings (last_hidden_state).

Parameters:
  • sequences (List[str]) – Input DNA/RNA sequences for token-level encoding

  • batch_size (int, default=8) – Number of sequences to process per batch

  • max_length (int, default=512) – Maximum sequence length for tokenization

  • use_autocast (bool, default=False) – Enable mixed-precision autocast (CUDA only)

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision

  • require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning

  • return_on_cpu (bool, default=True) – Transfer outputs to CPU memory

Returns:

torch.Tensor – Token embeddings with shape (num_sequences, max_length, hidden_size)

Note

When require_grad=True, gradients flow through the transformer model for end-to-end training. Set return_on_cpu=False to keep the output tensors on the model's device for downstream processing.
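A flag like `require_grad` is typically implemented by choosing between `torch.no_grad()` and a no-op context. The sketch below illustrates that pattern with a hypothetical `run_encoder` helper and a linear layer standing in for the transformer; it is not the library's implementation.

```python
import contextlib

import torch


def run_encoder(encoder, inputs, require_grad=False):
    """Hypothetical helper: skip graph construction unless gradients are needed."""
    context = contextlib.nullcontext() if require_grad else torch.no_grad()
    with context:
        return encoder(inputs)


encoder = torch.nn.Linear(4, 2)  # stand-in for the transformer
x = torch.randn(3, 4)

frozen = run_encoder(encoder, x)                        # no graph kept
trainable = run_encoder(encoder, x, require_grad=True)  # graph kept for backprop
```

Only `trainable` can be used in a loss that is later backpropagated into the encoder's parameters.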

compute_similarity(embedding1, embedding2, dim=0)[source]

Compute cosine similarity between two embeddings.

Parameters:
  • embedding1 (torch.Tensor) – The first embedding

  • embedding2 (torch.Tensor) – The second embedding

  • dim (int, optional) – Dimension along which to compute cosine similarity. Defaults to 0

Returns:

float – Cosine similarity score between -1 and 1
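For two 1-D embeddings, cosine similarity along `dim=0` reduces to the dot product divided by the product of the norms. A minimal sketch with made-up vectors, using `torch.nn.functional.cosine_similarity`:

```python
import torch
import torch.nn.functional as F

emb1 = torch.tensor([1.0, 0.0, 1.0])  # illustrative embeddings, shape (H,)
emb2 = torch.tensor([1.0, 1.0, 0.0])

# dot = 1, each norm = sqrt(2), so the similarity is 1 / 2
similarity = F.cosine_similarity(emb1, emb2, dim=0).item()
print(f"{similarity:.4f}")  # 0.5000
```

With embeddings that keep a batch dimension, for example shape (1, H), the reduction would need `dim=1` instead.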

Example

>>> emb1 = model.encode("ATCGGCTA")
>>> emb2 = model.encode("GGCTAGCTA")
>>> similarity = model.compute_similarity(emb1, emb2)
>>> print(f"Cosine similarity: {similarity:.4f}")
Cosine similarity: 0.8234

property device

Get the current device for the underlying model.

Returns:

torch.device – The device where the model currently resides.

Note

This queries the model parameters directly so it stays correct when external frameworks (e.g., Accelerate/DDP) move the module across devices after initialization.

encode(sequence, max_length=512, agg='head', keep_dim=False, require_grad: bool = False, return_on_cpu: bool = True, use_autocast: bool = False, amp_dtype=None)[source]

Encode a single sequence.

Parameters:
  • sequence (str) – Input sequence.

  • max_length (int) – Tokenizer truncation/padding length.

  • agg (str) – Aggregation strategy: head, mean, or tail.

  • keep_dim (bool) – Whether to keep the batch dimension.

  • require_grad (bool) – Whether to keep the computation graph for fine-tuning.

  • return_on_cpu (bool) – Whether to move the result to the CPU.

  • use_autocast (bool) – Whether to enable autocast.

  • amp_dtype (torch.dtype | None) – dtype used for autocast.

Returns:

torch.Tensor – Pooled embedding with shape (H,) or (1, H).

encode_tokens(sequence, max_length=512, use_autocast=False, amp_dtype=None, require_grad: bool = False, return_on_cpu: bool = True)[source]

Encode a single sequence to token-level embeddings.

Parameters:
  • sequence (str) – Input DNA/RNA sequence for token-level encoding

  • max_length (int, default=512) – Maximum sequence length for tokenization

  • use_autocast (bool, default=False) – Enable mixed-precision computation via autocast (CUDA only)

  • amp_dtype (torch.dtype, optional) – Data type for automatic mixed precision

  • require_grad (bool, default=False) – Preserve gradient computation graph for fine-tuning

  • return_on_cpu (bool, default=True) – Transfer output to CPU memory

Returns:

torch.Tensor – Token embeddings with shape (max_length, hidden_size)

Example

>>> model = OmniModelForEmbedding("yangheng/OmniGenome-52M")
>>> sequence = "ATCGATCGATCG"
>>> token_embeddings = model.encode_tokens(sequence, max_length=200)
>>> print(f"Token embeddings shape: {token_embeddings.shape}")
Token embeddings shape: torch.Size([200, 768])
load_embeddings(embedding_path)[source]

Load embeddings from a file.

Parameters:

embedding_path (str) – Path to the saved embeddings

Returns:

torch.Tensor – The loaded embeddings

Example

>>> embeddings = model.load_embeddings("embeddings.pt")
>>> print(f"Loaded embeddings shape: {embeddings.shape}")
Loaded embeddings shape: torch.Size([100, 768])
save_embeddings(embeddings, output_path)[source]

Save the generated embeddings to a file.

Parameters:
  • embeddings (torch.Tensor) – The embeddings to save

  • output_path (str) – Path to save the embeddings

Example

>>> embeddings = model.batch_encode(sequences)
>>> model.save_embeddings(embeddings, "embeddings.pt")
>>> print("Embeddings saved successfully")

MLM Models

Masked Language Model (MLM) for genomic sequences.

This module provides a masked language model implementation specifically designed for genomic sequences. It supports masked language modeling tasks where tokens are randomly masked and the model learns to predict the original tokens.

class omnigenbench.src.model.mlm.model.OmniModelForMLM(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

Masked Language Model for genomic sequences.

This model implements masked language modeling for genomic sequences, where tokens are randomly masked and the model learns to predict the original tokens. It’s useful for pre-training genomic language models and understanding sequence patterns and dependencies.

Variables:

loss_fn – Cross-entropy loss function for masked language modeling

forward(**inputs)[source]

Forward pass for masked language modeling.

Parameters:

**inputs – Input tensors including input_ids, attention_mask, and labels

Returns:

dict – Dictionary containing loss, logits, and last_hidden_state

inference(sequence_or_inputs, **kwargs)[source]

Perform inference for masked language modeling, decoding predictions to sequences.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing decoded predictions, logits, and last_hidden_state

loss_function(logits, labels)[source]

Compute the loss for masked language modeling.

Parameters:
  • logits (torch.Tensor) – Model predictions [batch_size, seq_len, vocab_size]

  • labels (torch.Tensor) – Ground truth labels [batch_size, seq_len]

Returns:

torch.Tensor – Computed cross-entropy loss value
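To make the shapes concrete, here is a minimal pure-Python sketch of token-level cross-entropy over logits of shape [batch_size, seq_len, vocab_size]. It assumes that positions labelled -100 are skipped, a common convention for unmasked tokens; the source does not state which ignore index the library uses. `masked_cross_entropy` is a hypothetical name, not the library method.

```python
import math

IGNORE_INDEX = -100  # assumed convention for positions that carry no label


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over labelled positions (hypothetical helper).

    logits: nested list [batch][seq_len][vocab_size]
    labels: nested list [batch][seq_len] of class indices or ignore_index
    """
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logits, labels):
        for token_logits, label in zip(seq_logits, seq_labels):
            if label == ignore_index:
                continue
            # Numerically stable log-sum-exp for the softmax normaliser.
            m = max(token_logits)
            log_z = m + math.log(sum(math.exp(v - m) for v in token_logits))
            total += log_z - token_logits[label]
            count += 1
    if count == 0:
        raise ValueError("no labelled positions to compute a loss over")
    return total / count


logits = [[[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
labels = [[0, IGNORE_INDEX]]  # only the first position contributes
loss = masked_cross_entropy(logits, labels)
uniform_loss = masked_cross_entropy([[[0.0, 0.0, 0.0]]], [[1]])
```

With uniform logits the loss equals the log of the vocabulary size, which is a handy sanity check for an untrained model.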

predict(sequence_or_inputs, **kwargs)[source]

Generate predictions for masked language modeling.

Parameters:
  • sequence_or_inputs – Input sequences or pre-processed inputs

  • **kwargs – Additional keyword arguments

Returns:

dict – Dictionary containing predictions, logits, and last_hidden_state

RNA Design Models

RNA design model using masked language modeling and evolutionary algorithms.

This module provides an RNA design model that combines masked language modeling with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.

class omnigenbench.src.model.rna_design.model.OmniModelForRNADesign(model='yangheng/OmniGenome-186M', device=None, parallel=False, *args, **kwargs)[source]

Bases: Module

RNA design model using masked language modeling and evolutionary algorithms.

This model combines a pre-trained masked language model with evolutionary algorithms to design RNA sequences that fold into specific target structures. It uses a multi-objective optimization approach to balance structure similarity and thermodynamic stability.

Variables:
  • device – Device to run the model on (CPU or GPU)

  • parallel – Whether to use parallel processing for structure prediction

  • tokenizer – Tokenizer for processing RNA sequences

  • model – Pre-trained masked language model

design(structure, mutation_ratio=0.5, num_population=100, num_generation=100)[source]

Design RNA sequences for a target structure using evolutionary algorithms.

Parameters:
  • structure (str) – Target RNA structure in dot-bracket notation

  • mutation_ratio (float) – Ratio of tokens to mutate (default: 0.5)

  • num_population (int) – Population size (default: 100)

  • num_generation (int) – Number of generations (default: 100)

Returns:

list – List of designed RNA sequences with their fitness scores
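The overall shape of such an evolutionary loop can be sketched in plain Python. This is a toy stand-in, not the library's algorithm: it swaps the masked language model for uniform random mutation, and it swaps structure prediction for a simple proxy score (the fraction of target base pairs whose two bases can pair). All function names here are hypothetical.

```python
import random

# Watson-Crick and G-U wobble pairs.
PAIRS = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G")}


def parse_pairs(structure):
    """Return (i, j) index pairs for matching brackets in dot-bracket notation."""
    stack, pairs = [], []
    for i, ch in enumerate(structure):
        if ch == "(":
            stack.append(i)
        elif ch == ")":
            if not stack:
                raise ValueError("unbalanced structure")
            pairs.append((stack.pop(), i))
    if stack:
        raise ValueError("unbalanced structure")
    return pairs


def fitness(seq, pairs):
    """Toy proxy: fraction of target pairs whose bases are able to pair."""
    if not pairs:
        return 1.0
    return sum((seq[i], seq[j]) in PAIRS for i, j in pairs) / len(pairs)


def mutate(seq, ratio, rng):
    """Resample each base with probability `ratio` (random, not MLM-guided)."""
    return "".join(rng.choice("ACGU") if rng.random() < ratio else b for b in seq)


def toy_design(structure, mutation_ratio=0.5, num_population=100,
               num_generation=100, seed=0):
    rng = random.Random(seed)
    pairs = parse_pairs(structure)
    population = ["".join(rng.choice("ACGU") for _ in structure)
                  for _ in range(num_population)]
    for _ in range(num_generation):
        children = [mutate(s, mutation_ratio, rng) for s in population]
        # Keep the best of parents and children (elitist selection).
        ranked = sorted(population + children,
                        key=lambda s: fitness(s, pairs), reverse=True)
        population = ranked[:num_population]
        if fitness(population[0], pairs) == 1.0:
            break
    scored = [(s, fitness(s, pairs)) for s in population]
    return sorted(scored, key=lambda item: item[1], reverse=True)


target = "((((....))))"
designs = toy_design(target, num_population=20, num_generation=50)
best_seq, best_score = designs[0]
```

Because parents compete with their children, the best score never decreases from one generation to the next. A real design run would score candidates by folding them and comparing against the target structure.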

Sequence-to-Sequence Models

Sequence-to-sequence model for genomic sequences.

This module provides a sequence-to-sequence model implementation for genomic sequences. It’s designed for tasks where the input and output are both sequences, such as sequence translation, structure prediction, or sequence transformation tasks.

class omnigenbench.src.model.seq2seq.model.OmniModelForSeq2Seq(config_or_model, tokenizer, *args, **kwargs)[source]

Bases: OmniModel

This model implements a sequence-to-sequence architecture for genomic sequences, where the input is one sequence and the output is another sequence. It’s useful for tasks like sequence translation, structure prediction, or sequence transformation. The model can be extended to implement specific seq2seq tasks by overriding the forward, predict, and inference methods.

Augmentation Models

Data augmentation model for genomic sequences.

This module provides a data augmentation model that uses masked language modeling to generate augmented versions of genomic sequences. It’s useful for expanding training datasets and improving model robustness.

class omnigenbench.src.model.augmentation.model.OmniModelForAugmentation(model_name_or_path=None, noise_ratio=0.15, max_length=1026, instance_num=1, *args, **kwargs)[source]

Bases: Module

Data augmentation model for genomic sequences using masked language modeling. This model uses a pre-trained masked language model to generate augmented versions of genomic sequences by randomly masking tokens and predicting replacements. It’s useful for expanding training datasets and improving model generalization.

Variables:
  • tokenizer – Tokenizer for processing genomic sequences

  • model – Pre-trained masked language model

  • device – Device to run the model on (CPU or GPU)

  • noise_ratio – Proportion of tokens to mask for augmentation

  • max_length – Maximum sequence length for tokenization

  • k – Number of augmented instances to generate per sequence

apply_noise_to_sequence(seq)[source]

Apply noise to a single sequence by randomly masking tokens.

Parameters:

seq (str) – Input genomic sequence

Returns:

str – Sequence with randomly masked tokens
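A minimal sketch of this kind of masking is shown below. It assumes one token per nucleotide, a literal `<mask>` token, and a fixed number of masked positions derived from the noise ratio; the library's tokenizer and sampling scheme may differ. `mask_sequence` is a hypothetical helper.

```python
import random


def mask_sequence(seq, noise_ratio=0.15, mask_token="<mask>", seed=None):
    """Replace roughly noise_ratio of the positions with mask_token.

    Assumptions for illustration: character-level tokens and a "<mask>"
    token string; neither is confirmed by the library documentation.
    """
    rng = random.Random(seed)
    tokens = list(seq)
    if not tokens:
        return ""
    # Mask at least one position, and never more than the sequence length.
    num_to_mask = min(len(tokens), max(1, round(len(tokens) * noise_ratio)))
    for idx in rng.sample(range(len(tokens)), num_to_mask):
        tokens[idx] = mask_token
    return "".join(tokens)


masked = mask_sequence("ATCGATCGATCGATCGATCG", noise_ratio=0.15, seed=0)
```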

augment(seq, k=None)[source]

Generate multiple augmented instances for a single sequence.

Parameters:
  • seq (str) – Input genomic sequence

  • k (int, optional) – Number of augmented instances to generate (default: None, uses self.k)

Returns:

list – List of augmented sequences

augment_from_file(input_file, output_file)[source]

Run the full augmentation pipeline from an input file to an output file.

This method loads sequences from an input file, augments them using the MLM model, and saves the augmented sequences to an output file.

Parameters:
  • input_file (str) – Path to the input file containing sequences

  • output_file (str) – Path to the output file where augmented sequences will be saved

augment_sequence(seq)[source]

Perform augmentation on a single sequence by predicting masked tokens.

Parameters:

seq (str) – Input genomic sequence with masked tokens

Returns:

str – Augmented sequence with predicted tokens replacing masked tokens

augment_sequences(sequences)[source]

Augment a list of sequences by applying noise and performing MLM-based predictions.

Parameters:

sequences (list) – List of genomic sequences to augment

Returns:

list – List of all augmented sequences

load_sequences_from_file(input_file)[source]

Load sequences from a JSON file.

Parameters:

input_file (str) – Path to the input JSON file containing sequences

Returns:

list – List of sequences loaded from the file

save_augmented_sequences(augmented_sequences, output_file)[source]

Save augmented sequences to a JSON file.

Parameters:
  • augmented_sequences (list) – List of augmented sequences to save

  • output_file (str) – Path to the output JSON file
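The file layout is not specified here, so the round-trip sketch below makes an explicit assumption: one JSON object per line, with the sequence stored under a hypothetical "seq" key. The helper names are also hypothetical; check the library source for the actual schema before relying on it.

```python
import json
import os
import tempfile


def save_sequences(sequences, path, key="seq"):
    """Write one JSON object per line (assumed layout, hypothetical key)."""
    with open(path, "w", encoding="utf-8") as handle:
        for seq in sequences:
            handle.write(json.dumps({key: seq}) + "\n")


def load_sequences(path, key="seq"):
    """Read sequences back from the layout written by save_sequences."""
    sequences = []
    with open(path, "r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:  # skip blank lines
                sequences.append(json.loads(line)[key])
    return sequences


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "augmented.json")
    save_sequences(["ATCG", "GGCA"], path)
    restored = load_sequences(path)
```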

Model Utilities

This module provides utility classes and functions for handling model inputs, pooling operations, and attention mechanisms used across different OmniGenome model types.

class omnigenbench.src.model.module_utils.InteractingAttention(embed_size, num_heads=24)[source]

Bases: Module

An interacting attention mechanism for sequence modeling.

This class implements a multi-head attention mechanism with residual connections and layer normalization. It’s designed for processing sequences where different parts of the sequence need to interact with each other.

Variables:
  • attention – Multi-head attention layer

  • layer_norm – Layer normalization for residual connections

  • fc_out – Output projection layer

forward(query, keys, values)[source]

Forward pass through the interacting attention mechanism.

Parameters:
  • query (torch.Tensor) – Query tensor [batch_size, query_len, embed_size]

  • keys (torch.Tensor) – Key tensor [batch_size, key_len, embed_size]

  • values (torch.Tensor) – Value tensor [batch_size, value_len, embed_size]

Returns:

torch.Tensor – Output tensor with same shape as query

class omnigenbench.src.model.module_utils.OmniPooling(config, *args, **kwargs)[source]

Bases: Module

A flexible pooling layer for OmniGenome models that handles different input formats.

This class provides a unified interface for pooling operations across different model architectures, supporting both causal language models and encoder-based models. It can handle various input formats including tuples, dictionaries, BatchEncoding objects, and tensors.

Variables:
  • config – Model configuration object containing architecture and tokenizer settings

  • pooler – BertPooler instance for non-causal models, None for causal models

forward(inputs, last_hidden_state)[source]

Perform pooling operation on the last hidden state.

This method handles different input formats and applies the appropriate pooling:
  • For causal language models: uses the last non-padded token

  • For encoder models: uses the BertPooler

Parameters:
  • inputs – Input data in various formats (tuple, dict, BatchEncoding, or tensor)

  • last_hidden_state (torch.Tensor) – Hidden states from the model [batch_size, seq_len, hidden_size]

Returns:

torch.Tensor – Pooled representation [batch_size, hidden_size]

Raises:

ValueError – If input format is not supported or cannot be parsed
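The causal-model branch, selecting the last non-padded token, can be sketched without any framework. The example assumes right-padded inputs and a known pad token id; `last_token_pool` is a hypothetical helper operating on nested lists instead of tensors.

```python
def last_token_pool(input_ids, last_hidden_state, pad_token_id):
    """Pick the hidden state of the last non-padding token per sequence.

    input_ids: nested list [batch][seq_len]
    last_hidden_state: nested list [batch][seq_len][hidden_size]
    Returns a nested list [batch][hidden_size].
    """
    pooled = []
    for ids, states in zip(input_ids, last_hidden_state):
        non_pad = [i for i, tok in enumerate(ids) if tok != pad_token_id]
        if not non_pad:
            raise ValueError("sequence contains only padding tokens")
        pooled.append(list(states[non_pad[-1]]))
    return pooled


# Two sequences padded with id 0 to a common length of 4.
input_ids = [[5, 6, 7, 0], [8, 9, 0, 0]]
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]],
    [[1.1, 1.2], [1.3, 1.4], [0.0, 0.0], [0.0, 0.0]],
]
pooled = last_token_pool(input_ids, hidden, pad_token_id=0)
```

In a causal model only the final real token has attended to the whole sequence, which is why it serves as the sequence-level summary.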