# Source code for omnigenbench.src.model.embedding.model

# -*- coding: utf-8 -*-
# file: model.py
# time: 18:37 22/09/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import torch
from transformers import AutoTokenizer, AutoModel

from ....src.misc.utils import fprint


class OmniModelForEmbedding(torch.nn.Module):
    """
    Unified interface for loading a pre-trained model and turning sequences
    into embeddings.

    Pooled (one vector per sequence) and token-level embeddings are both
    supported, for single sequences and for batches.

    Attributes:
        tokenizer: Tokenizer loaded with ``AutoTokenizer.from_pretrained``.
        model: Model loaded with ``AutoModel.from_pretrained``, put in eval mode.
        _device: Device selected at construction time (CUDA when available,
            otherwise CPU). Only used as a fallback by :attr:`device`.

    Example:
        >>> from omnigenbench import OmniModelForEmbedding
        >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
        >>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
        >>> embeddings = model.batch_encode(sequences)
        >>> print(f"Embeddings shape: {embeddings.shape}")
        Embeddings shape: torch.Size([2, 768])
    """

    # Aggregation strategies understood by ``encode`` / ``batch_encode``.
    _SUPPORTED_AGG = ("head", "mean", "tail")

    def __init__(self, model_name_or_path, *args, **kwargs):
        """
        Initialize the embedding model.

        Args:
            model_name_or_path (str): Name or path of the pre-trained model.
            *args: Extra positional arguments for ``AutoModel.from_pretrained``.
            **kwargs: Extra keyword arguments for ``AutoModel.from_pretrained``.
                ``trust_remote_code``, when given, is also forwarded to
                ``AutoTokenizer.from_pretrained``.
        """
        super().__init__()
        # A checkpoint that needs trust_remote_code for its model may need it
        # for its tokenizer as well; forward only that flag, nothing else.
        tokenizer_kwargs = {
            key: kwargs[key] for key in ("trust_remote_code",) if key in kwargs
        }
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, **tokenizer_kwargs
        )
        self.model = AutoModel.from_pretrained(model_name_or_path, *args, **kwargs)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self._device)
        self.model.eval()  # Set model to evaluation mode

    # ------------------------------------------------------------------
    # Private helpers shared by the public encode methods
    # ------------------------------------------------------------------
    def _check_agg(self, agg):
        """Raise ``ValueError`` early if ``agg`` is not a supported strategy."""
        if agg not in self._SUPPORTED_AGG:
            raise ValueError(f"Unsupported agg: {agg}")

    def _hidden_size(self):
        """Return ``model.config.hidden_size`` when available, else 0."""
        config = getattr(self.model, "config", None)
        return int(getattr(config, "hidden_size", 0) or 0)

    def _tokenize(self, sequences, max_length, padding):
        """Tokenize ``sequences`` and move every returned tensor to the model device."""
        encoded = self.tokenizer(
            sequences,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=max_length,
        )
        target = self.device
        return {name: tensor.to(target) for name, tensor in encoded.items()}

    def _forward_hidden(self, inputs, require_grad, use_autocast, amp_dtype):
        """
        Run the model and return ``last_hidden_state``.

        Gradient mode and autocast are independent switches: gradients follow
        ``require_grad`` and autocast is applied on top whenever it is
        requested and the model lives on a CUDA device.
        """
        grad_mode = torch.enable_grad() if require_grad else torch.no_grad()
        with grad_mode:
            if use_autocast and self.device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=amp_dtype):
                    return self.model(**inputs).last_hidden_state
            return self.model(**inputs).last_hidden_state

    @staticmethod
    def _pool(hidden, mask, agg):
        """
        Reduce ``hidden`` (B, L, H) to (B, H) using ``mask`` (B, L).

        * ``head``: the first position.
        * ``mean``: average over positions where ``mask`` is non-zero.
        * ``tail``: the last position where ``mask`` is non-zero, which is
          valid for both right- and left-padded batches.
        """
        if agg == "head":
            return hidden[:, 0, :]
        if agg == "mean":
            # clamp avoids a 0/0 NaN for a (degenerate) fully masked row
            token_counts = mask.sum(1, keepdim=True).clamp(min=1)
            return (hidden * mask.unsqueeze(-1)).sum(1) / token_counts
        if agg == "tail":
            positions = torch.arange(mask.size(1), device=mask.device)
            # Largest attended position per row; independent of padding side.
            last_index = (mask.long() * positions).max(dim=1).values
            rows = torch.arange(hidden.size(0), device=hidden.device)
            return hidden[rows, last_index.to(hidden.device)]
        raise ValueError(f"Unsupported agg: {agg}")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def batch_encode(
        self,
        sequences,
        batch_size=8,
        max_length=512,
        agg="head",
        require_grad: bool = False,
        return_on_cpu: bool = True,
        use_autocast: bool = False,
        amp_dtype=None,
    ) -> torch.Tensor:
        """
        Batch encode sequences into aggregated (pooled) embeddings.

        Args:
            sequences (List[str]): Input sequences. A single ``str`` is
                treated as a one-element list.
            batch_size (int): Number of sequences per forward pass.
            max_length (int): Tokenizer truncation length.
            agg (str): Aggregation method, one of ``head``, ``mean``, ``tail``.
            require_grad (bool): Run under ``torch.enable_grad()`` when True,
                otherwise under ``torch.no_grad()``.
            return_on_cpu (bool): Move the pooled result to CPU when True,
                otherwise leave it on the model device.
            use_autocast (bool): Enable ``torch.autocast`` (CUDA only).
            amp_dtype (torch.dtype | None): dtype passed to ``torch.autocast``.

        Returns:
            torch.Tensor: Shape ``(num_sequences, hidden_size)``. For empty
            input an empty tensor with zero rows is returned.

        Raises:
            ValueError: If ``agg`` is not supported.
        """
        self._check_agg(agg)
        if isinstance(sequences, str):
            # Slicing a str would embed arbitrary character chunks.
            sequences = [sequences]
        if len(sequences) == 0:
            target = torch.device("cpu") if return_on_cpu else self.device
            return torch.empty((0, self._hidden_size()), device=target)

        pooled_batches = []
        for start in range(0, len(sequences), batch_size):
            inputs = self._tokenize(
                sequences[start : start + batch_size], max_length, padding=True
            )
            hidden = self._forward_hidden(
                inputs, require_grad, use_autocast, amp_dtype
            )
            # Pool on the model device and move only the (B, H) result,
            # instead of copying the full (B, L, H) hidden state.
            pooled = self._pool(hidden, inputs["attention_mask"], agg)
            pooled_batches.append(pooled.cpu() if return_on_cpu else pooled)
        return torch.cat(pooled_batches, 0)

    def batch_encode_tokens(
        self,
        sequences,
        batch_size=8,
        max_length=512,
        use_autocast=False,
        amp_dtype=None,
        require_grad: bool = False,
        return_on_cpu: bool = True,
    ) -> torch.Tensor:
        """
        Encode sequences to token-level embeddings (``last_hidden_state``).

        Every sequence is padded/truncated to ``max_length`` so the batches
        can be concatenated.

        Args:
            sequences (List[str]): Input sequences. A single ``str`` is
                treated as a one-element list.
            batch_size (int, default=8): Number of sequences per forward pass.
            max_length (int, default=512): Tokenizer padding/truncation length.
            use_autocast (bool, default=False): Enable ``torch.autocast``
                (CUDA only).
            amp_dtype (torch.dtype, optional): dtype passed to ``torch.autocast``.
            require_grad (bool, default=False): Run under
                ``torch.enable_grad()`` when True, else ``torch.no_grad()``.
            return_on_cpu (bool, default=True): Move outputs to CPU.

        Returns:
            torch.Tensor: Shape ``(num_sequences, max_length, hidden_size)``.
            For empty input an empty tensor with zero rows is returned.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        if len(sequences) == 0:
            target = torch.device("cpu") if return_on_cpu else self.device
            return torch.empty(
                (0, max_length, self._hidden_size()), device=target
            )

        token_batches = []
        for start in range(0, len(sequences), batch_size):
            inputs = self._tokenize(
                sequences[start : start + batch_size],
                max_length,
                padding="max_length",
            )
            last_hidden = self._forward_hidden(
                inputs, require_grad, use_autocast, amp_dtype
            )  # (B, L, H)
            token_batches.append(last_hidden.cpu() if return_on_cpu else last_hidden)
        return torch.cat(token_batches, dim=0)

    def encode_tokens(
        self,
        sequence,
        max_length=512,
        use_autocast=False,
        amp_dtype=None,
        require_grad: bool = False,
        return_on_cpu: bool = True,
    ) -> torch.Tensor:
        """
        Encode a single sequence to token-level embeddings.

        Args:
            sequence (str): Input sequence.
            max_length (int, default=512): Tokenizer padding/truncation length.
            use_autocast (bool, default=False): Enable ``torch.autocast``
                (CUDA only).
            amp_dtype (torch.dtype, optional): dtype passed to ``torch.autocast``.
            require_grad (bool, default=False): Run under
                ``torch.enable_grad()`` when True, else ``torch.no_grad()``.
            return_on_cpu (bool, default=True): Move the output to CPU.

        Returns:
            torch.Tensor: Shape ``(max_length, hidden_size)``.

        Example:
            >>> model = OmniModelForEmbedding("yangheng/OmniGenome-52M")
            >>> sequence = "ATCGATCGATCG"
            >>> token_embeddings = model.encode_tokens(sequence, max_length=200)
            >>> print(f"Token embeddings shape: {token_embeddings.shape}")
            Token embeddings shape: torch.Size([200, 768])
        """
        inputs = self._tokenize(sequence, max_length, padding="max_length")
        hidden = self._forward_hidden(
            inputs, require_grad, use_autocast, amp_dtype
        )  # (1, L, H)
        if return_on_cpu:
            hidden = hidden.cpu()
        # Remove batch dimension for single sequence
        return hidden.squeeze(0)  # (L, H)

    def encode(
        self,
        sequence,
        max_length=512,
        agg="head",
        keep_dim=False,
        require_grad: bool = False,
        return_on_cpu: bool = True,
        use_autocast: bool = False,
        amp_dtype=None,
    ) -> torch.Tensor:
        """
        Encode a single sequence into one pooled embedding.

        Args:
            sequence (str): Input sequence.
            max_length (int): Tokenizer truncation length.
            agg (str): Aggregation strategy, one of ``head``, ``mean``, ``tail``.
            keep_dim (bool): Keep the leading batch dimension when True.
            require_grad (bool): Run under ``torch.enable_grad()`` when True,
                otherwise under ``torch.no_grad()``.
            return_on_cpu (bool): Move the result to CPU when True.
            use_autocast (bool): Enable ``torch.autocast`` (CUDA only).
            amp_dtype (torch.dtype | None): dtype passed to ``torch.autocast``.

        Returns:
            torch.Tensor: Shape ``(hidden_size,)``, or ``(1, hidden_size)``
            when ``keep_dim`` is True.

        Raises:
            ValueError: If ``agg`` is not supported.
        """
        self._check_agg(agg)
        inputs = self._tokenize(sequence, max_length, padding=True)
        hidden = self._forward_hidden(
            inputs, require_grad, use_autocast, amp_dtype
        )  # (1, L, H)
        emb = self._pool(hidden, inputs["attention_mask"], agg)
        if return_on_cpu:
            emb = emb.cpu()
        return emb if keep_dim else emb.squeeze(0)

    def save_embeddings(self, embeddings, output_path):
        """
        Save embeddings to a file with ``torch.save``.

        Args:
            embeddings (torch.Tensor): The embeddings to save.
            output_path (str): Destination path.

        Example:
            >>> embeddings = model.batch_encode(sequences)
            >>> model.save_embeddings(embeddings, "embeddings.pt")
        """
        torch.save(embeddings, output_path)
        fprint(f"Embeddings saved to {output_path}")

    def load_embeddings(self, embedding_path, map_location=None):
        """
        Load embeddings from a file with ``torch.load``.

        Args:
            embedding_path (str): Path to the saved embeddings.
            map_location: Forwarded to ``torch.load``. Pass ``"cpu"`` to load
                tensors that were saved from a GPU on a CPU-only machine.
                The default ``None`` keeps ``torch.load``'s own behaviour.

        Returns:
            torch.Tensor: The loaded embeddings.

        Example:
            >>> embeddings = model.load_embeddings("embeddings.pt")
            >>> print(f"Loaded embeddings shape: {embeddings.shape}")
            Loaded embeddings shape: torch.Size([100, 768])
        """
        # SECURITY NOTE: torch.load may unpickle arbitrary objects depending
        # on the installed torch version and its ``weights_only`` default;
        # only load files from trusted sources.
        embeddings = torch.load(embedding_path, map_location=map_location)
        fprint(f"Loaded embeddings from {embedding_path}")
        return embeddings

    def compute_similarity(self, embedding1, embedding2, dim=0):
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1 (torch.Tensor): The first embedding.
            embedding2 (torch.Tensor): The second embedding.
            dim (int, optional): Dimension along which the similarity is
                computed. Defaults to 0.

        Returns:
            torch.Tensor: Result of ``torch.nn.functional.cosine_similarity``.
            For two 1-D embeddings with ``dim=0`` this is a 0-dim tensor;
            call ``.item()`` to obtain a Python float.

        Example:
            >>> emb1 = model.encode("ATCGGCTA")
            >>> emb2 = model.encode("GGCTAGCTA")
            >>> similarity = model.compute_similarity(emb1, emb2)
            >>> print(f"Cosine similarity: {similarity.item():.4f}")
        """
        return torch.nn.functional.cosine_similarity(embedding1, embedding2, dim=dim)

    @property
    def device(self):
        """
        Get the current device of the underlying model.

        Returns:
            torch.device: Device of the model's first parameter, or the
            device chosen at construction time if the model has no parameters.

        Note:
            The parameters are queried directly so the answer stays correct
            if the module is moved to another device after initialization.
        """
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return self._device
# Example usage
if __name__ == "__main__":
    # Demo: embed two short sequences, round-trip them through disk, then
    # compare the two reloaded vectors.
    demo_sequences = ["ATCGGCTA", "GGCTAGCTA"]
    checkpoint = "anonymous8/OmniGenome-186M"
    embedder = OmniModelForEmbedding(checkpoint)

    # Pooled embedding of the first sequence only
    first_vector = embedder.encode(demo_sequences[0])
    fprint(f"Single embedding shape: {first_vector.shape}")

    # Pooled embeddings of the whole list
    pooled = embedder.batch_encode(demo_sequences)
    fprint(f"Embeddings for sequences: {pooled}")

    # Save and load embeddings
    embedder.save_embeddings(pooled, "embeddings.pt")
    restored = embedder.load_embeddings("embeddings.pt")

    # Compute similarity between the two restored embeddings
    vec_a, vec_b = restored[0], restored[1]
    score = embedder.compute_similarity(vec_a, vec_b)
    fprint(f"Cosine similarity: {score}")