Source code for gec_metrics.metrics.impara

from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedTokenizer
)
from .base import MetricBase, MetricBaseForReferenceFree
from .utils import AutoModelForSequenceClassificationMeanPool
import torch
import torch.nn as nn
import math
from dataclasses import dataclass

class SimilarityEstimator(MetricBaseForReferenceFree):
    '''Reference-free similarity estimator.

    Each (source, hypothesis) pair is scored by the cosine similarity between
    the mean-pooled last hidden states of an encoder loaded with AutoModel.
    '''

    @dataclass
    class Config(MetricBaseForReferenceFree.Config):
        '''Similarity Estimator configuration.

        Args:
            model (str): Model name to compute similarity.
            batch_size (int): Batch size during inference.
            max_length (int): Maximum length in tokenization.
                The input is truncated if longer than it.
            no_cuda (bool): If True, it will work on CPU.
        '''
        model: str = 'google-bert/bert-base-cased'
        batch_size: int = 32
        max_length: int = 128
        no_cuda: bool = False

    def __init__(self, config: Config = None):
        '''Load the encoder and its tokenizer named by config.model.

        Args:
            config (Config): The configuration. It is forwarded unchanged to
                the base class initializer; self.config is read afterwards.
        '''
        super().__init__(config)
        self.model = AutoModel.from_pretrained(self.config.model).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model)
        if not self.config.no_cuda and torch.cuda.is_available():
            self.model.cuda()

    def score_sentence(
        self,
        sources: list[str],
        hypotheses: list[str]
    ) -> torch.Tensor:
        '''Compute similarity scores.

        Args:
            sources (list[str]): Source sentence.
                The shape is (num_sentences, )
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            torch.Tensor: The similarity scores as a 1-D tensor of shape
                (num_sentences, ). A tensor (not a Python list) is returned
                because IMPARA.score_sentence compares it element-wise
                against the threshold. Empty input gives an empty tensor.
        '''
        assert len(sources) == len(hypotheses)
        if not sources:
            # Nothing to encode; without this guard there would be no batch
            # output to concatenate.
            return torch.empty(0, device=self.model.device)

        bsz = self.config.batch_size
        tokenizer_args = {
            'max_length': self.config.max_length,
            'padding': "max_length",
            'truncation': True,
            'return_tensors': 'pt'
        }
        device = self.model.device
        # Collect per-batch results and concatenate once at the end instead
        # of re-allocating the accumulated tensor on every iteration.
        batch_similarities = []
        for i in range(0, len(sources), bsz):
            src_encode = self.tokenizer(sources[i: i + bsz], **tokenizer_args)
            hyp_encode = self.tokenizer(hypotheses[i: i + bsz], **tokenizer_args)
            src_encode = {k: v.to(device) for k, v in src_encode.items()}
            hyp_encode = {k: v.to(device) for k, v in hyp_encode.items()}
            out = self.forward(
                src_encode['input_ids'],
                src_encode['attention_mask'],
                hyp_encode['input_ids'],
                hyp_encode['attention_mask'],
            )
            batch_similarities.append(out.view(-1))
        similarities = torch.cat(batch_similarities, dim=-1)
        assert len(similarities) == len(sources)
        return similarities

    @torch.no_grad()
    def forward(
        self,
        src_input_ids: torch.Tensor,
        src_attention_mask: torch.Tensor,
        pred_input_ids: torch.Tensor,
        pred_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        '''Compute the cosine similarity given source and corrected sentences.

        Args:
            src_input_ids (torch.Tensor): Tokenized source sentences.
                The shape is (num_batch, sequence_length)
            src_attention_mask (torch.Tensor): The attention mask to handle padding.
                The shape is (num_batch, sequence_length)
            pred_input_ids (torch.Tensor): Tokenized corrected sentences.
                The shape is (num_batch, sequence_length)
            pred_attention_mask (torch.Tensor): The attention mask to handle padding.
                The shape is (num_batch, sequence_length)

        Returns:
            torch.Tensor: The cosine similarity.
                The shape is (num_batch, )
        '''
        # Keyword arguments keep the call independent of the positional
        # parameter order of whichever architecture AutoModel resolved to.
        src_state = self.model(
            input_ids=src_input_ids,
            attention_mask=src_attention_mask
        ).last_hidden_state
        pred_state = self.model(
            input_ids=pred_input_ids,
            attention_mask=pred_attention_mask
        ).last_hidden_state
        src_pooler = self.mean_pooling(src_state, src_attention_mask)
        trg_pooler = self.mean_pooling(pred_state, pred_attention_mask)
        # dim=1 compares along the hidden dimension, the same default that
        # nn.CosineSimilarity() uses.
        return nn.functional.cosine_similarity(src_pooler, trg_pooler, dim=1)

    def mean_pooling(
        self,
        states: torch.Tensor,
        mask: torch.Tensor
    ) -> torch.Tensor:
        '''Compute mean pooling.

        Only the representation with mask==1 are used. The input tensors are
        left unmodified.

        Args:
            states (torch.Tensor): The token-level representation.
                The shape is (num_batch, sequence_length, hidden_size)
            mask: torch.Tensor: The mask indicates padding or not.
                The shape is (num_batch, sequence_length)

        Returns:
            torch.Tensor: The mean pooled representation.
                The shape is (num_batch, hidden_size)
        '''
        # masked_fill returns a new tensor, so the caller's hidden states are
        # not overwritten (an in-place indexed assignment would mutate them).
        padding = (mask == 0).unsqueeze(-1)  # batch x seq_len x 1
        masked_states = states.masked_fill(padding, 0)  # batch x seq_len x hidden
        sum_logits = torch.sum(masked_states, dim=1)  # batch x hidden
        length = torch.sum(mask, dim=-1, keepdim=True)  # batch x 1
        pooled_logits = torch.div(sum_logits, length)  # batch x hidden
        return pooled_logits

    @property
    def device(self):
        '''The device on which the underlying encoder lives.'''
        return self.model.device
class IMPARA(MetricBaseForReferenceFree):
    '''IMPARA metric.

    The sentence-level score is the sigmoid output of a quality estimation
    (QE) model on the hypothesis; it is replaced by 0 when the similarity
    between source and hypothesis, computed by SimilarityEstimator, is not
    greater than config.threshold.
    '''

    @dataclass
    class Config(MetricBase.Config):
        '''IMPARA configuration.

        Args:
            model_qe (str): Quality estimation model.
            model_se (str): Similarity estimation model.
            pooling (str): Pooling method. 'cls' or 'mean'.
            max_length (int): Maximum length of inputs.
            threshold (float): Threshold for the similarity score.
            no_cuda (bool): If True, work on CPU.
            batch_size (int): Batch size for the inference.
        '''
        model_qe: str = 'gotutiyan/IMPARA-QE'
        model_se: str = 'google-bert/bert-base-cased'
        pooling: str = 'cls'
        max_length: int = 128
        threshold: float = 0.9
        no_cuda: bool = False
        batch_size: int = 32

    def __init__(self, config: Config = None):
        '''Build the similarity estimator and load the QE model and tokenizer.

        Args:
            config (Config): The configuration. It is forwarded unchanged to
                the base class initializer; self.config is read afterwards.
        '''
        super().__init__(config)
        assert self.config.pooling in ['cls', 'mean'], "The config.pooling should be in ['cls', 'mean']."
        self.similarity_estimator = SimilarityEstimator(SimilarityEstimator.Config(
            model=self.config.model_se,
            batch_size=self.config.batch_size,
            max_length=self.config.max_length,
            no_cuda=self.config.no_cuda
        ))
        self.model_qe = AutoModelForSequenceClassification.from_pretrained(self.config.model_qe).eval()
        self.tokenizer_qe = AutoTokenizer.from_pretrained(self.config.model_qe)
        if not self.config.no_cuda and torch.cuda.is_available():
            self.model_qe.cuda()
        if self.config.pooling == 'mean':
            # Wrap the model to use a mean pooling instead of CLS representation.
            self.model_qe = AutoModelForSequenceClassificationMeanPool(
                self.model_qe
            )

    def score_sentence_se(
        self,
        sources: list[str],
        hypotheses: list[str]
    ) -> torch.Tensor:
        '''Compute similarity scores.

        Args:
            sources (list[str]): Source sentence.
                The shape is (num_sentences, )
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            torch.Tensor: The similarity scores, exactly as returned by
                SimilarityEstimator.score_sentence.
        '''
        return self.similarity_estimator.score_sentence(
            sources,
            hypotheses
        )

    @torch.no_grad()
    def score_sentence_qe(
        self,
        hypotheses: list[str]
    ) -> torch.Tensor:
        '''Compute quality scores.

        Runs under torch.no_grad() so that calling this method directly does
        not record an autograd graph.

        Args:
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            torch.Tensor: The quality scores (sigmoid of the QE logits) as a
                1-D tensor of shape (num_sentences, ). Empty input gives an
                empty tensor.
        '''
        if not hypotheses:
            # Nothing to encode; without this guard there would be no batch
            # output to concatenate.
            return torch.empty(0, device=self.model_qe.device)

        batch_size = self.config.batch_size
        # The tokenizer arguments do not depend on the batch, so build them once.
        tokenizer_args = {
            'max_length': self.config.max_length,
            'padding': "max_length",
            'truncation': True,
            'return_tensors': 'pt'
        }
        device = self.model_qe.device
        batch_scores = []
        for i in range(0, len(hypotheses), batch_size):
            batch = hypotheses[i: i + batch_size]
            hyp_encode_qe = self.tokenizer_qe(
                batch,
                **tokenizer_args
            )
            hyp_encode_qe = {k: v.to(device) for k, v in hyp_encode_qe.items()}
            # view(-1) flattens the logits to one value per sentence.
            qe_logits = self.model_qe(
                hyp_encode_qe['input_ids'],
                hyp_encode_qe['attention_mask']
            ).logits.view(-1)
            batch_scores.append(torch.sigmoid(qe_logits))
        scores = torch.cat(batch_scores, dim=-1)
        assert len(scores) == len(hypotheses)
        return scores

    @torch.no_grad()
    def score_sentence(
        self,
        sources: list[str],
        hypotheses: list[str]
    ) -> list[float]:
        '''Calculate sentence-level scores.

        Args:
            sources (list[str]): Source sentence.
                The shape is (num_sentences, )
            hypotheses (list[str]): Corrected sentences.
                The shape is (num_sentences, )

        Returns:
            list[float]: The sentence-level scores. Each one is the quality
                score of the hypothesis, or 0 when its similarity to the
                source is less than or equal to config.threshold.
        '''
        assert len(sources) == len(hypotheses)
        if not sources:
            # Handle empty input here so the result does not depend on how
            # the underlying scorers treat an empty batch.
            return []
        se_scores = self.score_sentence_se(
            sources,
            hypotheses
        )
        qe_scores = self.score_sentence_qe(
            hypotheses
        )
        # Hypotheses that drift too far from the source get a zero score.
        # .to() is a no-op when both models already share a device.
        too_dissimilar = (se_scores <= self.config.threshold).to(qe_scores.device)
        scores = qe_scores.masked_fill(too_dissimilar, 0).tolist()
        assert len(scores) == len(sources)
        return scores