# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""BERT models."""

__all__ = ['BERTClassifier', 'BERTRegression']

from mxnet.gluon import HybridBlock
from mxnet.gluon import nn


class BERTRegression(HybridBlock):
    """Model for sentence (pair) regression task with BERT.

    The model feeds token ids and token type ids into BERT to get the
    pooled BERT sequence representation, then applies a Dense layer with
    a single output unit for regression.

    Parameters
    ----------
    bert: BERTModel
        Bidirectional encoder with transformer.
    dropout : float or None, default 0.0.
        Dropout probability for the bert output.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """

    def __init__(self, bert, dropout=0.0, prefix=None, params=None):
        super(BERTRegression, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.regression = nn.HybridSequential(prefix=prefix)
            # Dropout is only inserted when a non-zero rate is requested,
            # so dropout=0.0 (or None) yields a pure Dense head.
            if dropout:
                self.regression.add(nn.Dropout(rate=dropout))
            # Single output unit: one real-valued score per sequence.
            self.regression.add(nn.Dense(units=1))

    def hybrid_forward(self, F, inputs, token_types, valid_length=None):
        # pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the
            word belongs to the first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size)
            Valid length of the sequence. This is used to mask the padded
            tokens.

        Returns
        -------
        outputs : NDArray
            Shape (batch_size, 1)
        """
        # Only the pooled ([CLS]-based) representation feeds the head;
        # the per-token sequence output is discarded.
        _, pooler_out = self.bert(inputs, token_types, valid_length)
        return self.regression(pooler_out)


class BERTClassifier(HybridBlock):
    """Model for sentence (pair) classification task with BERT.

    The model feeds token ids and token type ids into BERT to get the
    pooled BERT sequence representation, then applies a Dense layer for
    classification.

    Parameters
    ----------
    bert: BERTModel
        Bidirectional encoder with transformer.
    num_classes : int, default is 2
        The number of target classes.
    dropout : float or None, default 0.0.
        Dropout probability for the bert output.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """

    def __init__(self, bert, num_classes=2, dropout=0.0,
                 prefix=None, params=None):
        super(BERTClassifier, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.classifier = nn.HybridSequential(prefix=prefix)
            # Dropout is only inserted when a non-zero rate is requested,
            # so dropout=0.0 (or None) yields a pure Dense head.
            if dropout:
                self.classifier.add(nn.Dropout(rate=dropout))
            self.classifier.add(nn.Dense(units=num_classes))

    def hybrid_forward(self, F, inputs, token_types, valid_length=None):
        # pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the
            word belongs to the first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size)
            Valid length of the sequence. This is used to mask the padded
            tokens.

        Returns
        -------
        outputs : NDArray
            Shape (batch_size, num_classes)
        """
        # Only the pooled ([CLS]-based) representation feeds the head;
        # the per-token sequence output is discarded.
        _, pooler_out = self.bert(inputs, token_types, valid_length)
        return self.classifier(pooler_out)