megnet.utils.layer module

TensorFlow layer utilities.

gather(tensor: tensorflow.python.framework.ops.Tensor, indices: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor

Alternative implementation of tf.gather that does not produce tf.gather's index warnings.

Parameters
  • tensor – (Tensor) tensor to gather values from

  • indices – (Tensor) indices of the values to gather

repeat_with_index(x: tensorflow.python.framework.ops.Tensor, index: tensorflow.python.framework.ops.Tensor, axis: int = 1)

Given a tensor x of shape (N, M, K), repeat its slices along the middle axis (axis=1) according to the 1d index tensor index of shape (G,). For example, if axis=1 and index = Tensor([0, 0, 0, 1, 2, 2]), then M = 3 (three unique values) and the result has shape (N, 6, K): the first slice along M is repeated 3 times, the second once and the third twice.

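Parameters
  • x – (3d Tensor) tensor whose slices are repeated

  • index – (1d Tensor) index tensor; the number of times value i occurs gives the repetition count of slice i along axis

  • axis – (int) axis along which to repeat (default 1)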

Returns

(3d Tensor) tensor after repetition
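The repetition described above can be expressed with tf.math.bincount and tf.repeat. This is a minimal sketch of the documented behaviour, not the megnet source: the name repeat_with_index_sketch is hypothetical, and it assumes the values in index lie in 0..M-1 with M-1 present, so that the counts line up with x's size along axis.

```python
import tensorflow as tf


def repeat_with_index_sketch(x: tf.Tensor, index: tf.Tensor, axis: int = 1) -> tf.Tensor:
    """Hypothetical re-implementation of repeat_with_index's documented behaviour.

    Slice i of x along `axis` is repeated as many times as the value i
    occurs in `index`. Assumes index values lie in 0..M-1 and M-1 occurs.
    """
    index = tf.reshape(index, (-1,))
    counts = tf.math.bincount(index)  # [0, 0, 0, 1, 2, 2] -> [3, 1, 2]
    return tf.repeat(x, counts, axis=axis)


# The worked example from the text: M = 3 slices expanded to G = 6.
N, M, K = 2, 3, 4
x = tf.reshape(tf.range(N * M * K, dtype=tf.float32), (N, M, K))
index = tf.constant([0, 0, 0, 1, 2, 2])
y = repeat_with_index_sketch(x, index)
print(y.shape)  # (2, 6, 4)
```

The output keeps N and K unchanged and replaces M with G, the length of index.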