"""
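Crystal graph converters and related helpers (pre-trained elemental embeddings, bond-type graphs and disordered-site graphs).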
"""
from copy import deepcopy
from pathlib import Path
from typing import Union, List, Dict
import numpy as np
from monty.serialization import loadfn
from pymatgen.core import Element
from pymatgen.core import Structure
from pymatgen.analysis.local_env import NearNeighbors
from megnet.data.graph import Converter
from megnet.data.graph import StructureGraph, StructureGraphFixedRadius
# Absolute directory of this module; bundled resource files are resolved relative to it.
MODULE_DIR = Path(__file__).parent.absolute()


def get_elemental_embeddings() -> Dict:
    """
    Provides the pre-trained elemental embeddings using formation energies,
    which can be used to speed up the training of other models. The embeddings
    are also extremely useful elemental descriptors that encode chemical
    similarity that may be used in other ways. See
    "Graph Networks as a Universal Machine Learning Framework for Molecules
    and Crystals", https://arxiv.org/abs/1812.05055

    The data are read from ``resources/elemental_embedding_1MEGNet_layer.json``
    next to this module on every call (nothing is cached), so callers that need
    the table repeatedly should keep their own reference.

    :return: Dict of elemental embeddings as
        {element symbol: length-16 numeric vector}
    """
    return loadfn(MODULE_DIR / "resources" / "elemental_embedding_1MEGNet_layer.json")
class CrystalGraph(StructureGraphFixedRadius):
    """
    Crystal graph whose atom feature is the atomic number (Z) and whose bond
    feature is the interatomic distance. State features may optionally be
    supplied as well.
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "MinimumDistanceNNAll",
        atom_converter: Converter = None,
        bond_converter: Converter = None,
        cutoff: float = 5.0,
    ):
        """
        Build a converter that turns a structure into a crystal graph.

        Args:
            nn_strategy (str or NearNeighbors): near-neighbor strategy, given
                either by name or as an instance
            atom_converter (Converter): converter applied to atom features
            bond_converter (Converter): converter applied to bond features
            cutoff (float): cutoff radius used to find neighbors
        """
        self.cutoff = cutoff
        super().__init__(
            cutoff=self.cutoff,
            nn_strategy=nn_strategy,
            atom_converter=atom_converter,
            bond_converter=bond_converter,
        )
class CrystalGraphWithBondTypes(StructureGraph):
    """
    Crystal graph whose bond attributes are overwritten by a bond type that
    depends only on the metallicity of the two bonded atoms. Three types are
    distinguished:

    - 0: nonmetal-nonmetal
    - 1: metal-nonmetal
    - 2: metal-metal
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "VoronoiNN",
        atom_converter: Converter = None,
        bond_converter: Converter = None,
    ):
        """
        Args:
            nn_strategy (str or NearNeighbors): near-neighbor strategy, given
                either by name or as an instance
            atom_converter (Converter): converter applied to atom features
            bond_converter (Converter): converter applied to bond features
        """
        super().__init__(
            bond_converter=bond_converter,
            atom_converter=atom_converter,
            nn_strategy=nn_strategy,
        )

    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
        """
        Convert a structure into a graph dictionary with typed bonds.

        Args:
            structure (Structure): pymatgen Structure
            state_attributes (list): state attributes
        Returns: graph dictionary whose "bond" entries are bond types
        """
        base_graph = super().convert(structure, state_attributes=state_attributes)
        return self._get_bond_type(base_graph)

    @staticmethod
    def _get_bond_type(graph) -> Dict:
        # Work on a copy so the caller's graph is left untouched.
        typed_graph = deepcopy(graph)
        # graph["atom"] holds atomic numbers; look up metallicity once per atom.
        metallic = [Element.from_Z(z).is_metal for z in graph["atom"]]
        bonds = typed_graph["bond"]
        for position, (src, dst) in enumerate(zip(graph["index1"], graph["index2"])):
            # bool + bool counts metallic endpoints: 0, 1 or 2.
            bonds[position] = metallic[src] + metallic[dst]
        return typed_graph
class _AtomEmbeddingMap(Converter):
    """
    Converter from site occupancies to fixed elemental embeddings, used with
    CrystalGraphDisordered. The feature of a site is the fraction-weighted
    sum of the embedding vectors of the elements occupying it.
    """

    def __init__(self, embedding_dict: dict = None):
        """
        Args:
            embedding_dict (dict): mapping from element symbol to embedding
                vector; when None, the pre-trained table returned by
                get_elemental_embeddings() is used
        """
        self.embedding_dict = get_elemental_embeddings() if embedding_dict is None else embedding_dict

    def convert(self, atoms: list) -> np.ndarray:
        """
        Convert a list of {symbol: fraction} site descriptions to features.

        Args:
            atoms (list): one {symbol: fraction} dict per site
        Returns:
            np.ndarray with one row per site, reshaped to (len(atoms), -1)
        """

        def weighted_sum(occupancy):
            total = 0  # becomes an array after the first addition
            for symbol, fraction in occupancy.items():
                total += np.array(self.embedding_dict[symbol]) * fraction
            return total

        stacked = np.array([weighted_sum(occupancy) for occupancy in atoms])
        return stacked.reshape((len(atoms), -1))
class CrystalGraphDisordered(StructureGraphFixedRadius):
    """
    Crystal graph that supports disordered (partially occupied) sites.

    Each site is described by an {element symbol: fraction} occupancy dict.
    By default, atom features are the fraction-weighted sums of the
    pre-trained elemental embeddings.
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "MinimumDistanceNNAll",
        atom_converter: Converter = None,
        bond_converter: Converter = None,
        cutoff: float = 5.0,
    ):
        """
        Convert the structure into crystal graph
        Args:
            nn_strategy (str or NearNeighbors): NearNeighbor strategy
            atom_converter (Converter): atom features converter. When None,
                a new _AtomEmbeddingMap built from the pre-trained elemental
                embeddings is used.
            bond_converter (Converter): bond features converter
            cutoff (float): cutoff radius
        """
        # Create the default converter per instance. Using `_AtomEmbeddingMap()`
        # as the default argument value would run once at import time (loading
        # the embedding file as an import side effect) and share one converter
        # object among every instance of this class.
        if atom_converter is None:
            atom_converter = _AtomEmbeddingMap()
        self.cutoff = cutoff
        super().__init__(
            nn_strategy=nn_strategy, atom_converter=atom_converter, bond_converter=bond_converter, cutoff=self.cutoff
        )

    @staticmethod
    def get_atom_features(structure) -> List[dict]:
        """
        For a structure return the list of dictionary for the site occupancy
        for example, Fe0.5Ni0.5 site will be returned as {"Fe": 0.5, "Ni": 0.5}
        Args:
            structure (Structure): pymatgen Structure with potential site disorder
        Returns:
            a list of site fraction description, keyed by element symbol
        """
        # get_el_amt_dict keys by element symbol, so oxidation-state-decorated
        # species (e.g. Fe2+) map to "Fe", which matches the symbol-keyed
        # embedding table; Composition.as_dict would give "Fe2+" instead.
        return [site.species.get_el_amt_dict() for site in structure.sites]