# Source code for megnet.utils.models
"""
Model utilities, mainly for model loading and download
"""
import logging
import os
from glob import glob
from zipfile import ZipFile
from megnet.models import MEGNetModel, GraphModel
# NOTE(review): basicConfig() and setLevel(INFO) run at import time, which
# configures logging for whoever imports this module. Kept as-is because
# callers may rely on the INFO messages being visible.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Directory that contains this module.
CWD = os.path.dirname(os.path.abspath(__file__))
# Where the downloaded model archive is written.
TEMP_PATH = os.path.join(CWD, "./mvl_models.zip")
# Model directory next to this module, checked when the package-level one is missing.
LOCAL_MODEL_PATH = os.path.join(CWD, "./mvl_models")
# Package-level model directory, two levels above this module.
MODEL_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../mvl_models")

# User-friendly model name -> model file path relative to a model directory.
MODEL_MAPPING = {
    "Eform_MP_2019": "mp-2019.4.1/formation_energy.hdf5",
    "Eform_MP_2018": "mp-2018.6.1/formation_energy.hdf5",
    "Efermi_MP_2019": "mp-2019.4.1/efermi.hdf5",
    "Bandgap_classifier_MP_2018": "mp-2018.6.1/band_classification.hdf5",
    "Bandgap_MP_2018": "mp-2018.6.1/band_gap_regression.hdf5",
    "logK_MP_2018": "mp-2018.6.1/log10K.hdf5",
    "logG_MP_2018": "mp-2018.6.1/log10G.hdf5",
    "logK_MP_2019": "mp-2019.4.1/log10K.hdf5",
    "logG_MP_2019": "mp-2019.4.1/log10G.hdf5",
}

# QM9 models are registered only if their files exist under MODEL_PATH at
# import time. The model name is the file name without its extension;
# os.path is used instead of splitting on "/" so that the names are also
# correct when glob returns paths with a different separator (e.g. Windows).
qm9_models = glob(os.path.join(MODEL_PATH, "qm9-2018.6.1/*.hdf5"))
MODEL_MAPPING.update(
    {
        f"QM9_{name}_2018": f"qm9-2018.6.1/{name}.hdf5"
        for name in (os.path.splitext(os.path.basename(path))[0] for path in qm9_models)
    }
)

# Names accepted by load_model.
AVAILABLE_MODELS = list(MODEL_MAPPING.keys())
def load_model(model_name: str) -> GraphModel:
    """
    Load a model by its user friendly name, as listed in
    megnet.utils.models.AVAILABLE_MODELS.

    The package-level directory (MODEL_PATH) is checked first, then the
    local directory (LOCAL_MODEL_PATH). If the model file is in neither,
    the model archive is downloaded once and the local directory is
    checked again.

    Args:
        model_name: str model name string

    Returns: GraphModel

    Raises:
        ValueError: if model_name is not in AVAILABLE_MODELS
        FileNotFoundError: if the model file is still missing after the
            archive has been downloaded and extracted
    """
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"model name {model_name} not in available model list {AVAILABLE_MODELS}")

    relative_path = MODEL_MAPPING[model_name]

    mvl_path = os.path.join(MODEL_PATH, relative_path)
    if os.path.isfile(mvl_path):
        return MEGNetModel.from_file(mvl_path)

    logger.info("Package-level mvl_models not included, trying " "temperary mvl_models downloads..")
    local_mvl_path = os.path.join(LOCAL_MODEL_PATH, relative_path)
    if not os.path.isfile(local_mvl_path):
        _download_models()
        # Check once instead of recursing: if the extracted archive does not
        # contain this file, recursing would download it again without end.
        if not os.path.isfile(local_mvl_path):
            raise FileNotFoundError(
                f"model file for {model_name} not found at {local_mvl_path} after downloading mvl_models"
            )

    logger.info("Model found in local mvl_models path")
    return MEGNetModel.from_file(local_mvl_path)
def _download_models(url: str = "https://ndownloader.figshare.com/files/22291785", file_path: str = TEMP_PATH):
    """
    Fetch the zipped machine learning model files and unpack them

    Args:
        url: (str) url link for the models
        file_path: (str) path the downloaded zip archive is written to; the
            archive is extracted into the directory containing this path
    """
    logger.info(f"Fetching {os.path.basename(file_path)} from {url} to {file_path}")
    from urllib.request import urlretrieve

    urlretrieve(url, file_path)

    logger.info("Start extracting models...")
    extract_dir = os.path.dirname(file_path)
    with ZipFile(file_path, "r") as archive:
        archive.extractall(extract_dir)