# Source code for megnet.config
"""Data types"""
import numpy as np
import tensorflow as tf
# Lookup table: dtype name -> the matching numpy type ("numpy") and
# tensorflow dtype ("tf"). Both libraries expose these types under the
# same attribute name, so each entry is resolved by name.
DTYPES = {
    type_name: {"numpy": getattr(np, type_name), "tf": getattr(tf, type_name)}
    for type_name in ("float32", "float16", "int32", "int16")
}
class DataType:
    """
    Data types for tensorflow. The numpy and tensorflow float/int types
    in use are stored as class attributes, which lets users choose
    between 32-bit float and int, and 16-bit float and int.
    """

    # Defaults are 32-bit until set_dtype is called.
    np_float = np.float32
    np_int = np.int32
    tf_float = tf.float32
    tf_int = tf.int32

    @classmethod
    def set_dtype(cls, data_type: str) -> None:
        """
        Class method to set the data types

        Args:
            data_type (str): a string ending in '16' or '32'

        Raises:
            ValueError: if data_type ends with neither '32' nor '16'
        """
        # Find which supported bit width the string ends with; the else
        # branch only runs when no width matched.
        for bits in ("32", "16"):
            if data_type.endswith(bits):
                break
        else:
            raise ValueError("Data type not known, choose '16' or '32'")

        float_entry = DTYPES["float" + bits]
        int_entry = DTYPES["int" + bits]
        cls.np_float, cls.tf_float = float_entry["numpy"], float_entry["tf"]
        cls.np_int, cls.tf_int = int_entry["numpy"], int_entry["tf"]
def set_global_dtypes(data_type: str) -> None:
    """
    Function to set the data types

    Thin module-level wrapper that forwards to DataType.set_dtype, which
    updates the float and int types stored on the DataType class.

    Args:
        data_type (str): '16' or '32'

    Raises:
        ValueError: propagated from DataType.set_dtype when data_type
            ends with neither '32' nor '16'
    """
    DataType.set_dtype(data_type)