Source code for omnigenbench.src.explainability.shared_methods.tsne_explainer
## Author: Shasha Zhou <sz484@exeter.ac.uk># Description:## Copyright (C) 2020-2025. All Rights Reserved.## -*- coding: utf-8 -*-# file: tsne_explainer.py# time: 2025-06-16 21:33# author: Shasha Zhou <sz484@exeter.ac.uk># Copyright (C) 2020-2025. All Rights Reserved.fromsklearn.manifoldimportTSNEfromtypingimportList,Union,Optionalfrom...abc.abstract_explainerimportAbstractExplainerfrom...misc.utilsimportfprint
[docs]classTSNEExplainer(AbstractExplainer):"""Visualizes high-dimensional sequence embeddings in 2D using t-SNE. This explainer generates high-dimensional embeddings from a set of input sequences using a given model. It then applies the t-SNE (t-Distributed Stochastic Neighbor Embedding) algorithm to project these embeddings into a two-dimensional space. This is useful for visualizing the structure of the learned embedding space and observing how sequences with different labels cluster. Attributes: model: The model used to generate sequence embeddings. tsne (sklearn.manifold.TSNE): The t-SNE transformer instance. """def__init__(self,model,**kwargs):"""Initializes the TSNEExplainer. Args: model: A model object capable of generating embeddings, which should have a `batch_encode` method (e.g., `OmniModelForEmbedding`). **kwargs: Additional keyword arguments to be passed directly to the `sklearn.manifold.TSNE` constructor. This allows for customization of parameters like `perplexity`, `learning_rate`, `n_iter`, etc. """super().__init__(model)self.tsne=TSNE(n_components=2,**kwargs)
[docs]defexplain(self,sequences:List[str],labels:List[Union[int,str]],embedding_file:Optional[str]=None,**kwargs,):"""Generates 2D embeddings for a set of sequences using t-SNE. This method first obtains high-dimensional embeddings for the input sequences, either by generating them with the model or by loading them from a file. It then applies the fitted t-SNE algorithm to project these embeddings into a two-dimensional representation suitable for plotting. """fprint("Starting t-SNE explanation")ifembedding_fileisnotNone:fprint(f"Loading embeddings from {embedding_file}")model_embeddings=self.model.load_embeddings(embedding_file)else:fprint("Encoding sequences")model_embeddings=self.model.batch_encode(sequences,**kwargs)fprint("Fitting t-SNE")tsne_embeddings=self.tsne.fit_transform(model_embeddings)fprint("t-SNE explanation completed")returntsne_embeddings