Source code for megnet.layers.readout.linear
"""
Linear readout layer includes stats calculated on the atom dimension
"""
from tensorflow.keras.layers import Layer
import tensorflow as tf
MAPPING = {
"mean": tf.math.segment_mean,
"sum": tf.math.segment_sum,
"max": tf.math.segment_max,
"min": tf.math.segment_min,
"prod": tf.math.segment_prod,
}
[docs]class LinearWithIndex(Layer):
"""
Sum or average the node/edge attributes to get a structure-level vector
"""
def __init__(self, mode="mean", **kwargs):
"""
Args:
mode: (str) 'mean', 'sum', 'max', 'mean' or 'prod'
**kwargs:
"""
super().__init__(**kwargs)
self.mode = mode
self.reduce_method = MAPPING.get(mode, None)
if self.reduce_method is None:
raise ValueError("mode not supported")
[docs] def build(self, input_shape):
"""
Build tensors
Args:
input_shape (sequence of tuple): input shapes
"""
self.built = True
[docs] def call(self, inputs, mask=None):
"""
Main logic
Args:
inputs (tuple of tensor): input tensors
mask (tensor): mask tensor
Returns: output tensor
"""
prop, index = inputs
index = tf.reshape(index, (-1,))
prop = tf.transpose(a=prop, perm=[1, 0, 2])
out = self.reduce_method(prop, index)
out = tf.transpose(a=out, perm=[1, 0, 2])
return out
[docs] def compute_output_shape(self, input_shape):
"""
Compute output shapes from input shapes
Args:
input_shape (sequence of tuple): input shapes
Returns: sequence of tuples output shapes
"""
prop_shape = input_shape[0]
return prop_shape[0], None, prop_shape[-1]
[docs] def get_config(self):
"""
Part of keras layer interface, where the signature is converted into a dict
Returns:
configurational dictionary
"""
config = {"mode": self.mode}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))