megnet.models package
Submodules
Module contents
The models package contains various graph-based models.
- class GraphModel(model: keras.engine.training.Model, graph_converter: megnet.data.graph.StructureGraph, target_scaler: megnet.utils.preprocessing.Scaler = <megnet.utils.preprocessing.DummyScaler object>, metadata: Optional[Dict] = None, **kwargs)
Bases: object
Composition of a keras model and a converter class for transferring structure objects to input tensors. Methods are added to train the model from (structures, targets) pairs.
- Parameters
model – (keras model)
graph_converter – (object) an object that turns a structure into a graph, check megnet.data.crystal
target_scaler – (object) a scaler object for converting targets, check megnet.utils.preprocessing
metadata – (dict) An optional dict of metadata associated with the model. Recommended to incorporate some basic information such as units, MAE performance, etc.
- check_dimension(graph: Dict) -> bool
Check the model dimension against the graph converter dimension
- Parameters
graph – structure graph
- Returns
bool
- classmethod from_file(filename: str) -> megnet.models.base.GraphModel
Class method to load a model from two files: filename for the keras model and filename.json for the additional converters
- Parameters
filename – (str) model file name
- Returns
GraphModel
- get_all_graphs_targets(structures: List[pymatgen.core.structure.Structure], targets: List[float], scrub_failed_structures: bool = False) -> tuple
Compute the graphs from structures and return (graphs, targets), with an option to automatically remove structures whose graph computation failed
- Parameters
structures – (list) pymatgen structure list
targets – (list) target property list
scrub_failed_structures – (bool) whether to remove structures whose graph computation failed
- Returns
graphs, targets
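The scrubbing option can be pictured with a small stand-in that does not call megnet. In this sketch `toy_convert` and `collect_graphs_targets` are hypothetical names, structures are plain dictionaries, and a failed conversion is assumed to surface as an exception; the library's actual handling may differ.

```python
from typing import Any, Callable, Dict, List, Tuple


def toy_convert(structure: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical converter: fails for structures without atoms."""
    if not structure["atoms"]:
        raise ValueError("cannot build a graph from an empty structure")
    return {"atom": list(structure["atoms"]), "bond": [], "state": [0.0]}


def collect_graphs_targets(
    structures: List[Dict[str, Any]],
    targets: List[float],
    convert: Callable[[Dict[str, Any]], Dict[str, Any]],
    scrub_failed_structures: bool = False,
) -> Tuple[List[Dict[str, Any]], List[float]]:
    """Convert each structure while keeping graphs and targets aligned."""
    graphs: List[Dict[str, Any]] = []
    kept_targets: List[float] = []
    for structure, target in zip(structures, targets):
        try:
            graph = convert(structure)
        except Exception:
            if scrub_failed_structures:
                continue  # drop the structure together with its target
            raise  # without scrubbing, the failure reaches the caller
        graphs.append(graph)
        kept_targets.append(target)
    return graphs, kept_targets


structures = [{"atoms": [8, 1, 1]}, {"atoms": []}, {"atoms": [11, 17]}]
targets = [-1.2, 0.0, -2.1]
graphs, kept = collect_graphs_targets(
    structures, targets, toy_convert, scrub_failed_structures=True
)
print(len(graphs), kept)
```

The point of returning both lists is that a removed structure must take its target with it, otherwise the remaining pairs would be misaligned.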
- predict_graph(graph: Dict) -> numpy.ndarray
Predict property from graph
- Parameters
graph – a graph dictionary, see megnet.data.graph
- Returns
predicted target value
- predict_graphs(graphs: List[Dict]) -> numpy.ndarray
Predict properties from graphs
- Parameters
graphs – a list of graph dictionaries, see megnet.data.graph
- Returns
predicted target values
- predict_structure(structure: pymatgen.core.structure.Structure) -> numpy.ndarray
Predict property from structure
- Parameters
structure – pymatgen structure or molecule
- Returns
predicted target value
- predict_structures(structures: List[pymatgen.core.structure.Structure]) -> numpy.ndarray
Predict properties of structure list
- Parameters
structures – list of pymatgen Structure/Molecule
- Returns
predicted target values
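The four predict methods differ only in their starting point: a structure is first converted to a graph, the network predicts from the graph, and the target scaler maps the output back to the original units. The sketch below illustrates that composition with stand-ins only. `ToyConverter`, `ToyScaler` and `ToyGraphModel` are hypothetical, the "network" is a plain sum, and the scaler interface is simplified, so this is not the library's implementation.

```python
from typing import Dict, List


class ToyConverter:
    """Hypothetical converter exposing a ``convert`` method."""

    def convert(self, structure: List[int]) -> Dict[str, List[int]]:
        return {"atom": list(structure)}


class ToyScaler:
    """Hypothetical scaler: the network works in units of ``scale``."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def transform(self, target: float) -> float:
        return target / self.scale

    def inverse_transform(self, scaled: float) -> float:
        return scaled * self.scale


class ToyGraphModel:
    """Stand-in showing how converter, network and scaler are chained."""

    def __init__(self, converter: ToyConverter, scaler: ToyScaler) -> None:
        self.converter = converter
        self.scaler = scaler

    def _network(self, graph: Dict[str, List[int]]) -> float:
        # Placeholder for the keras model: returns a *scaled* prediction.
        return float(sum(graph["atom"]))

    def predict_graph(self, graph: Dict[str, List[int]]) -> float:
        return self.scaler.inverse_transform(self._network(graph))

    def predict_structure(self, structure: List[int]) -> float:
        return self.predict_graph(self.converter.convert(structure))

    def predict_structures(self, structures: List[List[int]]) -> List[float]:
        return [self.predict_structure(s) for s in structures]


toy_model = ToyGraphModel(ToyConverter(), ToyScaler(scale=2.0))
single = toy_model.predict_structure([1, 1, 8])
batch = toy_model.predict_structures([[1], [2, 3]])
print(single, batch)
```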
- save_model(filename: str) -> None
Save the model to a keras hdf5 file and a json config file for the additional converters
- Parameters
filename – (str) output file name
- Returns
None
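According to the descriptions of save_model and from_file, a saved model consists of two files. The helpers below are a minimal sketch for checking that both are present before loading; `companion_paths` and `is_complete_save` are hypothetical names, and the sketch assumes the config name is formed by appending `.json` to the full model file name, as the from_file description suggests.

```python
import tempfile
from pathlib import Path
from typing import Tuple


def companion_paths(filename: str) -> Tuple[str, str]:
    """Return (keras model file, json config file) for a model file name."""
    return filename, filename + ".json"


def is_complete_save(filename: str) -> bool:
    """True only if both the model file and its json config exist."""
    return all(Path(p).is_file() for p in companion_paths(filename))


with tempfile.TemporaryDirectory() as tmp:
    model_file = str(Path(tmp) / "my_model.hdf5")
    model_path, config_path = companion_paths(model_file)
    Path(model_path).touch()
    only_model = is_complete_save(model_file)  # config is still missing
    Path(config_path).touch()
    both_files = is_complete_save(model_file)

print(only_model, both_files)
```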
- train(train_structures: List[pymatgen.core.structure.Structure], train_targets: List[float], validation_structures: Optional[List[pymatgen.core.structure.Structure]] = None, validation_targets: Optional[List[float]] = None, sample_weights: Optional[List[float]] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: Optional[List[keras.callbacks.Callback]] = None, scrub_failed_structures: bool = False, prev_model: Optional[str] = None, save_checkpoint: bool = True, automatic_correction: bool = False, lr_scaling_factor: float = 0.5, patience: int = 500, dirname: str = 'callback', **kwargs) -> megnet.models.base.GraphModel
- Parameters
train_structures – (list) list of pymatgen structures
train_targets – (list) list of target values
validation_structures – (list) list of pymatgen structures as validation
validation_targets – (list) list of validation targets
sample_weights – (list) list of sample weights for training data
epochs – (int) number of epochs
batch_size – (int) training batch size
verbose – (int) keras fit verbose: 0 silent, 1 progress bar, 2 one line per epoch
callbacks – (list) megnet or keras callback functions for training
scrub_failed_structures – (bool) whether to scrub structures with failed graph computation
prev_model – (str) file name for previously saved model
save_checkpoint – (bool) whether to save checkpoint
automatic_correction – (bool) correct nan errors
lr_scaling_factor – (float, less than 1) scale the learning rate down when nan loss encountered
patience – (int) patience for early stopping
dirname – (str) the directory in which to save checkpoints, if save_checkpoint=True
**kwargs –
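A short training run could look like the sketch below. It uses only the train and predict_structure parameters listed on this page plus the centers and width keyword arguments of MEGNetModel. The converter class `CrystalGraph` and its `cutoff` argument are assumptions taken from megnet.data.crystal rather than from this page, the numeric values are arbitrary, and `train_and_predict` is a hypothetical helper name. The imports sit inside the function so that defining it does not require megnet to be installed.

```python
def train_and_predict(structures, targets, new_structure, epochs=10):
    """Train a MEGNetModel briefly and predict one structure.

    Requires numpy, pymatgen and megnet at call time.
    """
    import numpy as np
    from megnet.data.crystal import CrystalGraph  # assumed converter class
    from megnet.models import MEGNetModel

    r_cutoff = 5.0  # arbitrary neighbour cutoff for this sketch
    centers = np.linspace(0, r_cutoff + 1, 10)  # Gaussian expansion centers
    converter = CrystalGraph(cutoff=r_cutoff)

    model = MEGNetModel(graph_converter=converter, centers=centers, width=0.5)
    # save_checkpoint=False keeps the run from writing to the checkpoint directory
    model.train(structures, targets, epochs=epochs, save_checkpoint=False)
    return model.predict_structure(new_structure)
```

Here `structures` and `new_structure` are expected to be pymatgen structures and `targets` a list of floats of the same length as `structures`.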
- train_from_graphs(train_graphs: List[Dict], train_targets: List[float], validation_graphs: Optional[List[Dict]] = None, validation_targets: Optional[List[float]] = None, sample_weights: Optional[List[float]] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: Optional[List[keras.callbacks.Callback]] = None, prev_model: Optional[str] = None, lr_scaling_factor: float = 0.5, patience: int = 500, save_checkpoint: bool = True, automatic_correction: bool = False, dirname: str = 'callback', **kwargs) -> megnet.models.base.GraphModel
- Parameters
train_graphs – (list) list of graph dictionaries
train_targets – (list) list of target values
validation_graphs – (list) list of graphs as validation
validation_targets – (list) list of validation targets
sample_weights – (list) list of sample weights
epochs – (int) number of epochs
batch_size – (int) training batch size
verbose – (int) keras fit verbose: 0 silent, 1 progress bar, 2 one line per epoch
callbacks – (list) megnet or keras callback functions for training
prev_model – (str) file name for previously saved model
lr_scaling_factor – (float, less than 1) scale the learning rate down when nan loss encountered
patience – (int) patience for early stopping
save_checkpoint – (bool) whether to save checkpoint
automatic_correction – (bool) correct nan errors
dirname – (str) the directory in which to save checkpoints, if save_checkpoint=True
**kwargs –
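The role of lr_scaling_factor can be illustrated in isolation. The sketch below only shows the arithmetic implied by the parameter description: each nan loss multiplies the learning rate by a factor below 1. `next_learning_rate` and `replay` are hypothetical helpers, and whatever else the library does when it corrects a nan loss (for example restoring earlier weights) is not modelled here.

```python
import math
from typing import List, Sequence


def next_learning_rate(lr: float, loss: float, lr_scaling_factor: float = 0.5) -> float:
    """Scale the learning rate down only when the loss is nan."""
    if not 0.0 < lr_scaling_factor < 1.0:
        raise ValueError("lr_scaling_factor must lie strictly between 0 and 1")
    return lr * lr_scaling_factor if math.isnan(loss) else lr


def replay(losses: Sequence[float], lr: float, lr_scaling_factor: float = 0.5) -> List[float]:
    """Return the learning rate in effect after each observed loss."""
    rates = []
    for loss in losses:
        lr = next_learning_rate(lr, loss, lr_scaling_factor)
        rates.append(lr)
    return rates


rates = replay([0.8, float("nan"), 0.6, float("nan")], lr=0.001)
print(rates)
```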
- class MEGNetModel(nfeat_edge: Optional[int] = None, nfeat_global: Optional[int] = None, nfeat_node: Optional[int] = None, nblocks: int = 3, lr: float = 0.001, n1: int = 64, n2: int = 32, n3: int = 16, nvocal: int = 95, embedding_dim: int = 16, nbvocal: Optional[int] = None, bond_embedding_dim: Optional[int] = None, ngvocal: Optional[int] = None, global_embedding_dim: Optional[int] = None, npass: int = 3, ntarget: int = 1, act: Callable = <function softplus2>, is_classification: bool = False, loss: str = 'mse', metrics: Optional[List[str]] = None, l2_coef: Optional[float] = None, dropout: Optional[float] = None, graph_converter: Optional[megnet.data.graph.StructureGraph] = None, target_scaler: megnet.utils.preprocessing.Scaler = <megnet.utils.preprocessing.DummyScaler object>, optimizer_kwargs: Dict = {'clipnorm': 3}, dropout_on_predict: bool = False, sample_weight_mode: Optional[str] = None, **kwargs)
Bases: megnet.models.base.GraphModel
Construct a graph network model with or without explicit atom features. If nfeat_node is specified, a general graph model is assumed; otherwise a crystal graph model with the atomic number Z as the atom feature is assumed.
- Parameters
nfeat_edge – (int) number of bond features
nfeat_global – (int) number of state features
nfeat_node – (int) number of atom features
nblocks – (int) number of MEGNetLayer blocks
lr – (float) learning rate
n1 – (int) number of hidden units in layer 1 in MEGNetLayer
n2 – (int) number of hidden units in layer 2 in MEGNetLayer
n3 – (int) number of hidden units in layer 3 in MEGNetLayer
nvocal – (int) total number of elements
embedding_dim – (int) embedding dimension
nbvocal – (int) number of bond types if bond attributes are types
bond_embedding_dim – (int) bond embedding dimension
ngvocal – (int) number of global types if global attributes are types
global_embedding_dim – (int) global embedding dimension
npass – (int) number of recurrent steps in Set2Set layer
ntarget – (int) number of output targets
act – (object) activation function
l2_coef – (float or None) l2 regularization parameter
is_classification – (bool) whether it is a classification task
loss – (object or str) loss function
metrics – (list or dict) List or dictionary of Keras metrics to be evaluated by the model during training and testing
dropout – (float) dropout rate
graph_converter – (object) object that exposes a “convert” method for structure to graph conversion
target_scaler – (object) object that exposes “transform” and “inverse_transform” methods for transforming the target values
optimizer_kwargs – (dict) extra keywords for the optimizer, for example clipnorm and clipvalue
sample_weight_mode – (str) sample weight mode for compilation
**kwargs – (dict) when bond inputs are pure distances (neither expanded distances nor integers for embedding, i.e., nfeat_edge=None and bond_embedding_dim=None), kwargs can take additional inputs for expanding the distance using a Gaussian basis: centers (np.ndarray), the array defining the Gaussian expansion centers, and width (float), the width of the Gaussian basis
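The Gaussian expansion mentioned for centers and width turns a single bond distance into one feature per center. The sketch below is a plain-Python illustration, not the library's code: `gaussian_expand` is a hypothetical name, and the functional form exp(-(d - c)^2 / width^2) is an assumption, since the exact normalisation used by the library is not stated here.

```python
import math
from typing import List, Sequence


def gaussian_expand(distance: float, centers: Sequence[float], width: float) -> List[float]:
    """Expand one distance into len(centers) Gaussian features."""
    if width <= 0:
        raise ValueError("width must be positive")
    return [math.exp(-((distance - c) ** 2) / width ** 2) for c in centers]


centers = [0.0, 1.0, 2.0, 3.0]
features = gaussian_expand(2.0, centers, width=0.5)
print(features)
```

Each feature peaks when the distance coincides with its center, so a smaller width gives sharper, more localised features and a larger width gives smoother ones.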
- classmethod from_mvl_models(name: str) -> megnet.models.base.GraphModel
Load a model using mvl model names
- Parameters
name – (str) model name string. Please check megnet.utils.models.AVAILABLE_MODELS for available models
- Returns
GraphModel instance
- classmethod from_url(url: str) -> megnet.models.base.GraphModel
Download and load a model from a URL, e.g. https://github.com/materialsvirtuallab/megnet/blob/master/mvl_models/mp-2019.4.1/formation_energy.hdf5
- Parameters
url – (str) url link of the model
- Returns
GraphModel