Source code for omnigenbench.utility.model_hub.model_hub

# -*- coding: utf-8 -*-
# file: model_hub.py
# time: 18:13 12/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.
import json
import os

import autocuda
import torch
from transformers import AutoConfig, AutoModel

from ..hub_utils import query_models_info, download_model
from ...src.misc.utils import env_meta_info, fprint
from ...src.abc.abstract_tokenizer import OmniTokenizer


class ModelHub:
    """
    Unified interface for loading pre-trained models from the OmniGenome hub
    or from local paths.

    It resolves a model name or local directory, downloads the model from the
    hub when the path does not exist locally, loads the tokenizer and weights,
    and places the model on a device with a given dtype.

    Attributes:
        metadata (dict): Environment metadata information (from ``env_meta_info``).

    Example:
        >>> from omnigenbench import ModelHub
        >>> hub = ModelHub()
        >>> # Load a model from the hub
        >>> model, tokenizer = ModelHub.load_model_and_tokenizer("model_name")
        >>> # Check available models
        >>> models = hub.available_models()
        >>> print(list(models.keys()))
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the ModelHub instance.

        Args:
            *args: Positional arguments forwarded to ``super().__init__``.
            **kwargs: Keyword arguments forwarded to ``super().__init__``.
        """
        super(ModelHub, self).__init__(*args, **kwargs)
        self.metadata = env_meta_info()

    @staticmethod
    def load_model_and_tokenizer(
        model_name_or_path,
        local_only=False,
        device=None,
        dtype=torch.float16,
        **kwargs,
    ):
        """
        Load a model together with its tokenizer.

        The model is loaded via :meth:`ModelHub.load`, which casts it to
        ``dtype`` and moves it to ``device`` (auto-detected when ``None``).

        Args:
            model_name_or_path (str): Name or path of the model to load.
            local_only (bool, optional): Whether to use only the local cache.
                Defaults to False.
            device (str, optional): Device to load the model on. If None, it is
                auto-detected with ``autocuda.auto_cuda()``.
            dtype (torch.dtype, optional): Data type for the model.
                Defaults to torch.float16.
            **kwargs: Additional keyword arguments passed to the model loading
                functions.

        Returns:
            tuple: ``(model, model.tokenizer)``.

        Raises:
            ValueError: If ``model_name_or_path`` is not a string.

        Example:
            >>> model, tokenizer = ModelHub.load_model_and_tokenizer("yangheng/OmniGenome-186M")
            >>> print(f"Model loaded on device: {next(model.parameters()).device}")
        """
        # Forward device/dtype so the model is placed exactly once. Previously
        # they were dropped here, so ``load`` always moved the model to the
        # auto-detected device before it was moved again to ``device``.
        model = ModelHub.load(
            model_name_or_path,
            local_only=local_only,
            device=device,
            dtype=dtype,
            **kwargs,
        )
        fprint(f"The model and tokenizer has been loaded from {model_name_or_path}.")
        return model, model.tokenizer

    @staticmethod
    def load(
        model_name_or_path,
        local_only=False,
        device=None,
        dtype=torch.float16,
        **kwargs,
    ):
        """
        Load a model from a local directory or from the OmniGenome hub.

        If ``model_name_or_path`` is an existing path, it is used directly.
        Otherwise the model is fetched with ``download_model``. The directory
        is expected to contain ``metadata.json`` and ``pytorch_model.bin``.

        Args:
            model_name_or_path (str): Name or path of the model to load.
            local_only (bool, optional): Whether to use only the local cache
                when downloading. Defaults to False.
            device (str, optional): Device to load the model on. If None, it is
                auto-detected with ``autocuda.auto_cuda()``.
            dtype (torch.dtype, optional): Data type for the model.
                Defaults to torch.float16.
            **kwargs: Additional keyword arguments passed to ``download_model``,
                the config/tokenizer/model loaders and the model class.

        Returns:
            torch.nn.Module: The loaded model.

        Raises:
            ValueError: If ``model_name_or_path`` is not a string.
            FileNotFoundError: If ``metadata.json`` or ``pytorch_model.bin`` is
                missing from the resolved directory.

        Example:
            >>> model = ModelHub.load("yangheng/OmniGenome-186M")
            >>> print(f"Model type: {type(model)}")
        """
        import importlib

        if not isinstance(model_name_or_path, str):
            raise ValueError("model_name_or_path must be a string.")

        if os.path.exists(model_name_or_path):
            path = model_name_or_path
        else:
            path = download_model(model_name_or_path, local_only=local_only, **kwargs)

        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; only load models from sources you trust.
        config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)
        with open(os.path.join(path, "metadata.json"), "r", encoding="utf8") as f:
            metadata = json.load(f)
        tokenizer = OmniTokenizer.from_pretrained(path, **kwargs)
        config.metadata = metadata
        base_model = AutoModel.from_config(config, trust_remote_code=True, **kwargs)

        # The wrapper class is resolved dynamically from metadata: presumably
        # ``library_name`` names the package whose ``model`` submodule defines
        # ``model_cls`` — TODO confirm against the metadata writer.
        model_lib = importlib.import_module(metadata["library_name"].lower()).model
        model_cls = getattr(model_lib, metadata["model_cls"])
        model = model_cls(
            base_model,
            tokenizer,
            label2id=config.label2id,
            num_labels=config.num_labels,
            **kwargs,
        )

        # Weights are always deserialized onto the CPU first (``device`` is a
        # named parameter, so it never appears in kwargs); placement happens below.
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code for untrusted downloads; consider weights_only=True
        # once the minimum supported torch version allows it.
        with open(os.path.join(path, "pytorch_model.bin"), "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

        return ModelHub._place_model(model, device, dtype)

    @staticmethod
    def _place_model(model, device, dtype):
        """
        Cast ``model`` to ``dtype`` and move it to ``device``.

        When ``device`` is None, a device is picked with ``autocuda.auto_cuda()``
        and a notice is printed.

        Returns:
            The same model object, for convenience.
        """
        model.to(dtype)
        if device is None:
            device = autocuda.auto_cuda()
            fprint(
                f"No device is specified, the model will be loaded to the default device: {device}"
            )
        model.to(device)
        return model

    def available_models(
        self, model_name_or_path=None, local_only=False, repo="", **kwargs
    ):
        """
        Query information about models available on the OmniGenome hub.

        Args:
            model_name_or_path (str, optional): Filter models by name.
                Defaults to None.
            local_only (bool, optional): Whether to use only the local cache.
                Defaults to False.
            repo (str, optional): Repository URL to query. Defaults to "".
            **kwargs: Additional keyword arguments forwarded to
                ``query_models_info``.

        Returns:
            dict: The result of ``query_models_info``.

        Example:
            >>> hub = ModelHub()
            >>> models = hub.available_models()
            >>> print(f"Available models: {len(models)}")
            >>> dna_models = hub.available_models("DNA")
            >>> print(f"DNA models: {list(dna_models.keys())}")
        """
        models_info = query_models_info(
            model_name_or_path, local_only=local_only, repo=repo, **kwargs
        )
        return models_info

    def push(self, model, **kwargs):
        """
        Push a model to the hub (not implemented).

        Args:
            model: The model to push to the hub.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: Always; this method has not been implemented yet.
        """
        raise NotImplementedError("This method has not implemented yet.")