import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self, dim_input, dim_hidden, dim_output, num_layers,
                 activation='relu', dropout_rate=0, layer_norm=False,
                 residual_connection=False):
        """Wrap a (residual) feed-forward neural network in an easy-to-use way."""
        super().__init__()
        assert num_layers >= 0  # 0 hidden layers reduces the model to a linear map
        if num_layers > 0:
            assert dim_hidden > 0
        if residual_connection:
            # Residual additions require matching input and hidden dimensions.
            assert dim_hidden == dim_input
        self.residual_connection = residual_connection
        self.stack = nn.ModuleList()
        for layer_idx in range(num_layers):
            layer = []
            if layer_norm:
                layer.append(nn.LayerNorm(dim_input if layer_idx == 0 else dim_hidden))
            layer.append(nn.Linear(dim_input if layer_idx == 0 else dim_hidden,
                                   dim_hidden))
            layer.append({'tanh': nn.Tanh(), 'relu': nn.ReLU()}[activation])
            if dropout_rate > 0:
                layer.append(nn.Dropout(dropout_rate))
            self.stack.append(nn.Sequential(*layer))
        self.out = nn.Linear(dim_input if num_layers < 1 else dim_hidden, dim_output)

    def forward(self, x):
        for layer in self.stack:
            x = (x + layer(x)) if self.residual_connection else layer(x)
        return self.out(x)


class ConditionalDistributionZ(nn.Module):
    def __init__(self, number_clusters, dim_input, num_layers, dim_hidden):
        """Map the input (the BERT representation of a sentence) to a
        probability vector for query routing or document assignment over the
        clusters.

        Parameters
        ----------
        number_clusters : int
            The number of clusters to which documents are assigned or queries
            are routed.
        dim_input : int
            The dimension of the input (768 for the BERT base model).
        num_layers : int
            The number of layers used in this model.
        dim_hidden : int
            The number of neurons in each hidden layer of this model.
        """
        super().__init__()
        self.number_clusters = number_clusters
        self.softmax = nn.Softmax(dim=-1)
        self.ff = FeedForward(dim_input, dim_hidden, number_clusters, num_layers)

    def forward(self, inputs):
        # Reshape the logits to (batch, 1, number_clusters) so that each row
        # holds one distribution over the clusters.
        logits = self.ff(inputs).view(inputs.size(0), 1, self.number_clusters)
        probability_vector = self.softmax(logits)
        return probability_vector


class MarginalDistributionZ(nn.Module):
    def __init__(self, number_clusters):
        """Approximate the distribution of cluster sizes for document
        assignment. A `softmax` over a learned parameter vector guarantees
        that the output is a valid distribution.

        Parameters
        ----------
        number_clusters : int
            The number of clusters to which documents are assigned.
        """
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        # A single learned embedding row stores the unnormalized logits.
        self.theta = nn.Embedding(1, number_clusters)

    def forward(self):
        """Return the approximated distribution."""
        logits = self.theta.weight
        probability_vector = self.softmax(logits)
        return probability_vector


def get_init_function(init_value):
    """Build the initialization function for the prior and posterior models.
    The scale of the uniform initialization is controlled by `init_value`.
    """
    def init_function(m):
        if init_value > 0.:
            if hasattr(m, 'weight'):
                m.weight.data.uniform_(-init_value, init_value)
            if hasattr(m, 'bias') and m.bias is not None:
                m.bias.data.fill_(0.)
    return init_function


def cross_entropy_p_q(p, q):
    """Compute the cross entropy between two distributions. If `p` or `q`
    holds multiple distributions, the cross entropy is computed pairwise and
    the results are averaged.
    """
    if len(q.size()) == 2:
        # Broadcast a shared 2-D `q` across the batch dimension of `p`.
        q = q.repeat(p.size(0), 1, 1)
    return (-(p * torch.log(q)).sum(dim=(1, 2))).mean()
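
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes 768-dimensional BERT sentence representations as input; the batch
# size, cluster count, layer sizes, and init scale below are made-up values.
if __name__ == '__main__':
    torch.manual_seed(0)

    # Posterior q(z | x): routes each sentence representation to the clusters.
    posterior = ConditionalDistributionZ(number_clusters=16, dim_input=768,
                                         num_layers=2, dim_hidden=256)
    # Prior p(z): a learned distribution over cluster sizes.
    prior = MarginalDistributionZ(number_clusters=16)

    # Re-initialize all weights uniformly in [-0.1, 0.1] and biases to zero.
    posterior.apply(get_init_function(0.1))
    prior.apply(get_init_function(0.1))

    embeddings = torch.randn(4, 768)     # stand-in for BERT representations
    q_z_given_x = posterior(embeddings)  # shape: (4, 1, 16)
    p_z = prior()                        # shape: (1, 16)

    # Cross entropy H(q, p); `cross_entropy_p_q` repeats the 2-D prior across
    # the batch before averaging the per-example cross entropies.
    loss = cross_entropy_p_q(q_z_given_x, p_z)
    print(loss.item())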