m3gnet.layers package

Module contents

Graph layers

class m3gnet.layers.AtomNetwork(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdate

Atom network that takes a graph as input and then calculates new atom attributes

call(graph: List, **kwargs) → List

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: a new MaterialGraph in list repr

update_atoms(graph: List) → tensorflow.python.framework.ops.Tensor

Take a graph as input and calculate the updated atom attributes

Args:

graph (list): list repr of a MaterialGraph

Returns: tf.Tensor

class m3gnet.layers.AtomReduceState(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Reduce atom attributes to states via sum or mean

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Main layer logic

Args:

graph (list): input graph list representation
**kwargs: additional keyword arguments

Returns: tf.Tensor

get_config() → dict

Get the configuration dict

Returns: dict
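
The sum/mean reduction can be sketched with NumPy as below. This is a minimal illustration, not the layer's code: it assumes each atom carries an integer index naming the graph it belongs to (`atom_graph_index`, a hypothetical name), whereas the real layer works on TensorFlow tensors taken from the graph list representation.

```python
import numpy as np


def reduce_atoms_to_states(atom_attrs, atom_graph_index, n_graphs, method="mean"):
    """Reduce per-atom feature rows to one row per graph (hypothetical helper)."""
    atom_attrs = np.asarray(atom_attrs, dtype=float)
    index = np.asarray(atom_graph_index)
    out = np.zeros((n_graphs, atom_attrs.shape[1]))
    # Scatter-add every atom's features into the row of the graph it belongs to
    np.add.at(out, index, atom_attrs)
    if method == "sum":
        return out
    if method == "mean":
        counts = np.bincount(index, minlength=n_graphs).reshape(-1, 1)
        # Guard against graphs with no atoms to avoid division by zero
        return out / np.maximum(counts, 1)
    raise ValueError(f"unknown reduction method: {method}")


# Two graphs: atoms 0 and 1 belong to graph 0, atom 2 belongs to graph 1
states = reduce_atoms_to_states([[1, 2], [3, 4], [5, 6]], [0, 0, 1], 2, "mean")
```

With `"mean"`, graph 0 gets the average of its two atom rows; with `"sum"`, it gets their total.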

class m3gnet.layers.AtomRef(*args, **kwargs)

Bases: m3gnet.layers._atom_ref.BaseAtomRef

Atom reference values. For example, if the average H energy is -20.0 and the average O energy is -10.0, the AtomRef predicts -20.0 * 2 + (-10.0) = -50.0 as the atom reference energy of H2O.
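
The H2O arithmetic above amounts to summing one reference value per atom. The sketch below reproduces it. It assumes the structure is given as a plain list of atomic numbers, and it follows the `set_property_per_element` convention that entry i holds the value for atomic number i. The helper name `atom_ref_property` is hypothetical.

```python
import numpy as np


def atom_ref_property(atomic_numbers, property_per_element):
    """Sum per-element reference values over the atoms of one structure."""
    # Entry i of property_per_element holds the value for atomic number i
    values = np.asarray(property_per_element, dtype=float)
    return float(np.sum(values[np.asarray(atomic_numbers)]))


# Reference values indexed by atomic number: H (Z=1) -> -20.0, O (Z=8) -> -10.0
property_per_element = np.zeros(10)
property_per_element[1] = -20.0
property_per_element[8] = -10.0

print(atom_ref_property([1, 1, 8], property_per_element))  # -50.0
```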

call(graph: List, **kwargs)

Args:

graph (list): a list repr of a graph
**kwargs: additional keyword arguments

Returns: atom reference values

fit(structs_or_graphs, properties)

Fit the elemental reference values for the properties

Args:

structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of extensive properties
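
One way to fit elemental reference values is a linear least-squares problem: atom counts per element multiplied by per-element values should reproduce the extensive property. The sketch below assumes that formulation. It also represents each structure as a list of atomic numbers rather than a Structure or MaterialGraph. It is not confirmed to match how `fit` is implemented, and `composition_matrix` and `fit_property_per_element` are hypothetical names.

```python
import numpy as np


def composition_matrix(structures, max_z):
    """Count atoms of each atomic number (0..max_z) in every structure."""
    features = np.zeros((len(structures), max_z + 1))
    for row, atomic_numbers in enumerate(structures):
        for z in atomic_numbers:
            features[row, z] += 1
    return features


def fit_property_per_element(structures, properties, max_z):
    """Least-squares fit of per-element values to extensive properties."""
    features = composition_matrix(structures, max_z)
    targets = np.asarray(properties, dtype=float)
    # lstsq returns the minimum-norm solution, so elements never seen get 0
    coeffs, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return coeffs


# H2, O2 and H2O with energies consistent with H = -20.0 and O = -10.0
coeffs = fit_property_per_element([[1, 1], [8, 8], [1, 1, 8]], [-40.0, -20.0, -50.0], max_z=8)
```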

get_config()

Get dict config for serialization

Returns: dict

inverse_transform(structs_or_graphs, properties)

Take the transformed values and get the original values

Args:

structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of extensive properties

Returns: original (uncorrected) property values

predict_properties(structs_or_graphs)

Args:

structs_or_graphs (list): list of graphs or structures for which to calculate the atom-summed property values

Returns: atom-summed property values

set_property_per_element(property_per_element)

Set the per-element property values

Args:

property_per_element (np.ndarray): array of elemental properties; the i-th row is the elemental value for atomic number i.

transform(structs_or_graphs, properties)

Correct the extensive properties by subtracting the atom reference values

Args:

structs_or_graphs (list): list of graphs or structures
properties (np.ndarray): array of extensive properties

Returns: corrected property values
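
`transform` and `inverse_transform` are described as mirror operations: one subtracts the atom reference values and the other adds them back. The NumPy sketch below shows that round trip. Structures are again plain lists of atomic numbers (an assumption), and all function names are illustrative rather than the class's methods.

```python
import numpy as np


def atom_ref_values(structures, property_per_element):
    """Atom-summed reference value for each structure."""
    values = np.asarray(property_per_element, dtype=float)
    return np.array([values[np.asarray(z)].sum() for z in structures])


def transform(structures, properties, property_per_element):
    # Subtract the atom reference values from the extensive properties
    return np.asarray(properties, dtype=float) - atom_ref_values(structures, property_per_element)


def inverse_transform(structures, properties, property_per_element):
    # Add the atom reference values back to recover the original properties
    return np.asarray(properties, dtype=float) + atom_ref_values(structures, property_per_element)
```

Training on `transform` output means the model only learns the residual beyond the elemental baseline. `inverse_transform` then restores the absolute scale.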

class m3gnet.layers.BaseAtomRef(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Base AtomRef that predicts a zero correction

call(graph: List, **kwargs)

Args:

graph (list): list repr of a graph
**kwargs: additional keyword arguments

Returns: 0

class m3gnet.layers.BondNetwork(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdate

Edge network that takes a graph as input and then calculates new bond attributes

call(graph: List, **kwargs) → List

Update the bonds and return a copy of the graph

Args:

graph (list): graph list representation
**kwargs: additional keyword arguments

Returns: graph list representation

update_bonds(graph: List) → tensorflow.python.framework.ops.Tensor

Update bond info in the graph

Args:

graph (list): list representation of a graph

Returns: tf.Tensor

class m3gnet.layers.ConcatAtoms(*args, **kwargs)

Bases: m3gnet.layers._bond.BondNetwork

\[e_{ij}^\prime = \mathrm{Update}(v_i \oplus v_j \oplus e_{ij} \oplus u)\]
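
The input to the update in the equation above is a per-bond concatenation of the two end-atom features, the bond feature and the graph state. The sketch below builds that input for a single graph. It assumes bonds are given as (i, j) atom index pairs and that u is one state vector broadcast to every bond. The learned Update function (an MLP in practice) is left out, and `concat_bond_inputs` is a hypothetical name.

```python
import numpy as np


def concat_bond_inputs(atoms, bonds, bond_atom_indices, state):
    """Build v_i ⊕ v_j ⊕ e_ij ⊕ u for every bond of a single graph."""
    atoms = np.asarray(atoms, dtype=float)
    bonds = np.asarray(bonds, dtype=float)
    idx = np.asarray(bond_atom_indices)
    v_i = atoms[idx[:, 0]]  # features of the first atom of each bond
    v_j = atoms[idx[:, 1]]  # features of the second atom of each bond
    # Broadcast the single graph state u to every bond
    u = np.repeat(np.asarray(state, dtype=float).reshape(1, -1), len(bonds), axis=0)
    return np.concatenate([v_i, v_j, bonds, u], axis=1)
```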

get_config() → dict

Get config dict for serialization

Returns: dict

update_bonds(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Update bond information

Args:

graph (list): list representation of the graph
**kwargs: additional keyword arguments

Returns: tf.Tensor

class m3gnet.layers.ConcatBondAtomState(*args, **kwargs)

Bases: m3gnet.layers._state.StateNetwork

\[u^\prime = \mathrm{Update}(\bar{e}^\prime \oplus \bar{v}^\prime \oplus u)\]
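
The equation above concatenates pooled bond features, pooled atom features and the current state. The sketch below assumes the bar denotes a plain mean over the updated bonds and atoms of one graph, as in MEGNet-style state updates. That reading is an assumption, not confirmed by this page, and the learned Update function is omitted. `concat_state_inputs` is a hypothetical name.

```python
import numpy as np


def concat_state_inputs(bonds, atoms, state):
    """Build ē' ⊕ v̄' ⊕ u for a single graph, taking the bar as a mean (assumed)."""
    e_bar = np.asarray(bonds, dtype=float).mean(axis=0)  # average updated bond features
    v_bar = np.asarray(atoms, dtype=float).mean(axis=0)  # average updated atom features
    return np.concatenate([e_bar, v_bar, np.asarray(state, dtype=float).ravel()])
```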

get_config() → dict

Get config dict for serialization

Returns: dict

update_states(graph: List) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph

Returns: tf.Tensor

class m3gnet.layers.Embedding(*args, **kwargs)

Bases: keras.layers.embeddings.Embedding

Thin wrapper for embedding atomic numbers into feature vectors

call(inputs: List[tensorflow.python.framework.ops.Tensor]) → tensorflow.python.framework.ops.Tensor

Implementation of the layer call

Args:

inputs (list): list representation of graph

Returns: tf.Tensor

class m3gnet.layers.GatedAtomUpdate(*args, **kwargs)

Bases: m3gnet.layers._atom.AtomNetwork

Takes the neighbor atom attributes and bond attributes and uses them to update the center atom

get_config() → dict

Get config dict for serialization

Returns: dict

update_atoms(graph: List) → tensorflow.python.framework.ops.Tensor

Update atom attributes

Args:

graph (list): list repr of MaterialGraph

Returns: tf.Tensor

class m3gnet.layers.GatedMLP(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Gated MLP implementation. It implements the following:

\[\mathrm{out} = \mathrm{MLP}(x) \odot \mathrm{MLP}_{\sigma}(x)\]

where the latter changes the last layer's activation function to a sigmoid, and \(\odot\) denotes element-wise multiplication.
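
The gating formula can be sketched with two small dense stacks. The second stack's last activation is a sigmoid, so its output in (0, 1) scales the first stack's output element by element. The tanh hidden activation and the linear last layer of the value branch are arbitrary choices made for this illustration. Neither is claimed to match the layer's defaults, and all names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def mlp(x, weights, biases, last_activation):
    """Apply dense layers; hidden layers use tanh (an arbitrary choice here)."""
    for k, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        x = last_activation(x) if k == len(weights) - 1 else np.tanh(x)
    return x


def gated_mlp(x, value_params, gate_params):
    """value_params and gate_params are (weights, biases) tuples of equal depth."""
    value = mlp(x, *value_params, last_activation=lambda t: t)  # MLP(x)
    gate = mlp(x, *gate_params, last_activation=sigmoid)  # MLP_sigma(x), values in (0, 1)
    return value * gate  # element-wise gating
```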

call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor

Implementation of the layer call

Args:

inputs (list): list representation of graph

Returns: tf.Tensor

get_config() → Dict

Get layer configuration in dictionary format

Returns: dict

class m3gnet.layers.GraphFeaturizer(*args, **kwargs)

Bases: m3gnet.layers._gn.GraphNetworkLayer

Graph featurizer that performs several steps to convert an initial graph (atomic numbers as atom attributes, bond distances as bond attributes) into a graph with proper feature dimensions

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.GraphFieldEmbedding(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdateFunc

Embeds a categorical field of a graph into a continuous space

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.GraphNetworkLayer(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdate

A graph network layer that performs bond, atom and state updates in the sequence bond -> atom -> state. The input and output of each step are graphs in list representation.

call(graph: List, **kwargs) → List

Args:

graph (List): a graph in list representation
**kwargs: additional keyword arguments

Returns: updated graph in list representation

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.GraphUpdate(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

A graph update takes a graph list representation as input and outputs an updated graph

call(graph: List, **kwargs) → List

Args:

graph (list): list representation of a graph
**kwargs: additional keyword arguments

Returns: updated graph in list representation

class m3gnet.layers.GraphUpdateFunc(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdate

Update a graph with a function

call(graph: List, **kwargs) → List

Call logic of the layer

Args:

graph (list): list representation of the graph
**kwargs: additional keyword arguments

Returns: graph list

get_config() → dict

Get config dict for serialization

Returns: config dict

class m3gnet.layers.MLP(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Multi-layer perceptron

call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor

Implementation of the layer call

Args:

inputs (list): list representation of graph

Returns: tf.Tensor

get_config() → Dict

Get layer configuration in dictionary format

Returns: dict

class m3gnet.layers.MultiFieldReadout(*args, **kwargs)

Bases: m3gnet.layers._readout.ReadOut

Reads out both bond and atom attributes

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: tf.Tensor

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.PairDistance(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Compute pair distances from atom positions, bond indices, lattices and periodic offsets.

call(graph: List, **kwargs)

Calculate the pair distance from a MaterialGraph.

Args:

graph (list): A list representation of a MaterialGraph object
**kwargs: additional keyword arguments

Returns: tf.Tensor distance tensor
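
Pair vectors and distances under periodic boundary conditions can be sketched as x_j + n·L - x_i, where n holds the integer image offsets and L is the lattice matrix with lattice vectors as rows. The sketch below assumes this sign convention and applies the offset to the second atom of each bond. The layers' actual convention may differ. It covers both PairDistance and PairVector, and the function names are hypothetical.

```python
import numpy as np


def pair_vectors(positions, bond_atom_indices, lattice, pbc_offsets):
    """Vectors x_j + n·L - x_i for each bond of a single periodic structure."""
    positions = np.asarray(positions, dtype=float)
    idx = np.asarray(bond_atom_indices)
    # Integer image offsets (n1, n2, n3) times the lattice matrix (rows are lattice vectors)
    shifts = np.asarray(pbc_offsets, dtype=float) @ np.asarray(lattice, dtype=float)
    return positions[idx[:, 1]] + shifts - positions[idx[:, 0]]


def pair_distances(positions, bond_atom_indices, lattice, pbc_offsets):
    """Euclidean length of each pair vector."""
    vectors = pair_vectors(positions, bond_atom_indices, lattice, pbc_offsets)
    return np.linalg.norm(vectors, axis=1)


# Cubic cell of side 3: the same atom pair seen directly and through the -x image
lattice = 3.0 * np.eye(3)
positions = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
d = pair_distances(positions, [[0, 1], [0, 1]], lattice, [[0, 0, 0], [-1, 0, 0]])
```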

class m3gnet.layers.PairRadialBasisExpansion(*args, **kwargs)

Bases: m3gnet.layers._bond.BondNetwork

Expand the radial distance onto a basis

call(graph: List, **kwargs)

In addition to updating the bonds, also add the bond weights for later use

Args:

graph (list): list representation of a graph
**kwargs: additional keyword arguments

get_config() → dict

Get config dict for serialization

Returns: dict

update_bonds(graph: List) → tensorflow.python.framework.ops.Tensor

Update the bond info using RBF

Args:

graph (list): list representation of a graph

Returns: updated bond info

class m3gnet.layers.PairVector(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Compute pair atom distance vectors from a graph

call(graph: List, **kwargs)

Calculate the pair vector distance from a MaterialGraph.

Args:

graph (List): A list representation of a MaterialGraph object
**kwargs: additional keyword arguments

Returns: tf.Tensor distance vector tensor

class m3gnet.layers.Pipe(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Simple layer for consecutive layer calls, similar to Sequential
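
Chaining consecutive calls, as Pipe does, is a left fold of the inputs through the layer list. The sketch below shows that composition with plain Python callables standing in for Keras layers. `pipe` is a hypothetical name, not the class's API.

```python
from functools import reduce
from typing import Any, Callable, Sequence


def pipe(layers: Sequence[Callable[[Any], Any]], inputs: Any) -> Any:
    """Feed inputs through each callable in order: layers[-1](...layers[0](inputs))."""
    return reduce(lambda value, layer: layer(value), layers, inputs)


# (3 + 1) * 2
result = pipe([lambda x: x + 1, lambda x: x * 2], 3)
```

An empty layer list returns the inputs unchanged, which mirrors an empty `Sequential`.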

call(inputs: List[tensorflow.python.framework.ops.Tensor], **kwargs) → tensorflow.python.framework.ops.Tensor

Run the inputs through the layers

Args:

inputs (List): a graph in list representation
**kwargs: additional keyword arguments

Returns: tf.Tensor

classmethod from_config(config: Dict) → m3gnet.layers._core.Pipe

Construct Pipe object from a config dict

Args:

config (dict): configuration dictionary

Returns: Pipe object

get_config() → Dict

Get layer configuration in dictionary format

Returns: dict

class m3gnet.layers.RadialBasisFunctions(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Radial distribution function basis

call(r: tensorflow.python.framework.ops.Tensor, **kwargs)

Args:

r (tf.Tensor): 1D radial distance tensor
**kwargs: additional keyword arguments

Returns: radial basis functions
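
Expanding a 1D distance tensor onto a basis turns each distance into a feature row with one column per basis function. The sketch below uses Gaussians exp(-(r - c_k)^2 / w^2) purely as an illustration. It does not claim that this is the basis RadialBasisFunctions uses by default, and `centers`, `width` and `gaussian_basis` are hypothetical names.

```python
import numpy as np


def gaussian_basis(r, centers, width):
    """Expand each distance onto Gaussians exp(-(r - c_k)^2 / width^2)."""
    r = np.asarray(r, dtype=float).reshape(-1, 1)  # shape (n_bonds, 1)
    centers = np.asarray(centers, dtype=float).reshape(1, -1)  # shape (1, n_basis)
    return np.exp(-((r - centers) ** 2) / width**2)  # shape (n_bonds, n_basis)


features = gaussian_basis([1.0, 2.0], centers=[1.0, 2.0, 3.0], width=1.0)
```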

get_config() → dict

Get config of the class for serialization

Returns: dict

class m3gnet.layers.ReadOut(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Readout reduces a graph into a tensor

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: tf.Tensor, the tensor readout

class m3gnet.layers.ReduceReadOut(*args, **kwargs)

Bases: m3gnet.layers._readout.ReadOut

Reduce atom or bond attributes into lower dimensional tensors as readout. This could be summing up the atoms or bonds, or taking the mean, etc.

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: tf.Tensor

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.Set2Set(*args, **kwargs)

Bases: m3gnet.layers._readout.ReadOut

The Set2Set readout function

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: tf.Tensor

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.SphericalBesselWithHarmonics(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Spherical Bessel functions as the radial function and spherical harmonics as the angular function

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): the list representation of a graph
**kwargs: additional keyword arguments

Returns: combined radial and spherical harmonic expansion of the distance and angle

get_config() → dict

Get the config dict for serialization

Returns: config dict

class m3gnet.layers.StateNetwork(*args, **kwargs)

Bases: m3gnet.layers._base.GraphUpdate

State network that takes a graph as input and then calculates new state attributes

call(graph: List, **kwargs) → List

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: new graph in list repr

update_states(graph: List) → tensorflow.python.framework.ops.Tensor

Calculate the new state attributes

Args:

graph (list): list repr of a MaterialGraph

Returns: tf.Tensor

class m3gnet.layers.ThreeDInteraction(*args, **kwargs)

Bases: keras.engine.base_layer.Layer

Includes three-body (3D) interactions in the bond update

call(graph: List, three_basis: tensorflow.python.framework.ops.Tensor, three_cutoff: float, **kwargs) → List

Args:

graph (list): graph list representation
three_basis (tf.Tensor): three-body basis expansion
three_cutoff (float): cutoff radius
**kwargs: additional keyword arguments

Returns: graph list representation

get_config() → dict

Get config dict for serialization

Returns: dict

class m3gnet.layers.WeightedReadout(*args, **kwargs)

Bases: m3gnet.layers._readout.ReadOut

Perform a weighted average of the readout field. The weights are learnable parameters of this layer.

call(graph: List, **kwargs) → tensorflow.python.framework.ops.Tensor

Args:

graph (list): list repr of a MaterialGraph
**kwargs: additional keyword arguments

Returns: tf.Tensor

get_config()

Get config dict for serialization

Returns: dict

m3gnet.layers.cosine(r: tensorflow.python.framework.ops.Tensor, cutoff: float) → tensorflow.python.framework.ops.Tensor

Cosine cutoff function

Args:

r (tf.Tensor): radial distance tensor
cutoff (float): cutoff distance

Returns: cosine cutoff function values
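
A common form of the cosine cutoff is f_c(r) = 0.5 (cos(π r / r_c) + 1) for r ≤ r_c and 0 beyond. It decays smoothly from 1 at r = 0 to 0 at the cutoff. The NumPy sketch below assumes this standard form. It is not confirmed to be the exact expression inside `m3gnet.layers.cosine`.

```python
import numpy as np


def cosine_cutoff(r, cutoff):
    """0.5 * (cos(pi * r / cutoff) + 1) inside the cutoff, 0 beyond it."""
    r = np.asarray(r, dtype=float)
    return np.where(r <= cutoff, 0.5 * (np.cos(np.pi * r / cutoff) + 1.0), 0.0)


values = cosine_cutoff([0.0, 2.5, 5.0, 6.0], cutoff=5.0)
```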

m3gnet.layers.polynomial(r: tensorflow.python.framework.ops.Tensor, cutoff: float) → tensorflow.python.framework.ops.Tensor

Polynomial cutoff function

Args:

r (tf.Tensor): radial distance tensor
cutoff (float): cutoff distance

Returns: polynomial cutoff function values
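
One widely used smooth polynomial cutoff is f(x) = 1 - 6x^5 + 15x^4 - 10x^3 with x = r / r_c, set to 0 for x ≥ 1. It equals 1 at r = 0 and 0 at the cutoff, and its first and second derivatives vanish at both ends. The sketch below assumes this polynomial for illustration. The exact coefficients used by `m3gnet.layers.polynomial` are not stated on this page.

```python
import numpy as np


def polynomial_cutoff(r, cutoff):
    """1 - 6x^5 + 15x^4 - 10x^3 with x = r / cutoff, set to 0 for r >= cutoff."""
    x = np.asarray(r, dtype=float) / cutoff
    value = 1.0 - 6.0 * x**5 + 15.0 * x**4 - 10.0 * x**3
    return np.where(x < 1.0, value, 0.0)


values = polynomial_cutoff([0.0, 2.5, 5.0, 7.0], cutoff=5.0)
```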