Source code for modelsight.curves.compare

"""
This file implements functions for annotating plots
with the results of statistical tests between pairs of estimators.
"""

from typing import Callable, Dict, Tuple, List
import matplotlib
from matplotlib import patches
import matplotlib.pyplot as plt

from modelsight.curves._delong import delong_roc_test
from modelsight._typing import CVModellingOutput

def annot_stat_vertical(text: str,
                        x: float,
                        y1: float,
                        y2: float,
                        ww: float = 0.02,
                        col: str = 'black',
                        fontsize: int = 13,
                        voffset: float = 0,
                        n_elems: int = None,
                        ax=None,
                        **kwargs):
    """
    Draw a vertical whisker at position `x` that spans from `y1` to `y2`
    and annotate it with `text`.

    Parameters
    ----------
    text : str
        Annotation for the whisker.
    x : float
        x-position the whisker is positioned at.
    y1 : float
        Starting y position.
    y2 : float
        Ending y position.
    ww : float, optional
        Whisker width, by default 0.02
    col : str, optional
        Colour of both the whisker and the annotation, by default 'black'
    fontsize : int, optional
        Fontsize for the annotation, by default 13
    voffset : float, optional
        Vertical offset subtracted from the y-position of the annotation,
        by default 0. It is only applied to single-character annotations.
        Some font families and characters occupy different vertical spaces;
        this parameter allows compensating for such variations.
    n_elems : int, optional
        Number of discrete elements on the y-axis, by default None.
        It scales the height of the white patch drawn beneath
        single-character annotations. When None, the number of y tick labels
        of `ax` is used.
    ax : plt.Axes, optional
        Axes to draw on, by default None (the current Axes is used).
    **kwargs
        rect_h_base : float, optional
            Base height of the rectangle patch for single-character
            annotations, by default 0.1
        fontsize_nonsignif : int, optional
            Fontsize for multi-character annotations (called "non
            significant" because single-character annotations are usually a
            significance symbol such as *), by default `fontsize`.
        Any other keyword argument is accepted and ignored.

    Returns
    -------
    None
    """
    ax = plt.gca() if ax is None else ax

    # The annotation is centred on the outer edge of the whisker.
    text_x_pos = x + ww
    text_y_pos = (y1 + y2) / 2

    # Whisker: a bracket from (x, y1) out to x + ww, up to y2 and back to x.
    ax.plot([x, x + ww, x + ww, x], [y1, y1, y2, y2], lw=1, c=col)

    if len(text) == 1:
        # Single-character annotation (typically a significance symbol).
        # The built-in bbox of .text() doesn't produce acceptable results for
        # a lone symbol, so a white rectangle is drawn by hand beneath it to
        # hide the whisker line and keep the symbol readable.
        ax.text(
            text_x_pos,
            (text_y_pos - voffset) + 0.17,  # fixed shift in data units on top of voffset
            text,
            ha='center',
            va='center',
            color=col,
            size=fontsize,
            zorder=10
        )

        # Without this fallback the height computation below fails with a
        # TypeError whenever the caller relies on the default n_elems=None.
        if n_elems is None:
            n_elems = len(ax.get_yticklabels())

        # Rectangle's props
        rect_h_base = kwargs.get("rect_h_base", 0.1)
        rect_w = 0.05 - (0.375 * 0.05)
        rect_h = rect_h_base * n_elems  # height grows with the number of y entries
        rect_x_offset = -0.002
        rect_y_offset = 0.01

        # zorder below that of the text so the patch sits under the symbol.
        rect = patches.Rectangle(
            (
                text_x_pos - (rect_w / 2) + rect_x_offset,
                text_y_pos - (rect_h / 2) + rect_y_offset
            ),
            width=rect_w,
            height=rect_h,
            linewidth=1,
            edgecolor='w',
            facecolor='w',
            zorder=9
        )
        ax.add_patch(rect)
    else:
        # Multi-character annotation: the built-in bbox of .text() is used to
        # draw the white background beneath the annotation.
        fontsize_nonsignif = kwargs.get("fontsize_nonsignif", fontsize)

        ax.text(
            text_x_pos,
            text_y_pos,
            text,
            ha='center',
            va='center',
            color=col,
            size=fontsize_nonsignif,
            zorder=10,
            bbox=dict(
                boxstyle='square,pad=0',
                facecolor="white",
                edgecolor="white"
            )
        )
def annot_stat_horizontal(text: str,
                          x1: float,
                          x2: float,
                          y: float,
                          wh: float = 0.02,
                          col: str = "black",
                          fontsize: int = 13,
                          voffset: float = 0,
                          n_elems: int = None,
                          ax: plt.Axes = None,
                          **kwargs):
    """
    Draw a horizontal whisker at height `y` that spans from `x1` to `x2`
    and annotate it with `text`.

    Parameters
    ----------
    text : str
        Annotation for the whisker.
    x1 : float
        Starting x position.
    x2 : float
        Ending x position.
    y : float
        y-position the whisker is positioned at.
    wh : float, optional
        Whisker height, by default 0.02
    col : str, optional
        Colour of both the whisker and the annotation, by default 'black'
    fontsize : int, optional
        Fontsize for the annotation, by default 13
    voffset : float, optional
        Vertical offset added to the y-position of the annotation,
        by default 0. It is only applied to single-character annotations.
        Some font families and characters occupy different vertical spaces;
        this parameter allows compensating for such variations.
    n_elems : int, optional
        Number of discrete elements on the axis, by default None.
        Accepted for symmetry with the caller (add_annotations); this
        function does not read it.
    ax : plt.Axes, optional
        Axes to draw on, by default None (the current Axes is used).
    **kwargs
        fontsize_nonsignif : int, optional
            Fontsize for multi-character annotations (called "non
            significant" because single-character annotations are usually a
            significance symbol such as *), by default `fontsize`.
        Any other keyword argument is accepted and ignored.

    Returns
    -------
    None
    """
    axes = ax if ax is not None else plt.gca()

    # The annotation is centred horizontally on the top edge of the whisker.
    label_y = y + wh
    label_x = (x1 + x2) / 2

    # Whisker: a bracket from (x1, y) up to y + wh, across to x2 and back
    # down. Clipping is disabled so it stays visible outside the axes area.
    axes.plot(
        [x1, x1, x2, x2],
        [y, label_y, label_y, y],
        lw=1,
        c=col,
        clip_on=False
    )

    # Properties shared by both kinds of annotation.
    shared_text_props = dict(ha='center', va='center', color=col, zorder=10)

    if len(text) != 1:
        # Multi-character annotation: rely on the built-in bbox of .text()
        # for the white background.
        axes.text(
            label_x,
            label_y,
            text,
            size=kwargs.pop("fontsize_nonsignif", fontsize),
            bbox=dict(
                boxstyle='square,pad=0',
                facecolor="white",
                edgecolor="white"
            ),
            **shared_text_props
        )
        return

    # Single-character annotation (typically a significance symbol).
    # The built-in bbox of .text() doesn't produce acceptable results for a
    # lone symbol, so a white rectangle is drawn by hand beneath it to hide
    # the whisker line and keep the symbol readable.
    axes.text(
        label_x,
        label_y + voffset,
        text,
        size=fontsize,
        **shared_text_props
    )

    patch_w = 0.09
    patch_h = 0.05 - (0.375 * 0.05)
    patch_dx = 0.005
    patch_dy = -0.001
    lower_left = (
        label_x - (patch_w / 2) + patch_dx,
        label_y - (patch_h / 2) + patch_dy
    )

    # zorder below that of the text so the patch sits under the symbol.
    axes.add_patch(
        patches.Rectangle(
            lower_left,
            width=patch_w,
            height=patch_h,
            linewidth=1,
            edgecolor='w',
            facecolor='w',
            zorder=9,
            clip_on=False
        )
    )
def add_annotations(comparisons: Dict[str, Tuple[str, str, float]],
                    alpha: float,
                    bars: matplotlib.container.BarContainer,
                    direction: str,
                    order: List[Tuple[str, str]],
                    symbol: str = "*",
                    symbol_fontsize: int = 22,
                    voffset: float = 0,
                    ext_voffset: float = 0,
                    ext_hoffset: float = 0,
                    P_val_rounding: int = 2,
                    ax: plt.Axes = None,
                    **kwargs):
    """
    Annotate the specified plot (`ax`) with the provided comparison results,
    either vertically or horizontally depending on the value of `direction`.

    Parameters
    ----------
    comparisons : Dict[str, Tuple[str, str, float]]
        The results of model comparisons. Keys have the form
        "<first>_<second>"; values are (first name, second name, P value).
    alpha : float
        The significance level. A comparison with P <= alpha is annotated
        with `symbol`, otherwise with the rounded P value.
    bars : matplotlib.container.BarContainer
        The bars of the plot. The width of the first bar is used to centre
        horizontal whiskers on the bars; it is not read when
        direction = "vertical".
    direction : str
        The direction for annotation. Possible values are "horizontal" and
        "vertical".
    order : List[Tuple[str, str]]
        The order in which the comparisons should be displayed. Each entry
        is a tuple of two algorithm names. When empty or None, all entries of
        `comparisons` are displayed in the dictionary's own order.
    symbol : str, optional
        The symbol used in place of the P value when statistical significance
        is achieved according to the specified alpha, by default "*".
    symbol_fontsize : int, optional
        Fontsize for the symbol used when statistical significance is
        achieved, by default 22
    voffset : float, optional
        Vertical offset for the annotation, by default 0
    ext_voffset : float, optional
        Forwarded as `ext_offset` to `annot_stat_vertical` when
        direction = "vertical", by default 0
    ext_hoffset : float, optional
        Forwarded as `ext_offset` to `annot_stat_horizontal` when
        direction = "horizontal", by default 0
    P_val_rounding : int, optional
        Number of decimal places to round P values at, by default 2
    ax : plt.Axes, optional
        The plot to be annotated. Although it defaults to None, it is
        required.
    **kwargs
        whisker_y_offset : float, optional
            Extra y offset for horizontal whiskers, by default 0.
            Consumed only when direction = "horizontal".
        space_between_whiskers : float, optional
            Extra spacing between consecutive vertical whiskers, by default
            0. Consumed only when direction = "vertical".
        Remaining keyword arguments are forwarded to the whisker-drawing
        function.

    Returns
    -------
    ax : plt.Axes
        The annotated plot.

    Raises
    ------
    ValueError
        When ax is None.
    ValueError
        When a pair listed in `order` has no entry in `comparisons`.
    ValueError
        When `direction` is neither "horizontal" nor "vertical".
    """
    if ax is None:
        raise ValueError("I need an Axes to draw comparisons on.")

    if order:
        comparisons_list = []
        for fst_algo, snd_algo in order:
            cmp_key = f"{fst_algo}_{snd_algo}"
            cmp = comparisons.get(cmp_key)
            if not cmp:
                # The key is looked up in `comparisons`, so that is what the
                # message must point at.
                raise ValueError(
                    f"The comparison {cmp_key} does not exist in the provided comparisons."
                )
            comparisons_list.append(cmp)
    else:
        comparisons_list = list(comparisons.values())

    if direction == "horizontal":
        width = bars[0].get_width()
        entity_labels = ax.get_xticklabels()
        # Map every tick label to its (slightly shifted) position on the axis.
        entity_idx = {label.get_text(): (i + 0.03)
                      for i, label in enumerate(entity_labels)}

        whisker_y_offset = kwargs.pop("whisker_y_offset", 0)

        # Read the limit once, before drawing, so every whisker is placed
        # relative to the same reference.
        y_lim_upper = ax.get_ylim()[1] + 0.05 + whisker_y_offset
        v_offset = 0.07

        for i, (fst_model, snd_model, P) in enumerate(comparisons_list):
            is_significant = P <= alpha
            P_str = symbol if is_significant else f"{P:.{P_val_rounding}f}"

            annot_stat_horizontal(
                text=P_str,
                x1=entity_idx[fst_model] + width / 2,
                x2=entity_idx[snd_model] + width / 2,
                # overall distance from the upper limit of y
                # + inter-distance between whiskers
                y=(y_lim_upper - 0.17) + (i * v_offset),
                wh=0.02,
                col="black",
                fontsize=symbol_fontsize,
                voffset=voffset,
                ext_offset=ext_hoffset,
                n_elems=len(entity_labels),
                ax=ax,
                **kwargs
            )
    elif direction == "vertical":
        entity_labels = ax.get_yticklabels()
        # Map every tick label to its (slightly shifted) position on the axis.
        entity_idx = {label.get_text(): (i + 0.03)
                      for i, label in enumerate(entity_labels)}

        space_between_whiskers = kwargs.pop("space_between_whiskers", 0)

        # Read the limit once, before drawing, so every whisker is placed
        # relative to the same reference.
        x_lim_upper = ax.get_xlim()[1] + 0
        h_offset = 0.07 + space_between_whiskers

        for i, (fst_model, snd_model, P) in enumerate(comparisons_list):
            is_significant = P <= alpha
            P_str = symbol if is_significant else f"{P:.{P_val_rounding}f}"

            annot_stat_vertical(
                text=P_str,
                x=x_lim_upper + (i * h_offset),
                y1=entity_idx[fst_model],
                y2=entity_idx[snd_model],
                ww=0.02,
                col="black",
                # Decide on significance rather than on a literal "*", so a
                # custom `symbol` is drawn at `symbol_fontsize` as well.
                fontsize=symbol_fontsize if is_significant else 16,
                voffset=voffset,
                ext_offset=ext_voffset,
                n_elems=len(entity_labels),
                ax=ax,
                **kwargs
            )
    else:
        # An unknown direction used to return the plot silently unannotated.
        raise ValueError(
            f"Unknown direction {direction!r}: expected 'horizontal' or 'vertical'."
        )

    return ax
def roc_single_comparison(cv_preds: CVModellingOutput,
                          fst_algo: str,
                          snd_algo: str) -> Dict[str, Tuple[str, str, float]]:
    """
    Compare the areas under the Receiver Operating Characteristic curves of
    two algorithms, computed on the same set of data points, by means of the
    DeLong test.

    Parameters
    ----------
    cv_preds : CVModellingOutput
        The output of a cross-validation process encompassing multiple
        (n>=2) models.
    fst_algo : str
        The name of the first algorithm for the comparison.
        Must be an existing key of `cv_preds`.
    snd_algo : str
        The name of the second algorithm for the comparison.
        Must be an existing key of `cv_preds`.

    Returns
    -------
    comparison_result : Dict[str, Tuple[str, str, float]]
        A single-entry dictionary. The key has the form
        "<fst_algo>_<snd_algo>"; the value is a tuple whose first two
        elements are the names of the compared algorithms and whose third
        element is the value returned by `delong_roc_test`, i.e. the P value
        for the null hypothesis that the two AUC values are equal.
    """
    fst_output = cv_preds[fst_algo]

    # The ground truths are taken from the first algorithm's entry only; both
    # sets of probabilities are tested against them.
    labels = fst_output.gts_val_conc
    fst_scores = fst_output.probas_val_conc
    snd_scores = cv_preds[snd_algo].probas_val_conc

    p_value = delong_roc_test(labels, fst_scores, snd_scores)

    return {f"{fst_algo}_{snd_algo}": (fst_algo, snd_algo, p_value)}
def roc_comparisons(cv_preds: CVModellingOutput,
                    target_algo: str):
    """
    Compare the AUC of the specified algorithm with the AUCs of all the
    other algorithms found in `cv_preds`.

    Parameters
    ----------
    cv_preds : CVModellingOutput
        The output of a cross-validation process encompassing multiple
        (n>=2) models.
    target_algo : str
        The name of the algorithm whose AUC is compared with all the other
        AUCs. It is always the first element of each comparison.

    Returns
    -------
    comparisons : Dict[str, Tuple[str, str, float]]
        A dictionary containing the results of all comparisons, one entry per
        algorithm other than `target_algo`. See the output of
        `roc_single_comparison`. Entries are ordered newest first, i.e. in
        reverse order with respect to the keys of `cv_preds`.
    """
    comparisons = {}
    other_algos = (name for name in cv_preds.keys() if name != target_algo)

    for other_algo in other_algos:
        result = roc_single_comparison(cv_preds, target_algo, other_algo)
        # The new result is placed in front of the ones collected so far.
        comparisons = {**result, **comparisons}

    return comparisons