m3gnet.trainers package

Module contents

M3GNet trainers

class m3gnet.trainers.PotentialTrainer(potential: m3gnet.models._base.Potential, optimizer: keras.optimizer_v2.optimizer_v2.OptimizerV2)

Bases: object

Trainer for M3GNet potentials

train(graphs_or_structures: List, energies: List, forces: List, stresses: Optional[List] = None, validation_graphs_or_structures: Optional[List] = None, val_energies: Optional[List] = None, val_forces: Optional[List] = None, val_stresses: Optional[List] = None, loss: keras.losses.Loss = <function mean_squared_error>, force_loss_ratio: float = 1, stress_loss_ratio: float = 0.1, batch_size: int = 32, epochs: int = 1000, callbacks: Optional[List] = None, save_checkpoint: bool = True, early_stop_patience: int = 200, verbose: int = 1, fit_per_element_offset: bool = False)
Args:

graphs_or_structures (list): a list of MaterialGraph objects or structures
energies (list): list of training energies
forces (list): list of training forces
stresses (list): optional list of training stresses
validation_graphs_or_structures (list): optional list of validation graphs or structures
val_energies (list): optional list of validation energies
val_forces (list): optional list of validation forces
val_stresses (list): optional list of validation stresses
loss (tf.keras.losses.Loss): loss object
force_loss_ratio (float): weight of the force term in the loss
stress_loss_ratio (float): weight of the stress term in the loss
batch_size (int): number of graphs combined into each batch
epochs (int): number of training epochs
callbacks (list): list of callback functions
save_checkpoint (bool): whether to save model checkpoints
early_stop_patience (int): number of epochs without improvement before training stops early
verbose (int): verbosity level of the training progress output
fit_per_element_offset (bool): whether to train an element-wise offset, e.g., elemental energies. If trained, the offset is added to the neural network predictions.

Returns: None
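The force_loss_ratio and stress_loss_ratio arguments weight the force and stress terms relative to the energy term. The sketch below assumes the combined loss has the form energy loss + force_loss_ratio * force loss + stress_loss_ratio * stress loss, with mean squared error standing in for the default loss. It uses plain Python lists instead of TensorFlow tensors, the helper names (mse, flatten, potential_loss) are hypothetical, and whether M3GNet normalizes energies per atom before computing the loss is not covered here.

```python
# Sketch of a weighted energy/force/stress loss.
# ASSUMPTION: the combined form below mirrors what force_loss_ratio and
# stress_loss_ratio describe; it is not M3GNet's actual implementation.
from typing import List, Optional, Sequence


def mse(target: Sequence[float], pred: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    if len(target) != len(pred) or not target:
        raise ValueError("target and prediction must be non-empty and equal length")
    return sum((t - p) ** 2 for t, p in zip(target, pred)) / len(target)


def flatten(rows: Sequence[Sequence[float]]) -> List[float]:
    """Flatten nested vectors, e.g. per-atom forces or 3x3 stress tensors."""
    return [x for row in rows for x in row]


def potential_loss(
    e_true: Sequence[float],
    e_pred: Sequence[float],
    f_true: Sequence[Sequence[float]],
    f_pred: Sequence[Sequence[float]],
    s_true: Optional[Sequence[Sequence[float]]] = None,
    s_pred: Optional[Sequence[Sequence[float]]] = None,
    force_loss_ratio: float = 1.0,
    stress_loss_ratio: float = 0.1,
) -> float:
    """Energy loss plus weighted force and (optional) stress losses."""
    total = mse(e_true, e_pred)
    total += force_loss_ratio * mse(flatten(f_true), flatten(f_pred))
    # Stresses are optional, matching stresses=None in the signature above.
    if s_true is not None and s_pred is not None:
        total += stress_loss_ratio * mse(flatten(s_true), flatten(s_pred))
    return total


if __name__ == "__main__":
    zeros3 = [[0.0] * 3 for _ in range(3)]
    twos3 = [[2.0] * 3 for _ in range(3)]
    loss = potential_loss(
        e_true=[1.0, 2.0], e_pred=[1.0, 3.0],            # energy MSE = 0.5
        f_true=[[0.0] * 3] * 2, f_pred=[[1.0] * 3] * 2,  # force MSE = 1.0
        s_true=zeros3, s_pred=twos3,                     # stress MSE = 4.0
    )
    print(loss)  # approximately 1.9 = 0.5 + 1.0 * 1.0 + 0.1 * 4.0
```

With the default stress_loss_ratio of 0.1, a stress error contributes an order of magnitude less to the total than an equally large force error, which is the trade-off these two ratios control.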

class m3gnet.trainers.Trainer(model: m3gnet.models._m3gnet.M3GNet, optimizer: keras.optimizer_v2.optimizer_v2.OptimizerV2)

Bases: object

Trainer for material properties

restart_from_directory(dirname: str)

Continue previous model training from a directory.

Args:

dirname (str): directory name

Returns:

train(graphs_or_structures: List, targets: List, validation_graphs_or_structures: Optional[List] = None, validation_targets: Optional[List] = None, loss: keras.losses.Loss = <function mean_squared_error>, train_metrics: Optional[List] = None, val_metrics: Optional[List] = None, val_monitor: str = 'val_mae', batch_size: int = 128, epochs: int = 1000, callbacks: Optional[List] = None, save_checkpoint: bool = True, early_stop_patience: int = 200, verbose: int = 1, clip_norm: Optional[float] = 10.0, fit_per_element_offset: bool = False)
Args:

graphs_or_structures (list): a list of MaterialGraph objects or structures
targets (list): list of target property values (floats)
validation_graphs_or_structures (list): optional list of validation graphs or structures
validation_targets (list): optional list of validation property values
loss (tf.keras.losses.Loss): loss object
train_metrics (list): list of training metrics
val_metrics (list): list of validation metrics
val_monitor (str): field to monitor during validation, e.g., “val_mae”, “val_acc” or “val_auc”
batch_size (int): number of graphs combined into each batch
epochs (int): number of training epochs
callbacks (list): list of callback functions
save_checkpoint (bool): whether to save model checkpoints
early_stop_patience (int): number of epochs without improvement before training stops early
verbose (int): verbosity level of the training progress output
clip_norm (float): threshold for gradient norm clipping
fit_per_element_offset (bool): whether to train an element-wise offset, e.g., elemental energies. If trained, the offset is added to the neural network predictions.

Returns: None
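Both trainers accept fit_per_element_offset. One common way to obtain element-wise offsets is a linear least-squares fit of each structure's energy (or other extensive target) against its element counts; the network then only needs to learn the residual, and the offset is added back to its predictions. The sketch below assumes that approach, represents compositions as plain dictionaries, and uses hypothetical helper names (solve, fit_element_offsets, offset_energy); it is not M3GNet's own offset-fitting code.

```python
# Sketch: fit per-element offsets by linear least squares.
# ASSUMPTION: models a total energy (or other extensive target) as
# sum over elements of count * offset; this is an illustration, not
# M3GNet's own offset-fitting code.
from typing import Dict, List, Sequence


def solve(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve a square linear system by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("singular system: element counts are linearly dependent")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def fit_element_offsets(
    compositions: Sequence[Dict[str, int]], energies: Sequence[float]
) -> Dict[str, float]:
    """Least-squares offsets via the normal equations (A^T A) x = A^T y."""
    elements = sorted({el for comp in compositions for el in comp})
    a = [[float(comp.get(el, 0)) for el in elements] for comp in compositions]
    k = len(elements)
    ata = [[sum(row[i] * row[j] for row in a) for j in range(k)] for i in range(k)]
    aty = [sum(row[i] * y for row, y in zip(a, energies)) for i in range(k)]
    return dict(zip(elements, solve(ata, aty)))


def offset_energy(composition: Dict[str, int], offsets: Dict[str, float]) -> float:
    """Offset contribution that would be added to the network's prediction."""
    return sum(count * offsets[el] for el, count in composition.items())


if __name__ == "__main__":
    comps = [{"Li": 2, "O": 1}, {"Li": 2, "O": 2}, {"O": 2}]
    energies = [-8.5, -13.0, -9.0]  # illustrative numbers, not real DFT data
    offsets = fit_element_offsets(comps, energies)
    residuals = [e - offset_energy(c, offsets) for c, e in zip(comps, energies)]
    print(offsets)    # close to {'Li': -2.0, 'O': -4.5}
    print(residuals)  # close to zero: the network would learn what remains
```

In this toy data the energies are exactly linear in the element counts, so the residuals vanish; with real data the residuals carry the structure-dependent part that the neural network is trained on.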