Source code for megnet.data.crystal

"""
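Crystal graph converters: classes that turn pymatgen structures into graph
dictionaries, plus access to the pre-trained elemental embeddings.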
"""
from copy import deepcopy
from pathlib import Path
from typing import Union, List, Dict

import numpy as np
from monty.serialization import loadfn
from pymatgen.core import Element
from pymatgen.core import Structure
from pymatgen.analysis.local_env import NearNeighbors

from megnet.data.graph import Converter
from megnet.data.graph import StructureGraph, StructureGraphFixedRadius

MODULE_DIR = Path(__file__).parent.absolute()


def get_elemental_embeddings() -> Dict:
    """
    Load the pre-trained elemental embeddings shipped with this module.

    The embeddings were trained on formation energies and can be used to
    speed up the training of other models. They are also useful elemental
    descriptors that encode chemical similarity and may be used in other
    ways. See "Graph Networks as a Universal Machine Learning Framework for
    Molecules and Crystals", https://arxiv.org/abs/1812.05055

    Returns:
        Dict of elemental embeddings keyed by element symbol; each value is
        the embedding for that element (documented upstream as length 16).
    """
    # The JSON file lives in the "resources" folder next to this module.
    embedding_file = MODULE_DIR.joinpath("resources", "elemental_embedding_1MEGNet_layer.json")
    return loadfn(embedding_file)
class CrystalGraph(StructureGraphFixedRadius):
    """
    Crystal-to-graph converter that uses the atomic number z as the atom
    feature and the interatomic distance as the bond feature. State
    features may optionally be included.
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "MinimumDistanceNNAll",
        atom_converter: Converter = None,
        bond_converter: Converter = None,
        cutoff: float = 5.0,
    ):
        """
        Set up the structure-to-crystal-graph converter.

        Args:
            nn_strategy (str): NearNeighbor strategy
            atom_converter (Converter): atom features converter
            bond_converter (Converter): bond features converter
            cutoff (float): cutoff radius
        """
        # Store the cutoff on the instance first, then hand the stored value
        # on to the parent initializer together with the other settings.
        self.cutoff = cutoff
        parent_kwargs = {
            "nn_strategy": nn_strategy,
            "atom_converter": atom_converter,
            "bond_converter": bond_converter,
            "cutoff": self.cutoff,
        }
        super().__init__(**parent_kwargs)
class CrystalGraphWithBondTypes(StructureGraph):
    """
    Structure graph whose bond attributes are replaced by bond types, defined
    simply by the metallicity of the two atoms forming the bond. Three
    scenarios are considered: nonmetal-nonmetal (type 0), metal-nonmetal
    (type 1), and metal-metal (type 2).
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "VoronoiNN",
        atom_converter: Converter = None,
        bond_converter: Converter = None,
    ):
        """
        Args:
            nn_strategy (str): NearNeighbor strategy
            atom_converter (Converter): atom features converter
            bond_converter (Converter): bond features converter
        """
        super().__init__(
            nn_strategy=nn_strategy,
            atom_converter=atom_converter,
            bond_converter=bond_converter,
        )

    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
        """
        Build the graph for a structure and overwrite its bond attributes
        with bond types.

        Args:
            structure (Structure): pymatgen Structure
            state_attributes (list): state attributes

        Returns:
            graph dictionary
        """
        # The parent class builds the graph; only the bond entries are
        # rewritten afterwards.
        return self._get_bond_type(super().convert(structure, state_attributes=state_attributes))

    @staticmethod
    def _get_bond_type(graph) -> Dict:
        """
        Return a deep copy of ``graph`` whose "bond" entries hold bond types.

        Args:
            graph: graph dictionary with "atom", "bond", "index1" and
                "index2" entries; "atom" is read as atomic numbers.

        Returns:
            the copied graph dictionary with rewritten "bond" values; the
            input graph is left untouched
        """
        typed_graph = deepcopy(graph)
        species = [Element.from_Z(number) for number in graph["atom"]]
        bond_values = typed_graph["bond"]
        endpoints = zip(graph["index1"], graph["index2"])
        for position, (first, second) in enumerate(endpoints):
            # Adding the two is_metal flags gives the bond type described in
            # the class docstring (0, 1 or 2).
            bond_values[position] = species[first].is_metal + species[second].is_metal
        return typed_graph
class _AtomEmbeddingMap(Converter):
    """
    Fixed atom embedding lookup, used with CrystalGraphDisordered
    """

    def __init__(self, embedding_dict: dict = None):
        """
        Args:
            embedding_dict (dict): element to element vector dictionary.
                When omitted, the result of get_elemental_embeddings() is
                used.
        """
        self.embedding_dict = get_elemental_embeddings() if embedding_dict is None else embedding_dict

    def convert(self, atoms: list) -> np.ndarray:
        """
        Convert a list of {symbol: fraction} site descriptions into numeric
        features.

        Args:
            atoms (list): one {symbol: fraction} mapping per site

        Returns:
            array with one row per site, each row being the
            fraction-weighted sum of the embeddings of the species on
            that site
        """

        def weighted_embedding(occupancy):
            # Accumulation starts from the plain integer 0; the first
            # addition produces an array and later ones update it in place.
            total = 0
            for symbol, fraction in occupancy.items():
                total += np.array(self.embedding_dict[symbol]) * fraction
            return total

        rows = [weighted_embedding(site) for site in atoms]
        return np.array(rows).reshape((len(atoms), -1))
class CrystalGraphDisordered(StructureGraphFixedRadius):
    """
    Crystal graph converter that enables predictions for structures with
    disordered sites
    """

    def __init__(
        self,
        nn_strategy: Union[str, NearNeighbors] = "MinimumDistanceNNAll",
        # NOTE: this default is evaluated once, when the class body is
        # executed, so the same _AtomEmbeddingMap instance is shared by every
        # converter that does not pass its own atom_converter.
        atom_converter: Converter = _AtomEmbeddingMap(),
        bond_converter: Converter = None,
        cutoff: float = 5.0,
    ):
        """
        Set up the structure-to-crystal-graph converter.

        Args:
            nn_strategy (str): NearNeighbor strategy
            atom_converter (Converter): atom features converter
            bond_converter (Converter): bond features converter
            cutoff (float): cutoff radius
        """
        # Store the cutoff on the instance first, then hand the stored value
        # on to the parent initializer together with the other settings.
        self.cutoff = cutoff
        parent_kwargs = {
            "nn_strategy": nn_strategy,
            "atom_converter": atom_converter,
            "bond_converter": bond_converter,
            "cutoff": self.cutoff,
        }
        super().__init__(**parent_kwargs)

    @staticmethod
    def get_atom_features(structure) -> List[dict]:
        """
        Describe the species occupancy of every site in a structure.

        For example, a Fe0.5Ni0.5 site will be returned as
        {"Fe": 0.5, "Ni": 0.5}

        Args:
            structure (Structure): pymatgen Structure with potential site
                disorder

        Returns:
            a list with one site fraction description per site
        """
        occupancies = []
        for site in structure.sites:
            occupancies.append(site.species.as_dict())
        return occupancies