Source code for megnet.layers.readout.set2set

"""
Set2Set implementation
"""
import tensorflow as tf
import tensorflow.keras.backend as kb
from tensorflow.keras import activations, initializers, regularizers, constraints
from tensorflow.keras.layers import Layer


class Set2Set(Layer):
    """
    For a set of vectors, the set2set neural network maps it to a single
    vector. Order invariance is achieved by an attention mechanism. See
    Vinyals, Oriol, Samy Bengio, and Manjunath Kudlur. "Order matters:
    Sequence to sequence for sets." arXiv preprint arXiv:1511.06391 (2015).
    """

    def __init__(
        self,
        T=3,
        n_hidden=512,
        activation=None,
        activation_lstm="tanh",
        recurrent_activation="hard_sigmoid",
        kernel_initializer="glorot_uniform",
        recurrent_initializer="orthogonal",
        bias_initializer="zeros",
        use_bias=True,
        unit_forget_bias=True,
        kernel_regularizer=None,
        recurrent_regularizer=None,
        bias_regularizer=None,
        kernel_constraint=None,
        recurrent_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        """
        Args:
            T: (int) number of recurrent steps
            n_hidden: (int) number of hidden units
            activation: (str or object) activation function
            activation_lstm: (str or object) activation function for the LSTM
            recurrent_activation: (str or object) activation function for the
                recurrent step
            kernel_initializer: (str or object) initializer for kernel weights
            recurrent_initializer: (str or object) initializer for recurrent weights
            bias_initializer: (str or object) initializer for biases
            use_bias: (bool) whether to use biases
            unit_forget_bias: (bool) whether to initialize the forget-gate bias to one
            kernel_regularizer: (str or object) regularizer for kernel weights
            recurrent_regularizer: (str or object) regularizer for recurrent weights
            bias_regularizer: (str or object) regularizer for biases
            kernel_constraint: (str or object) constraint for kernel weights
            recurrent_constraint: (str or object) constraint for recurrent weights
            bias_constraint: (str or object) constraint for biases
            kwargs: other inputs for the keras Layer class
        """
        super().__init__(**kwargs)
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.activation_lstm = activations.get(activation_lstm)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.unit_forget_bias = unit_forget_bias
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.T = T
        self.n_hidden = n_hidden
    def build(self, input_shape):
        """
        Build tensors

        Args:
            input_shape (sequence of tuple): input shapes
        """
        feature_shape, index_shape = input_shape
        self.m_weight = self.add_weight(
            shape=(feature_shape[-1], self.n_hidden),
            initializer=self.kernel_initializer,
            name="x_to_m_weight",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        if self.use_bias:
            self.m_bias = self.add_weight(
                shape=(self.n_hidden,),
                initializer=self.bias_initializer,
                name="x_to_m_bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.m_bias = None
        self.recurrent_kernel = self.add_weight(
            shape=(2 * self.n_hidden, 4 * self.n_hidden),
            name="recurrent_kernel",
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
        )
        if self.use_bias:
            if self.unit_forget_bias:

                def bias_initializer(_, *args, **kwargs):
                    return kb.concatenate(
                        [
                            self.bias_initializer((self.n_hidden,), *args, **kwargs),
                            initializers.Ones()((self.n_hidden,), *args, **kwargs),
                            self.bias_initializer((self.n_hidden * 2,), *args, **kwargs),
                        ]
                    )

            else:
                bias_initializer = self.bias_initializer
            self.recurrent_bias = self.add_weight(
                shape=(self.n_hidden * 4,),
                name="recurrent_bias",
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.recurrent_bias = None
        self.built = True
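    # Weight layout built above (a reading of the code, for orientation):
    # `m_weight`/`m_bias` project node features to the n_hidden-dimensional
    # memories m_i, while `recurrent_kernel`/`recurrent_bias` hold the four
    # LSTM gate blocks [i | f | c | o] side by side along the last axis.
    # This packing is why the unit-forget-bias initializer places ones in the
    # second n_hidden-wide slice of the bias vector.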
    def compute_output_shape(self, input_shape):
        """
        Compute output shapes from input shapes

        Args:
            input_shape (sequence of tuple): input shapes

        Returns:
            sequence of tuples: output shapes
        """
        feature_shape, index_shape = input_shape
        return feature_shape[0], None, 2 * self.n_hidden
    def call(self, inputs, mask=None):
        """
        Main logic

        Args:
            inputs (tuple of tensor): input tensors
            mask (tensor): mask tensor

        Returns:
            output tensor
        """
        features, feature_graph_index = inputs
        feature_graph_index = feature_graph_index[0]
        counts = tf.math.bincount(feature_graph_index)
        n_count = tf.shape(counts)[0]
        n_feature = tf.shape(features)[0]
        # Project node features to the memory vectors m_i
        m = kb.dot(features, self.m_weight)
        if self.use_bias:
            m += self.m_bias

        self.h = tf.zeros(tf.stack([n_feature, n_count, self.n_hidden]))
        self.c = tf.zeros(tf.stack([n_feature, n_count, self.n_hidden]))
        q_star = tf.zeros(tf.stack([n_feature, n_count, 2 * self.n_hidden]))
        for _ in range(self.T):
            # Carry both the hidden and the cell state across recurrent steps
            # (the original assigned the new cell state to a local variable,
            # leaving self.c frozen at zeros)
            self.h, self.c = self._lstm(q_star, self.c)
            # Attention logits: dot product between each memory m_i and the
            # query of the graph it belongs to
            e_i_t = tf.reduce_sum(input_tensor=m * tf.repeat(self.h, repeats=counts, axis=1), axis=-1)
            # Subtract the per-graph maximum for numerical stability
            maxes = tf.math.segment_max(e_i_t[0], feature_graph_index)
            maxes = tf.repeat(maxes, repeats=counts)
            e_i_t -= tf.expand_dims(maxes, axis=0)
            exp = tf.exp(e_i_t)
            # Per-graph softmax denominator via segment sums
            seg_sum = tf.transpose(
                a=tf.math.segment_sum(tf.transpose(a=exp, perm=[1, 0]), feature_graph_index), perm=[1, 0]
            )
            seg_sum = tf.expand_dims(seg_sum, axis=-1)
            interm = tf.repeat(seg_sum, repeats=counts, axis=1)
            a_i_t = exp / interm[..., 0]
            # Attention-weighted sum of memories within each graph
            r_t = tf.transpose(
                a=tf.math.segment_sum(
                    tf.transpose(a=tf.multiply(m, a_i_t[:, :, None]), perm=[1, 0, 2]), feature_graph_index
                ),
                perm=[1, 0, 2],
            )
            q_star = kb.concatenate([self.h, r_t], axis=-1)
        return q_star
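    # The attention step in `call` above, in compact form (notation loosely
    # following the set2set paper): for step t and node i in graph g(i),
    #     e_{i,t} = <m_i, h_{g(i),t}>           logits via repeat + reduce_sum
    #     a_{i,t} = softmax_g(e_{i,t})          per-graph segment softmax
    #     r_{g,t} = sum_{i in g} a_{i,t} m_i    attention-weighted readout
    #     q*_{g,t} = [h_{g,t} ; r_{g,t}]        query for the next LSTM step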
    def _lstm(self, h, c):
        # One LSTM step; the packed recurrent kernel produces all four gate
        # pre-activations in a single dense product
        z = kb.dot(h, self.recurrent_kernel)
        if self.use_bias:
            z += self.recurrent_bias
        z0 = z[:, :, : self.n_hidden]
        z1 = z[:, :, self.n_hidden : 2 * self.n_hidden]
        z2 = z[:, :, 2 * self.n_hidden : 3 * self.n_hidden]
        z3 = z[:, :, 3 * self.n_hidden :]
        i = self.recurrent_activation(z0)  # input gate
        f = self.recurrent_activation(z1)  # forget gate
        c = f * c + i * self.activation_lstm(z2)  # cell state update
        o = self.recurrent_activation(z3)  # output gate
        h = o * self.activation_lstm(c)
        return h, c
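    # Shape trace for `_lstm` with a single structure batch: the query q*
    # enters as (1, n_graphs, 2 * n_hidden), z is (1, n_graphs, 4 * n_hidden),
    # and each gate slice z0..z3 is (1, n_graphs, n_hidden), matching the
    # shapes of the hidden and cell states.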
    def get_config(self):
        """
        Part of the keras layer interface, where the signature is converted
        into a dict

        Returns:
            configurational dictionary
        """
        config = {
            "T": self.T,
            "n_hidden": self.n_hidden,
            "activation": activations.serialize(self.activation),
            "activation_lstm": activations.serialize(self.activation_lstm),
            "recurrent_activation": activations.serialize(self.recurrent_activation),
            "kernel_initializer": initializers.serialize(self.kernel_initializer),
            "recurrent_initializer": initializers.serialize(self.recurrent_initializer),
            "bias_initializer": initializers.serialize(self.bias_initializer),
            "use_bias": self.use_bias,
            "unit_forget_bias": self.unit_forget_bias,
            "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
            "recurrent_regularizer": regularizers.serialize(self.recurrent_regularizer),
            "bias_regularizer": regularizers.serialize(self.bias_regularizer),
            "kernel_constraint": constraints.serialize(self.kernel_constraint),
            "recurrent_constraint": constraints.serialize(self.recurrent_constraint),
            "bias_constraint": constraints.serialize(self.bias_constraint),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
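
# ---------------------------------------------------------------------------
# Minimal usage sketch. The toy shapes and values below are illustrative
# assumptions, not part of the original source: two graphs (3 and 2 nodes)
# are fed through the layer, with node features flattened along axis 1 and a
# parallel graph-index tensor, yielding one (2 * n_hidden)-dimensional vector
# per graph.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    node_features = tf.constant(np.random.rand(1, 5, 8), dtype=tf.float32)
    graph_index = tf.constant([[0, 0, 0, 1, 1]], dtype=tf.int32)
    readout = Set2Set(T=2, n_hidden=4)
    q_star = readout([node_features, graph_index])
    print(q_star.shape)  # expected: (1, 2, 8), i.e. (1, n_graphs, 2 * n_hidden)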