Source code for gec_metrics.metrics.pt_errant

from .errant import ERRANT
import errant
from .bertscore import BertScore
from dataclasses import dataclass
from .base import MetricBaseForSourceFree

class PTERRANT(ERRANT):
    '''ERRANT-style metric in which every edit carries a weight.

    Edits are extracted exactly as in the parent ``ERRANT`` class, but
    instead of counting each true positive / false positive / false negative
    as 1, each edit contributes a weight computed by a sentence-level
    "weight model" (see :meth:`calc_edit_weights`).
    '''

    @dataclass
    class Config(ERRANT.Config):
        '''Configuration of PTERRANT

        - weight_model_name (str): Model to compute edit-level weights.
            Currently only "bertscore" is available.
        - weight_model_config (MetricBaseForSourceFree.Config): The config
            instance of the weight model. If not specified, it uses the
            default one.

        Also, you can use the same configurations as ERRANT.
        '''
        weight_model_name: str = 'bertscore'
        weight_model_config: MetricBaseForSourceFree.Config = None

    def __init__(self, config: Config = None):
        '''Build the metric and instantiate the configured weight model.

        Args:
            config (Config): Configuration instance. It is forwarded
                unchanged to the parent constructor.

        Raises:
            KeyError: If ``config.weight_model_name`` is not a known model.
        '''
        super().__init__(config)
        # Registry of the available weight models, keyed by config name.
        name2class = {
            'bertscore': BertScore
        }
        if self.config.weight_model_name not in name2class:
            raise KeyError(f'The model name {self.config.weight_model_name} is invalid. It should be in {list(name2class.keys())}.')
        weight_model_cls = name2class[self.config.weight_model_name]
        weight_model_config = self.config.weight_model_config
        if weight_model_config is None:
            # Use default config
            weight_model_config = weight_model_cls.Config()
        self.weight_model = weight_model_cls(weight_model_config)

    def calc_edit_weights(
        self,
        src: str,
        ref: str,
        edits: list[errant.edit.Edit]
    ) -> dict[tuple[int, int, str], float]:
        '''Calculate a weight for each edit.

        Each distinct edit is applied on its own to the source sentence and
        the resulting sentence is scored against the reference by the weight
        model. The weight is the absolute difference between that score and
        the score of the unedited source.

        Args:
            src (str): Source sentence.
            ref (str): Reference sentence.
            edits (list[errant.edit.Edit]): Edits. Duplicates (same
                ``o_start``, ``o_end`` and ``c_str``) are scored only once.

        Returns:
            dict[tuple[int, int, str], float]: The weight of each edit,
                keyed by ``(o_start, o_end, c_str)``. Empty when there are
                no edits.
        '''
        if not edits:
            return {}
        # Remove duplications, keeping the first occurrence of each edit.
        # A dict preserves insertion order, so the order matches the input.
        unique_edits = {}
        for edit in edits:
            unique_edits.setdefault(
                (edit.o_start, edit.o_end, edit.c_str), edit
            )
        num_edits = len(unique_edits)
        # One sentence per edit: the source with only that edit applied.
        sents = [self.apply_edits(src, [e]) for e in unique_edits.values()]
        # NOTE(review): assumes score_sentence returns a plain list, so that
        # ``* num_edits`` repeats the single source score - confirm.
        scores1 = self.weight_model.score_sentence([src], [[ref]]) * num_edits
        scores2 = self.weight_model.score_sentence(sents, [[ref] * num_edits])
        return {
            key: abs(s2 - s1)
            for key, s1, s2 in zip(unique_edits, scores1, scores2)
        }

    def score_base(
        self,
        sources: list[str],
        hypotheses: list[str],
        references: list[list[str]]
    ) -> list[list[dict[str, "Score"]]]:
        '''Calculate scores while retaining sentence and reference boundaries.

        The results can be aggregated according to the purpose,
            e.g., at sentence-level or corpus-level.

        Args:
            sources (list[str]): Source sentences.
            hypotheses (list[str]): Corrected sentences.
            references (list[list[str]]): Reference sentences.
                The shape is (the number of references, the number of sentences).

        Returns:
            list[list[dict[str, "Score"]]]: The verbose scores.
                - The list shape is (num_sents, num_refs)
                - The dict contains error type-wise scores.
        '''
        num_refs = len(references)
        scores = []  # shape will be: (num_sents, num_refs, )
        for sent_id, source in enumerate(sources):
            hyp_edits = self.edit_extraction(source, hypotheses[sent_id])
            h_edits = [(e.o_start, e.o_end, e.c_str) for e in hyp_edits]
            h_types = [e.type for e in hyp_edits]
            h_edit_set = set(h_edits)
            sent_scores = []  # shape will be: (num_refs, )
            for ref_id in range(num_refs):
                reference = references[ref_id][sent_id]
                ref_edits = self.edit_extraction(source, reference)
                # Weights cover both sides so every lookup below succeeds.
                weights = self.calc_edit_weights(
                    source, reference, hyp_edits + ref_edits
                )
                r_edits = [(e.o_start, e.o_end, e.c_str) for e in ref_edits]
                r_types = [e.type for e in ref_edits]
                r_edit_set = set(r_edits)
                this_score = dict()
                # Hypothesis edits: weighted TP if also in the reference,
                # otherwise weighted FP.
                for h_edit, h_type in zip(h_edits, h_types):
                    if h_type not in this_score:
                        this_score[h_type] = self.Score(beta=self.config.beta)
                    if h_edit in r_edit_set:
                        this_score[h_type].tp += weights[h_edit]
                    else:
                        this_score[h_type].fp += weights[h_edit]
                # Reference edits missing from the hypothesis: weighted FN.
                for r_edit, r_type in zip(r_edits, r_types):
                    if r_edit in h_edit_set:
                        continue
                    if r_type not in this_score:
                        this_score[r_type] = self.Score(beta=self.config.beta)
                    this_score[r_type].fn += weights[r_edit]
                sent_scores.append(this_score)
            scores.append(sent_scores)
        return scores