gec_metrics.metrics.utils module

class gec_metrics.metrics.utils.AutoModelForSequenceClassificationMeanPool(model: AutoModelForSequenceClassification)

Bases: Module

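An extended version of BertForSequenceClassification that uses mean pooling of the token hidden states, rather than the pooled [CLS] representation, to build the sequence representation. It is currently intended for use with IMPARA.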

property device
forward(input_ids: Tensor | None = None, attention_mask: Tensor | None = None, token_type_ids: Tensor | None = None, position_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, labels: Tensor | None = None, output_attentions: bool | None = None, return_dict: bool | None = None) → Tuple[Tensor] | SequenceClassifierOutput
labels (torch.LongTensor of shape (batch_size,), optional):

Labels for computing the sequence classification or regression loss. Indices should be in [0, …, config.num_labels - 1]. If config.num_labels == 1, a regression loss (mean squared error) is computed; if config.num_labels > 1, a classification loss (cross-entropy) is computed.
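To make the num_labels rule concrete, here is a framework-free sketch of the arithmetic. It is an illustration under stated assumptions, not the library's code: the helper name sequence_classification_loss is hypothetical, logits are assumed to have shape (batch_size, num_labels), and the real model would typically rely on PyTorch's built-in losses (torch.nn.MSELoss, torch.nn.CrossEntropyLoss) rather than loops.

```python
import math
from typing import Sequence


def sequence_classification_loss(
    logits: Sequence[Sequence[float]], labels: Sequence[float]
) -> float:
    """Hypothetical helper: batch-mean loss chosen by num_labels.

    logits: batch_size x num_labels
    labels: batch_size (float targets for regression, class indices otherwise)
    """
    num_labels = len(logits[0])
    if num_labels == 1:
        # Regression: mean squared error between the single logit and the target.
        return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)
    # Classification: cross-entropy = log-sum-exp(logits) - logit of the true class.
    total = 0.0
    for row, y in zip(logits, labels):
        peak = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in row))
        total += log_sum_exp - row[int(y)]
    return total / len(labels)


# Example: num_labels == 1, so MSE = ((1 - 0)^2 + (3 - 1)^2) / 2.
print(sequence_classification_loss([[1.0], [3.0]], [0.0, 1.0]))  # 2.5
```

With two equal logits and either label, the cross-entropy branch returns log(2), the loss of a uniform prediction over two classes.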

classmethod from_pretrained(name_or_path, **kwards)
mean_pooling(hidden_state, attention_mask)
save_pretrained(save_path, **kwards)
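mean_pooling carries no docstring. A common definition of masked mean pooling, assumed here rather than confirmed by the source, averages each sequence's token hidden states while ignoring padding: pooled_b = Σ_t m_{b,t} h_{b,t} / max(Σ_t m_{b,t}, ε), where m is the attention mask (1 for real tokens, 0 for padding). The sketch uses plain Python lists instead of tensors so it runs without PyTorch; the function name masked_mean_pool is hypothetical.

```python
from typing import List, Sequence


def masked_mean_pool(
    hidden_state: Sequence[Sequence[Sequence[float]]],
    attention_mask: Sequence[Sequence[int]],
    eps: float = 1e-9,
) -> List[List[float]]:
    """Hypothetical sketch of masked mean pooling.

    hidden_state: batch_size x seq_len x hidden_size
    attention_mask: batch_size x seq_len (1 = real token, 0 = padding)
    Returns: batch_size x hidden_size
    """
    pooled: List[List[float]] = []
    for tokens, mask in zip(hidden_state, attention_mask):
        total = [0.0] * len(tokens[0])
        count = 0
        for vec, m in zip(tokens, mask):
            # Padding positions have m == 0, so they add nothing to the sum or count.
            for i, x in enumerate(vec):
                total[i] += m * x
            count += m
        # Clamp the denominator so an all-padding row yields zeros, not a ZeroDivisionError.
        denom = max(count, eps)
        pooled.append([t / denom for t in total])
    return pooled


# Example: the padded third token is ignored.
print(masked_mean_pool([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]], [[1, 1, 0]]))  # [[2.0, 3.0]]
```

With PyTorch tensors the same computation is usually written as a broadcast multiply by attention_mask.unsqueeze(-1), a sum over the sequence dimension, and a division by the clamped mask sum.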