# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""
DGL implementation of GVP and GVP-GNN (without the autoregressive
functionality), modified from source:
https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py
"""
import functools

import torch
from torch import nn
import torch.nn.functional as F


def tuple_sum(*args):
    """
    Sums any number of tuples (s, V) elementwise.
    """
    return tuple(map(sum, zip(*args)))


def tuple_cat(*args, dim=-1):
    """
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed as the
                `dim` index for the scalar-channel tensors. This means
                that `dim=-1` will be applied as `dim=-2` for the
                vector-channel tensors.
    """
    dim %= len(args[0][0].shape)
    s_args, v_args = list(zip(*args))
    return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)


def tuple_index(x, idx):
    """
    Indexes into a tuple (s, V) along the first dimension.

    :param idx: any object which can be used to index into a `torch.Tensor`
    """
    return x[0][idx], x[1][idx]


def randn(n, dims, device="cpu"):
    """
    Returns random tuples (s, V) drawn elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)
    :return: (s, V) with s.shape = (n, n_scalar) and V.shape = (n, n_vector, 3)
    """
    return torch.randn(n, dims[0], device=device), torch.randn(
        n, dims[1], 3, device=device
    )


def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    """
    L2 norm of tensor clamped above a minimum value `eps`.

    :param sqrt: if `False`, returns the square of the L2 norm
    """
    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    return torch.sqrt(out) if sqrt else out
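

# ---------------------------------------------------------------------------
# Illustrative example (not part of the original gvp-pytorch port): a minimal
# sketch of how the (s, V) tuple helpers above compose. The helper name
# `_demo_tuple_helpers` and all dimension choices are arbitrary assumptions;
# the function is never called at import time.
def _demo_tuple_helpers():
    x1 = randn(5, (8, 4))            # s: (5, 8), V: (5, 4, 3)
    x2 = randn(5, (8, 4))

    summed = tuple_sum(x1, x2)       # elementwise sum of scalar and vector channels
    catted = tuple_cat(x1, x2)       # s: (5, 16), V: (5, 8, 3)
    first = tuple_index(catted, 0)   # s: (16,), V: (8, 3)

    # clamped L2 norms of the vector channels of x1, shape (5, 4)
    vn = _norm_no_nan(x1[1], axis=-1)
    return summed, catted, first, vn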


class GVP(nn.Module):
    """
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
    :param activations: tuple of functions (scalar_act, vector_act)
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating
                        if `True`)
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        h_dim=None,
        activations=(F.relu, torch.sigmoid),
        vector_gate=False,
    ):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        """
        if self.vi:
            s, v = x
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    if self.vector_act:
                        gate = self.wsv(self.vector_act(s))
                    else:
                        gate = self.wsv(s)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True)
                    )
        else:
            s = self.ws(x)
            if self.vo:
                v = torch.zeros(
                    s.shape[0], self.vo, 3, device=self.dummy_param.device
                )
        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s


class _VDropout(nn.Module):
    """
    Vector channel dropout where the elements of each vector
    channel are dropped together.
    """

    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        """
        :param x: `torch.Tensor` corresponding to vector channels
        """
        device = self.dummy_param.device
        if not self.training:
            return x
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x


class Dropout(nn.Module):
    """
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    """

    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        """
        if type(x) is torch.Tensor:
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)


class LayerNorm(nn.Module):
    """
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    """

    def __init__(self, dims):
        super(LayerNorm, self).__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        """
        if not self.v:
            return self.scalar_norm(x)
        s, v = x
        vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
        vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
        return self.scalar_norm(s), v / vn
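

# ---------------------------------------------------------------------------
# Illustrative example (not part of the original gvp-pytorch port): chaining a
# GVP with the tuple-aware Dropout and LayerNorm defined above. The helper
# name `_demo_gvp_stack` and the dimension choices are arbitrary assumptions;
# the function is never called at import time.
def _demo_gvp_stack():
    in_dims, out_dims = (8, 4), (16, 2)   # (n_scalar, n_vector)
    x = randn(10, in_dims)                # 10 random data points

    gvp = GVP(in_dims, out_dims, vector_gate=True)
    dropout = Dropout(0.1)
    norm = LayerNorm(out_dims)

    s, v = norm(dropout(gvp(x)))
    assert s.shape == (10, 16) and v.shape == (10, 2, 3)
    return s, v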


class GVPConv(nn.Module):
    """
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers;
    see `GVPConvLayer` for that.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function, overrides n_layers
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating
                        if `True`)
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        edge_dims,
        n_layers=3,
        module_list=None,
        activations=(F.relu, torch.sigmoid),
        vector_gate=False,
    ):
        super(GVPConv, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        GVP_ = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate
        )

        module_list = module_list or []
        if not module_list:
            if n_layers == 1:
                module_list.append(
                    GVP_(
                        (2 * self.si + self.se, 2 * self.vi + self.ve),
                        (self.so, self.vo),
                        activations=(None, None),
                    )
                )
            else:
                module_list.append(
                    GVP_(
                        (2 * self.si + self.se, 2 * self.vi + self.ve),
                        out_dims,
                    )
                )
                for i in range(n_layers - 2):
                    module_list.append(GVP_(out_dims, out_dims))
                module_list.append(
                    GVP_(out_dims, out_dims, activations=(None, None))
                )
        self.message_func = nn.Sequential(*module_list)

    def forward(self, g):
        g.update_all(
            message_func=self.message_udf, reduce_func=self.reduce_udf
        )
        return g.ndata["node_s_agg"], g.ndata["node_v_agg"]

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        return message

    def message_udf(self, edges):
        """
        message function for GVP-GNN

        :param edges: EdgeBatch
        :return dict[str, tensor]: s_m: scalar message; v_m: vector message
        """
        s_i, v_i = edges.src["node_s"], edges.src["node_v"]
        s_j, v_j = edges.dst["node_s"], edges.dst["node_v"]
        edge_attr = edges.data["edge_s"], edges.data["edge_v"]
        s_m, v_m = self.message(s_i, v_i, s_j, v_j, edge_attr)
        return {"s_m": s_m, "v_m": v_m}

    def reduce_udf(self, nodes):
        """
        reduce function for GVP-GNN

        :param nodes: NodeBatch
        """
        s_m, v_m = nodes.mailbox["s_m"], nodes.mailbox["v_m"]
        return {
            "node_s_agg": torch.mean(s_m, dim=1),
            "node_v_agg": torch.mean(v_m, dim=1),
        }
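

# ---------------------------------------------------------------------------
# Illustrative example (not part of the original gvp-pytorch port): running
# GVPConv on a tiny DGL graph. The feature names ("node_s", "node_v",
# "edge_s", "edge_v") are the ones the message/reduce UDFs above expect; the
# graph topology, dimension choices, and the helper name `_demo_gvp_conv` are
# arbitrary assumptions. The function is never called at import time.
def _demo_gvp_conv():
    import dgl  # assumed available, since the module targets DGL graphs

    node_dims, edge_dims = (16, 4), (8, 1)   # (n_scalar, n_vector)

    # 3-node directed cycle, so every node receives at least one message
    src, dst = torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])
    g = dgl.graph((src, dst))
    g.ndata["node_s"], g.ndata["node_v"] = randn(g.num_nodes(), node_dims)
    g.edata["edge_s"], g.edata["edge_v"] = randn(g.num_edges(), edge_dims)

    conv = GVPConv(node_dims, node_dims, edge_dims, n_layers=3)
    s_out, v_out = conv(g)                   # mean-aggregated messages per node
    assert s_out.shape == (3, 16) and v_out.shape == (3, 4, 3)
    return s_out, v_out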


class GVPConvLayer(nn.Module):
    """
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating
                        if `True`)
    """

    def __init__(
        self,
        node_dims,
        edge_dims,
        n_message=3,
        n_feedforward=2,
        drop_rate=0.1,
        activations=(F.relu, torch.sigmoid),
        vector_gate=False,
    ):
        super(GVPConvLayer, self).__init__()
        self.conv = GVPConv(
            node_dims,
            node_dims,
            edge_dims,
            n_message,
            activations=activations,
            vector_gate=vector_gate,
        )
        GVP_ = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate
        )
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(
                GVP_(node_dims, node_dims, activations=(None, None))
            )
        else:
            hid_dims = 4 * node_dims[0], 2 * node_dims[1]
            ff_func.append(GVP_(node_dims, hid_dims))
            for i in range(n_feedforward - 2):
                ff_func.append(GVP_(hid_dims, hid_dims))
            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

    def forward(self, g):
        """
        :param g: dgl.graph
        """
        dh = self.conv(g)
        x = g.ndata["node_s"], g.ndata["node_v"]
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        return x
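

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original gvp-pytorch port): one
# full GVPConvLayer pass over a small, random, fully connected DGL graph.
# Graph size and all dimension choices are arbitrary assumptions.
if __name__ == "__main__":
    import dgl  # assumed available, since the module targets DGL graphs

    node_dims, edge_dims = (100, 16), (32, 1)
    n_nodes = 8

    # fully connected directed graph without self-loops
    src, dst = zip(
        *[(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]
    )
    g = dgl.graph((torch.tensor(src), torch.tensor(dst)))
    g.ndata["node_s"], g.ndata["node_v"] = randn(g.num_nodes(), node_dims)
    g.edata["edge_s"], g.edata["edge_v"] = randn(g.num_edges(), edge_dims)

    layer = GVPConvLayer(node_dims, edge_dims, drop_rate=0.1)
    s, v = layer(g)
    print("node_s:", tuple(s.shape), "node_v:", tuple(v.shape))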