"""
SchNet implementation
"""
import tensorflow as tf
import tensorflow.keras.backend as kb
from megnet.activations import softplus2
from megnet.layers.graph.base import GraphNetworkLayer
class InteractionLayer(GraphNetworkLayer):
    """
    The continuous-filter InteractionLayer in SchNet.
    Schütt et al. "SchNet: A continuous-filter convolutional neural network for modeling quantum interactions"
    Methods:
        call(inputs, mask=None): the logic of the layer, returns the final graph
        compute_output_shape(input_shape): compute static output shapes, returns a list of tuple shapes
        build(input_shape): initialize the weights and biases for each function
        phi_e(inputs): update function for bonds, returns the updated bond attribute e_p
        rho_e_v(e_p, inputs): aggregate the updated bonds e_p into per-atom attributes b_e_p
        phi_v(b_e_p, inputs): update the atom attributes using b_e_p from the previous step and all the inputs,
            returns v_p
        rho_e_u(e_p, inputs): aggregate bonds to the global attribute
        rho_v_u(v_p, inputs): aggregate atoms to the global attribute
        get_config(): part of the keras interface for serialization
    """
def __init__(
self,
activation=softplus2,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs,
):
"""
Args:
activation (str): Default: None. The activation function used for each sub-neural network. Examples include
'relu', 'softmax', 'tanh', 'sigmoid' and etc.
use_bias (bool): Default: True. Whether to use the bias term in the neural network.
kernel_initializer (str): Default: 'glorot_uniform'. Initialization function for the layer kernel weights,
bias_initializer (str): Default: 'zeros'
activity_regularizer (str): Default: None. The regularization function for the output
kernel_constraint (str): Default: None. Keras constraint for kernel values
bias_constraint (str): Default: None .Keras constraint for bias values
"""
super().__init__(
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs,
)
    def build(self, input_shapes):
        """
        Build the weights for the layer
        Args:
            input_shapes (sequence of tuple): the shapes of all input tensors
        """
vdim = input_shapes[0][2]
edim = input_shapes[1][2]
        with kb.name_scope(self.name):
            with kb.name_scope("phi_e"):
                e_shapes = [[edim, vdim]] + [[vdim, vdim]] * 2
                self.phi_e_weights = [
                    self.add_weight(
                        shape=i,
                        initializer=self.kernel_initializer,
                        name=f"weight_e_{j}",
                        regularizer=self.kernel_regularizer,
                        constraint=self.kernel_constraint,
                    )
                    for j, i in enumerate(e_shapes)
                ]
                if self.use_bias:
                    self.phi_e_biases = [
                        self.add_weight(
                            shape=(i[-1],),
                            initializer=self.bias_initializer,
                            name=f"bias_e_{j}",
                            regularizer=self.bias_regularizer,
                            constraint=self.bias_constraint,
                        )
                        for j, i in enumerate(e_shapes)
                    ]
                else:
                    # None placeholders keep indexing valid when use_bias is False
                    self.phi_e_biases = [None] * len(e_shapes)
        with kb.name_scope(self.name):
            with kb.name_scope("phi_v"):
                v_shapes = [[vdim, vdim]] * 3
                self.phi_v_weights = [
                    self.add_weight(
                        shape=i,
                        initializer=self.kernel_initializer,
                        name=f"weight_v_{j}",
                        regularizer=self.kernel_regularizer,
                        constraint=self.kernel_constraint,
                    )
                    for j, i in enumerate(v_shapes)
                ]
                if self.use_bias:
                    self.phi_v_biases = [
                        self.add_weight(
                            shape=(i[-1],),
                            initializer=self.bias_initializer,
                            name=f"bias_v_{j}",
                            regularizer=self.bias_regularizer,
                            constraint=self.bias_constraint,
                        )
                        for j, i in enumerate(v_shapes)
                    ]
                else:
                    # None placeholders keep indexing valid when use_bias is False
                    self.phi_v_biases = [None] * len(v_shapes)
self.built = True
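        # At this point, for hypothetical sizes vdim=64 and edim=100, phi_e holds
        # kernels of shape [100, 64], [64, 64], [64, 64] (the filter-generating
        # network) and phi_v holds three [64, 64] kernels (the atom-wise layers).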
    def compute_output_shape(self, input_shape):
        """
        Compute output shapes from input shapes
        Args:
            input_shape (sequence of tuple): input shapes
        Returns: sequence of tuples, the output shapes
        """
        return input_shape
    def phi_e(self, inputs):
        """
        Edge update function. SchNet does not update the bond attributes,
        so this is an identity mapping.
        Args:
            inputs (tuple of tensor): the graph inputs
        Returns:
            output tensor, the (unchanged) bond attributes
        """
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        return edges
    def rho_e_v(self, e_p, inputs):
        """
        Reduce edge attributes to node attributes (the continuous-filter
        convolution, eqn 5 in the paper)
        Args:
            e_p: updated bond attributes
            inputs: the whole input list
        Returns: summed tensor
        """
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        # Atom-wise dense layer applied before the convolution
        atomwise1 = self._mlp(nodes, self.phi_v_weights[0], self.phi_v_biases[0])
        # Filter-generating network: map the bond features to filter weights
        cfconv1 = self.activation(self._mlp(edges, self.phi_e_weights[0], self.phi_e_biases[0]))
        cfconv2 = self.activation(self._mlp(cfconv1, self.phi_e_weights[1], self.phi_e_biases[1]))
        cfconv_out = self._mlp(cfconv2, self.phi_e_weights[2], self.phi_e_biases[2])
        index1 = tf.reshape(index1, (-1,))
        index2 = tf.reshape(index2, (-1,))
        # Gather the neighbor atom features of each bond, weight them by the
        # generated filters, and sum the messages onto the center atoms
        fr = tf.gather(atomwise1, index2, axis=1)
        after_cfconv = atomwise1 + tf.transpose(
            a=tf.math.segment_sum(tf.transpose(a=fr * cfconv_out, perm=[1, 0, 2]), index1), perm=[1, 0, 2]
        )
        # Two more atom-wise dense layers after the convolution
        atomwise2 = self.activation(self._mlp(after_cfconv, self.phi_v_weights[1], self.phi_v_biases[1]))
        atomwise3 = self._mlp(atomwise2, self.phi_v_weights[2], self.phi_v_biases[2])
        return atomwise3
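
    # A toy view of the aggregation used above (hypothetical values):
    #   tf.math.segment_sum(tf.constant([1., 2., 3.]), tf.constant([0, 0, 1]))
    #   -> [3., 3.]  (bonds 0 and 1 share center atom 0; bond 2 belongs to atom 1)
    # Note that tf.math.segment_sum requires index1 to be sorted by center atom.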
    def phi_v(self, b_ei_p, inputs):
        """
        Node update function: a residual update of the atom attributes
        Args:
            b_ei_p (tensor): edge aggregated tensor
            inputs (tuple of tensors): other graph inputs
        Returns: updated node tensor
        """
        nodes, edges, u, index1, index2, gnode, gbond = inputs
        return nodes + b_ei_p
    def rho_e_u(self, e_p, inputs):
        """
        Aggregate edges to the global state. SchNet does not use the global
        state, so a constant 0 is returned.
        Args:
            e_p (tensor): edge tensor
            inputs (tuple of tensors): other graph input tensors
        Returns: edge aggregated tensor for states
        """
        return 0
    def rho_v_u(self, v_p, inputs):
        """
        Aggregate atoms/nodes to the global state. Not used in SchNet,
        so a constant 0 is returned.
        Args:
            v_p (tf.Tensor): updated atom/node attributes
            inputs (Sequence): list or tuple for the graph inputs
        Returns:
            atom/node to global/state aggregated tensor
        """
        return 0
    def phi_u(self, b_e_p, b_v_p, inputs):
        """
        Global state update function: SchNet leaves the state unchanged.
        Args:
            b_e_p (tf.Tensor): edge/bond to global aggregated tensor
            b_v_p (tf.Tensor): node/atom to global aggregated tensor
            inputs (Sequence): list or tuple for the graph inputs
        Returns:
            updated global/state attributes
        """
        return inputs[2]
    @staticmethod
    def _mlp(input_, weights, bias):
        # Single dense transform; bias may be None when use_bias is False
        output = kb.dot(input_, weights)
        if bias is not None:
            output += bias
        return output
    def get_config(self):
        """
        Part of the keras layer interface, where the signature is converted into a dict
        Returns:
            configurational dictionary
        """
        base_config = super().get_config()
        return base_config
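

# A minimal usage sketch (not part of the original module). It assumes
# megnet's GraphNetworkLayer.call chains phi_e -> rho_e_v -> phi_v ->
# rho_e_u/rho_v_u -> phi_u and returns [v_p, e_p, u_p]; the toy sizes and
# bond indices below are arbitrary. All tensors carry megnet's leading
# batch axis of size 1.
if __name__ == "__main__":
    n_atom, n_bond, n_feat = 4, 6, 8
    atoms = tf.random.normal((1, n_atom, n_feat))  # atom features
    bonds = tf.random.normal((1, n_bond, n_feat))  # expanded bond features
    state = tf.zeros((1, 1, 2))                    # global state, unused by SchNet
    index1 = tf.constant([[0, 0, 1, 1, 2, 3]])     # center atom of each bond (sorted)
    index2 = tf.constant([[1, 2, 0, 3, 0, 1]])     # neighbor atom of each bond
    gnode = tf.constant([[0, 0, 0, 0]])            # atom -> structure assignment
    gbond = tf.constant([[0, 0, 0, 0, 0, 0]])      # bond -> structure assignment

    layer = InteractionLayer()
    v_p, e_p, u_p = layer([atoms, bonds, state, index1, index2, gnode, gbond])
    print(v_p.shape)  # (1, 4, 8): updated atom features, same shape as the input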