Source code for gec_metrics.analysis.attributor.shapley_sampling

from .base import AttributorBase
import errant
from typing import Union, Optional
import itertools
import random

class AttributorShapleySampling(AttributorBase):
    '''Attribute a metric score to individual edits via Shapley sampling.

    ``generate`` builds the incrementally edited sentences for a set of edit
    orderings, and ``post_process`` averages each edit's marginal score
    contribution over those orderings.

    The number of sampled orderings is read from ``self.config.num_samples``
    (``self.config`` is presumably set by ``AttributorBase.__init__`` --
    confirm against the base class).
    '''

    def __init__(self, config):
        super().__init__(config)

    def _sample_orders(self, num_edits: int) -> list[tuple]:
        '''Return the edit orderings to evaluate.

        Args:
            num_edits (int): Number of edits to order.

        Returns:
            list[tuple]: All permutations of ``range(num_edits)`` when there
                are at most ``self.config.num_samples`` of them; otherwise
                ``self.config.num_samples`` distinct random permutations.
        '''
        num_samples = self.config.num_samples
        indices = list(range(num_edits))

        # Compare num_edits! against the budget, stopping as soon as it is
        # exceeded so that a huge factorial is never computed in full.
        num_permutations = 1
        need_sampling = False
        for i in range(1, num_edits + 1):
            num_permutations *= i
            if num_permutations > num_samples:
                need_sampling = True
                break

        if not need_sampling:
            return list(itertools.permutations(indices))

        orders = []
        seen = set()
        # Terminates because num_samples < num_edits! on this branch, so
        # enough distinct permutations exist.
        while len(orders) < num_samples:
            random.shuffle(indices)
            # random.shuffle works in place: snapshot the current order as a
            # tuple, otherwise every stored order would alias the same list
            # and end up equal to the last shuffle.
            order = tuple(indices)
            if order not in seen:
                seen.add(order)
                orders.append(order)
        return orders

    def generate(
        self,
        src: str,
        edits: list[errant.edit.Edit]
    ) -> list[dict]:
        '''Generate edited sentences by applying sampled patterns of edits.

        For every ordering of the edits, the edits are applied cumulatively
        (first one, then the first two, ...), producing one output element
        per prefix. Orderings are emitted back to back, so an ordering of
        ``len(edits)`` edits contributes ``len(edits)`` consecutive elements.

        Args:
            src (str): Source sentence.
            edits (list[errant.edit.Edit]): Edits to be applied to the source.

        Returns:
            list[dict]: Each element has two keys:
                "sentence": An edited sentence.
                "indices": Tuple of indices of the edits that were applied
                    to the source sentence, in application order.
        '''
        edited = []
        for order in self._sample_orders(len(edits)):
            current_edits = []
            for i, idx in enumerate(order):
                current_edits.append(edits[idx])
                corrected = self.apply_edits(src, current_edits)
                edited.append({
                    'sentence': corrected,
                    'indices': tuple(order[:i + 1])
                })
        return edited

    def post_process(
        self,
        scores: list[float],
        sent_level_score: Optional[float] = None,
        indices: Optional[list[tuple]] = None
    ) -> list[float]:
        r'''Calculate Shapley sampling values.

        ``scores`` and ``indices`` are expected in the layout produced by
        ``generate``: consecutive groups of ``num_edits`` elements, one group
        per ordering, where the last element of a group lists the complete
        ordering.

        Args:
            scores (list[float]): \delta M() scores.
            sent_level_score (Optional[float]): Used when normalization.
                Not used by this attributor.
            indices (Optional[list[tuple]]): Which edits were applied
                to the source. Required despite the ``None`` default.

        Returns:
            list[float]: Post processed scores, one per edit. Empty when
                there are no edits to attribute.

        Raises:
            ValueError: If ``indices`` is None, if ``scores`` is empty while
                ``indices`` is not, or if ``len(scores)`` is not a multiple
                of the number of edits.
        '''
        if indices is None:
            raise ValueError(
                'indices is required to calculate Shapley sampling values.'
            )
        if not indices:
            # generate() yields nothing when there are no edits.
            return []

        num_edits = max(len(applied) for applied in indices)
        if num_edits == 0:
            return []

        num_orders, remainder = divmod(len(scores), num_edits)
        if remainder != 0:
            raise ValueError(
                'The number of scores must be a multiple of the number '
                'of edits.'
            )
        if num_orders == 0:
            raise ValueError('scores must not be empty.')

        attributed_scores = [0.0] * num_edits
        for sample_id in range(num_orders):
            begin = sample_id * num_edits
            end = begin + num_edits
            batch_scores = scores[begin:end]
            batch_indices = indices[begin:end]
            # The last entry of a group holds the full ordering; the score at
            # position i is the one obtained after applying the first i + 1
            # edits of that ordering.
            for i, edit_id in enumerate(batch_indices[-1]):
                if i == 0:
                    attributed_scores[edit_id] += batch_scores[i]
                else:
                    # Marginal contribution of this edit in this ordering.
                    attributed_scores[edit_id] += \
                        batch_scores[i] - batch_scores[i - 1]

        return [s / num_orders for s in attributed_scores]