Source code for omnigenbench.src.lora.lora_model

# -*- coding: utf-8 -*-
# file: lora_model.py
# time: 12:36 11/06/2025
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# homepage: https://yangheng95.github.io
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2025. All Rights Reserved.
"""
This module provides Low-Rank Adaptation (LoRA) implementation for efficient fine-tuning of large
genomic language models. LoRA reduces the number of trainable parameters
by adding low-rank adaptation layers to existing model weights.
"""
from torch import nn
from ...src.misc.utils import fprint


def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
    """
    Collect the names of the ``nn.Linear`` sub-modules of a model.

    The result is meant to be used as LoRA target modules. An optional keyword
    filter keeps only the modules whose qualified name contains at least one
    of the given keywords (literal, case-insensitive match).

    Args:
        model: The model whose ``named_modules()`` are inspected
        keyword_filter (str, list, tuple, optional): Keyword or keywords that a
            module name must contain; ``None`` keeps every linear module
        use_full_path (bool): If True, return dotted module paths; otherwise
            return only the last path component (default: True)

    Returns:
        list: Sorted, de-duplicated list of matching linear module names

    Raises:
        TypeError: If keyword_filter is not None, str, or a list/tuple of str
    """
    import re

    from torch import nn

    matcher = None
    if keyword_filter is not None:
        if isinstance(keyword_filter, str):
            keywords = [keyword_filter]
        elif isinstance(keyword_filter, (list, tuple)):
            keywords = keyword_filter
        else:
            raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
        # Keywords are escaped so they match literally, then OR-ed together.
        matcher = re.compile(
            "|".join(re.escape(keyword) for keyword in keywords), re.IGNORECASE
        )

    def _is_selected(module_name):
        # Without a filter every linear layer qualifies.
        return matcher is None or matcher.search(module_name) is not None

    selected = {
        (path if use_full_path else path.rsplit(".", 1)[-1])
        for path, layer in model.named_modules()
        if isinstance(layer, nn.Linear) and _is_selected(path)
    }
    return sorted(selected)
def auto_lora_model(model, **kwargs):
    """
    Wrap a model with LoRA adapters using ``peft``.

    Target modules are taken from ``target_modules`` when given; otherwise all
    linear layers found by ``find_linear_target_modules`` (optionally narrowed
    by ``keyword_filter``) are used. All parameters of the input model are
    frozen before the adapters are injected.

    Note:
        The input model is modified in place: its parameters get
        ``requires_grad = False`` and, when ``model.config`` is not a
        ``PretrainedConfig``, the config is replaced by one built from it.

    Args:
        model: The base model to adapt with LoRA
        **kwargs: LoRA configuration. Recognised keys and their defaults:
            ``target_modules`` (None), ``keyword_filter`` (None),
            ``use_rslora`` (True), ``bias`` ("none"), ``r`` (32),
            ``lora_alpha`` (256), ``lora_dropout`` (0.1). Any remaining
            entries are forwarded unchanged to ``LoraConfig``.

    Returns:
        The LoRA-adapted model returned by ``get_peft_model``

    Raises:
        AssertionError: If no target modules are found for LoRA injection
    """
    from peft import LoraConfig, get_peft_model
    from transformers import PretrainedConfig

    # Workaround for models with a custom (non-HF) config class, e.g. EVO-1:
    # convert the config into a PretrainedConfig.
    if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
        try:
            # NOTE(review): "Loader" is presumably dropped because it cannot be
            # carried over into a PretrainedConfig -- confirm against EVO-1.
            delattr(model.config, "Loader")
        except (AttributeError, KeyError):
            # The custom config has no "Loader" entry; nothing to strip.
            pass
        model.config = PretrainedConfig.from_dict(dict(model.config))

    target_modules = kwargs.pop("target_modules", None)
    # Pop (not get) so that this helper-only option is not forwarded to
    # LoraConfig, which does not accept it.
    keyword_filter = kwargs.pop("keyword_filter", None)
    use_rslora = kwargs.pop("use_rslora", True)
    bias = kwargs.pop("bias", "none")
    r = kwargs.pop("r", 32)
    lora_alpha = kwargs.pop("lora_alpha", 256)
    lora_dropout = kwargs.pop("lora_dropout", 0.1)

    if target_modules is None:
        target_modules = find_linear_target_modules(
            model, keyword_filter=keyword_filter
        )
    # The finder returns a (possibly empty) list, never None, so emptiness is
    # what has to be checked. Raised explicitly so the check survives `-O`.
    if not target_modules:
        raise AssertionError("No target modules found for LoRA injection.")

    config = LoraConfig(
        target_modules=target_modules,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias=bias,
        use_rslora=use_rslora,
        **kwargs,
    )

    # Freeze every existing weight before the adapters are added.
    for param in model.parameters():
        param.requires_grad = False

    lora_model = get_peft_model(model, config)
    trainable_params, all_param = lora_model.get_nb_trainable_parameters()
    fprint(
        f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
        f" || trainable%: {100 * trainable_params / all_param:.4f}"
    )
    return lora_model
class OmniLoraModel(nn.Module):
    """
    Wrapper around a LoRA-adapted model.

    The wrapper builds the adapted model with ``auto_lora_model`` and forwards
    the framework-facing calls (``predict``, ``save``, ``model_info``, ...) to
    ``lora_model.base_model``.

    Attributes:
        lora_model: The underlying LoRA-adapted model
        config: Configuration object taken from the wrapped model
        device: Device of the first model parameter (set by ``to``)
        dtype: Data type of the first model parameter (set by ``to``)
    """

    def __init__(self, model, **kwargs):
        """
        Initialize the LoRA-adapted model.

        Args:
            model: The base model to adapt with LoRA
            **kwargs: LoRA configuration parameters; ``target_modules`` is
                required, everything is passed on to ``auto_lora_model``

        Raises:
            ValueError: If no target modules are specified for LoRA injection
        """
        super().__init__()

        if kwargs.get("target_modules", None) is None:
            raise ValueError(
                "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
                "please specify the target modules using the 'target_modules' argument. "
                "The target modules depend on the model architecture, such as 'query', 'value', etc. "
            )

        self.lora_model = auto_lora_model(model, **kwargs)

        fprint(
            "To reduce GPU memory occupation, "
            "you should avoid include non-trainable parameters into optimizers, "
            "e.g., optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), ...), "
            "AVOIDING: optimizer = torch.optim.AdamW(model.parameters(), ...)"
        )

        # Stored as an instance attribute; this shadows the `config` method
        # defined further down on initialized instances.
        self.config = model.config
        self.to("cpu")  # Move the model to CPU initially
        fprint(
            "LoRA model initialized with the following configuration:\n",
            self.lora_model,
        )

    def to(self, *args, **kwargs):
        """
        Move the model to a specific device and/or data type.

        After moving, the device and dtype of the first parameter are recorded
        on this wrapper and propagated, best effort, to every sub-module of
        the LoRA model (needed for evo-1 and similar models).

        Args:
            *args: Device specification (e.g., 'cuda', 'cpu')
            **kwargs: Additional arguments including dtype

        Returns:
            self: The model instance
        """
        self.lora_model.to(*args, **kwargs)

        first_param = next(self.parameters(), None)
        if first_param is None:
            # No parameters available: there is nothing to record or propagate.
            return self

        self.device = first_param.device
        self.dtype = first_param.dtype

        for module in self.lora_model.modules():
            # Handle failures per module: some modules expose `device`/`dtype`
            # as read-only properties, and one such module must not stop the
            # propagation to the remaining ones.
            try:
                module.device = self.device
                if hasattr(module, "dtype"):
                    module.dtype = self.dtype
            except (AttributeError, TypeError):
                continue
        return self

    def forward(self, *args, **kwargs):
        """
        Forward pass through the LoRA model.

        Args:
            *args: Positional arguments for the forward pass
            **kwargs: Keyword arguments for the forward pass

        Returns:
            The output from the LoRA model
        """
        return self.lora_model(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """
        Generate predictions by delegating to the base model's ``predict``.

        Args:
            *args: Positional arguments for prediction
            **kwargs: Keyword arguments for prediction

        Returns:
            Model predictions
        """
        return self.lora_model.base_model.predict(*args, **kwargs)

    def save(self, *args, **kwargs):
        """
        Save the model by delegating to the base model's ``save``.

        Args:
            *args: Positional arguments for saving
            **kwargs: Keyword arguments for saving

        Returns:
            Result of the save operation
        """
        return self.lora_model.base_model.save(*args, **kwargs)

    def model_info(self):
        """
        Get information about the LoRA model.

        Returns:
            Model information from the base model
        """
        return self.lora_model.base_model.model_info()

    def set_loss_fn(self, fn):
        """
        Set the loss function on the base model.

        Args:
            fn: Loss function to set

        Returns:
            Result of setting the loss function
        """
        return self.lora_model.base_model.set_loss_fn(fn)

    def last_hidden_state_forward(self, **kwargs):
        """
        Forward pass to get the last hidden state.

        Args:
            **kwargs: Keyword arguments for the forward pass

        Returns:
            Last hidden state from the base model
        """
        return self.lora_model.base_model.last_hidden_state_forward(**kwargs)

    def tokenizer(self):
        """
        Get the tokenizer from the base model.

        Returns:
            The tokenizer from the base model
        """
        return self.lora_model.base_model.tokenizer

    # NOTE: `__init__` assigns an instance attribute named `config`, so on an
    # initialized instance `obj.config` resolves to that attribute, not to
    # this method.
    def config(self):
        """
        Get the configuration from the base model.

        Returns:
            The configuration from the base model
        """
        return self.lora_model.base_model.config

    def model(self):
        """
        Get the base model.

        Returns:
            The base model
        """
        return self.lora_model.base_model.model