Source code for megnet.data.local_env

"""
Various NearNeighbors strategies to define the local environments
of sites in a structure/molecule. Most of them are taken directly
from pymatgen.analysis.local_env. A suitable NearNeighbors class
should implement the get_nn_info method, and this method
needs to return a list of dicts, each entry having the following keys:
['site', 'image', 'weight', 'site_index']

The weight will be used as the bond attribute in subsequent graph
construction.

"""
from importlib import import_module
from inspect import getfullargspec
from typing import Dict, List, Union

from pymatgen.core import Structure, Molecule
from pymatgen.analysis import local_env
from pymatgen.analysis.local_env import (
    NearNeighbors,
    VoronoiNN,
    JmolNN,
    MinimumDistanceNN,
    OpenBabelNN,
    CovalentBondNN,
    MinimumVIRENN,
    MinimumOKeeffeNN,
    BrunnerNN_reciprocal,
    BrunnerNN_real,
    BrunnerNN_relative,
    EconNN,
    CrystalNN,
    CutOffDictNN,
    Critic2NN,
)


class MinimumDistanceNNAll(NearNeighbors):
    """
    Near-neighbor strategy that treats every site found within a fixed
    cutoff radius as bonded.
    """

    def __init__(self, cutoff: float = 4.0):
        """
        Args:
            cutoff (float): cutoff radius in Angstrom used when looking
                for near-neighbor sites (default: 4.0).
        """
        self.cutoff = cutoff

    def get_nn_info(self, structure: Structure, n: int) -> List[Dict]:
        """
        Collect all neighbors of the site with index ``n`` that
        ``structure.get_neighbors`` reports within ``self.cutoff``.

        Args:
            structure (Structure): input structure.
            n (int): index of the site whose neighbors are requested.

        Returns:
            List of dicts, one per neighbor, with the keys

            * ``"site"``: the neighbor object itself,
            * ``"image"``: value of ``self._get_image(structure, neighbor)``,
            * ``"weight"``: the neighbor's ``nn_distance``,
            * ``"site_index"``: value of
              ``self._get_original_site(structure, neighbor)``.
        """
        center = structure[n]
        return [
            {
                "site": neighbor,
                "image": self._get_image(structure, neighbor),
                "weight": neighbor.nn_distance,
                "site_index": self._get_original_site(structure, neighbor),
            }
            for neighbor in structure.get_neighbors(center, self.cutoff)
        ]
class AllAtomPairs(NearNeighbors):
    """
    Near-neighbor strategy that bonds every atom of a molecule to every
    other atom.
    """

    def get_nn_info(self, molecule: Molecule, n: int) -> List[Dict]:
        """
        Pair the site at index ``n`` with every other site in the molecule.

        Args:
            molecule (Molecule): pymatgen Molecule.
            n (int): index of the central site. It is compared directly
                with the positions produced by ``enumerate(molecule)``,
                so only the entry at exactly that position is skipped.

        Returns:
            List of neighbor dicts with the keys ``"site"``, ``"image"``
            (always ``None``), ``"weight"`` (distance from the central
            site to the neighbor) and ``"site_index"`` (position of the
            neighbor in the molecule).
        """
        center = molecule[n]
        return [
            {"site": other, "image": None, "weight": center.distance(other), "site_index": idx}
            for idx, other in enumerate(molecule)
            if idx != n
        ]
def serialize(identifier: Union[str, NearNeighbors]):
    """
    Serialize a local environment object to a dictionary.

    Strings and ``None`` are passed through unchanged. For a NearNeighbors
    instance, the positional argument names of its class ``__init__`` are
    looked up as attributes on the instance and stored in the result,
    together with the ``"@module"`` and ``"@class"`` entries that identify
    the class. If the instance has a ``kwargs`` attribute, its items are
    merged into the result as well.

    Args:
        identifier: (NearNeighbors object/str/None)

    Returns:
        dictionary, the input string, or None

    Raises:
        ValueError: if an ``__init__`` argument has no attribute of the same
            name on the instance, or if ``identifier`` is of an unsupported
            type.
    """
    if isinstance(identifier, str):
        return identifier

    if isinstance(identifier, NearNeighbors):
        cls = identifier.__class__
        d = {"@module": cls.__module__, "@class": cls.__name__}
        for arg in getfullargspec(cls.__init__).args:
            if arg == "self":
                continue
            try:
                d[arg] = getattr(identifier, arg)
            except AttributeError as err:
                # Name the offending argument and keep the original
                # AttributeError attached as the cause for easier debugging.
                raise ValueError(
                    f"Cannot find the argument {arg!r} as an attribute of {cls.__name__}"
                ) from err
        if hasattr(identifier, "kwargs"):
            d.update(identifier.kwargs)
        return d

    if identifier is None:
        return None

    raise ValueError("Unknown identifier for local environment ", identifier)
[docs]def deserialize(config: Dict): """ Deserialize the config dict to object Args: config: (dict) nn_strategy config dict from seralize function Returns: object """ if config is None: return None if ("@module" not in config) or ("@class" not in config): raise ValueError("The config dict cannot be loaded") modname = config["@module"] classname = config["@class"] mod = __import__(modname, globals(), locals(), [classname]) cls_ = getattr(mod, classname) data = {k: v for k, v in config.items() if not k.startswith("@")} return cls_(**data)
# Registry of the near-neighbor classes known to this module, keyed by the
# lower-cased class name so that string lookups are case-insensitive.
NNDict = {
    nn_cls.__name__.lower(): nn_cls
    for nn_cls in (
        NearNeighbors,
        VoronoiNN,
        JmolNN,
        MinimumDistanceNN,
        OpenBabelNN,
        CovalentBondNN,
        MinimumVIRENN,
        MinimumOKeeffeNN,
        BrunnerNN_reciprocal,
        BrunnerNN_real,
        BrunnerNN_relative,
        EconNN,
        CrystalNN,
        CutOffDictNN,
        Critic2NN,
        MinimumDistanceNNAll,
        AllAtomPairs,
    )
}
def get(identifier: Union[str, Dict, NearNeighbors]) -> NearNeighbors:
    """
    Resolve a near-neighbor strategy from a name, a config dict or an
    existing object.

    Args:
        identifier (str, dict or NearNeighbors): target to resolve.

            * str: looked up case-insensitively in ``NNDict``; if absent,
              the exact string is tried as an attribute name of pymatgen's
              ``local_env`` module. In both cases the stored object is
              returned as is (for ``NNDict`` that is the registered class,
              not an instance).
            * dict: handed to ``deserialize``.
            * NearNeighbors: returned unchanged.

    Returns:
        The resolved near-neighbor strategy.

    Raises:
        ValueError: if the identifier cannot be resolved.
    """
    if isinstance(identifier, str):
        key = identifier.lower()
        if key in NNDict:
            return NNDict[key]
        # Not registered here: fall back to pymatgen's local_env module.
        candidate = getattr(local_env, identifier, None)
        if candidate is not None:
            return candidate

    if isinstance(identifier, dict):
        return deserialize(identifier)

    if isinstance(identifier, NearNeighbors):
        return identifier

    raise ValueError(f"{identifier} not identified")