Source code for omnigenbench.src.model.augmentation.model
# -*- coding: utf-8 -*-# file: model.py# time: 18:37 22/09/2024# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)# github: https://github.com/yangheng95# huggingface: https://huggingface.co/yangheng# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en# Copyright (C) 2019-2024. All Rights Reserved."""Data augmentation model for genomic sequences.This module provides a data augmentation model that uses masked language modelingto generate augmented versions of genomic sequences. It's useful for expandingtraining datasets and improving model robustness."""importtorchimportrandomimportjsonimporttqdmfromtransformersimportAutoModelForMaskedLM,AutoTokenizerimportautocuda
class OmniModelForAugmentation(torch.nn.Module):
    """
    Data augmentation model for genomic sequences using masked language modeling.

    A pre-trained masked language model (MLM) is used to produce variants of an
    input sequence: a fraction of the tokens is replaced by the tokenizer's mask
    token, and the MLM's highest-scoring prediction is substituted at every
    masked position.

    Attributes:
        tokenizer: Tokenizer loaded from ``model_name_or_path``
        model: Masked language model loaded from ``model_name_or_path``
        device: Device returned by ``autocuda.auto_cuda()``; the model is moved to it
        noise_ratio: Proportion of tokens to mask for augmentation
        max_length: Maximum sequence length for tokenization
        k: Number of augmented instances to generate per sequence
    """

    def __init__(
        self,
        model_name_or_path=None,
        noise_ratio=0.15,
        max_length=1026,
        instance_num=1,
        *args,
        **kwargs
    ):
        """
        Initialize the augmentation model.

        Args:
            model_name_or_path (str): Path or model name for loading the
                pre-trained tokenizer and model
            noise_ratio (float): The proportion of tokens to mask in each
                sequence for augmentation (default: 0.15)
            max_length (int): The maximum sequence length for tokenization
                (default: 1026)
            instance_num (int): Number of augmented instances to generate per
                sequence (default: 1)
            *args: Accepted for interface compatibility; not used
            **kwargs: Accepted for interface compatibility; not used

        Raises:
            Exception: Whatever ``AutoTokenizer.from_pretrained`` raised, when
                the failure is not related to ``RnaTokenizer``.
        """
        super().__init__()
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        except Exception as e:
            # Only the "RnaTokenizer" failure has a known fallback. Anything
            # else is re-raised so the object is never left without a tokenizer.
            if "RnaTokenizer" not in str(e):
                raise
            from multimolecule import RnaTokenizer

            self.tokenizer = RnaTokenizer.from_pretrained(model_name_or_path)

        self.model = AutoModelForMaskedLM.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.device = autocuda.auto_cuda()
        self.model.to(self.device)

        # Hyperparameters for augmentation
        self.noise_ratio = noise_ratio
        self.max_length = max_length
        self.k = instance_num

    def load_sequences_from_file(self, input_file):
        """
        Load sequences from a JSON Lines file.

        Every non-blank line must be a JSON object with a ``"seq"`` key.
        Blank lines (for example a trailing newline) are skipped.

        Args:
            input_file (str): Path to the input file, one JSON object per line

        Returns:
            list: The ``"seq"`` value of every record, in file order
        """
        sequences = []
        with open(input_file, "r", encoding="utf-8") as f:
            # Iterate lazily instead of reading the whole file with readlines().
            for line in f:
                line = line.strip()
                if not line:
                    continue
                sequences.append(json.loads(line)["seq"])
        return sequences

    def apply_noise_to_sequence(self, seq):
        """
        Apply noise to a single sequence by randomly masking tokens.

        ``int(len(tokens) * noise_ratio)`` distinct token positions are
        replaced by the tokenizer's mask token. The count is clamped to
        ``[0, len(tokens)]``.

        Args:
            seq (str): Input genomic sequence

        Returns:
            str: The tokens joined back together without a separator, with the
                selected positions replaced by the mask token
        """
        seq_list = self.tokenizer.tokenize(seq)
        num_tokens = len(seq_list)
        num_to_mask = max(0, min(num_tokens, int(num_tokens * self.noise_ratio)))
        # Sample without replacement: drawing each index independently can hit
        # the same position several times and mask fewer tokens than requested.
        for idx in random.sample(range(num_tokens), num_to_mask):
            seq_list[idx] = self.tokenizer.mask_token
        return "".join(seq_list)

    def augment_sequence(self, seq):
        """
        Perform augmentation on a single sequence by predicting masked tokens.

        Args:
            seq (str): Input genomic sequence with masked tokens

        Returns:
            str: Decoded sequence in which every mask token has been replaced
                by the model's highest-scoring token at that position
        """
        tokenized_inputs = self.tokenizer(
            seq,
            padding="do_not_pad",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        tokenized_inputs = tokenized_inputs.to(self.device)

        with torch.no_grad():
            predictions = self.model(**tokenized_inputs)["logits"]
        predicted_tokens = predictions.argmax(dim=-1).cpu()

        # Replace masked tokens with predicted tokens; the mask is computed once.
        input_ids = tokenized_inputs["input_ids"][0].cpu()
        is_masked = input_ids == self.tokenizer.mask_token_id
        input_ids = torch.where(is_masked, predicted_tokens[0], input_ids)

        return self.tokenizer.decode(input_ids, skip_special_tokens=True)

    def augment(self, seq, k=None):
        """
        Generate multiple augmented instances for a single sequence.

        Args:
            seq (str): Input genomic sequence
            k (int, optional): Number of augmented instances to generate
                (default: None, uses self.k)

        Returns:
            list: List of augmented sequences
        """
        num_instances = self.k if k is None else k
        return [
            self.augment_sequence(self.apply_noise_to_sequence(seq))
            for _ in range(num_instances)
        ]

    def augment_sequences(self, sequences):
        """
        Augment a list of sequences by applying noise and performing MLM-based
        predictions.

        Args:
            sequences (list): List of genomic sequences to augment

        Returns:
            list: Flat list of all augmented sequences, ``self.k`` per input,
                in input order
        """
        all_augmented_sequences = []
        for seq in tqdm.tqdm(sequences, desc="Augmenting Sequences"):
            all_augmented_sequences.extend(self.augment(seq))
        return all_augmented_sequences

    def save_augmented_sequences(self, augmented_sequences, output_file):
        """
        Save augmented sequences to a JSON Lines file.

        One ``{"aug_seq": <sequence>}`` object is written per line. An
        existing file is overwritten.

        Args:
            augmented_sequences (list): List of augmented sequences to save
            output_file (str): Path to the output file
        """
        with open(output_file, "w", encoding="utf-8") as f:
            for seq in augmented_sequences:
                f.write(json.dumps({"aug_seq": seq}) + "\n")

    def augment_from_file(self, input_file, output_file):
        """
        Main function to handle the augmentation process from a file input to
        a file output.

        Loads sequences from ``input_file``, augments them with the MLM model,
        and saves the augmented sequences to ``output_file``.

        Args:
            input_file (str): Path to the input file containing sequences
            output_file (str): Path to the output file where augmented
                sequences will be saved
        """
        sequences = self.load_sequences_from_file(input_file)
        augmented_sequences = self.augment_sequences(sequences)
        self.save_augmented_sequences(augmented_sequences, output_file)
# Example usage: build an augmenter, run one prediction pass on a raw
# sequence, then augment every sequence of a toy dataset file.
if __name__ == "__main__":
    input_file = "toy_datasets/test.json"
    output_file = "toy_datasets/augmented_sequences.json"

    model = OmniModelForAugmentation(
        model_name_or_path="anonymous8/OmniGenome-186M",
        # Fraction of tokens masked per sequence.
        noise_ratio=0.2,
        # Upper bound on the tokenized length.
        max_length=1026,
        # Augmented variants produced for each input sequence.
        instance_num=3,
    )

    # Single-sequence demo; the result is kept in `aug` for inspection only.
    aug = model.augment_sequence("ATCTTGCATTGAAG")

    # File-to-file augmentation.
    model.augment_from_file(input_file, output_file)