Source code for omnigenbench.src.metric.ranking_metric
# -*- coding: utf-8 -*-# file: ranking_metric.py# time: 13:27 09/04/2024# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)# github: https://github.com/yangheng95# huggingface: https://huggingface.co/yangheng# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en# Copyright (C) 2019-2024. All Rights Reserved.importtypesimportwarningsimportnumpyasnpimportsklearn.metricsasmetricsfrom..abc.abstract_metricimportOmniMetric
class RankingMetric(OmniMetric):
    """
    A specialized metric class for ranking tasks and evaluation.

    Any attribute that is not found on the instance is looked up by name in
    ``sklearn.metrics``. If a plain function with that name exists, a wrapper
    is returned that accepts either ``(y_true, y_score)`` or a single
    HuggingFace ``EvalPrediction``-like object, preprocesses the inputs and
    returns ``{metric_name: value}``.

    Attributes:
        ignore_y: Label value whose positions are dropped from both
            ``y_true`` and ``y_score`` before the metric is computed.
            It is only read here; presumably it is set by the parent
            ``OmniMetric`` initializer -- confirm against the base class.

    Example:
        >>> import numpy as np
        >>> from omnigenbench import RankingMetric
        >>> metric = RankingMetric(ignore_y=-100)
        >>> y_true = np.array([0, 1, 1, 0, 1])
        >>> y_score = np.array([0.1, 0.9, 0.8, 0.2, 0.7])
        >>> result = metric.roc_auc_score(y_true, y_score)
        >>> # result maps the metric name to its value, i.e. it has the
        >>> # single key 'roc_auc_score' (1.0 for these perfectly
        >>> # separated scores, assuming ``flatten`` leaves 1-D arrays as is).
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the RankingMetric class.

        Args:
            *args: Additional positional arguments passed to the parent class.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        super().__init__(*args, **kwargs)

    def __getattr__(self, name):
        """
        Dynamically create ranking metric computation methods.

        Python only calls this hook when normal attribute lookup fails, so
        real attributes and methods are never shadowed.

        Args:
            name (str): Name of the ``sklearn.metrics`` function to wrap.

        Returns:
            callable: Wrapper function for the requested ranking metric.

        Raises:
            AttributeError: If ``sklearn.metrics`` has no plain function
                called ``name``.
        """
        metric_func = getattr(metrics, name, None)
        if not (metric_func and isinstance(metric_func, types.FunctionType)):
            # Report the actual class name (the previous message referred to
            # a non-existent 'CustomMetrics' class).
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )

        def wrapper(y_true=None, y_score=None, *args, **kwargs):
            """
            Compute the ranking metric from the true and predicted values.

            Args:
                y_true: The true values, or an object whose class is named
                    ``EvalPrediction`` (as produced by HuggingFace trainers).
                y_score: The predicted values (scores for ranking). Ignored
                    and re-derived when ``y_true`` is an ``EvalPrediction``
                    that contains a matching prediction array.
                *args: Additional positional arguments for the metric.
                **kwargs: Additional keyword arguments for the metric.

            Returns:
                dict: ``{name: value}`` for the requested metric.
            """
            # HuggingFace trainers hand over a single EvalPrediction object;
            # it is matched by class name so transformers need not be imported.
            if y_true.__class__.__name__ == "EvalPrediction":
                eval_prediction = y_true
                if hasattr(eval_prediction, "label_ids"):
                    y_true = eval_prediction.label_ids
                # ``labels`` takes precedence over ``label_ids`` if both exist.
                if hasattr(eval_prediction, "labels"):
                    y_true = eval_prediction.labels
                # Use the first prediction entry that has the labels' shape
                # but is not simply a copy of the labels themselves.
                for candidate in eval_prediction.predictions:
                    if candidate.shape == y_true.shape and not np.all(
                        candidate == y_true
                    ):
                        y_score = candidate
                        break

            y_true, y_score = RankingMetric.flatten(y_true, y_score)

            if self.ignore_y is not None:
                # Drop every position whose label equals the ignore value.
                keep_idx = np.where(y_true != self.ignore_y)
                y_true = y_true[keep_idx]
                try:
                    y_score = y_score[keep_idx]
                except Exception as e:
                    # Deliberately best-effort: warn and let the metric
                    # itself report any resulting length mismatch.
                    warnings.warn(str(e))

            # A subclass may provide its own computation via ``compute``;
            # otherwise fall back to the scikit-learn function that was
            # resolved above. Previously the base ``compute`` was always
            # called, which unconditionally raised NotImplementedError.
            if type(self).compute is not RankingMetric.compute:
                value = self.compute(y_true, y_score, *args, **kwargs)
            else:
                value = metric_func(y_true, y_score, *args, **kwargs)
            return {name: value}

        return wrapper

    def compute(self, y_true, y_score, *args, **kwargs):
        """
        Compute the ranking metric from the true and predicted values.

        Subclasses may override this hook to supply custom metric logic;
        when it is overridden, the dynamically created metric wrappers call
        it instead of the scikit-learn function.

        Args:
            y_true: The true values.
            y_score: The predicted values (scores for ranking).
            *args: Additional positional arguments for the metric.
            **kwargs: Additional keyword arguments for the metric.

        Returns:
            The computed ranking metric value.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            "Method compute() is not implemented in the child class."
        )