Source code for megnet.utils.general

"""
Operation utilities on lists and arrays
"""
from collections.abc import Iterable
from typing import Union, List, Sequence, Optional

import numpy as np


def to_list(x: Union[Iterable, np.ndarray]) -> List:
    """
    If x is not a list, convert it to list

    Any iterable (list, tuple, string, n-d numpy array, ...) is converted
    with ``list(x)``, i.e. only the outermost level is turned into a list.
    A 0-d numpy array or any non-iterable object is wrapped in a
    one-element list.

    Args:
        x: object to be converted

    Returns:
        (list) list built from x
    """
    # A 0-d array exposes __iter__ and therefore counts as Iterable, but
    # iterating over it raises TypeError. Treat it like a scalar instead.
    if isinstance(x, np.ndarray) and x.ndim == 0:
        return [x.item()]
    # np.ndarray is itself Iterable, so n-d arrays are handled here as well.
    if isinstance(x, Iterable):
        return list(x)
    return [x]
def expand_1st(x: np.ndarray) -> np.ndarray:
    """
    Prepend a new axis of length one to the input array

    Args:
        x: (np.array) input array

    Returns:
        (np.array) the input with an extra leading dimension
    """
    leading_axis = 0
    expanded = np.expand_dims(x, leading_axis)
    return expanded
def fast_label_binarize(value: List, labels: List) -> List[int]:
    """Encode a single value against a list of possible labels

    With exactly two labels the result is a single-element list that
    holds 1 when value equals the first label and 0 otherwise. For any
    other number of labels the result is a one-hot list with one slot
    per label; it stays all zeros when value is not among the labels.

    Args:
        value: Value to encode
        labels (list): Possible class values

    Returns:
        ([int]): List of integers
    """
    n_labels = len(labels)

    # Binary case: a single flag is enough
    if n_labels == 2:
        is_first = value == labels[0]
        return [int(is_first)]

    one_hot = [0 for _ in range(n_labels)]
    if value not in labels:
        return one_hot
    position = labels.index(value)
    one_hot[position] = 1
    return one_hot
def check_shape(array: Optional[np.ndarray], shape: Sequence) -> bool:
    """
    Tell whether the leading dimensions of array agree with shape.

    The None entries of shape are dropped and the remaining integers are
    compared, in order, with the leading dimensions of array. Dimensions
    of array beyond those are not inspected, and the comparison stops at
    whichever of the two runs out first.

    Example: array with shape [10, 20, 30, 40] matches with
    [10, 20, None], but does not match with shape [10, 20, 30, 20]

    Args:
        array (np.ndarray or None): array to be checked; None always
            passes the check
        shape (Sequence): integer array shape, it may end with None

    Returns:
        bool
    """
    if array is None:
        return True

    required = [dim for dim in shape if dim is not None]
    # Nothing but None in shape: there is nothing to compare against
    if not required:
        return True

    leading = array.shape[: len(required)]
    for have, want in zip(leading, required):
        if have != want:
            return False
    return True
def reshape(array: np.ndarray, shape: Sequence) -> np.ndarray:
    """
    Take an array and reshape it according to shape. Here shape may
    contain None field at the end.

    if array shape is [3, 4] and shape is [3, 4, None], then array is
    shaped to [3, 4, 1]. If the two shapes do not match then report an
    error

    An array that already has at least as many dimensions as shape is
    returned unchanged. Otherwise trailing axes of length one are appended
    and the array is tiled along every axis where shape asks for a larger
    size than the array currently has.

    Args:
        array (np.ndarray): array to be reshaped
        shape (Sequence): shape dimensions

    Returns:
        np.ndarray, reshaped array

    Raises:
        ValueError: if check_shape reports that array does not match shape
    """
    if not check_shape(array, shape):
        raise ValueError("array cannot be reshaped due to mismatch")
    if array.ndim >= len(shape):
        return array

    # None entries stand for "any size"; materialize them as length-one axes
    target = [1 if dim is None else dim for dim in shape]
    missing_axes = list(range(array.ndim, len(target)))
    expanded = np.expand_dims(array, axis=missing_axes)

    # Repeat count per axis. A zero-length axis cannot be grown by tiling,
    # and dividing by it would raise ZeroDivisionError, so it is kept as is.
    tiles = [want // have if have else 1 for want, have in zip(target, expanded.shape)]
    return np.tile(expanded, tiles)