Source code for gec_metrics.analysis.attributor.sub

from .base import AttributorBase
import errant
from typing import Union, Optional
from gecommon import apply_edits

class AttributorSub(AttributorBase):
    '''Attributor that builds one variant per edit by leaving that edit out.

    ``generate`` produces the leave-one-out sentences and ``post_process``
    converts the scores obtained for those sentences into per-edit scores.
    '''

    def __init__(self, config):
        super().__init__(config)

    def generate(
        self,
        src: str,
        edits: Union[list[errant.edit.Edit], list[list[errant.edit.Edit]]]
    ) -> list[dict]:
        '''Generate edited sentences by leaving out each edit in turn.

        Args:
            src (str): Source sentence.
            edits (list[errant.edit.Edit] or list[list[errant.edit.Edit]]):
                Edits to be applied to the source. When the first element
                is itself a list, every element is treated as a group of
                edits: a whole group is left out at a time and the
                remaining groups are flattened before being applied.

        Returns:
            list[dict]: One element per edit (or edit group), with two keys:
                "sentence": The result of ``apply_edits`` on ``src`` with
                    every edit except the one at that position.
                "indices": A one-element tuple holding the position of the
                    edit (or group) that was left out.
        '''
        if not edits:
            return []

        # The input shape does not change while iterating, so decide once
        # whether the remaining edits need flattening.
        is_grouped = isinstance(edits[0], list)

        edited = []
        for skip in range(len(edits)):
            # Keep every edit (or group) except the one at position `skip`.
            remaining = [
                kept for pos, kept in enumerate(edits) if pos != skip
            ]
            if is_grouped:
                remaining = [
                    edit for group in remaining for edit in group
                ]
            edited.append({
                'sentence': apply_edits(src, remaining),
                'indices': (skip, )
            })
        return edited

    def post_process(
        self,
        scores: list[float],
        sent_level_score: Optional[float] = None,
        indices: Optional[list[tuple]] = None
    ) -> list[float]:
        r'''Normalize each score by the sum of the scores.

        Each input score ``s`` is first turned into the difference
        ``sent_level_score - s``. The differences are then rescaled by
        ``sent_level_score / sum(differences)`` so that they add up to
        ``sent_level_score``. If the differences sum to zero, every
        returned value is zero.

        Args:
            scores (list[float]): \delta M() scores.
            sent_level_score (Optional[float]): Sentence-level score used
                for the differences and the normalization. Although it
                defaults to ``None`` in the signature, a value is required.
            indices (Optional[list[tuple]]): Which edits were applied to
                the source. Not used by this implementation.

        Returns:
            list[float]: Post-processed scores, in the same order as
                ``scores``.

        Raises:
            TypeError: If ``sent_level_score`` is ``None``.
        '''
        if sent_level_score is None:
            # The subtraction below would raise TypeError anyway; raise the
            # same exception type with a message that names the cause.
            raise TypeError(
                'sent_level_score must be provided to post_process()'
            )

        diffs = [sent_level_score - s for s in scores]
        total = sum(diffs)
        # Avoid dividing by zero when the differences cancel out.
        weight = 0 if total == 0 else sent_level_score / total
        return [weight * d for d in diffs]