megnet.utils.layer module
TensorFlow layer utilities
- gather(tensor: tensorflow.python.framework.ops.Tensor, indices: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor
An alternative implementation of tf.gather that avoids the index warnings.
- Parameters
tensor – (Tensor) tensor to be gathered
indices – (Tensor) indices tensor
- repeat_with_index(x: tensorflow.python.framework.ops.Tensor, index: tensorflow.python.framework.ops.Tensor, axis: int = 1)
Given a tensor x of shape (N*M*K), repeat the entries along the middle axis (axis=1) according to the index tensor index of shape (G,). For example, if axis=1 and index = Tensor([0, 0, 0, 1, 2, 2]), then M = 3 (three unique values) and the final tensor has the shape (N*6*K), with the first entry along M repeated 3 times, the second once, and the third twice.
- Args:
x: (3d Tensor) tensor to be augmented
index: (1d Tensor) repetition tensor
axis: (int) axis for repetition
- Returns
(3d Tensor) tensor after repetition
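The repetition behaviour described for repeat_with_index can be illustrated with a small sketch. This is not the megnet implementation: it uses NumPy instead of TensorFlow so that it runs without a TensorFlow install, the helper name repeat_with_index_np is hypothetical, and it assumes the values in index are sorted so that the counts line up with positions along the repeated axis.

```python
import numpy as np


def repeat_with_index_np(x, index, axis=1):
    """Hypothetical NumPy stand-in for the behaviour described above.

    Assumes `index` is sorted (e.g. [0, 0, 0, 1, 2, 2]) and contains one
    distinct value per entry of `x` along `axis`.
    """
    index = np.asarray(index).reshape(-1)
    # Number of occurrences of each distinct index value, in sorted order.
    _, counts = np.unique(index, return_counts=True)
    # Repeat the i-th slice along `axis` counts[i] times.
    return np.repeat(x, counts, axis=axis)


# x has shape (N, M, K) = (2, 3, 4)
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
index = np.array([0, 0, 0, 1, 2, 2])

out = repeat_with_index_np(x, index, axis=1)
print(out.shape)  # (2, 6, 4)
```

Under the sorted-index assumption the result matches plain indexing along the axis, x[:, index], which is one way to check the output by hand.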