Source code for megnet.models.base

"""
Implements basic GraphModels.
"""

import os
from typing import Dict, List, Union
from warnings import warn

import numpy as np
from monty.serialization import dumpfn, loadfn
from tensorflow.keras.backend import int_shape
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model
from pymatgen.core import Structure

from megnet.callbacks import ModelCheckpointMAE, ManualStop, ReduceLRUponNan
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, StructureGraph
from megnet.utils.preprocessing import DummyScaler, Scaler


class GraphModel:
    """
    Composition of keras model and converter class for transferring structure
    objects to input tensors. We add methods to train the model from
    (structures, targets) pairs
    """

    def __init__(
        self,
        model: Model,
        graph_converter: StructureGraph,
        target_scaler: Scaler = DummyScaler(),
        metadata: Dict = None,
        **kwargs,
    ):
        """
        Args:
            model: (keras model)
            graph_converter: (object) an object that turns a structure into a
                graph, check `megnet.data.crystal`
            target_scaler: (object) a scaler object for converting targets,
                check `megnet.utils.preprocessing`
            metadata: (dict) An optional dict of metadata associated with the
                model. Recommended to incorporate some basic information such
                as units, MAE performance, etc.
        """
        self.model = model
        self.graph_converter = graph_converter
        self.target_scaler = target_scaler
        self.metadata = metadata or {}

    def __getattr__(self, p):
        return getattr(self.model, p)
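A minimal construction sketch, assuming a compiled keras model (`keras_model`, hypothetical) and the `CrystalGraph` converter from `megnet.data.crystal` that the docstring points to; the Gaussian expansion settings and cutoff are illustrative, not prescribed by this module:

from megnet.data.crystal import CrystalGraph
from megnet.data.graph import GaussianDistance

# Illustrative converter: 100 Gaussian centers on [0, 5] with width 0.5 and
# a 4.0 neighbor cutoff; `keras_model` is a hypothetical compiled keras Model.
converter = CrystalGraph(
    bond_converter=GaussianDistance(np.linspace(0, 5, 100), 0.5),
    cutoff=4.0,
)
graph_model = GraphModel(model=keras_model, graph_converter=converter)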
    def train(
        self,
        train_structures: List[Structure],
        train_targets: List[float],
        validation_structures: List[Structure] = None,
        validation_targets: List[float] = None,
        sample_weights: List[float] = None,
        epochs: int = 1000,
        batch_size: int = 128,
        verbose: int = 1,
        callbacks: List[Callback] = None,
        scrub_failed_structures: bool = False,
        prev_model: str = None,
        save_checkpoint: bool = True,
        automatic_correction: bool = False,
        lr_scaling_factor: float = 0.5,
        patience: int = 500,
        dirname: str = "callback",
        **kwargs,
    ) -> "GraphModel":
        """
        Args:
            train_structures: (list) list of pymatgen structures
            train_targets: (list) list of target values
            validation_structures: (list) list of pymatgen structures as validation
            validation_targets: (list) list of validation targets
            sample_weights: (list) list of sample weights for training data
            epochs: (int) number of epochs
            batch_size: (int) training batch size
            verbose: (int) keras fit verbosity: 0 silent, 1 progress bar,
                2 one line per epoch
            callbacks: (list) megnet or keras callback functions for training
            scrub_failed_structures: (bool) whether to scrub structures with
                failed graph computation
            prev_model: (str) file name for previously saved model
            save_checkpoint: (bool) whether to save checkpoint
            automatic_correction: (bool) correct nan errors
            lr_scaling_factor: (float, less than 1) scale the learning rate down
                when nan loss encountered
            patience: (int) patience for early stopping
            dirname: (str) the directory in which to save checkpoints,
                if `save_checkpoint=True`
            **kwargs: additional keyword args passed to `train_from_graphs`
        """
        train_graphs, train_targets = self.get_all_graphs_targets(
            train_structures, train_targets, scrub_failed_structures=scrub_failed_structures
        )
        if (validation_structures is not None) and (validation_targets is not None):
            val_graphs, validation_targets = self.get_all_graphs_targets(
                validation_structures, validation_targets, scrub_failed_structures=scrub_failed_structures
            )
        else:
            val_graphs = None

        self.train_from_graphs(
            train_graphs,
            train_targets,
            validation_graphs=val_graphs,
            validation_targets=validation_targets,
            sample_weights=sample_weights,
            epochs=epochs,
            batch_size=batch_size,
            verbose=verbose,
            callbacks=callbacks,
            prev_model=prev_model,
            lr_scaling_factor=lr_scaling_factor,
            patience=patience,
            save_checkpoint=save_checkpoint,
            automatic_correction=automatic_correction,
            dirname=dirname,
            **kwargs,
        )
        return self
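A usage sketch for `train`, continuing with `graph_model` from the construction sketch above; the structure and target lists are hypothetical placeholders:

# Hypothetical data: lists of pymatgen Structures and scalar targets.
graph_model.train(
    structures,
    targets,
    validation_structures=val_structures,
    validation_targets=val_targets,
    epochs=100,
    batch_size=32,
)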
    def train_from_graphs(
        self,
        train_graphs: List[Dict],
        train_targets: List[float],
        validation_graphs: List[Dict] = None,
        validation_targets: List[float] = None,
        sample_weights: List[float] = None,
        epochs: int = 1000,
        batch_size: int = 128,
        verbose: int = 1,
        callbacks: List[Callback] = None,
        prev_model: str = None,
        lr_scaling_factor: float = 0.5,
        patience: int = 500,
        save_checkpoint: bool = True,
        automatic_correction: bool = False,
        dirname: str = "callback",
        **kwargs,
    ) -> "GraphModel":
        """
        Args:
            train_graphs: (list) list of graph dictionaries
            train_targets: (list) list of target values
            validation_graphs: (list) list of graphs as validation
            validation_targets: (list) list of validation targets
            sample_weights: (list) list of sample weights
            epochs: (int) number of epochs
            batch_size: (int) training batch size
            verbose: (int) keras fit verbosity: 0 silent, 1 progress bar,
                2 one line per epoch
            callbacks: (list) megnet or keras callback functions for training
            prev_model: (str) file name for previously saved model
            lr_scaling_factor: (float, less than 1) scale the learning rate down
                when nan loss encountered
            patience: (int) patience for early stopping
            save_checkpoint: (bool) whether to save checkpoint
            automatic_correction: (bool) correct nan errors
            dirname: (str) the directory in which to save checkpoints,
                if `save_checkpoint=True`
            **kwargs: additional keyword args passed to keras `fit`
        """
        # load from saved model
        if prev_model:
            self.load_weights(prev_model)
        is_classification = "entropy" in str(self.model.loss)
        monitor = "val_acc" if is_classification else "val_mae"
        mode = "max" if is_classification else "min"
        has_sample_weights = sample_weights is not None
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if callbacks is None:
            # with this callback you can stop the model training by `touch STOP`
            callbacks = [ManualStop()]
        train_nb_atoms = [len(i["atom"]) for i in train_graphs]
        train_targets = [self.target_scaler.transform(i, j) for i, j in zip(train_targets, train_nb_atoms)]

        if (validation_graphs is not None) and (validation_targets is not None):
            filepath = os.path.join(dirname, f"{monitor}_{{epoch:05d}}_{{{monitor}:.6f}}.hdf5")
            val_nb_atoms = [len(i["atom"]) for i in validation_graphs]
            validation_targets = [
                self.target_scaler.transform(i, j) for i, j in zip(validation_targets, val_nb_atoms)
            ]
            val_inputs = self.graph_converter.get_flat_data(validation_graphs, validation_targets)
            val_generator = self._create_generator(*val_inputs, batch_size=batch_size)
            steps_per_val = int(np.ceil(len(validation_graphs) / batch_size))

            if save_checkpoint:
                callbacks.extend(
                    [
                        ModelCheckpointMAE(
                            filepath=filepath,
                            monitor=monitor,
                            mode=mode,
                            save_best_only=True,
                            save_weights_only=False,
                            val_gen=val_generator,
                            steps_per_val=steps_per_val,
                            target_scaler=self.target_scaler,
                        )
                    ]
                )
                # avoid running validation twice in an epoch
                val_generator = None  # type: ignore
                steps_per_val = None  # type: ignore

            if automatic_correction:
                callbacks.extend(
                    [
                        ReduceLRUponNan(
                            filepath=filepath,
                            monitor=monitor,
                            mode=mode,
                            factor=lr_scaling_factor,
                            patience=patience,
                            has_sample_weights=has_sample_weights,
                        )
                    ]
                )
        else:
            val_generator = None  # type: ignore
            steps_per_val = None  # type: ignore

        train_inputs = self.graph_converter.get_flat_data(train_graphs, train_targets)
        # check dimension match
        self.check_dimension(train_graphs[0])
        train_generator = self._create_generator(*train_inputs, sample_weights=sample_weights, batch_size=batch_size)
        steps_per_train = int(np.ceil(len(train_graphs) / batch_size))
        self.fit(
            train_generator,
            steps_per_epoch=steps_per_train,
            validation_data=val_generator,
            validation_steps=steps_per_val,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            **kwargs,
        )
        return self
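The per-sample scaling above relies on the `Scaler` interface from `megnet.utils.preprocessing`: `transform(target, n)` is applied before fitting and `inverse_transform(pred, n)` at prediction time. A minimal sketch of a hypothetical per-atom scaler consistent with those calls (not a class shipped by the library):

class PerAtomScaler(Scaler):
    # Hypothetical scaler: normalizes an extensive target by atom count.
    def transform(self, target: float, n: int = 1) -> float:
        # called with (target, number_of_atoms) before fitting
        return target / n

    def inverse_transform(self, transformed_target: float, n: int = 1) -> float:
        # called with (prediction, number_of_atoms) to recover the total
        return transformed_target * n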
    def check_dimension(self, graph: Dict) -> bool:
        """
        Check the model dimension against the graph converter dimension

        Args:
            graph: structure graph

        Returns:
            bool, True if the data dimensions match the model inputs
        """
        test_inp = self.graph_converter.graph_to_input(graph)
        input_shapes = [i.shape for i in test_inp]
        model_input_shapes = [int_shape(i) for i in self.model.inputs]

        def _check_match(real_shape, tensor_shape):
            if len(real_shape) != len(tensor_shape):
                return False
            matched = True
            for i, j in zip(real_shape, tensor_shape):
                if j is None:
                    continue
                if i == j:
                    continue
                matched = False
            return matched

        for i, j, k in zip(
            ["atom features", "bond features", "state features"], input_shapes[:3], model_input_shapes[:3]
        ):
            matched = _check_match(j, k)
            if not matched:
                raise ValueError(
                    f"The data dimension for {i} is {j} and does not match model required shape of {k}"
                )
        return True
    def get_all_graphs_targets(
        self, structures: List[Structure], targets: List[float], scrub_failed_structures: bool = False
    ) -> tuple:
        """
        Compute the graphs from structures and spit out (graphs, targets) with
        options to automatically remove structures with failed graph computations

        Args:
            structures: (list) pymatgen structure list
            targets: (list) target property list
            scrub_failed_structures: (bool) whether to scrub those failed structures

        Returns:
            graphs, targets
        """
        graphs_valid = []
        targets_valid = []

        for i, (s, t) in enumerate(zip(structures, targets)):
            try:
                graph = self.graph_converter.convert(s)
                graphs_valid.append(graph)
                targets_valid.append(t)
            except Exception as e:
                if scrub_failed_structures:
                    warn(f"structure with index {i} failed the graph computations", UserWarning)
                    continue
                raise RuntimeError(str(e))
        return graphs_valid, targets_valid
    def predict_structure(self, structure: Structure) -> np.ndarray:
        """
        Predict property from structure

        Args:
            structure: pymatgen structure or molecule

        Returns:
            predicted target value
        """
        graph = self.graph_converter.convert(structure)
        return self.predict_graph(graph)
    def predict_structures(self, structures: List[Structure]) -> np.ndarray:
        """
        Predict properties of structure list

        Args:
            structures: list of pymatgen Structure/Molecule

        Returns:
            predicted target values
        """
        graphs = [self.graph_converter.convert(structure) for structure in structures]
        return self.predict_graphs(graphs)
    def predict_graph(self, graph: Dict) -> np.ndarray:
        """
        Predict property from graph

        Args:
            graph: a graph dictionary, see megnet.data.graph

        Returns:
            predicted target value
        """
        inp = self.graph_converter.graph_to_input(graph)
        pred = self.predict(inp)  # direct prediction, shape [1, 1, m]
        return self.target_scaler.inverse_transform(pred[0, 0], len(graph["atom"]))
    def predict_graphs(self, graphs: List[Dict]) -> np.ndarray:
        """
        Predict properties from graphs

        Args:
            graphs: a list of graph dictionaries, see megnet.data.graph

        Returns:
            predicted target values
        """
        inputs = self.graph_converter.get_flat_data(graphs)
        n_atoms = [len(graph["atom"]) for graph in graphs]
        pred_gen = self._create_generator(*inputs, is_shuffle=False)
        predicted = []
        for i in pred_gen:
            predicted.append(self.predict(i))
        pred_targets = np.concatenate(predicted, axis=1)[0]
        return np.array([self.target_scaler.inverse_transform(i, j) for i, j in zip(pred_targets, n_atoms)])
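A prediction sketch tying the methods above together, again assuming `graph_model` from the earlier sketches and hypothetical pymatgen inputs:

# Single structure -> one target; list of structures -> array of targets.
y = graph_model.predict_structure(structure)
ys = graph_model.predict_structures(structures)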
    def _create_generator(self, *args, **kwargs) -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
        if hasattr(self.graph_converter, "bond_converter"):
            kwargs.update({"distance_converter": self.graph_converter.bond_converter})
            return GraphBatchDistanceConvert(*args, **kwargs)
        return GraphBatchGenerator(*args, **kwargs)
    def save_model(self, filename: str) -> None:
        """
        Save the model to a keras model hdf5 and a json config for additional
        converters

        Args:
            filename: (str) output file name

        Returns:
            None
        """
        self.model.save(filename)
        dumpfn(
            {
                "graph_converter": self.graph_converter,
                "target_scaler": self.target_scaler,
                "metadata": self.metadata,
            },
            filename + ".json",
        )
    @classmethod
    def from_file(cls, filename: str) -> "GraphModel":
        """
        Class method to load model from
            filename for keras model
            filename.json for additional converters

        Args:
            filename: (str) model file name

        Returns:
            GraphModel
        """
        configs = loadfn(filename + ".json")
        from tensorflow.keras.models import load_model

        from megnet.layers import _CUSTOM_OBJECTS

        model = load_model(filename, custom_objects=_CUSTOM_OBJECTS)
        configs.update({"model": model})
        return GraphModel(**configs)
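Putting `save_model` and `from_file` together, a round-trip sketch; the file name is illustrative. Per `save_model` above, saving writes both the hdf5 file (the keras model) and a sibling json file (converter, scaler and metadata), and `from_file` expects the same pair:

# Illustrative file name; both my_model.hdf5 and my_model.hdf5.json must be
# present for from_file to restore the full GraphModel.
graph_model.save_model("my_model.hdf5")
restored = GraphModel.from_file("my_model.hdf5")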