from farm.modeling.prediction_head import TextClassificationHead
import torch


class ExtendedTextClassificationHead(TextClassificationHead):
    """TextClassificationHead whose ``logits_to_probs`` can collapse each row to one score.

    When class probabilities are not requested, every row is reduced to a single
    value that, for a two-label head, equals the softmax probability of label 1.
    NOTE(review): this reduction only makes sense for binary classification;
    with more than two labels the returned values are not class-1 probabilities.
    """

    def logits_to_probs(self, logits, return_class_probs, **kwargs):
        """Convert raw classifier logits into probabilities.

        :param logits: Raw classifier outputs. Softmax and argmax are taken over
            dim 1, so presumably shaped (batch, num_labels) -- confirm with caller.
        :param return_class_probs: If truthy, return the full softmax matrix;
            otherwise return one score per row (see ``:return:``).
        :param kwargs: Accepted for interface compatibility and ignored.
        :return: If ``return_class_probs``: a NumPy array of the softmax
            probabilities (moved to CPU). Otherwise: a list with one NumPy
            scalar per row -- the top probability when the argmax label is 1,
            else one minus the top probability.
        """
        probs = torch.softmax(logits, dim=1)
        if return_class_probs:
            return probs.cpu().numpy()

        top_probs = probs.max(dim=1).values
        # The predicted label is taken from the raw logits (not from probs) so
        # tie-breaking matches the previous implementation. The selection is
        # vectorised with torch.where instead of a per-row Python loop that
        # indexed the tensor (and, on GPU, synchronised) on every iteration.
        predicts_label_one = logits.argmax(dim=1) == 1
        scores = torch.where(predicts_label_one, top_probs, 1 - top_probs)
        # list(...) keeps the original return type: a list of NumPy scalars.
        return list(scores.cpu().numpy())