# Source code for gec_metrics.analysis.attributor.shapley

from .base import AttributorBase
import errant
from typing import Union, Optional
import math
from gecommon import apply_edits
    
class AttributorShapley(AttributorBase):
    '''Attribute a sentence-level score to individual edits via Shapley values.

    ``generate`` applies every one of the ``2 ** n`` subsets of ``n`` edits to
    the source sentence. The caller scores each resulting sentence, and
    ``post_process`` converts those per-subset scores into one Shapley value
    per edit.
    '''

    def __init__(self, config):
        super().__init__(config)

    def generate(
        self,
        src: str,
        edits: Union[list[errant.edit.Edit], list[list[errant.edit.Edit]]]
    ) -> list[dict]:
        '''Generate edited sentences by applying every subset of edits.

        Args:
            src (str): Source sentence.
            edits (list[errant.edit.Edit] | list[list[errant.edit.Edit]]):
                Edits to be applied to the source. When the elements are
                lists, each inner list is treated as one unit: all of its
                edits are applied together.

        Returns:
            list[dict]: ``2 ** len(edits)`` elements, each with two keys:
                "sentence": An edited sentence.
                "indices": Ascending tuple of the positions in ``edits`` that
                    were applied to the source sentence.
        '''
        num_edits = len(edits)
        # Grouped edits (list of lists) must be flattened before being handed
        # to apply_edits. Decide this once; guarding on ``edits`` avoids an
        # IndexError when there are no edits (the loop still runs once).
        is_grouped = bool(edits) and isinstance(edits[0], list)
        edited = []
        for mask in range(2 ** num_edits):
            # Bit j of ``mask`` selects edit j.
            # E.g. 5 is 0b101, which means the first and third edits are used.
            indices = tuple(j for j in range(num_edits) if (mask >> j) & 1)
            to_be_applied = [edits[j] for j in indices]
            if is_grouped:
                to_be_applied = [e for group in to_be_applied for e in group]
            edited.append({
                'sentence': apply_edits(src, to_be_applied),
                'indices': indices,
            })
        return edited

    def post_process(
        self,
        scores: list[float],
        sent_level_score: Optional[float] = None,
        indices: Optional[list[tuple]] = None
    ) -> list[float]:
        '''Calculate Shapley values.

        Args:
            scores (list[float]): Delta-M score of each edited sentence, one
                per subset of edits (``2 ** num_edits`` scores in total).
            sent_level_score (Optional[float]): Not used by this method;
                accepted for interface compatibility.
            indices (Optional[list[tuple]]): For each entry of ``scores``, the
                indices of the edits that were applied to the source (the
                "indices" field produced by ``generate``). Required despite
                the ``None`` default.

        Returns:
            list[float]: The Shapley value of each edit, in edit order.

        Raises:
            ValueError: If ``indices`` is missing or its length differs from
                ``scores``, or if ``len(scores)`` is not a power of two.
        '''
        if indices is None or len(scores) != len(indices):
            raise ValueError(
                'indices must be given and have the same length as scores.'
            )
        num_scores = len(scores)
        # Shapley-based attribution scores all 2 ** num_edits subsets, so the
        # number of edits is the exact base-2 logarithm of len(scores).
        num_edits = num_scores.bit_length() - 1
        if num_scores == 0 or (1 << num_edits) != num_scores:
            raise ValueError(
                f'len(scores) must be a power of two, got {num_scores}.'
            )

        # Map each applied-edit subset (normalized to an ascending tuple) to
        # its score, so lookups below do not depend on how callers ordered it.
        idx2score = {
            tuple(sorted(idx)): score for idx, score in zip(indices, scores)
        }

        # The Shapley weight |S|! * (n - |S| - 1)! / n! depends only on |S|,
        # so compute it once per subset size.
        n_factorial = math.factorial(num_edits)
        weights = [
            math.factorial(s) * math.factorial(num_edits - s - 1) / n_factorial
            for s in range(num_edits)
        ]

        attributed_scores = [0.0] * num_edits
        for i in range(num_edits):
            # \boldsymbol{e} \setminus e_i in Eq.2 of the paper.
            others = tuple(j for j in range(num_edits) if j != i)
            # Loop over every subset \boldsymbol{e}' of the other edits;
            # bit k of ``mask`` selects others[k].
            for mask in range(2 ** len(others)):
                subset = [j for k, j in enumerate(others) if (mask >> k) & 1]
                # \boldsymbol{e}' \cup \{e_i\} in Eq.2.
                subset_with_i = tuple(sorted(subset + [i]))
                attributed_scores[i] += weights[len(subset)] * (
                    idx2score[subset_with_i] - idx2score[tuple(subset)]
                )
        return attributed_scores