# Source code for megnet.utils.layer

"""
Tensorflow layer utilities
"""
import numpy as np  # noqa
import tensorflow as tf
from megnet.config import DataType


def _repeat(x: tf.Tensor, n: tf.Tensor, axis: int = 1) -> tf.Tensor:
    """
    Repeat the slices of x along ``axis``, the i-th slice being repeated
    n[i] times.

    For example, given a tensor x (N*M*K) with M = 3, axis=1, and
    n = Tensor([3, 1, 2]), the final tensor would have the shape (N*6*K)
    with the first slice along M repeated 3 times, the second 1 time and
    the third 2 times.

    Args:
        x: (Tensor) tensor to be augmented, typically 3d; its rank must be
            statically known
        n: (1d Tensor) number of repetitions for each slice along ``axis``
        axis: (int) axis for repetition, negative values count from the
            last axis

    Returns:
        (Tensor) tensor after repetition
    """
    assert len(n.shape) == 1, "n must be a 1d tensor"
    x_dim = len(x.shape)
    # Normalize a negative axis up front: the `axis + 1` arithmetic below is
    # only meaningful for a non-negative axis.
    if axis < 0:
        axis += x_dim

    # Every slice is first repeated `maxlen` times, the surplus copies are
    # dropped by the mask at the end.
    maxlen = tf.reduce_max(input_tensor=n)
    x_shape = tf.shape(input=x)

    # Insert a new axis right after `axis` and tile along it, e.g.
    # (N, M, K) -> (N, M, 1, K) -> (N, M, maxlen, K) for axis=1
    multiples = [1] * (x_dim + 1)
    multiples[axis + 1] = maxlen
    x_tiled = tf.tile(tf.expand_dims(x, axis + 1), tf.stack(multiples))

    # Merge the tiled axis back into `axis`, e.g. (N, M * maxlen, K)
    new_shape = tf.unstack(x_shape)
    new_shape[axis] = -1
    # Use the static size of the last axis in the reshape target when it is
    # available. Skip this when `axis` itself is the last axis (it would
    # overwrite the -1 placeholder) or when the static size is unknown
    # (None is not a valid reshape entry).
    last_dim = x.shape[-1]
    if axis != x_dim - 1 and last_dim is not None:
        new_shape[-1] = last_dim
    x_tiled = tf.reshape(x_tiled, new_shape)

    # Boolean matrix of shape [len(n), maxlen] where entry [i, j] is true
    # if j < n[i]; flattened, it lines up with the merged axis of x_tiled.
    mask = tf.sequence_mask(n, maxlen)
    mask = tf.reshape(mask, (-1,))
    # keep only the first n[i] copies of every slice
    return tf.boolean_mask(tensor=x_tiled, mask=mask, axis=axis)


def repeat_with_index(x: tf.Tensor, index: tf.Tensor, axis: int = 1) -> tf.Tensor:
    """
    Given a tensor x (N*M*K), repeat the middle axis (axis=1)
    according to the index tensor index (G, )
    for example, if axis=1 and index = Tensor([0, 0, 0, 1, 2, 2])
    then M = 3 (3 unique values),
    and the final tensor would have the shape (N*6*K) with the
    first one in M repeated 3 times,
    second 1 time and third 2 times.

    Args:
        x: (3d Tensor) tensor to be augmented
        index: (1d Tensor) repetition tensor
        axis: (int) axis for repetition

    Returns:
        (3d Tensor) tensor after repetition
    """
    # flatten so that index tensors with extra singleton axes are accepted
    index = tf.reshape(index, (-1,))
    # Only the per-value counts are needed. They are reported in order of
    # first occurrence of each value, so the i-th count applies to the i-th
    # slice of x only if `index` is grouped like the example above
    # (NOTE(review): confirm callers always pass grouped indices).
    _, _, n = tf.unique_with_counts(index)
    return _repeat(x, n, axis)


def gather(tensor: tf.Tensor, indices: tf.Tensor) -> tf.Tensor:
    """
    Alternative implementation to tf.gather, without the index warnings

    Args:
        tensor: (Tensor) tensor to be gathered
        indices: (Tensor) indices tensor

    Returns:
        (Tensor) the entries of ``tensor`` along its first axis that are
        selected by ``indices``
    """
    # The array is typed with the configured float type
    # (NOTE(review): `tensor` is presumably expected to have that dtype).
    ta = tf.TensorArray(dtype=DataType.tf_float, size=0, dynamic_size=True)
    # split the tensor along its first axis into the array elements
    ta = ta.unstack(tensor)
    # pick the requested elements and pack them into a single tensor
    results = ta.gather(indices)
    return results