Source code for megnet.callbacks

"""
Callback functions used in the training process
"""
import logging
import os
import re
import warnings
from collections import deque
from glob import glob
from typing import Dict

import numpy as np
import tensorflow.keras.backend as kb
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.utils import Sequence

from megnet.utils.metrics import mae, accuracy
from megnet.utils.preprocessing import DummyScaler, Scaler

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ModelCheckpointMAE(Callback):
    """
    Save the best MAE model with target scaler
    """

    def __init__(
        self,
        filepath: str = "./callback/val_mae_{epoch:05d}_{val_mae:.6f}.hdf5",
        monitor: str = "val_mae",
        verbose: int = 0,
        save_best_only: bool = True,
        save_weights_only: bool = False,
        val_gen: Sequence = None,
        steps_per_val: int = None,
        target_scaler: Scaler = None,
        period: int = 1,
        mode: str = "auto",
    ):
        """
        Args:
            filepath (string): path to save the model file with format. For example
                `weights.{epoch:02d}-{val_mae:.6f}.hdf5` will save the corresponding epoch and
                val_mae in the filename
            monitor (string): quantity to monitor, default to "val_mae"
            verbose (int): 0 for no training log, 1 for only epoch-level log and 2 for batch-level log
            save_best_only (bool): whether to save only the best model
            save_weights_only (bool): whether to save the weights only, excluding model structure
            val_gen (generator): validation generator
            steps_per_val (int): steps per epoch for validation generator
            target_scaler (object): exposing inverse_transform method to scale the output
            period (int): number of epochs between invocations of this callback
            mode (string): choose from "min", "max" or "auto"
        """
        super().__init__()
        if val_gen is None:
            raise ValueError("No validation data is provided!")
        self.verbose = verbose
        if self.verbose > 0:
            logging.basicConfig(level=logging.INFO)
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0
        self.val_gen = val_gen
        self.steps_per_val = steps_per_val or len(val_gen)
        self.target_scaler = target_scaler or DummyScaler()

        if monitor == "val_mae":
            self.metric = mae
            self.monitor = "val_mae"
        elif monitor == "val_acc":
            self.metric = accuracy
            self.filepath = self.filepath.replace("val_mae", "val_acc")
            self.monitor = "val_acc"

        if mode == "min":
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == "max":
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch: int, logs: Dict = None) -> None:
        """
        Code called by the callback at the end of an epoch

        Args:
            epoch (int): epoch id
            logs (dict): logs of training

        Returns:
            None
        """
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            val_pred = []
            val_y = []
            for i in range(self.steps_per_val):
                val_data = self.val_gen[i]  # type: ignore
                nb_atom = _count(np.array(val_data[0][-2]))
                stop_training = self.model.stop_training  # save stop_training state
                pred_ = self.model.predict(val_data[0])
                self.model.stop_training = stop_training
                val_pred.append(self.target_scaler.inverse_transform(pred_[0, :, :], nb_atom[:, None]))
                val_y.append(self.target_scaler.inverse_transform(val_data[1][0, :, :], nb_atom[:, None]))
            current = self.metric(np.concatenate(val_y, axis=0), np.concatenate(val_pred, axis=0))
            filepath = self.filepath.format(**{"epoch": epoch + 1, self.monitor: current})

            if self.save_best_only:
                if current is None:
                    warnings.warn(f"Can save best model only with {self.monitor} available, skipping.", RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        logger.info(
                            f"\nEpoch {epoch+1:05d}: {self.monitor} improved from {self.best:.5f} to {current:.5f},"
                            f" saving model to {filepath}"
                        )
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            self.model.save(filepath, overwrite=True)
                    else:
                        if self.verbose > 0:
                            logger.info(f"\nEpoch {epoch+1:05d}: {self.monitor} did not improve from {self.best:.5f}")
            else:
                logger.info(f"\nEpoch {epoch+1:05d}: saving model to {filepath}")
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
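

# Example usage (illustrative sketch, not part of this module): pairing the
# checkpoint callback with model.fit. ``model``, ``train_gen``, ``val_gen`` and
# ``scaler`` are hypothetical user-supplied objects (a compiled model, training
# and validation generators, and a target Scaler).
#
#     checkpoint = ModelCheckpointMAE(
#         filepath="./callback/val_mae_{epoch:05d}_{val_mae:.6f}.hdf5",
#         val_gen=val_gen,
#         steps_per_val=len(val_gen),
#         target_scaler=scaler,
#     )
#     model.fit(train_gen, epochs=100, callbacks=[checkpoint])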


class ManualStop(Callback):
    """
    Stop the training manually by putting a "STOP" file in the directory
    """

    def on_batch_end(self, epoch: int, logs: Dict = None) -> None:
        """
        Code called at the end of a batch

        Args:
            epoch (int): epoch id
            logs (Dict): log dict

        Returns:
            None
        """
        if os.path.isfile("STOP"):
            self.model.stop_training = True
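

# Example (sketch): with ManualStop registered, creating an empty file named
# "STOP" in the working directory halts training at the end of the current batch.
# ``model`` and ``train_gen`` are hypothetical user-supplied objects.
#
#     model.fit(train_gen, epochs=100, callbacks=[ManualStop()])
#     # ...then, from a shell in the same working directory: touch STOP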


class ReduceLRUponNan(Callback):
    """
    This callback addresses a problem that can occur in regression training:
    the loss may become nan, or may suddenly shoot up. If that happens, the
    callback reduces the learning rate and loads the last best model saved
    during training.

    It also provides patience-based early stopping. This will be moved to an
    independent callback in the future.
    """

    def __init__(
        self,
        filepath: str = "./callback/val_mae_{epoch:05d}_{val_mae:.6f}.hdf5",
        factor: float = 0.5,
        verbose: bool = True,
        patience: int = 500,
        monitor: str = "val_mae",
        mode: str = "auto",
        has_sample_weights: bool = False,
    ):
        """
        Args:
            filepath (str): filepath for saved model checkpoint, should be consistent
                with the checkpoint callback
            factor (float): a value < 1 for scaling the learning rate
            verbose (bool): whether to show the loading event
            patience (int): number of epochs allowed without improvement in the monitored
                metric; used as the criterion for early stopping
            monitor (str): target metric to monitor
            mode (str): min, max or auto
            has_sample_weights (bool): whether the data has sample weights
        """
        self.filepath = filepath
        self.verbose = verbose
        self.factor = factor
        self.losses: deque = deque([], maxlen=10)
        self.patience = patience
        self.monitor = monitor
        super().__init__()

        if mode == "min":
            self.monitor_op = np.argmin
        elif mode == "max":
            self.monitor_op = np.argmax
        else:
            if "acc" in self.monitor:
                self.monitor_op = np.argmax
            else:
                self.monitor_op = np.argmin

        # get variable names from the filepath pattern
        variable_name_pattern = r"{(.+?)}"
        self.variable_names = re.findall(variable_name_pattern, filepath)
        self.variable_names = [i.split(":")[0] for i in self.variable_names]
        self.has_sample_weights = has_sample_weights
        if self.monitor not in self.variable_names:
            raise ValueError("The monitored metric should be in the name pattern")

    def on_epoch_end(self, epoch: int, logs: Dict = None):
        """
        Check the loss value at the end of an epoch

        Args:
            epoch (int): epoch id
            logs (dict): log history

        Returns:
            None
        """
        logs = logs or {}
        loss = logs.get("loss")
        last_saved_epoch, last_metric, last_file = self._get_checkpoints()
        if last_saved_epoch is not None:
            if last_saved_epoch + self.patience <= epoch:
                self.model.stop_training = True
                logger.info(f"{self.monitor} did not improve for {self.patience} epochs, stopping the fitting...")

        if loss is not None:
            self.losses.append(loss)
            if np.isnan(loss) or np.isinf(loss):
                if self.verbose:
                    logger.info("Nan loss found!")
                self._reduce_lr_and_load(last_file)
                if self.verbose:
                    logger.info(f"Now lr is {float(kb.eval(self.model.optimizer.lr))}.")
            else:
                if len(self.losses) > 1:
                    if self.losses[-1] > (self.losses[-2] * 100):
                        self._reduce_lr_and_load(last_file)
                        if self.verbose:
                            logger.info(
                                f"Loss shot up from {self.losses[-2]:.3f} to {self.losses[-1]:.3f}! Reducing lr "
                            )
                            logger.info(f"Now lr is {float(kb.eval(self.model.optimizer.lr))}.")

    def _reduce_lr_and_load(self, last_file):
        old_value = float(kb.eval(self.model.optimizer.lr))
        self.model.reset_states()
        self.model.optimizer.lr = old_value * self.factor

        if last_file is not None:
            self.model.load_weights(last_file)
            if self.verbose:
                logger.info(f"Load weights {last_file}")
        else:
            logger.info("No weights were loaded")
        opt_dict = self.model.optimizer.get_config()
        sample_weight_mode = "temporal" if self.has_sample_weights else None
        self.model.compile(
            self.model.optimizer.__class__(**opt_dict), self.model.loss, sample_weight_mode=sample_weight_mode
        )

    def _get_checkpoints(self):
        file_pattern = re.sub(r"{(.+?)}", r"([0-9\.]+)", self.filepath)
        glob_pattern = re.sub(r"{(.+?)}", r"*", self.filepath)
        all_check_points = glob(glob_pattern)

        if len(all_check_points) > 0:
            metric_index = self.variable_names.index(self.monitor)
            epoch_index = self.variable_names.index("epoch")
            metric_values = []
            epochs = []
            for i in all_check_points:
                metrics = re.findall(file_pattern, i)[0]
                metric_values.append(float(metrics[metric_index]))
                epochs.append(int(metrics[epoch_index]))
            ind = self.monitor_op(metric_values)
            return epochs[ind], metric_values[ind], all_check_points[ind]
        return None, None, None
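

# Example usage (illustrative sketch): ReduceLRUponNan is typically paired with
# ModelCheckpointMAE using the same ``filepath`` pattern, so the last best
# checkpoint can be reloaded when the loss diverges. ``model``, ``train_gen``
# and ``val_gen`` are hypothetical user-supplied objects.
#
#     pattern = "./callback/val_mae_{epoch:05d}_{val_mae:.6f}.hdf5"
#     callbacks = [
#         ModelCheckpointMAE(filepath=pattern, val_gen=val_gen),
#         ReduceLRUponNan(filepath=pattern, factor=0.5, patience=500),
#     ]
#     model.fit(train_gen, epochs=1000, callbacks=callbacks)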


def _count(a: np.ndarray) -> np.ndarray:
    """
    Count the number of appearances of each element in a

    Args:
        a: (np.array)

    Returns:
        (np.array) number of appearances of each element in a
    """
    a = a.ravel()
    a = np.r_[a[0], a, np.Inf]
    z = np.where(np.abs(np.diff(a)) > 0)[0]
    z = np.r_[0, z]
    return np.diff(z)
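

# Worked example (sketch): _count returns the run lengths of consecutive equal
# values, e.g. turning a flat per-atom structure-index array into the number of
# atoms per structure.
#
#     >>> _count(np.array([0, 0, 0, 1, 1, 2]))
#     array([3, 2, 1])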