m3gnet.layers package
Module contents
Graph layers
- class m3gnet.layers.AtomNetwork(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdate
Atom network that takes a graph as input and calculates new atom attributes
- call(graph: List, **kwargs) → List
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: a new MaterialGraph in list repr
- update_atoms(graph: List) → tensorflow.python.framework.ops.Tensor
Take a graph input and calculate the updated atom attributes
Args:
graph (list): list repr of a MaterialGraph
Returns: tf.Tensor
- class m3gnet.layers.AtomReduceState(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Reduce atom attributes to states via sum or mean
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
Main layer logic
Args:
graph (list): input graph list representation
**kwargs: additional keyword arguments
Returns: tf.Tensor
- get_config() → dict
Get the configuration dict
Returns: dict
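The sum/mean reduction can be illustrated with a framework-free sketch in plain Python. It assumes that every atom in a batch carries the index of the graph it belongs to (`atom_graph_index`, a hypothetical name); the real layer operates on TensorFlow tensors taken from the graph list representation, so this shows the idea rather than its implementation.

```python
from typing import List, Sequence


def reduce_atoms_to_states(
    atom_features: Sequence[Sequence[float]],
    atom_graph_index: Sequence[int],
    n_graphs: int,
    method: str = "mean",
) -> List[List[float]]:
    """Pool per-atom feature vectors into one vector per graph."""
    dim = len(atom_features[0])
    totals = [[0.0] * dim for _ in range(n_graphs)]
    counts = [0] * n_graphs
    # Accumulate each atom's features into the row of the graph it belongs to
    for features, g in zip(atom_features, atom_graph_index):
        counts[g] += 1
        for k, value in enumerate(features):
            totals[g][k] += value
    if method == "sum":
        return totals
    if method == "mean":
        # max(c, 1) avoids dividing by zero for a graph with no atoms
        return [[v / max(c, 1) for v in row] for row, c in zip(totals, counts)]
    raise ValueError(f"unknown reduction method: {method!r}")


# Two graphs in one batch: atoms 0 and 1 belong to graph 0, atom 2 to graph 1
feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(reduce_atoms_to_states(feats, [0, 0, 1], 2, "sum"))   # [[4.0, 6.0], [5.0, 6.0]]
print(reduce_atoms_to_states(feats, [0, 0, 1], 2, "mean"))  # [[2.0, 3.0], [5.0, 6.0]]
```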
- class m3gnet.layers.AtomRef(*args, **kwargs)
Bases:
m3gnet.layers._atom_ref.BaseAtomRef
Atom reference values. For example, if the average H energy is -20.0 and the average O energy is -10.0, the AtomRef predicts an atom reference energy of 2 × (-20.0) + (-10.0) = -50.0 for H2O.
- fit(structs_or_graphs, properties)
Fit the elemental reference values for the properties
Args:
structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of extensive properties
Returns:
- get_config()
Get dict config for serialization
Returns: dict
- inverse_transform(structs_or_graphs, properties)
Take the transformed values and recover the original values by adding back the atom reference values
Args:
structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of transformed extensive properties
Returns: original property values
- predict_properties(structs_or_graphs)
Calculate the atom-summed reference property values
Args:
structs_or_graphs (list): list of graphs or structures
Returns: atom-summed property values
- set_property_per_element(property_per_element)
Set the property value per element
Args:
property_per_element (np.ndarray): array of elemental properties; the i-th row is the elemental value for atomic number i.
Returns:
- transform(structs_or_graphs, properties)
Correct the extensive properties by subtracting the atom reference values
Args:
structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of extensive properties
Returns: corrected property values
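A minimal sketch of the atom-reference arithmetic, assuming each structure is given simply as a list of atomic numbers and the reference values are stored as `set_property_per_element` describes (index i holds the value for atomic number i). The helper names and the numbers are illustrative only, and fitting the reference values (`fit`) is not shown.

```python
from typing import List, Sequence

# Hypothetical reference table: index i holds the value for atomic number i
property_per_element = [0.0] * 9
property_per_element[1] = -20.0  # H
property_per_element[8] = -10.0  # O


def atom_ref_value(atomic_numbers: Sequence[int], per_element: Sequence[float]) -> float:
    """Sum of the elemental reference values over all atoms in a structure."""
    return sum(per_element[z] for z in atomic_numbers)


def transform(structs: Sequence[Sequence[int]], properties: Sequence[float],
              per_element: Sequence[float]) -> List[float]:
    """Subtract the atom reference from each extensive property."""
    return [p - atom_ref_value(s, per_element) for s, p in zip(structs, properties)]


def inverse_transform(structs: Sequence[Sequence[int]], properties: Sequence[float],
                      per_element: Sequence[float]) -> List[float]:
    """Add the atom reference back to recover the original properties."""
    return [p + atom_ref_value(s, per_element) for s, p in zip(structs, properties)]


water = [1, 1, 8]  # H2O given as a list of atomic numbers
print(atom_ref_value(water, property_per_element))  # -50.0
corrected = transform([water], [-52.5], property_per_element)
print(corrected)  # [-2.5]
print(inverse_transform([water], corrected, property_per_element))  # [-52.5]
```

Because `transform` and `inverse_transform` add and subtract the same per-structure sum, a model can be trained on the corrected targets and its predictions mapped back exactly.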
- class m3gnet.layers.BaseAtomRef(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Base AtomRef that predicts a zero correction
- class m3gnet.layers.BondNetwork(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdate
Edge network that takes a graph as input and calculates new bond attributes
- call(graph: List, **kwargs) → List
Update the bonds and return a copy of the graph
Args:
graph (list): graph list representation
**kwargs: additional keyword arguments
Returns: graph list representation
- update_bonds(graph: List) → tensorflow.python.framework.ops.Tensor
Update the bond information in the graph
Args:
graph (list): list representation of a graph
Returns: tf.Tensor
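- class m3gnet.layers.ConcatAtoms(*args, **kwargs)
Bases:
m3gnet.layers._bond.BondNetwork
\[e_{ij}' = \mathrm{Update}(v_i \oplus v_j \oplus e_{ij} \oplus u)\]
where \(v_i\) and \(v_j\) are the attributes of the two bonded atoms, \(e_{ij}\) is the bond attribute, \(u\) is the state attribute, and \(\oplus\) denotes concatenation.
- get_config() → dict
Get config dict for serialization
Returns: dict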
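The concatenation in this update can be sketched in plain Python for a single graph, assuming one state vector per graph and an arbitrary `update` callable standing in for the learned update function (in the layer itself this is a neural network). All names here are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def concat_atoms_update(
    atoms: Sequence[Vector],
    bonds: Sequence[Vector],
    bond_atom_indices: Sequence[Tuple[int, int]],
    state: Vector,
    update: Callable[[Vector], Vector],
) -> List[Vector]:
    """Return e_ij' = update(v_i ⊕ v_j ⊕ e_ij ⊕ u) for every bond of one graph."""
    new_bonds = []
    for (i, j), e_ij in zip(bond_atom_indices, bonds):
        # ⊕ is plain concatenation of the four feature vectors
        x = list(atoms[i]) + list(atoms[j]) + list(e_ij) + list(state)
        new_bonds.append(update(x))
    return new_bonds


atoms = [[1.0], [2.0]]
bonds = [[0.5]]
print(concat_atoms_update(atoms, bonds, [(0, 1)], [3.0], lambda x: x))         # [[1.0, 2.0, 0.5, 3.0]]
print(concat_atoms_update(atoms, bonds, [(0, 1)], [3.0], lambda x: [sum(x)]))  # [[6.5]]
```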
- class m3gnet.layers.ConcatBondAtomState(*args, **kwargs)
Bases:
m3gnet.layers._state.StateNetwork
\[u' = \mathrm{Update}(\bar{e}' \oplus \bar{v}' \oplus u)\]
where \(\bar{e}'\) and \(\bar{v}'\) are the updated bond and atom attributes aggregated over the graph, and \(\oplus\) denotes concatenation.
- get_config() → dict
Get config dict for serialization
Returns: dict
- update_states(graph: List) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): list repr of a MaterialGraph
Returns: tf.Tensor
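A plain-Python sketch of this state update, assuming that the bar denotes the mean over all bonds and over all atoms of one graph (the formula alone does not fix the pooling operation) and that `update` stands in for the learned network. All names are hypothetical.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def mean_vector(vectors: Sequence[Vector]) -> Vector:
    """Column-wise mean of a list of equal-length vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def concat_bond_atom_state(
    bonds: Sequence[Vector],
    atoms: Sequence[Vector],
    state: Vector,
    update: Callable[[Vector], Vector],
) -> Vector:
    """Return u' = update(mean(e') ⊕ mean(v') ⊕ u) for one graph."""
    e_bar = mean_vector(bonds)
    v_bar = mean_vector(atoms)
    return update(e_bar + v_bar + list(state))


bonds = [[1.0, 3.0], [3.0, 5.0]]
atoms = [[0.0], [2.0], [4.0]]
print(concat_bond_atom_state(bonds, atoms, [7.0], lambda x: x))  # [2.0, 4.0, 2.0, 7.0]
```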
- class m3gnet.layers.Embedding(*args, **kwargs)
Bases:
keras.layers.embeddings.Embedding
Thin wrapper for embedding atomic numbers into feature vectors
- call(inputs: List[tensorflow.python.framework.ops.Tensor]) → tensorflow.python.framework.ops.Tensor
Implementation of the layer call
- Args:
inputs (list): list representation of graph
Returns: tf.Tensor
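Conceptually, embedding atomic numbers is a trainable lookup table indexed by atomic number. The sketch below mimics that with a randomly initialised table in plain Python; the table size (95 rows), feature dimension and initialisation range are assumptions, and in the library the rows are trained together with the rest of the model.

```python
import random
from typing import List, Sequence


def make_embedding_table(n_elements: int, dim: int, seed: int = 0) -> List[List[float]]:
    """Randomly initialised lookup table: one row per atomic number."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.05, 0.05) for _ in range(dim)] for _ in range(n_elements)]


def embed(atomic_numbers: Sequence[int], table: List[List[float]]) -> List[List[float]]:
    """Replace every atomic number by its row of the table."""
    return [table[z] for z in atomic_numbers]


table = make_embedding_table(n_elements=95, dim=4)
vectors = embed([1, 1, 8], table)  # H, H, O
print(len(vectors), len(vectors[0]))  # 3 4
print(vectors[0] == vectors[1])       # True: the same element shares one vector
```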
- class m3gnet.layers.GatedAtomUpdate(*args, **kwargs)
Bases:
m3gnet.layers._atom.AtomNetwork
Take the neighbor atom attributes and bond attributes and use them to update the center atom
- get_config() → dict
Get config dict for serialization
Returns: dict
- update_atoms(graph: List) → tensorflow.python.framework.ops.Tensor
Update atom attributes
Args:
graph (list): list repr of MaterialGraph
Returns: tf.Tensor
- class m3gnet.layers.GatedMLP(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Gated MLP implementation. It computes
\[\mathrm{out} = \mathrm{MLP}(x) \odot \mathrm{MLP}_{\sigma}(x)\]
where the latter MLP uses a sigmoid activation in its last layer and \(\odot\) denotes element-wise multiplication.
- call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor
Implementation of the layer call
- Args:
inputs (list): list representation of graph
Returns: tf.Tensor
- get_config() → Dict
Get layer configuration in dictionary format
Returns: dict
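The gating formula can be sketched with small dense layers in plain Python. Assumptions: both branches share the hidden activation (swish here, chosen for illustration), the value branch keeps that activation in its last layer, and weights are given as nested lists of shape `[in][out]`. None of these helper names belong to m3gnet.

```python
import math
from typing import Callable, List, Sequence, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weights[in][out], bias[out])


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def swish(z: float) -> float:
    return z * sigmoid(z)


def dense(x: Sequence[float], layer: Layer, act: Callable[[float], float]) -> List[float]:
    weights, bias = layer
    # y_k = act(sum_i x_i * W[i][k] + b_k)
    return [act(sum(xi * row[k] for xi, row in zip(x, weights)) + bias[k]) for k in range(len(bias))]


def mlp(x: Sequence[float], layers: Sequence[Layer], act, last_act) -> List[float]:
    """Apply dense layers in turn; only the final layer uses last_act."""
    for idx, layer in enumerate(layers):
        x = dense(x, layer, last_act if idx == len(layers) - 1 else act)
    return list(x)


def gated_mlp(x, value_layers, gate_layers, act=swish) -> List[float]:
    value = mlp(x, value_layers, act, act)       # ordinary MLP branch
    gate = mlp(x, gate_layers, act, sigmoid)     # same structure, sigmoid on the last layer
    return [v * g for v, g in zip(value, gate)]  # element-wise gating


out = gated_mlp([2.0], value_layers=[([[1.0]], [0.0])], gate_layers=[([[0.0]], [0.0])])
print(out)  # swish(2.0) * sigmoid(0.0), i.e. swish(2.0) * 0.5
```

Since the sigmoid output lies in (0, 1), the gate can only scale each value feature down, acting as a learned soft mask.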
- class m3gnet.layers.GraphFeaturizer(*args, **kwargs)
Bases:
m3gnet.layers._gn.GraphNetworkLayer
Graph featurizer that converts an initial graph, whose atom attributes are atomic numbers and whose bond attributes are bond distances, into a graph with the proper feature dimensions
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.GraphFieldEmbedding(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdateFunc
Embed a categorical field of a graph into a continuous space
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.GraphNetworkLayer(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdate
A graph network layer that performs bond, atom and state updates in the sequence bond → atom → state. The input and output of each step are graphs in list representation.
- call(graph: List, **kwargs) → List
- Args:
graph (List): a graph in list representation
**kwargs: additional keyword arguments
Returns: updated graph in list representation
- get_config() → dict
Get config dict for serialization
Returns: dict
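The bond → atom → state ordering matters because each step reads the results of the previous one. A sketch under simplifying assumptions: the graph is a dict rather than the library's list representation, and the three update callables are placeholders for the learned networks.

```python
from typing import Any, Callable, Dict

Graph = Dict[str, Any]
Update = Callable[[Graph], Graph]


def graph_network_layer(graph: Graph, update_bonds: Update,
                        update_atoms: Update, update_states: Update) -> Graph:
    """Apply the three updates in the fixed order bond -> atom -> state."""
    graph = update_bonds(graph)   # bonds see the current atoms and state
    graph = update_atoms(graph)   # atoms see the freshly updated bonds
    graph = update_states(graph)  # the state sees both updated bonds and atoms
    return graph


# Toy updates that only touch a counter, to make the order visible
result = graph_network_layer(
    {"x": 1},
    update_bonds=lambda g: {**g, "x": g["x"] + 1},
    update_atoms=lambda g: {**g, "x": g["x"] * 10},
    update_states=lambda g: {**g, "x": g["x"] - 3},
)
print(result["x"])  # (1 + 1) * 10 - 3 = 17
```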
- class m3gnet.layers.GraphUpdate(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
A graph update takes a graph list representation as input and outputs an updated graph
- class m3gnet.layers.GraphUpdateFunc(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdate
Update a graph with a function
- call(graph: List, **kwargs) → List
Call logic of the layer
Args:
graph (list): list representation of the graph
**kwargs: additional keyword arguments
Returns: graph list
- get_config() → dict
Get config dict for serialization
Returns: config dict
- class m3gnet.layers.MLP(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Multi-layer perceptron
- call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor
Implementation of the layer call
- Args:
inputs (list): list representation of graph
Returns: tf.Tensor
- get_config() → Dict
Get layer configuration in dictionary format
Returns: dict
- class m3gnet.layers.MultiFieldReadout(*args, **kwargs)
Bases:
m3gnet.layers._readout.ReadOut
Read out both bond and atom attributes
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: tf.Tensor
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.PairDistance(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Compute pair distances from atom positions, bond indices, lattices and periodic offsets.
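For a periodic structure, the length of a bond from atom i to a periodic image of atom j can be computed as |r_j + n·L - r_i|, where n is the integer image offset and L the lattice matrix. The sketch below assumes that sign convention, a single structure, and lattice vectors stored as rows; the intermediate vectors are the kind of pair vectors described for `PairVector`. The function name is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def pair_vectors_and_distances(
    positions: Sequence[Sequence[float]],
    bond_atom_indices: Sequence[Tuple[int, int]],
    lattice: Sequence[Sequence[float]],
    offsets: Sequence[Sequence[int]],
) -> Tuple[List[List[float]], List[float]]:
    """Bond vectors r_j + n·L - r_i and their lengths for one periodic structure.

    lattice: 3x3 matrix whose rows are the lattice vectors.
    offsets: integer image offset n = (n_a, n_b, n_c) of atom j for each bond.
    """
    vectors, distances = [], []
    for (i, j), n in zip(bond_atom_indices, offsets):
        # Cartesian translation of the periodic image: n @ lattice
        shift = [sum(n[a] * lattice[a][k] for a in range(3)) for k in range(3)]
        vec = [positions[j][k] + shift[k] - positions[i][k] for k in range(3)]
        vectors.append(vec)
        distances.append(math.sqrt(sum(c * c for c in vec)))
    return vectors, distances


# Cubic cell with a = 3.0 Å and two atoms on the x axis
lattice = [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]
positions = [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0]]
vecs, dists = pair_vectors_and_distances(positions, [(0, 1), (0, 1)], lattice, [(0, 0, 0), (-1, 0, 0)])
print(dists)  # [2.5, 0.5]: the image in the neighbouring cell is closer
```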
- class m3gnet.layers.PairRadialBasisExpansion(*args, **kwargs)
Bases:
m3gnet.layers._bond.BondNetwork
Expand the radial distance onto a basis
- call(graph: List, **kwargs)
In addition to updating the bonds, also add the bond weights for later use
Args:
graph (list): list representation of a graph
**kwargs: additional keyword arguments
Returns:
- get_config() → dict
Get config dict for serialization
Returns: dict
- update_bonds(graph: List) → tensorflow.python.framework.ops.Tensor
Update the bond information using RBF
Args:
graph (list): list representation of a graph
Returns: updated bond information
- class m3gnet.layers.PairVector(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Compute the pair distance vectors between atoms from the graph
- class m3gnet.layers.Pipe(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Simple layer for consecutive layer calls, similar to Sequential
- call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor
Run the inputs through the layers
Args:
inputs (List): a graph in list representation
**kwargs: additional keyword arguments
Returns: tf.Tensor
- classmethod from_config(config: Dict) → m3gnet.layers._core.Pipe
Construct a Pipe object from a config dict
Args:
config (dict): configuration dictionary
Returns: Pipe object
- get_config() → Dict
Get layer configuration in dictionary format
Returns: dict
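Chaining layers this way amounts to a left fold over the layer list. A minimal sketch with plain callables standing in for Keras layers; forwarding of keyword arguments, which the real `call` accepts, is omitted, and the function name is hypothetical.

```python
from functools import reduce
from typing import Any, Callable, Sequence


def pipe(layers: Sequence[Callable[[Any], Any]], inputs: Any) -> Any:
    """Feed the output of each layer into the next, like a Sequential model."""
    return reduce(lambda x, layer: layer(x), layers, inputs)


print(pipe([lambda x: x + 1, lambda x: x * 2], 3))  # (3 + 1) * 2 = 8
print(pipe([], 5))                                   # no layers: inputs pass through, 5
```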
- class m3gnet.layers.RadialBasisFunctions(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Radial distribution function basis
- call(r: tensorflow.python.framework.ops.Tensor, **kwargs)
- Args:
r (tf.Tensor): 1D radial distance tensor
**kwargs: additional keyword arguments
Returns: radial basis functions
- get_config() → dict
Get config of the class for serialization
Returns: dict
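The expansion maps each scalar distance to a vector of basis-function values. The sketch below uses Gaussian radial basis functions purely as an example; which basis this layer actually implements, and its default centres and width, are not stated here, so the `centers` and `width` values (and the helper names) are assumptions.

```python
import math
from typing import List, Sequence


def linspace(start: float, stop: float, num: int) -> List[float]:
    """num evenly spaced points from start to stop inclusive (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + k * step for k in range(num)]


def gaussian_rbf(r: Sequence[float], centers: Sequence[float], width: float) -> List[List[float]]:
    """Expand each distance r_i into exp(-((r_i - mu_k) / width)^2) over all centres mu_k."""
    return [[math.exp(-(((ri - mu) / width) ** 2)) for mu in centers] for ri in r]


centers = linspace(0.0, 5.0, 6)  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
basis = gaussian_rbf([2.0], centers, width=0.5)
print(len(basis[0]))  # 6 basis values for one distance
print(basis[0][2])    # 1.0: the distance sits exactly on centre 2.0
```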
- class m3gnet.layers.ReadOut(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Readout reduces a graph into a tensor
- class m3gnet.layers.ReduceReadOut(*args, **kwargs)
Bases:
m3gnet.layers._readout.ReadOut
Reduce atom or bond attributes into lower-dimensional tensors as readout, for example by summing over the atoms or bonds or by taking their mean.
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: tf.Tensor
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.Set2Set(*args, **kwargs)
Bases:
m3gnet.layers._readout.ReadOut
The Set2Set readout function
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: tf.Tensor
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.SphericalBesselWithHarmonics(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Spherical Bessel functions as the radial basis and spherical harmonics as the angular basis
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): the list representation of a graph
**kwargs: additional keyword arguments
Returns: combined radial and spherical harmonic expansion of the distance and angle
- get_config() → dict
Get the config dict for serialization
Returns: config dict
- class m3gnet.layers.StateNetwork(*args, **kwargs)
Bases:
m3gnet.layers._base.GraphUpdate
State network that takes a graph as input and calculates new state attributes
- call(graph: List, **kwargs) → List
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: new graph in list repr
- update_states(graph: List) → tensorflow.python.framework.ops.Tensor
Calculate the new state attributes
Args:
graph (list): list repr of a MaterialGraph
Returns: tf.Tensor
- class m3gnet.layers.ThreeDInteraction(*args, **kwargs)
Bases:
keras.engine.base_layer.Layer
Include three-body (3D) interactions in the bond update
- call(graph: List, three_basis: tensorflow.python.framework.ops.Tensor, three_cutoff: float, **kwargs) → List
- Args:
graph (list): graph list representation
three_basis (tf.Tensor): three-body basis expansion
three_cutoff (float): cutoff radius
**kwargs: additional keyword arguments
Returns: graph list representation
- get_config() → dict
Get config dict for serialization
Returns: dict
- class m3gnet.layers.WeightedReadout(*args, **kwargs)
Bases:
m3gnet.layers._readout.ReadOut
Perform a weighted average of the readout field. The weights are learned by this layer.
- call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor
- Args:
graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments
Returns: tf.Tensor
- get_config()
Get config dict for serialization
Returns: dict
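One way to realise a learnable weighted average is to compute a gate per atom and normalise by the sum of the gates. The sketch below assumes a sigmoid gate of a linear map, with `a` and `b` as hypothetical stand-ins for learnable parameters; it illustrates the idea and is not necessarily how this layer computes its weights.

```python
import math
from typing import List, Sequence


def weighted_readout(features: Sequence[Sequence[float]], a: Sequence[float], b: float) -> List[float]:
    """Weighted average of one graph's per-atom features.

    Each weight is sigmoid(a · x_i + b); a and b stand in for learnable parameters.
    """
    weights = [1.0 / (1.0 + math.exp(-(sum(ak * xk for ak, xk in zip(a, x)) + b))) for x in features]
    total = sum(weights)
    dim = len(features[0])
    return [sum(w * x[k] for w, x in zip(weights, features)) / total for k in range(dim)]


feats = [[1.0, 2.0], [3.0, 6.0]]
print(weighted_readout(feats, a=[0.0, 0.0], b=0.0))  # equal weights give the plain mean [2.0, 4.0]
```

With non-zero `a`, atoms whose features push the gate towards 1 dominate the average, which is what lets the weighting be learned.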
- m3gnet.layers.cosine(r: tensorflow.python.framework.ops.Tensor, cutoff: float) → tensorflow.python.framework.ops.Tensor
Cosine cutoff function
Args:
r (tf.Tensor): radius distance tensor
cutoff (float): cutoff distance
Returns: cosine cutoff functions
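A scalar sketch of a cosine cutoff, assuming the common form f(r) = 0.5 (cos(πr/r_c) + 1) for r < r_c and 0 otherwise; the exact expression used by m3gnet is not given in this reference, and the function name is hypothetical. This form decays from 1 at r = 0 to 0 at r = r_c with zero slope at the cutoff.

```python
import math


def cosine_cutoff(r: float, cutoff: float) -> float:
    """0.5 * (cos(pi * r / cutoff) + 1) for r < cutoff, and 0 beyond it."""
    if r >= cutoff:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / cutoff) + 1.0)


print(cosine_cutoff(0.0, 5.0))  # 1.0 at r = 0
print(cosine_cutoff(2.5, 5.0))  # 0.5 halfway to the cutoff
print(cosine_cutoff(6.0, 5.0))  # 0.0 outside the cutoff
```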
- m3gnet.layers.polynomial(r: tensorflow.python.framework.ops.Tensor, cutoff: float) → tensorflow.python.framework.ops.Tensor
Polynomial cutoff function
Args:
r (tf.Tensor): radius distance tensor
cutoff (float): cutoff distance
Returns: polynomial cutoff functions
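A scalar sketch of a polynomial cutoff, assuming the form f(x) = 1 - 6x^5 + 15x^4 - 10x^3 with x = r/r_c for r < r_c and 0 otherwise; the exact polynomial used by m3gnet is not stated here, and the function name is hypothetical. This particular choice reaches zero at x = 1 with vanishing first and second derivatives, so the cutoff switches off smoothly.

```python
def polynomial_cutoff(r: float, cutoff: float) -> float:
    """1 - 6x^5 + 15x^4 - 10x^3 with x = r / cutoff, and 0 for r >= cutoff."""
    x = r / cutoff
    if x >= 1.0:
        return 0.0
    return 1.0 - 6.0 * x**5 + 15.0 * x**4 - 10.0 * x**3


print(polynomial_cutoff(0.0, 5.0))  # 1.0
print(polynomial_cutoff(2.5, 5.0))  # 0.5
print(polynomial_cutoff(5.0, 5.0))  # 0.0
```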