# Source code for omnigenbench.src.abc.abstract_model

# -*- coding: utf-8 -*-
# file: omnigenbench_model.py
# time: 18:36 06/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
import inspect
import json
import os
import shutil
import time
import warnings
from collections.abc import Mapping
from importlib import import_module

import dill
import findfile
import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    BatchEncoding,
    PretrainedConfig,
)

from ..misc.utils import fprint, env_meta_info

# Module-level warning policy: with the "once" action Python prints only the
# first occurrence of each matching warning message, so warnings raised on every
# call in this module (e.g. the per-forward "Ignored keys in inputs" warning)
# do not flood the logs. Note this affects the global warnings filter list.
warnings.filterwarnings("once")


def count_parameters(model):
    """
    Count the trainable parameters of a PyTorch model.

    Only parameters with ``requires_grad=True`` contribute; frozen parameters
    are skipped.

    Args:
        model (torch.nn.Module): A PyTorch model.

    Returns:
        int: The total number of trainable scalar parameters (0 for a model
        without trainable parameters).

    Example:
        >>> model = OmniModelForSequenceClassification(config, tokenizer)
        >>> num_params = count_parameters(model)
        >>> print(f"Model has {num_params} trainable parameters")
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            # frozen weights are not part of the trainable budget
            continue
        total += param.numel()
    return total
class OmniModel(torch.nn.Module):
    """
    Abstract base class for all genomic models in the OmniGenome framework.

    It wraps a Hugging Face-style base model (``self.model``) together with its
    tokenizer and provides a unified interface for model initialization,
    forward passes, loss computation, prediction, inference, and model
    persistence. Concrete subclasses implement ``forward`` and
    ``loss_function`` for their task.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initializes the model.

        This method handles different types of model initialization:

        - From a pre-trained model path or Hub ID (string)
        - From a PyTorch model instance
        - From a configuration object (``transformers.PretrainedConfig``)

        Args:
            config_or_model: A model configuration, a pre-trained model path
                (str), or a ``torch.nn.Module`` instance.
            tokenizer: The tokenizer associated with the model.
            *args: Additional positional arguments, forwarded to
                ``torch.nn.Module.__init__``.
            **kwargs: Additional keyword arguments.

                - label2id (dict): Mapping from class labels to IDs.
                - num_labels (int): The number of labels.
                - trust_remote_code (bool): Whether to trust remote code when
                  loading from Hugging Face Hub. Defaults to True.
                - ignore_mismatched_sizes (bool): Whether to ignore size
                  mismatches when loading pre-trained weights. Defaults to False.
                - dropout (float): Dropout rate. Defaults to 0.0.

                Any other keyword arguments are forwarded to
                ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If config_or_model is not a valid type, if neither
                label2id nor num_labels is given, if they disagree, or if the
                model class cannot be determined from the config.
            RuntimeError: If the hidden size cannot be determined from the config.

        Example:
            >>> # Initialize from a pre-trained model
            >>> model = OmniModelForSequenceClassification("model_path", tokenizer, num_labels=2)
            >>> # Initialize from a configuration
            >>> config = AutoConfig.from_pretrained("model_path")
            >>> model = OmniModelForSequenceClassification(config, tokenizer, num_labels=2)
        """
        label2id = kwargs.pop("label2id", None)
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        num_labels = kwargs.pop("num_labels", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        # Consumed here: torch.nn.Module.__init__ rejects unexpected keyword arguments.
        dropout = kwargs.pop("dropout", 0.0)

        # Derive whichever of label2id / num_labels is missing and make sure they agree.
        if label2id is not None and num_labels is None:
            num_labels = len(label2id)
        elif num_labels is not None and label2id is None:
            label2id = {str(i): i for i in range(num_labels)}
        elif label2id is None and num_labels is None:
            raise ValueError(
                "Either label2id or num_labels must be provided to initialize the model."
            )
        elif len(label2id) != num_labels:
            raise ValueError(
                "The length of label2id does not match num_labels. "
                f"Expected {num_labels}, but got {len(label2id)}."
            )

        # Module.__init__ must run before any attribute/submodule is assigned.
        super().__init__(*args, **kwargs)
        self.loss_fn = None

        if isinstance(config_or_model, str):
            config = AutoConfig.from_pretrained(
                config_or_model,
                num_labels=num_labels,
                label2id=label2id,
                trust_remote_code=trust_remote_code,
            )
            model_cls = self._resolve_model_class(config, config_or_model)
            self.model = model_cls.from_pretrained(
                config_or_model,
                config=config,
                trust_remote_code=trust_remote_code,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
            ).base_model
            self.model.config = config
        elif isinstance(config_or_model, torch.nn.Module):
            self.model = config_or_model
            # num_labels == len(label2id) is guaranteed by the validation above.
            self.model.config.num_labels = num_labels
            self.model.config.label2id = label2id
        elif isinstance(config_or_model, PretrainedConfig):
            config = config_or_model
            config.num_labels = num_labels
            config.label2id = label2id
            self.model = AutoModel.from_config(
                config, trust_remote_code=trust_remote_code
            )
            self.model.config = config
        else:
            raise ValueError(
                "The config_or_model should be either a string, a torch.nn.Module or a AutoConfig object."
            )

        # Keep the label mappings on the shared config in sync.
        self.config = self.model.config
        if isinstance(label2id, dict):
            self.config.label2id = label2id
            self.config.id2label = {v: k for k, v in label2id.items()}
            if (
                not hasattr(self.config, "num_labels")
                or len(self.config.id2label) != self.config.num_labels
            ):
                fprint(
                    "Warning: The number of labels in the config is not equal to the number of labels in the label2id dictionary. "
                )
                fprint(
                    "Please check the label2id dictionary and the num_labels parameter in the config."
                )
                self.config.num_labels = len(self.config.id2label)
        assert (
            len(self.config.label2id) == num_labels
        ), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."

        # The metadata of the model
        self.metadata = env_meta_info()
        self.metadata["model_cls"] = self.__class__.__name__

        # Normalize the hidden size: different architectures name it differently.
        for attr in ("n_embd", "d_model", "hidden_size"):
            value = getattr(self.config, attr, None)
            if value:
                self.config.hidden_size = value
                break
        else:
            raise RuntimeError(
                "The hidden size of the model is not found in the config."
            )

        # The tokenizer of the model; wrapper tokenizers expose the HF one as base_tokenizer.
        self.tokenizer = tokenizer
        self.metadata["tokenizer_cls"] = self.tokenizer.__class__.__name__
        if hasattr(self.tokenizer, "base_tokenizer"):
            self.pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id

        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.Tanh()

    @staticmethod
    def _resolve_model_class(config, model_path):
        """
        Choose the class used to load the base model for ``model_path``.

        Candidates come from ``config.auto_map`` (tokenizer/config entries
        excluded) or, failing that, from ``config.architectures``. ``AutoModel``
        is preferred; otherwise the last candidate in declaration order is used.
        Classes are looked up in ``multimolecule`` for multimolecule paths and
        in ``transformers`` otherwise.

        Raises:
            ValueError: If no usable candidate class is declared.
        """
        auto_map = getattr(config, "auto_map", None)
        architectures = getattr(config, "architectures", None)
        if auto_map:
            # Preserve declaration order; a set would make the choice depend on the hash seed.
            candidates = [
                name for name in auto_map if name not in ("AutoConfig", "AutoTokenizer")
            ]
            if not candidates:
                raise ValueError(
                    f"The model cannot be instantiated from {model_path}. "
                    f"Please check the model configuration contains the architectures or auto_map."
                )
        elif architectures:
            candidates = list(architectures)
        else:
            raise ValueError(
                "Neither `architectures` nor `auto_map` is defined in the config."
            )
        model_cls_name = "AutoModel" if "AutoModel" in candidates else candidates[-1]
        if "multimolecule" in repr(model_path).lower():
            module_name = "multimolecule"
        else:
            module_name = "transformers"
        return getattr(import_module(module_name), model_cls_name)

    def last_hidden_state_forward(self, **inputs):
        """
        Performs a forward pass to get the last hidden state from the base model.

        Only the keys accepted by the base model's ``forward`` signature are
        passed on; the rest are dropped with a warning. ``output_hidden_states``
        is always requested, and Evo (StripedHyena) models additionally receive
        ``input_ids`` as ``x``.

        Args:
            **inputs: The inputs to the model, typically ``input_ids``,
                ``attention_mask`` and other model-specific parameters.

        Returns:
            torch.Tensor: The last hidden state tensor.

        Raises:
            ValueError: If no hidden state can be found in the model outputs.

        Example:
            >>> inputs = {
            ...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
            ...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
            ... }
            >>> hidden_states = model.last_hidden_state_forward(**inputs)
        """
        model = self.model
        inputs["output_hidden_states"] = True
        if "strippedhyena" in model.__class__.__name__.lower():
            inputs["x"] = inputs["input_ids"]  # For compatibility with Evo models

        # ``**inputs`` always yields a plain dict; keep only what forward() accepts.
        forward_params = inspect.signature(model.forward).parameters
        model_inputs = {name: inputs[name] for name in forward_params if name in inputs}
        ignored_keys = set(inputs) - set(model_inputs)
        if ignored_keys:
            warnings.warn(f"Warning: Ignored keys in inputs: {ignored_keys}")

        outputs = model(**model_inputs)

        last_hidden_state = getattr(outputs, "last_hidden_state", None)
        if last_hidden_state is not None:
            return last_hidden_state
        warnings.warn(
            f"last_hidden_state not found in the outputs from the {model.__class__.__name__} model."
        )
        if isinstance(outputs, dict) and outputs.get("last_hidden_state") is not None:
            return outputs["last_hidden_state"]
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is not None:
            return hidden_states[-1]
        if isinstance(outputs, torch.Tensor):
            # A bare tensor is the hidden state itself; indexing it would select a batch item.
            return outputs
        if isinstance(outputs, (list, tuple)) and len(outputs) > 0:
            # For Evo models that return a tuple of (last_hidden_state, logits)
            return outputs[0] if len(outputs) <= 2 else outputs[-1]
        raise ValueError(
            f"Cannot find the last hidden state in the outputs from the {model.__class__.__name__} model, "
            f"please check the model architecture."
        )

    def loss_function(self, logits, labels):
        """
        Calculates the loss.

        Concrete model classes must override this to define the loss for their
        task (classification, regression, etc.).

        Args:
            logits (torch.Tensor): The model's output logits.
            labels (torch.Tensor): The ground truth labels.

        Returns:
            torch.Tensor: The calculated loss.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Example:
            >>> # In a classification model
            >>> loss = model.loss_function(logits, labels)
        """
        raise NotImplementedError(
            "The loss_function() function should be implemented for your model."
        )

    def set_loss_fn(self, loss_function):
        """
        Sets a custom loss function for the model.

        If the loss carries a ``weight`` tensor (e.g. class weights), it is
        moved to the base model's device: ``torch.nn.Module`` losses are moved
        in place with ``.to()``, other callables get their ``weight`` replaced.

        Args:
            loss_function (callable): A callable that takes logits and labels.

        Example:
            >>> import torch.nn as nn
            >>> model.set_loss_fn(nn.CrossEntropyLoss())
        """
        self.loss_fn = loss_function
        weight = getattr(loss_function, "weight", None)
        if not isinstance(weight, torch.Tensor):
            # Unweighted loss: nothing to move.
            return
        device = getattr(self.model, "device", None)
        if device is None:
            return
        if isinstance(loss_function, torch.nn.Module):
            loss_function.to(device)
        else:
            # Tensor.to() returns a new tensor, so the result must be stored back.
            loss_function.weight = weight.to(device)

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Returns the raw model outputs for raw sequences or tokenized inputs.

        No post-processing is applied; the outputs are useful for further
        processing of the model's direct predictions.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict / BatchEncoding).
            **kwargs: Additional arguments for tokenization.

        Returns:
            dict: The raw model outputs (as produced by ``forward``), plus
            ``loss`` and the tokenized ``inputs``.

        Example:
            >>> outputs = model.predict("ATCGATCG")
            >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
        """
        # Subclasses may override predict() with task-specific behavior.
        return self._forward_from_raw_input(sequence_or_inputs, **kwargs)

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Returns predictions for raw sequences or tokenized inputs.

        The base implementation returns the same raw outputs as ``predict``;
        subclasses override it to add post-processing (e.g. converting logits
        to class labels or probabilities).

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict / BatchEncoding).
            **kwargs: Additional arguments for tokenization.

        Returns:
            dict: The model outputs, plus ``loss`` and the tokenized ``inputs``.

        Example:
            >>> results = model.inference("ATCGATCG")
            >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
        """
        # Subclasses should override inference() with task-specific post-processing.
        return self._forward_from_raw_input(sequence_or_inputs, **kwargs)

    def __call__(self, **inputs):
        """
        The main forward pass of the model, suitable for training loops.

        Accepts either flat tokenized keyword arguments, or a single ``inputs``
        entry (as produced by the Hugging Face trainer integration) holding a
        mapping or a ``(features, labels)`` tuple. Labels are taken from
        ``labels`` (falling back to ``label``); the loss is computed when labels
        are present.

        Args:
            **inputs: Tokenized inputs, potentially including labels.

        Returns:
            dict: The outputs of ``forward`` with a ``loss`` entry (None when no
            labels were provided).

        Example:
            >>> outputs = model(
            ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
            ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
            ...     labels=torch.tensor([0])
            ... )
            >>> loss = outputs['loss']
        """
        # For transformer trainer integration, the real inputs live under "inputs".
        # For the native trainer, the kwargs already are the tokenized inputs.
        labels = inputs.pop("labels", None)
        inputs = inputs.pop("inputs", inputs)
        if isinstance(inputs, tuple):
            # (features, labels) pair
            if len(inputs) > 1:
                labels = inputs[1]
            inputs = inputs[0]
        elif isinstance(inputs, Mapping):
            inputs["labels"] = labels
            if labels is None:
                labels = inputs.get("label", None)

        outputs = self.forward(**inputs)
        if labels is not None:
            outputs["loss"] = self._calculate_loss(outputs, labels)
        else:
            outputs["loss"] = None
        return outputs

    def _calculate_loss(self, outputs, labels):
        """
        Returns ``outputs["loss"]`` if present, otherwise computes it.

        :param outputs: The dictionary-like model outputs.
        :param labels: The ground truth labels.
        :return: The loss.
        :raises RuntimeError: If there is no loss and logits or labels are missing.
        """
        loss = outputs.get("loss", None)
        if loss is not None:
            return loss
        logits = outputs.get("logits", None)
        if logits is not None and labels is not None:
            return self.loss_function(logits, labels)
        raise RuntimeError(
            "The output of the forward() function should be a dictionary-like objective"
            " and have either 'loss', or 'logits' and 'labels' attribute."
        )

    def save(self, path, overwrite=False, dtype=torch.float16, **kwargs):
        """
        Saves the model, tokenizer, and metadata to a directory.

        The base model is temporarily cast to ``dtype`` and moved to CPU while
        the files are written; its original dtype and device are restored
        afterwards, even if saving fails. The module is left in eval mode.

        :param path: The directory to save the model to. If it exists and
            ``overwrite`` is False, a timestamp suffix is appended instead.
        :param overwrite: Whether to write into ``path`` if it already exists.
        :param dtype: The data type to save the base model weights in.
        :param kwargs: Additional arguments (currently unused).
        """
        self.eval()
        if os.path.exists(path) and not overwrite:
            # Build the timestamp once so the message names the directory actually used.
            timestamped_path = f"{path}_{time.strftime('%Y%m%d_%H%M%S')}"
            fprint(
                f"The path {path} already exists, please set overwrite=True to overwrite it. "
                f"Rename the path to {timestamped_path} to save it with a timestamp."
            )
            path = timestamped_path
        os.makedirs(path, exist_ok=True)

        # Copy auxiliary files (configs, vocab, custom model scripts) from a local
        # source directory; weight files are skipped because they are re-written below.
        source_dir = self.config.name_or_path
        if source_dir and os.path.isdir(source_dir):
            for file in findfile.find_files(
                source_dir,
                or_key=["bin", "json", "txt", "py"],
                exclude_key=["pytorch_model.bin", "model.safetensors"],
                return_relative_path=False,
            ):
                destination = os.path.join(path, os.path.basename(file))
                if os.path.abspath(file) != os.path.abspath(destination):
                    shutil.copyfile(file, destination)

        _device = self.model.device
        _dtype = self.model.dtype
        self.model.to(dtype).to("cpu")
        try:
            self.tokenizer.save_pretrained(path)

            # Save metadata including information about the loss function
            metadata = self.metadata.copy()
            if self.loss_fn is not None:
                metadata["loss_fn_class"] = self.loss_fn.__class__.__name__
                metadata["loss_fn_module"] = self.loss_fn.__class__.__module__
            with open(os.path.join(path, "metadata.json"), "w", encoding="utf8") as f:
                json.dump(metadata, f)

            with open(os.path.join(path, "tokenizer.bin"), "wb") as f:
                dill.dump(self.tokenizer, f)

            # Save the underlying base model (do not remove: this also writes customized
            # model scripts). Lightweight baselines may lack `save_pretrained`.
            try:
                self.model.save_pretrained(path, safe_serialization=False)
            except AttributeError:
                if hasattr(self, "save_pretrained"):
                    try:
                        self.save_pretrained(path, overwrite=True)
                    except Exception as e:
                        # Best effort only: the full state dict below is still written.
                        warnings.warn(
                            f"save_pretrained() failed, relying on the full state dict: {e}"
                        )

            # Save complete state dict including all components (heads, loss buffers, ...)
            with open(os.path.join(path, "pytorch_model.bin"), "wb") as f:
                torch.save(self.state_dict(), f)
        finally:
            self.model.to(_dtype).to(_device)
        fprint(f"The model is saved to {path}.")

    def load(self, path, **kwargs):
        """
        Loads the model weights, loss function, and tokenizer from a directory.

        Args:
            path: The directory to load the model from.
            **kwargs: Additional arguments. ``device`` (default ``"cpu"``) is the
                ``map_location`` for the weights; the rest go to
                ``AutoConfig.from_pretrained``.

        Returns:
            The loaded model instance (``self``).

        Raises:
            ValueError: If the saved model class differs from this class.
        """
        device = kwargs.pop("device", "cpu")
        with open(os.path.join(path, "metadata.json"), "r", encoding="utf8") as f:
            metadata = json.load(f)

        if metadata["model_cls"] != self.__class__.__name__:  # Check the model class
            raise ValueError(
                f"The model class in the loaded model is {metadata['model_cls']}, "
                f"but the current model class is {self.__class__.__name__}."
            )

        # Only report config differences; the current config is kept.
        config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)
        for key, value in config.__dict__.items():
            if key not in self.config.__dict__ or self.config.__dict__[key] != value:
                fprint(
                    f"Warning: The value of the key {key} in the loaded model is {value}, "
                    f"but the current value is {self.config.__dict__.get(key, None)}."
                )

        # Attempt to restore any saved loss function (its buffers come from the state dict).
        # NOTE(review): the module name comes from metadata.json; only load trusted checkpoints.
        if "loss_fn_class" in metadata and "loss_fn_module" in metadata:
            try:
                loss_module = import_module(metadata["loss_fn_module"])
                loss_class = getattr(loss_module, metadata["loss_fn_class"])
                self.loss_fn = loss_class()
                fprint(
                    f"Restored loss function: {metadata['loss_fn_class']} from {metadata['loss_fn_module']}"
                )
            except (ImportError, AttributeError, TypeError) as e:
                warnings.warn(f"Could not restore loss function: {e}")

        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        with open(os.path.join(path, "pytorch_model.bin"), "rb") as f:
            loaded_state_dict = torch.load(f, map_location=device)

        current_keys = set(self.state_dict().keys())
        loaded_keys = set(loaded_state_dict.keys())
        missing_keys = current_keys - loaded_keys
        unexpected_keys = loaded_keys - current_keys
        if missing_keys:
            warnings.warn(f"Missing keys in loaded weights: {missing_keys}")
        if unexpected_keys:
            warnings.warn(f"Unexpected keys in loaded weights: {unexpected_keys}")

        self.load_state_dict(loaded_state_dict, strict=False)

        # Load the tokenizer (dill-pickled; same trust caveat as above)
        tokenizer_path = os.path.join(path, "tokenizer.bin")
        if os.path.exists(tokenizer_path):
            with open(tokenizer_path, "rb") as f:
                self.tokenizer = dill.load(f)

        return self

    def _forward_from_raw_input(self, sequence_or_inputs, **kwargs):
        """
        Tokenizes raw input if needed and runs a forward pass under ``torch.no_grad``.

        :param sequence_or_inputs: A sequence, list of sequences, or already
            tokenized inputs (any mapping, e.g. dict or BatchEncoding).
        :param kwargs: Additional arguments for tokenization.
        :return: The raw model outputs with the (device-moved) ``inputs`` added.
        """
        if isinstance(sequence_or_inputs, Mapping):
            inputs = sequence_or_inputs
        else:
            inputs = self.tokenizer(
                sequence_or_inputs,
                padding=kwargs.pop("padding", True),
                max_length=kwargs.pop("max_length", 1024),
                truncation=kwargs.pop("truncation", True),
                return_tensors=kwargs.pop("return_tensors", "pt"),
                **kwargs,
            )

        device = self.model.device
        if hasattr(inputs, "to"):
            inputs = inputs.to(device)  # e.g. BatchEncoding
        else:
            # Plain dicts have no .to(); move tensor values individually.
            inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in inputs.items()
            }

        with torch.no_grad():
            raw_outputs = self(**inputs)
        raw_outputs["inputs"] = inputs
        return raw_outputs

    @staticmethod
    def from_pretrained(model_name_or_path, tokenizer, *args, **kwargs):
        """
        Loads a pre-trained base model (and tokenizer, if needed) into an `OmniModel`.

        :param model_name_or_path: The name or path of the pre-trained model.
        :param tokenizer: The tokenizer to use; loaded from ``model_name_or_path``
            when None.
        :param args: Additional positional arguments for `OmniModel`.
        :param kwargs: ``label2id``, ``num_labels`` and ``dropout`` go to
            `OmniModel`; ``config`` overrides the loaded config; everything else
            goes to the Hugging Face loaders.
        :return: An instance of `OmniModel`.
        """
        # Options understood only by OmniModel.__init__ are kept away from the
        # Hugging Face loaders, which may forward unknown kwargs to the model constructor.
        omni_kwargs = {}
        for key in ("label2id", "num_labels", "dropout"):
            if key in kwargs:
                omni_kwargs[key] = kwargs.pop(key)

        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
        base_model = AutoModel.from_pretrained(
            model_name_or_path, config=config, **kwargs
        )
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        return OmniModel(base_model, tokenizer, *args, **omni_kwargs)

    def model_info(self):
        """
        Prints and returns detailed information about the model.

        :return: A string containing the model information.
        """
        info = f"Model Name: {self.__class__.__name__}\n"
        info += f"Model Metadata: {self.metadata}\n"
        info += f"Base Model Name: {self.config.name_or_path}\n"
        info += f"Model Type: {self.config.model_type}\n"
        info += f"Model Architecture: {self.config.architectures}\n"
        info += f"Model Parameters: {count_parameters(self.model) / 1e6} M\n"
        info += f"Model Config: {self.config}\n"
        fprint(info)
        return info