# Source code for gec_metrics.metrics.scribendi

from .base import MetricBase, MetricBaseForReferenceFree
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fuzzywuzzy.fuzz import token_sort_ratio
import abc

class Scribendi(MetricBaseForReferenceFree):
    '''Reference-free metric combining LM perplexity with surface similarity.

    Each (source, hypothesis) pair receives one of three scores:

    -  0: the hypothesis is identical to the source.
    -  1: the hypothesis has a strictly lower perplexity than the source
          and max(token sort ratio, levenshtein distance ratio) reaches
          ``Config.threshold``.
    - -1: otherwise.
    '''

    @dataclass
    class Config(MetricBase.Config):
        '''Scribendi configuration.
            - model (str): Model id of a language model.
            - threshold (float): Minimum value that the larger of
                the token sort ratio and the levenshtein distance ratio
                must reach for an edit to be scored as 1.
            - no_cuda (bool): If True, work on CPU.
            - batch_size (int): Batch size for the inference.
        '''
        model: str = 'gpt2'
        threshold: float = 0.8
        no_cuda: bool = False
        batch_size: int = 32

    def __init__(self, config: Config = None):
        '''Load the language model and its tokenizer.

        Args:
            config (Config): The configuration. It is forwarded to the
                base class as is (None is presumably replaced by a
                default there; confirm against MetricBase).
        '''
        super().__init__(config)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
        # Batched inference needs a padding token; reuse EOS for it.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if not self.config.no_cuda:
            self.model.cuda()

    def score_corpus(
        self,
        sources: list[str],
        hypotheses: list[str]
    ) -> float:
        '''Calculate a corpus-level score.

        Args:
            sources (list[str]): Source sentence.
                The shape is (num_sentences, )
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            float: The corpus-level score, i.e. the sum of
                the sentence-level scores.
        '''
        sentence_scores = self.score_sentence(sources, hypotheses)
        return sum(sentence_scores)

    def score_sentence(
        self,
        sources: list[str],
        hypotheses: list[str]
    ) -> list[float]:
        '''Calculate sentence-level scores.

        Args:
            sources (list[str]): Source sentence.
                The shape is (num_sentences, )
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            list[float]: The sentence-level scores.
                Each element is one of -1, 0, 1.

        Raises:
            ValueError: If sources and hypotheses differ in length.
        '''
        if len(sources) != len(hypotheses):
            raise ValueError(
                'sources and hypotheses must have the same length: '
                f'{len(sources)} != {len(hypotheses)}'
            )

        # Unchanged sentences keep the initial score of 0.
        scores = [0] * len(sources)
        changed_ids = [
            sent_id
            for sent_id, (s, h) in enumerate(zip(sources, hypotheses))
            if s != h
        ]
        if not changed_ids:
            return scores

        # Only the changed pairs are sent to the language model.
        changed_sources = [sources[i] for i in changed_ids]
        changed_hypotheses = [hypotheses[i] for i in changed_ids]
        ppl_sources = self.ppl(changed_sources)
        ppl_hypotheses = self.ppl(changed_hypotheses)

        for sent_id, src, hyp, ppl_s, ppl_h in zip(
            changed_ids,
            changed_sources,
            changed_hypotheses,
            ppl_sources,
            ppl_hypotheses
        ):
            if ppl_s <= ppl_h:
                # The edit did not lower the perplexity.
                scores[sent_id] = -1
                continue
            tsr = self.token_sort_ratio(src, hyp)
            ldr = self.levenshtein_distance_ratio(src, hyp)
            if max(tsr, ldr) >= self.config.threshold:
                scores[sent_id] = 1
            else:
                scores[sent_id] = -1
        return scores

    def ppl(
        self,
        sents: list[str]
    ) -> list[float]:
        '''Compute perplexity using a LM.

        The BOS token is prepended to every sentence, and the loss is
        averaged over the non-padding predicted tokens before
        exponentiation.

        Args:
            sents (list[str]): The sentences to be computed the perplexity.

        Returns:
            list[float]: The list of perplexity, in the order of sents.
                A sentence without any predicted token divides by
                a zero mask sum, so its value is not a finite number.
        '''
        ppls = []
        sents = [self.tokenizer.bos_token + sent for sent in sents]
        batch_size = self.config.batch_size
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        for start in range(0, len(sents), batch_size):
            batch = sents[start:start + batch_size]
            inputs = self.tokenizer(batch, return_tensors='pt', padding=True)
            if not self.config.no_cuda:
                inputs = {k: v.cuda() for k, v in inputs.items()}
            with torch.no_grad():
                # No labels are passed: only the logits are used, and
                # the masked loss is computed below.
                outputs = self.model(
                    inputs['input_ids'],
                    attention_mask=inputs['attention_mask']
                )
                # Position t predicts token t+1.
                shift_logits = outputs.logits[:, :-1, :].contiguous()
                shift_labels = inputs['input_ids'][:, 1:].contiguous()
                shift_mask = inputs['attention_mask'][:, 1:].contiguous()
                # Separate name so that the configured batch_size
                # used for slicing is never overwritten.
                num_rows, seq_len = shift_labels.shape
                token_loss = loss_fn(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                ).view(num_rows, seq_len)
                # The probability is normalized by the length.
                mean_loss = (
                    (token_loss * shift_mask).sum(dim=1)
                    / shift_mask.sum(dim=1)
                )
                ppls += torch.exp(mean_loss).tolist()
        return ppls

    def token_sort_ratio(self, src: str, pred: str) -> float:
        '''Compute the token sort ratio rescaled by 1/100.

        Args:
            src (str): The source sentence.
            pred (str): The corrected sentence.

        Returns:
            float: The token sort ratio.
        '''
        # Inside a method body this name resolves to the function
        # imported at module level, not to this method.
        return token_sort_ratio(src, pred) / 100

    def levenshtein_distance_ratio(self, src: str, pred: str) -> float:
        '''The character-level levenshtein distance ratio.

        Insertion and deletion cost 1 and replacement costs 2.
        The ratio is 1 - distance / (len(src) + len(pred)).

        Args:
            src (str): The source sentence.
            pred (str): The corrected sentence.

        Returns:
            float: The levenshtein distance ratio.
                Two empty strings give 1.0.
        '''
        total_len = len(src) + len(pred)
        if total_len == 0:
            # Both are empty, thus identical; also avoids dividing by zero.
            return 1.0

        # Only the previous row of the DP table is needed.
        prev_row = list(range(len(pred) + 1))
        for i, src_char in enumerate(src, start=1):
            curr_row = [i]
            for j, pred_char in enumerate(pred, start=1):
                # Replacement cost is 2
                cost = 0 if src_char == pred_char else 2
                curr_row.append(min(
                    prev_row[j - 1] + cost,
                    prev_row[j] + 1,
                    curr_row[j - 1] + 1
                ))
            prev_row = curr_row
        return 1 - prev_row[-1] / total_len