Source code for omnigenbench.src.metric.classification_metric
# -*- coding: utf-8 -*-# file: classification_metric.py# time: 12:57 09/04/2024# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)# github: https://github.com/yangheng95# huggingface: https://huggingface.co/yangheng# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en# Copyright (C) 2019-2024. All Rights Reserved.importtypesimportwarningsimportnumpyasnpimportsklearn.metricsasmetricsfrom..abc.abstract_metricimportOmniMetric
class ClassificationMetric(OmniMetric):
    """
    Classification metrics for the OmniGenome framework, backed by scikit-learn.

    Every plain function exported by ``sklearn.metrics`` is exposed as an
    attribute of the instance (see ``__getattribute__``). Accessing such an
    attribute returns a wrapper that unpacks Hugging Face ``EvalPrediction``
    objects, flattens the inputs, drops positions whose label equals
    ``ignore_y`` and returns ``{metric_name: value}``.

    Attributes:
        metric_func (callable): Metric function used by ``compute``. Expected
            to be stored by the ``OmniMetric`` initializer, which is not
            visible in this file.
        ignore_y (any): Label value to exclude from metric computation.
            Defaults to -100. Also expected to be stored by ``OmniMetric``.
        kwargs (dict): Keyword arguments forwarded to every metric call.
            They take precedence over keyword arguments given at call time.
    """

    def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
        """
        Initialize the classification metric.

        Args:
            metric_func (callable, optional): A metric function, typically
                from ``sklearn.metrics``. Required for ``compute`` unless the
                instance's ``compute`` is rebound by a dynamic metric access.
            ignore_y (any, optional): Label value to exclude from metric
                computation. Defaults to -100.
            *args: Additional positional arguments passed to ``OmniMetric``.
            **kwargs: Additional keyword arguments passed to ``OmniMetric``
                and stored on the instance for every metric call.

        Example:
            >>> metric = ClassificationMetric(metrics.accuracy_score)
            >>> metric = ClassificationMetric(ignore_y=-100)
        """
        super().__init__(metric_func, ignore_y, *args, **kwargs)
        self.kwargs = kwargs

    def __getattribute__(self, name):
        """
        Resolve attributes, exposing ``sklearn.metrics`` functions dynamically.

        ``__getattribute__`` (not ``__getattr__``) is overridden, so this hook
        runs on every attribute access and a name that matches a plain
        function in ``sklearn.metrics`` always wins over an attribute of the
        same name defined on the instance or class.

        Args:
            name (str): The attribute name being looked up.

        Returns:
            callable: A wrapper around the matching ``sklearn.metrics``
            function, or the regular attribute when ``name`` is not such a
            function.

        Example:
            >>> metric = ClassificationMetric()
            >>> accuracy_fn = metric.accuracy_score
            >>> result = accuracy_fn(y_true, y_pred)  # {'accuracy_score': ...}
        """
        metric_func = getattr(metrics, name, None)
        if not isinstance(metric_func, types.FunctionType):
            return super().__getattribute__(name)

        # Backward compatibility: accessing a metric has always rebound
        # ``compute`` on the instance to the raw sklearn function.
        self.compute = metric_func

        def wrapper(y_true=None, y_pred=None, *args, **kwargs):
            """
            Compute the selected metric on the true and predicted values.

            Args:
                y_true: Ground-truth labels, or a Hugging Face
                    ``EvalPrediction`` (detected by class name) carrying both
                    labels and predictions.
                y_pred: Predicted values. Ignored in favour of the extracted
                    predictions when ``y_true`` is an ``EvalPrediction`` that
                    contains a matching prediction array.
                *args: Additional positional arguments for the metric.
                **kwargs: Additional keyword arguments for the metric; the
                    instance-level ``kwargs`` override these on conflict.

            Returns:
                dict: ``{metric_name: metric_value}``.
            """
            # Detect EvalPrediction by class name so that transformers does
            # not have to be imported here.
            if y_true.__class__.__name__ == "EvalPrediction":
                eval_prediction = y_true
                if hasattr(eval_prediction, "label_ids"):
                    y_true = eval_prediction.label_ids
                if hasattr(eval_prediction, "labels"):
                    y_true = eval_prediction.labels
                # Pick the first prediction array shaped like the labels that
                # is not simply an echo of the labels themselves.
                for candidate in eval_prediction.predictions:
                    if candidate.shape == y_true.shape and not np.all(
                        candidate == y_true
                    ):
                        # Fix: this used to be stored in an unused ``y_score``
                        # variable, leaving ``y_pred`` as None.
                        y_pred = candidate
                        break

            y_true, y_pred = ClassificationMetric.flatten(y_true, y_pred)

            if self.ignore_y is not None:
                keep_idx = np.where(y_true != self.ignore_y)
                y_true = y_true[keep_idx]
                try:
                    y_pred = y_pred[keep_idx]
                except Exception as e:
                    # Deliberate best effort: if the predictions cannot be
                    # masked like the labels, warn and use them unmasked.
                    warnings.warn(str(e))

            kwargs.update(self.kwargs)
            # Call the function captured at access time rather than
            # ``self.compute``: the latter is rebound by every later metric
            # access and could run a different metric under this name.
            return {name: metric_func(y_true, y_pred, *args, **kwargs)}

        return wrapper

    def compute(self, y_true, y_pred, *args, **kwargs):
        """
        Compute the metric with the configured ``metric_func``.

        No preprocessing is applied here: the inputs are passed to
        ``metric_func`` as given, together with the instance-level ``kwargs``
        (which override call-time keyword arguments on conflict).

        Args:
            y_true: The true values (ground truth labels).
            y_pred: The predicted values (model predictions).
            *args: Additional positional arguments for the metric function.
            **kwargs: Additional keyword arguments for the metric function.

        Returns:
            Whatever ``metric_func`` returns (for example a float for
            ``sklearn.metrics.accuracy_score``).

        Raises:
            NotImplementedError: If no metric function was provided and the
                method is not implemented by a subclass.

        Example:
            >>> metric = ClassificationMetric(metrics.accuracy_score)
            >>> score = metric.compute(y_true, y_pred)
        """
        if self.metric_func is None:
            raise NotImplementedError(
                "Method compute() is not implemented in the child class."
            )
        kwargs.update(self.kwargs)
        return self.metric_func(y_true, y_pred, *args, **kwargs)