Source code for omnigenbench.utility.hub_utils

# -*- coding: utf-8 -*-
# file: hub_utils.py
# time: 16:54 13/04/2024
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2019-2024. All Rights Reserved.

import json
import os
from typing import Union, Dict, Any

import findfile
import requests
import tqdm
from packaging.version import Version
from termcolor import colored

from .. import __version__ as current_version
from ..src.misc.utils import fprint, default_omnigenome_repo


def unzip_checkpoint(checkpoint_path):
    """
    Extract a zipped checkpoint into a directory next to the archive.

    Args:
        checkpoint_path (str): The path to the checkpoint file. Only paths that
            end with ``.zip`` are extracted; any other path is returned as-is.

    Returns:
        str: The path to the extracted checkpoint directory (the archive path
        without its ``.zip`` suffix), or ``checkpoint_path`` itself when it
        does not end with ``.zip``.

    Example:
        >>> extracted_path = unzip_checkpoint("model.zip")
        >>> print(extracted_path)  # "model"
    """
    zip_suffix = ".zip"
    if not checkpoint_path.endswith(zip_suffix):
        fprint("Checkpoint path does not end with .zip, returning the original path.")
        return checkpoint_path

    import zipfile

    # Slice the suffix off rather than calling str.strip(".zip"): strip() treats
    # its argument as a SET of characters and removes them from both ends, so
    # "pizza.zip" would turn into "a" instead of "pizza".
    extract_dir = checkpoint_path[: -len(zip_suffix)]

    fprint("Unzipping checkpoint from {}...".format(checkpoint_path))
    with zipfile.ZipFile(checkpoint_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
    return extract_dir
def query_models_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Retrieve model information from the OmniGenome hub or from the local cache.

    The index is read from ``./models_info.json`` when ``local_only`` is True.
    Otherwise it is downloaded from ``repo`` and written to that file; if the
    download fails, the previously cached file is used instead.

    Args:
        keyword (Union[list, str]): Substring used to filter model names.
            NOTE: only ``str`` keywords are applied; any other type (e.g. a
            list) currently returns the whole index unfiltered.
        repo (str, optional): The repository URL to query. If None, uses the default hub.
        local_only (bool): Whether to use only local cache. Defaults to False.
        **kwargs: Additional keyword arguments. ``timeout`` (seconds, default
            30) bounds the HTTP request; everything else is ignored.

    Returns:
        Dict[str, Any]: A dictionary containing model information filtered by the keyword.

    Raises:
        FileNotFoundError: If the local cache file is needed but does not exist.

    Example:
        >>> # Query all models
        >>> models = query_models_info("")
        >>> print(len(models))  # Number of available models

        >>> # Query specific models
        >>> models = query_models_info("DNA")
        >>> print(models.keys())  # Models containing "DNA"
    """
    cache_file = "./models_info.json"

    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            models_info = json.load(f)
    else:
        # NOTE(review): this default differs from the sibling query functions,
        # which use default_omnigenome_repo + "resolve/main/" — confirm that
        # this endpoint is still the intended one.
        repo = repo if repo else "https://huggingface.co/spaces/anonymous8/gfm_hub/"
        try:
            # A timeout keeps a stalled connection from blocking forever.
            response = requests.get(
                repo + "models_info.json", timeout=kwargs.get("timeout", 30)
            )
            # Reject HTTP errors before parsing so an error payload can never
            # overwrite a good local cache.
            response.raise_for_status()
            models_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Best effort: fall back to whatever was cached by an earlier call.
            with open(cache_file, "r", encoding="utf8") as f:
                models_info = json.load(f)

    if isinstance(keyword, str):
        return {key: models_info[key] for key in models_info if keyword in key}
    return models_info
def query_pipelines_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Retrieve pipeline information from the OmniGenome hub or from the local cache.

    The index is read from ``./pipelines_info.json`` when ``local_only`` is
    True. Otherwise it is downloaded from ``<repo>resolve/main/`` and written
    to that file; if the download fails, the previously cached file is used.

    Args:
        keyword (Union[list, str]): Substring used to filter pipeline names.
            NOTE: only ``str`` keywords are applied; any other type (e.g. a
            list) currently returns the whole index unfiltered.
        repo (str, optional): The repository URL to query. If None, uses the default hub.
        local_only (bool): Whether to use only local cache. Defaults to False.
        **kwargs: Additional keyword arguments. ``timeout`` (seconds, default
            30) bounds the HTTP request; everything else is ignored.

    Returns:
        Dict[str, Any]: A dictionary containing pipeline information filtered by the keyword.

    Raises:
        FileNotFoundError: If the local cache file is needed but does not exist.

    Example:
        >>> # Query all pipelines
        >>> pipelines = query_pipelines_info("")
        >>> print(len(pipelines))  # Number of available pipelines

        >>> # Query specific pipelines
        >>> pipelines = query_pipelines_info("classification")
        >>> print(pipelines.keys())  # Pipelines containing "classification"
    """
    cache_file = "./pipelines_info.json"

    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        # The base URL is concatenated directly, so it is expected to end in "/".
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            # A timeout keeps a stalled connection from blocking forever.
            response = requests.get(
                repo + "pipelines_info.json", timeout=kwargs.get("timeout", 30)
            )
            # Reject HTTP errors before parsing so an error payload can never
            # overwrite a good local cache.
            response.raise_for_status()
            pipelines_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Best effort: fall back to whatever was cached by an earlier call.
            with open(cache_file, "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if isinstance(keyword, str):
        return {key: pipelines_info[key] for key in pipelines_info if keyword in key}
    return pipelines_info
def query_benchmarks_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Retrieve benchmark information from the OmniGenome hub or from the local cache.

    The index is read from ``./benchmarks_info.json`` when ``local_only`` is
    True. Otherwise it is downloaded from ``<repo>resolve/main/`` and written
    to that file; if the download fails, the previously cached file is used.

    Args:
        keyword (Union[list, str]): Substring used to filter benchmark names.
            NOTE: only ``str`` keywords are applied; any other type (e.g. a
            list) currently returns the whole index unfiltered.
        repo (str, optional): The repository URL to query. If None, uses the default hub.
        local_only (bool): Whether to use only local cache. Defaults to False.
        **kwargs: Additional keyword arguments. ``timeout`` (seconds, default
            30) bounds the HTTP request; everything else is ignored.

    Returns:
        Dict[str, Any]: A dictionary containing benchmark information filtered by the keyword.

    Raises:
        FileNotFoundError: If the local cache file is needed but does not exist.

    Example:
        >>> # Query all benchmarks
        >>> benchmarks = query_benchmarks_info("")
        >>> print(len(benchmarks))  # Number of available benchmarks

        >>> # Query specific benchmarks
        >>> benchmarks = query_benchmarks_info("RGB")
        >>> print(benchmarks.keys())  # Benchmarks containing "RGB"
    """
    cache_file = "./benchmarks_info.json"

    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        # The base URL is concatenated directly, so it is expected to end in "/".
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            # A timeout keeps a stalled connection from blocking forever.
            response = requests.get(
                repo + "benchmarks_info.json", timeout=kwargs.get("timeout", 30)
            )
            # Reject HTTP errors before parsing so an error payload can never
            # overwrite a good local cache.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            fprint(
                "Fail to download datasets info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Best effort: fall back to whatever was cached by an earlier call.
            with open(cache_file, "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if isinstance(keyword, str):
        return {key: benchmarks_info[key] for key in benchmarks_info if keyword in key}
    return benchmarks_info
def download_model(
    model_name_or_path: str, local_only: bool = False, repo: str = None, cache_dir=None
) -> str:
    """
    Download a model archive from the hub and return the extracted directory.

    Args:
        model_name_or_path (str): The name of the model, as listed in ``models_info.json``.
        local_only (bool): When True, the model index is read from the local
            ``./models_info.json`` instead of being fetched from ``repo``.
            Defaults to False.
        repo (str, optional): Base URL of the repository. If None, the default
            hub is used. ``resolve/main/`` is appended to it.
        cache_dir (str, optional): The directory to cache the downloaded model.
            If None, uses "__OMNIGENOME_DATA__/models/".

    Returns:
        str: A string representing the path to the downloaded model.

    Raises:
        FileNotFoundError: If the local ``models_info.json`` is needed but missing.
        ConnectionError: If the model download fails.
        ValueError: If the model is not found in the repository.

    Example:
        >>> # Download a model
        >>> model_path = download_model("DNABERT-2")
        >>> print(model_path)  # Path to the downloaded model

        >>> # Download with custom cache directory
        >>> model_path = download_model("DNABERT-2", cache_dir="./models")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/models/"
    os.makedirs(cache_dir, exist_ok=True)

    # NOTE(review): this lookup is not keyed on model_name_or_path, so any
    # config.json already under cache_dir is returned — confirm that sharing
    # one cache_dir between different models is not expected.
    ckpt_config = findfile.find_files(cache_dir, ["config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    # Resolve the base URL before branching: previously it was only set in the
    # remote branch, so local_only=True built the download URL from None.
    repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
    index_file = "./models_info.json"

    if local_only:
        try:
            with open(index_file, "r", encoding="utf8") as f:
                models_info = json.load(f)
        except FileNotFoundError as e:
            fprint(
                "Local models_info.json not found. Please run the script without local_only=True to download it."
            )
            raise FileNotFoundError("models_info.json not found in local cache.") from e
    else:
        try:
            response = requests.get(repo + "models_info.json", timeout=30)
            # Reject HTTP errors so an error payload never overwrites the cache.
            response.raise_for_status()
            models_info = response.json()
            with open(index_file, "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Fallback to local cache if remote download fails
            fprint("Using local cache for models_info.json. Ensure it is up-to-date.")
            try:
                with open(index_file, "r", encoding="utf8") as f:
                    models_info = json.load(f)
            except FileNotFoundError as err:
                fprint(
                    "Local models_info.json not found. Please run the script without local_only=True to download it."
                )
                raise FileNotFoundError(
                    "models_info.json not found in local cache."
                ) from err

    if model_name_or_path not in models_info:
        raise ValueError("Model not found in the repository.")

    model_info = models_info[model_name_or_path]
    try:
        # repo already ends with "/", so no extra separator is inserted
        # (the same form download_benchmark uses).
        model_url = f'{repo}models/{model_info["filename"]}'
        response = requests.get(model_url, stream=True, timeout=30)
        # Without this check an HTTP error page would be saved as the archive.
        response.raise_for_status()

        cache_path = os.path.join(cache_dir, f"{model_info['filename']}")
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)

        # content-length may be absent (e.g. chunked transfer); tqdm accepts
        # total=None and then shows an open-ended progress bar.
        content_length = response.headers.get("content-length")
        total_mb = int(content_length) // 1024 // 1024 if content_length else None

        with open(cache_path, "wb") as f:
            for chunk in tqdm.tqdm(
                response.iter_content(chunk_size=1024 * 1024),
                unit="MB",
                total=total_mb,
                desc="Downloading model",
            ):
                f.write(chunk)
    except Exception as e:
        raise ConnectionError("Fail to download model: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
def download_pipeline(
    pipeline_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
) -> str:
    """
    Download a pipeline archive from the hub and return the extracted directory.

    Args:
        pipeline_name_or_path (str): The name of the pipeline, as listed in
            ``pipelines_info.json``.
        local_only (bool): When True, the pipeline index is read from the local
            ``./pipelines_info.json`` instead of being fetched from ``repo``.
            Defaults to False.
        repo (str, optional): Base URL of the repository. If None, the default
            hub is used. ``resolve/main/`` is appended to it.
        cache_dir (str, optional): The directory to cache the downloaded pipeline.
            If None, uses "__OMNIGENOME_DATA__/pipelines/".

    Returns:
        str: A string representing the path to the downloaded pipeline.

    Raises:
        FileNotFoundError: If the local ``pipelines_info.json`` is needed but missing.
        ConnectionError: If the pipeline download fails.
        ValueError: If the pipeline is not found in the repository.

    Example:
        >>> # Download a pipeline
        >>> pipeline_path = download_pipeline("classification_pipeline")
        >>> print(pipeline_path)  # Path to the downloaded pipeline
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/pipelines/"
    os.makedirs(cache_dir, exist_ok=True)

    # NOTE(review): this lookup is not keyed on pipeline_name_or_path, so any
    # config.json already under cache_dir is returned — confirm that sharing
    # one cache_dir between different pipelines is not expected.
    ckpt_config = findfile.find_files(cache_dir, ["config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    # Resolve the base URL before branching: previously it was only set in the
    # remote branch, so local_only=True built the download URL from None.
    repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
    index_file = "./pipelines_info.json"

    if local_only:
        with open(index_file, "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        try:
            response = requests.get(repo + "pipelines_info.json", timeout=30)
            # Reject HTTP errors so an error payload never overwrites the cache.
            response.raise_for_status()
            pipelines_info = response.json()
            with open(index_file, "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Best effort: fall back to whatever was cached by an earlier call.
            with open(index_file, "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if pipeline_name_or_path not in pipelines_info:
        raise ValueError("Pipeline not found in the repository.")

    pipeline_info = pipelines_info[pipeline_name_or_path]
    try:
        # repo already ends with "/", so no extra separator is inserted
        # (the same form download_benchmark uses).
        pipeline_url = f'{repo}pipelines/{pipeline_info["filename"]}'
        response = requests.get(pipeline_url, stream=True, timeout=30)
        # Without this check an HTTP error page would be saved as the archive.
        response.raise_for_status()

        cache_path = os.path.join(cache_dir, f"{pipeline_info['filename']}")
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)

        # content-length may be absent (e.g. chunked transfer); tqdm accepts
        # total=None and then shows an open-ended progress bar.
        content_length = response.headers.get("content-length")
        total_mb = int(content_length) // 1024 // 1024 if content_length else None

        with open(cache_path, "wb") as f:
            for chunk in tqdm.tqdm(
                response.iter_content(chunk_size=1024 * 1024),
                unit="MB",
                total=total_mb,
                desc="Downloading pipeline",
            ):
                f.write(chunk)
    except Exception as e:
        raise ConnectionError("Fail to download pipeline: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
def download_benchmark(
    benchmark_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
) -> str:
    """
    Locate a benchmark locally, or download it from the hub and extract it.

    The lookup order is: a matching directory found by ``findfile.find_cwd_dir``,
    then a ``metadata.py`` under the cache directory whose path contains the
    benchmark name, and finally a download from the repository.

    Args:
        benchmark_name_or_path (str): The name or path of the benchmark to download.
        local_only (bool): When True, the benchmark index is read from the local
            ``./benchmarks_info.json`` instead of being fetched from ``repo``.
            Defaults to False.
        repo (str, optional): Base URL of the repository. If None, the default
            hub is used. ``resolve/main/`` is appended to it.
        cache_dir (str, optional): The directory to cache the downloaded benchmark.
            If None, uses "__OMNIGENOME_DATA__/benchmarks/".

    Returns:
        str: A string representing the path to the downloaded benchmark.

    Raises:
        FileNotFoundError: If the local ``benchmarks_info.json`` is needed but missing.
        ConnectionError: If the benchmark download fails.
        ValueError: If the benchmark is not found in the repository.

    Example:
        >>> # Download a benchmark
        >>> benchmark_path = download_benchmark("RGB")
        >>> print(benchmark_path)  # Path to the downloaded benchmark

        >>> # Download with custom cache directory
        >>> benchmark_path = download_benchmark("RGB", cache_dir="./benchmarks")
    """
    p = findfile.find_cwd_dir(benchmark_name_or_path)
    if p:
        fprint("Benchmark:", benchmark_name_or_path, "found in {}.".format(p))
        return p

    fprint(
        "Benchmark:",
        benchmark_name_or_path,
        "cannot be found. Search from the online hub to download...",
    )

    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/benchmarks/"
    os.makedirs(cache_dir, exist_ok=True)

    bench_config = findfile.find_file(
        cache_dir, [benchmark_name_or_path, "metadata.py"]
    )
    if bench_config:
        return os.path.dirname(bench_config)

    # Resolve the base URL before branching: previously it was only set in the
    # remote branch, so local_only=True built the download URL from None.
    repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
    index_file = "./benchmarks_info.json"

    if local_only:
        with open(index_file, "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        try:
            response = requests.get(repo + "benchmarks_info.json", timeout=30)
            # Reject HTTP errors so an error payload never overwrites the cache.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open(index_file, "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            fprint(
                "Fail to download datasets info from huggingface space, the error is: {}".format(
                    e
                )
            )
            # Best effort: fall back to whatever was cached by an earlier call.
            with open(index_file, "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if benchmark_name_or_path not in benchmarks_info:
        raise ValueError("Benchmark not found in the repository.")

    benchmarks_info_item = benchmarks_info[benchmark_name_or_path]
    try:
        benchmark_url = f'{repo}benchmarks/{benchmarks_info_item["filename"]}'
        response = requests.get(benchmark_url, stream=True, timeout=30)
        # Without this check an HTTP error page would be saved as the archive.
        response.raise_for_status()

        cache_path = os.path.join(cache_dir, f"{benchmarks_info_item['filename']}")
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)

        # content-length may be absent (e.g. chunked transfer); tqdm accepts
        # total=None and then shows an open-ended progress bar.
        content_length = response.headers.get("content-length")
        total_mb = int(content_length) // 1024 // 1024 if content_length else None

        with open(cache_path, "wb") as f:
            for chunk in tqdm.tqdm(
                response.iter_content(chunk_size=1024 * 1024),
                unit="MB",
                total=total_mb,
                desc="Downloading benchmark",
            ):
                f.write(chunk)
    except Exception as e:
        # requests' exceptions do not derive from the builtin ConnectionError,
        # so catching only ConnectionError let request failures escape without
        # being converted to the documented exception type.
        raise ConnectionError("Fail to download benchmark: {}".format(e)) from e

    fprint(
        f"Benchmark {benchmark_name_or_path} downloaded successfully to: {cache_path}"
    )
    # Extraction stays outside the try block so archive errors are not
    # reported as download failures.
    return unzip_checkpoint(cache_path)
def check_version(repo: str = None) -> None:
    """
    Compare the local OmniGenome version with the version published on the hub.

    The remote version is read from ``<repo>resolve/main/version.json``. The
    outcome is only printed; any failure while fetching or comparing is
    reported as a message and never raised.

    Args:
        repo (str, optional): The repository URL to check. If None, uses the default hub.

    Example:
        >>> check_version()  # Check version compatibility
    """
    repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
    try:
        # Bound the request so a stalled connection cannot block the caller
        # indefinitely; a timeout is reported like any other failure below.
        response = requests.get(repo + "version.json", timeout=10)
        response.raise_for_status()
        remote_version = response.json()["version"]

        # Parse each version once instead of once per comparison.
        local_parsed = Version(current_version)
        remote_parsed = Version(remote_version)

        if local_parsed < remote_parsed:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is older than the remote version ({remote_version}). "
                    f"Please consider updating.",
                    "yellow",
                )
            )
        elif local_parsed > remote_parsed:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is newer than the remote version ({remote_version}). "
                    f"This might cause compatibility issues.",
                    "yellow",
                )
            )
        else:
            fprint(
                colored(
                    f"OmniGenome version ({current_version}) is up to date.",
                    "green",
                )
            )
    except Exception as e:
        fprint(
            colored(
                f"Failed to check version: {e}",
                "red",
            )
        )