gec_metrics.metrics.utils module
- class gec_metrics.metrics.utils.AutoModelForSequenceClassificationMeanPool(model: AutoModelForSequenceClassification)
Bases: Module
An extended version of BERTForSequenceClassification that uses mean pooling. It is currently intended for use in IMPARA.
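To illustrate what mean pooling refers to here, the following is a minimal, framework-free sketch rather than the class's actual implementation. It assumes that padding positions are excluded from the average via the attention mask, which this page does not confirm; the function name `masked_mean_pool` is hypothetical.

```python
from typing import List


def masked_mean_pool(hidden: List[List[float]], mask: List[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1 (hypothetical helper)."""
    if len(hidden) != len(mask):
        raise ValueError("hidden and mask must have the same length")
    dim = len(hidden[0]) if hidden else 0
    kept = sum(mask)
    if kept == 0:
        # No real tokens: return a zero vector instead of dividing by zero.
        return [0.0] * dim
    totals = [0.0] * dim
    for vec, m in zip(hidden, mask):
        if m:  # skip padding positions (assumption: mask value 0 means padding)
            for i, v in enumerate(vec):
                totals[i] += v
    return [t / kept for t in totals]


# Three token vectors; the last one is padding and is ignored.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
pooled = masked_mean_pool(tokens, [1, 1, 0])
```

In a tensor library the same idea is usually written as a masked sum over the sequence dimension divided by the number of unmasked tokens.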
- property device
- forward(input_ids: Tensor | None = None, attention_mask: Tensor | None = None, token_type_ids: Tensor | None = None, position_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, labels: Tensor | None = None, output_attentions: bool | None = None, return_dict: bool | None = None) → Tuple[Tensor] | SequenceClassifierOutput
- labels (torch.LongTensor of shape (batch_size,), optional):
Labels for computing the sequence classification/regression loss. Indices should be in [0, …, config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (mean squared error); if config.num_labels > 1, a classification loss is computed (cross-entropy).
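The loss selection rule described for labels can be sketched in plain Python as follows. This is an illustration of the stated rule only, not the library's code; the names `mse_loss`, `cross_entropy_loss`, and `sequence_loss` are hypothetical, and logits are assumed to be one row of `num_labels` scores per example.

```python
import math
from typing import Sequence


def mse_loss(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error over a batch of scalar predictions."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def cross_entropy_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood of the true class under a softmax."""
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[y]
    return total / len(labels)


def sequence_loss(logits: Sequence[Sequence[float]], labels: Sequence, num_labels: int) -> float:
    """Pick the loss from num_labels, following the rule stated above."""
    if num_labels == 1:
        # Regression: each row holds a single score.
        return mse_loss([row[0] for row in logits], labels)
    # Classification: labels are class indices in [0, num_labels - 1].
    return cross_entropy_loss(logits, labels)


regression = sequence_loss([[0.5], [1.5]], [1.0, 1.0], num_labels=1)
classification = sequence_loss([[0.0, 0.0]], [0], num_labels=2)
```

With two equal logits the softmax assigns probability 0.5 to each class, so the classification loss in the last line equals ln 2.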