Source code for megnet.utils.descriptor

"""
This module implements atom/bond/structure-wise descriptor calculated from
pretrained megnet model
"""

import os
from typing import Union, Dict

import numpy as np
from tensorflow.keras.models import Model

from megnet.models import MEGNetModel, GraphModel
from megnet.utils.typing import StructureOrMolecule

DEFAULT_MODEL = os.path.join(os.path.dirname(__file__), "../../mvl_models/mp-2019.4.1/formation_energy.hdf5")


[docs]class MEGNetDescriptor: """ MEGNet descriptors. This class takes a trained model and then compute the intermediate outputs as structure features """ def __init__(self, model_name: Union[str, GraphModel, MEGNetModel] = DEFAULT_MODEL, use_cache: bool = True): """ Args: model_name (str or MEGNetModel): trained model. If it is str, then only models in mvl_models are used. use_cache (bool): whether to use cache for structure graph calculations """ if isinstance(model_name, str): model = MEGNetModel.from_file(model_name) elif isinstance(model_name, GraphModel): model = model_name else: raise ValueError("model_name only support str or GraphModel object") layers = model.layers important_prefix = ["meg", "set", "concatenate"] all_names = [i.name for i in layers if any(i.name.startswith(j) for j in important_prefix)] if any(i.startswith("megnet") for i in all_names): self.version = "v2" else: self.version = "v1" valid_outputs = [i.output for i in layers if any(i.name.startswith(j) for j in important_prefix)] outputs = [] valid_names = [] for i, j in zip(all_names, valid_outputs): if isinstance(j, list): for k, l in enumerate(j): valid_names.append(i + f"_{k}") outputs.append(l) else: valid_names.append(i) outputs.append(j) full_model = Model(inputs=model.inputs, outputs=outputs) model.model = full_model self.model = model self.valid_names = valid_names self._cache: Dict[str, float] = {} self.use_cache = use_cache def _predict_structure(self, structure: StructureOrMolecule) -> np.ndarray: graph = self.model.graph_converter.convert(structure) inp = self.model.graph_converter.graph_to_input(graph) return self.model.predict(inp) def _predict_feature(self, structure: StructureOrMolecule) -> np.ndarray: if not self.use_cache: return self._predict_structure(structure) s = str(structure) if s in self._cache: return self._cache[s] result = self._predict_structure(structure) self._cache[s] = result return result def _get_features(self, structure: StructureOrMolecule, prefix: str, level: int, index: int = None) -> np.ndarray: name = prefix if level is not None: name = f"{prefix}_{level}" if index is not None: name += f"_{index}" if name not in self.valid_names: raise ValueError(f"{name} not in original megnet model") ind = self.valid_names.index(name) out_all = self._predict_feature(structure) return out_all[ind][0] def _get_updated_prefix_level(self, prefix: str, level: int): mapping = { "meg_net_layer": ["megnet", level - 1], "set2_set": ["set2set_atom" if level == 1 else "set2set_bond", None], "concatenate": ["concatenate", None], } if self.version == "v2": return mapping[prefix][0], mapping[prefix][1] # type: ignore return prefix, level
[docs] def get_atom_features(self, structure: StructureOrMolecule, level: int = 3) -> np.ndarray: """ Get megnet atom features from structure Args: structure: pymatgen structure or molecule level: int, indicating the block number of megnet, starting from 1 Returns: nxm atomic feature matrix """ prefix, level = self._get_updated_prefix_level("meg_net_layer", level) return self._get_features(structure, prefix=prefix, level=level, index=0)
[docs] def get_bond_features(self, structure: StructureOrMolecule, level: int = 3) -> np.ndarray: """ Get bond features at megnet block level Args: structure: pymatgen structure level: int Returns: n_bond x m bond feature matrix """ prefix, level = self._get_updated_prefix_level("meg_net_layer", level) return self._get_features(structure, prefix=prefix, level=level, index=1)
[docs] def get_global_features(self, structure: StructureOrMolecule, level: int = 2) -> np.ndarray: """ Get state features at megnet block level Args: structure: pymatgen structure or molecule level: int Returns: 1 x m_g global feature vector """ prefix, level = self._get_updated_prefix_level("meg_net_layer", level) return self._get_features(structure, prefix=prefix, level=level, index=2)
[docs] def get_set2set(self, structure: StructureOrMolecule, ftype: str = "atom") -> np.ndarray: """ Get set2set output as features Args: structure (StructureOrMolecule): pymatgen structure or molecule ftype (str): atom or bond Returns: feature matrix, each row is a vector for an atom or bond """ mapping = {"atom": 1, "bond": 2} prefix, level = self._get_updated_prefix_level("set2_set", level=mapping[ftype]) return self._get_features(structure, prefix=prefix, level=level)
[docs] def get_structure_features(self, structure: StructureOrMolecule) -> np.ndarray: """ Get structure level feature vector Args: structure (StructureOrMolecule): pymatgen structure or molecule Returns: one feature vector for the structure """ prefix, level = self._get_updated_prefix_level("concatenate", level=1) return self._get_features(structure, prefix=prefix, level=level)