# Copyright (c) Materials Virtual Lab.
# Distributed under the terms of the BSD License.
"""Van Hove analysis for correlations."""
from __future__ import annotations
import itertools
from collections import Counter
from typing import TYPE_CHECKING, Callable
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm
from pymatgen.util.plotting import pretty_plot
from .rdf import RadialDistributionFunction
if TYPE_CHECKING:
    from pymatgen.analysis.diffusion.analyzer import DiffusionAnalyzer
    from pymatgen.core import Structure
__author__ = "Iek-Heng Chu"
__version__ = "1.0"
__date__ = "Aug 9, 2017"
[docs]
class VanHoveAnalysis:
    """
    Class for van Hove function analysis.
    In particular, self-part (Gs) and distinct-part (Gd) of the van Hove correlation function G(r,t) for given species
    and given structure are computed. If you use this class, please consider citing the following paper::
        Zhu, Z.; Chu, I.-H.; Deng, Z. and Ong, S. P. "Role of Na+ Interstitials
        and Dopants in Enhancing the Na+ Conductivity of the Cubic Na3PS4
        Superionic Conductor". Chem. Mater. (2015), 27, pp 8318-8325
    """
    def __init__(
        self,
        diffusion_analyzer: DiffusionAnalyzer,
        avg_nsteps: int = 50,
        ngrid: int = 101,
        rmax: float = 10.0,
        step_skip: int = 50,
        sigma: float = 0.1,
        cell_range: int = 1,
        species: tuple | list = ("Li", "Na"),
        reference_species: tuple | list | None = None,
        indices: list | None = None,
    ):
        """
        Initiation.
        Args:
            diffusion_analyzer (DiffusionAnalyzer): A
                pymatgen.analysis.diffusion.analyzer.DiffusionAnalyzer object
            avg_nsteps (int): Number of t0 used for statistical average
            ngrid (int): Number of radial grid points
            rmax (float): Maximum of radial grid (the minimum is always set zero)
            step_skip (int): # of time steps skipped during analysis. It defines
                the resolution of the reduced time grid
            sigma (float): Smearing of a Gaussian function
            cell_range (int): Range of translational vector elements associated
                with supercell. Default is 1, i.e. including the adjacent image
                cells along all three directions.
            species ([string]): a list of specie symbols of interest.
            reference_species ([string]): Set this option along with 'species'
                parameter to calculate the distinct-part of van Hove function.
                Note that the self-part of van Hove function is always computed
                only for those in "species" parameter.
            indices (list of int): If not None, only a subset of atomic indices
                will be selected for the analysis. If this is given, "species"
                parameter will be ignored.
        """
        # initial check
        if step_skip <= 0:
            raise ValueError("skip_step should be >=1!")
        n_ions, nsteps, ndim = diffusion_analyzer.disp.shape
        if nsteps <= avg_nsteps:
            raise ValueError("Number of timesteps is too small!")
        ntsteps = nsteps - avg_nsteps
        if ngrid - 1 <= 0:
            raise ValueError("Ntot should be greater than 1!")
        if sigma <= 0.0:
            raise ValueError("sigma should be > 0!")
        dr = rmax / (ngrid - 1)
        interval = np.linspace(0.0, rmax, ngrid)
        reduced_nt = int(ntsteps / float(step_skip)) + 1
        lattice = diffusion_analyzer.structure.lattice
        structure = diffusion_analyzer.structure
        if indices is None:
            indices = [j for j, site in enumerate(structure) if site.specie.symbol in species]
        ref_indices = indices
        if reference_species:
            ref_indices = [j for j, site in enumerate(structure) if site.specie.symbol in reference_species]
        rho = float(len(indices)) / lattice.volume
        # reduced time grid
        rtgrid = np.arange(0.0, reduced_nt)
        # van Hove functions
        gsrt = np.zeros((reduced_nt, ngrid), dtype=np.double)
        gdrt = np.zeros((reduced_nt, ngrid), dtype=np.double)
        tracking_ions = []
        ref_ions = []
        # auxiliary factor for 4*\pi*r^2
        aux_factor = 4.0 * np.pi * interval**2
        aux_factor[0] = np.pi * dr**2
        for _i, ss in enumerate(diffusion_analyzer.get_drift_corrected_structures()):
            all_fcoords = np.array(ss.frac_coords)
            tracking_ions.append(all_fcoords[indices, :])
            ref_ions.append(all_fcoords[ref_indices, :])
        tracking_ions = np.array(tracking_ions)  # type: ignore
        ref_ions = np.array(ref_ions)  # type: ignore
        gaussians = norm.pdf(interval[:, None], interval[None, :], sigma) / float(avg_nsteps) / float(len(ref_indices))
        # calculate self part of van Hove function
        image = np.array([0, 0, 0])
        for it in range(reduced_nt):
            dns = Counter()  # type: ignore
            it0 = min(it * step_skip, ntsteps)
            for it1 in range(avg_nsteps):
                dists = [
                    lattice.get_distance_and_image(tracking_ions[it1][u], tracking_ions[it0 + it1][u], jimage=image)[0]
                    for u in range(len(indices))
                ]
                r_indices = [int(dist / dr) for dist in filter(lambda e: e < rmax, dists)]
                dns.update(r_indices)  # type: ignore
            for indx, dn in dns.most_common(ngrid):
                gsrt[it, :] += gaussians[indx, :] * dn
        # calculate distinct part of van Hove function of species
        r = np.arange(-cell_range, cell_range + 1)
        arange = r[:, None] * np.array([1, 0, 0])[None, :]
        brange = r[:, None] * np.array([0, 1, 0])[None, :]
        crange = r[:, None] * np.array([0, 0, 1])[None, :]
        images = arange[:, None, None] + brange[None, :, None] + crange[None, None, :]
        images = images.reshape((len(r) ** 3, 3))
        # find the zero image vector
        zd = np.sum(images**2, axis=1)
        indx0 = np.argmin(zd)
        for it in range(reduced_nt):
            dns = Counter()
            it0 = min(it * step_skip, ntsteps)
            for it1 in range(avg_nsteps):
                dcf = (
                    tracking_ions[it0 + it1, :, None, None, :]  # type: ignore
                    + images[None, None, :, :]
                    - ref_ions[it1, None, :, None, :]  # type: ignore
                )
                dcc = lattice.get_cartesian_coords(dcf)
                d2 = np.sum(dcc**2, axis=3)
                dists = [
                    d2[u, v, j] ** 0.5
                    for u in range(len(indices))
                    for v in range(len(ref_indices))
                    for j in range(len(r) ** 3)
                    if u != v or j != indx0
                ]
                r_indices = [int(dist / dr) for dist in filter(lambda e: e < rmax, dists)]
                dns.update(r_indices)
            for indx, dn in dns.most_common(ngrid):
                gdrt[it, :] += gaussians[indx, :] * dn / aux_factor[indx] / rho
        self.obj = diffusion_analyzer
        self.avg_nsteps = avg_nsteps
        self.step_skip = step_skip
        self.rtgrid = rtgrid
        self.interval = interval
        self.gsrt = gsrt
        self.gdrt = gdrt
        # time interval (in ps) in gsrt and gdrt.
        self.timeskip = self.obj.time_step * self.obj.step_skip * step_skip / 1000.0
[docs]
    def get_3d_plot(self, figsize: tuple = (12, 8), mode: str = "distinct"):
        """
        Plot 3D self-part or distinct-part of van Hove function, which is
        specified by the input argument 'type'.
        Args:
            figsize (tuple): fig size in inches.
            mode (str): 'distinct' or 'both'.
        Returns:
            matplotlib.axes._subplots.Axes: axes object.
        """
        assert mode in ["distinct", "self"]
        if mode == "distinct":
            grt = self.gdrt.copy()
            vmax = 4.0
            cb_ticks = [0, 1, 2, 3, 4]
            cb_label = "$G_d$($t$,$r$)"
        else:
            grt = self.gsrt.copy()
            vmax = 1.0
            cb_ticks = [0, 1]
            cb_label = r"4$\pi r^2G_s$($t$,$r$)"
        y = np.arange(np.shape(grt)[1]) * self.interval[-1] / float(len(self.interval) - 1)
        x = np.arange(np.shape(grt)[0]) * self.timeskip
        X, Y = np.meshgrid(x, y, indexing="ij")
        ticksize = int(figsize[0] * 2.5)
        plt.figure(figsize=figsize, facecolor="w")
        plt.xticks(fontsize=ticksize)
        plt.yticks(fontsize=ticksize)
        ax = plt.gca()
        labelsize = int(figsize[0] * 3)
        plt.pcolor(X, Y, grt, cmap="jet", vmin=grt.min(), vmax=vmax)
        ax.set_xlabel("Time (ps)", size=labelsize)
        ax.set_ylabel(r"$r$ ($\AA$)", size=labelsize)
        ax.axis([x.min(), x.max(), y.min(), y.max()])
        cbar = plt.colorbar(ticks=cb_ticks)
        cbar.set_label(label=cb_label, size=labelsize)
        cbar.ax.tick_params(labelsize=ticksize)
        return ax 
[docs]
    def get_1d_plot(
        self,
        mode: str = "distinct",
        times: list | None = None,
        colors: list | None = None,
    ):
        """
        Plot the van Hove function at given r or t.
        Args:
            mode (str): Specify which part of van Hove function to be plotted.
            times (list of float): Time moments (in ps) in which the van Hove
                            function will be plotted.
            colors (list strings/tuples): Additional color settings. If not set,
                            seaborn.color_plaette("Set1", 10) will be used.
        Returns:
            matplotlib.axes._subplots.Axes: axes object.
        """
        if times is None:
            times = [0.0]
        if colors is None:
            import seaborn as sns
            colors = sns.color_palette("Set1", 10)
        assert mode in ["distinct", "self"]
        assert len(times) <= len(colors)
        if mode == "distinct":
            grt = self.gdrt.copy()
            ylabel = "$G_d$($t$,$r$)"
            ylim = [-0.005, 4.0]
        else:
            grt = self.gsrt.copy()
            ylabel = r"4$\pi r^2G_s$($t$,$r$)"
            ylim = [-0.005, 1.0]
        ax = pretty_plot(12, 8)
        for i, time in enumerate(times):
            index = int(np.round(time / self.timeskip))
            index = min(index, np.shape(grt)[0] - 1)
            new_time = index * self.timeskip
            label = str(new_time) + " ps"
            ax.plot(self.interval, grt[index], color=colors[i], label=label, linewidth=4.0)
        ax.set_xlabel(r"$r$ ($\AA$)")
        ax.set_ylabel(ylabel)
        ax.legend(loc="upper right", fontsize=36)
        ax.set_xlim(0.0, self.interval[-1] - 1.0)
        ax.set_ylim(ylim[0], ylim[1])
        return ax 
 
[docs]
class EvolutionAnalyzer:
    """Analyze the evolution of structures during AIMD simulations."""
    def __init__(self, structures: list, rmax: float = 10, step: int = 1, time_step: int = 2):
        """
        Initialization the EvolutionAnalyzer from MD simulations. From the
        structures obtained from MD simulations, we can analyze the structure
        evolution with time by some quantitative characterization such as RDF
        and atomic distribution.
        If you use this class, please consider citing the following paper:
        Tang, H.; Deng, Z.; Lin, Z.; Wang, Z.; Chu, I-H; Chen, C.; Zhu, Z.;
        Zheng, C.; Ong, S. P. "Probing Solid-Solid Interfacial Reactions in
        All-Solid-State Sodium-Ion Batteries with First-Principles
        Calculations", Chem. Mater. (2018), 30(1), pp 163-173.
        Args:
            structures ([Structure]): The list of structures obtained from MD
                simulations.
            rmax (float): Maximum of radial grid (the minimum is always zero).
            step (int): indicate the interval of input structures, which is used
                to calculated correct time step.
            time_step(int): the time step in fs, POTIM in INCAR.
        """
        self.pairs = self.get_pairs(structures[0])
        self.structure = structures[0]
        self.structures = structures
        self.rmax = rmax
        self.step = step
        self.time_step = time_step
[docs]
    @staticmethod
    def get_pairs(structure: Structure):
        """
        Get all element pairs in a structure.
        Args:
            structure (Structure): structure
        Returns:
            list of tuples
        """
        specie_list = [s.name for s in structure.types_of_specie]
        pairs = itertools.combinations_with_replacement(specie_list, 2)
        return list(pairs) 
[docs]
    @staticmethod
    def rdf(structure: Structure, pair: tuple, ngrid: int = 101, rmax: float = 10):
        """
        Process rdf from a given structure and pair.
        Args:
            structure (Structure): input structure.
            pair (str tuple): e.g. ("Na", "Na").
            ngrid (int): Number of radial grid points.
            rmax (float): Maximum of radial grid (the minimum is always zero).
        Returns:
            rdf (np.array)
        """
        r = RadialDistributionFunction.from_species(
            [structure],
            ngrid=ngrid,
            rmax=rmax,
            sigma=0.1,
            species=(pair[0]),
            reference_species=(pair[1]),
        )
        return r.rdf 
[docs]
    @staticmethod
    def atom_dist(
        structure: Structure,
        specie: str,
        ngrid: int = 101,
        window: float = 1,
        direction: str = "c",
    ):
        """
        Get atomic distribution for a given specie.
        Args:
            structure (Structure): input structure
            specie (str): species string for an element
            ngrid (int): Number of radial grid points.
            window (float): number of atoms will be counted within the range
                (i-window, i+window), unit is angstrom.
            direction (str): Choose from "a", "b" and "c". Default is "c".
        Returns:
            density (np.array): atomic concentration along one direction.
        """
        if direction in ["a", "b", "c"]:
            latt_len = getattr(structure.lattice, direction)
            ind = ["a", "b", "c"].index(direction)
            assert window <= latt_len, "Window range exceeds valid bounds!"
        else:
            raise ValueError("Choose from a, b and c!")
        atom_list = [site for site in structure.sites if site.species_string == specie]
        atom_total = structure.composition[specie]
        density = []
        for i in np.linspace(0, latt_len - window, ngrid):
            atoms = []
            for j in [-1, 0, 1]:
                temp = [s for s in atom_list if i - window < s.coords[ind] % latt_len + latt_len * j < i + window]
                atoms.extend(temp)
            density.append(len(atoms) / atom_total)
        return np.array(density) 
[docs]
    def get_df(self, func: Callable, save_csv: str | None = None, **kwargs):
        """
        Get the data frame for a given pair. This step would be very slow if
        there are hundreds or more structures to parse.
        Args:
            func (FunctionType): structure to spectrum function. choose from
                rdf (to get radial distribution function, pair required) or
                get_atomic_distribution (to get atomic distribution, specie
                required). Extra parameters can be parsed using kwargs.
                e.g. To get rdf dataframe:
                    df = EvolutionAnalyzer.get_df(
                        func=EvolutionAnalyzer.rdf, pair=("Na", "Na"))
                e.g. To get atomic distribution:
                    df = EvolutionAnalyzer.get_df(
                        func=EvolutionAnalyzer.atom_dist, specie="Na")
            save_csv (str): save pandas DataFrame to csv.
            **kwargs: Pass-through to func.
        Returns:
            pandas.DataFrame object: index is the radial distance in Angstrom,
                and column is the time step in ps.
        """
        prop_table = []
        ngrid = kwargs.get("ngrid", 101)
        if func == self.rdf:
            kwargs["rmax"] = self.rmax
        for structure in self.structures:
            prop_table.append(func(structure, **kwargs))
        index = np.arange(len(self.structures)) * self.time_step * self.step / 1000
        columns = np.linspace(0, self.rmax, ngrid)
        df = pd.DataFrame(prop_table, index=index, columns=columns)
        if save_csv is not None:
            df.to_csv(save_csv)
        return df 
[docs]
    @staticmethod
    def get_min_dist(df: pd.DataFrame, tol: float = 1e-10):
        """
        Get the shortest pair distance from the given DataFrame.
        Args:
            df (DataFrame): index is the radial distance in Angstrom, and
                column is the time step in ps.
            tol (float): any float number less than tol is considered as zero.
        Returns:
            The shortest pair distance throughout the table.
        """
        # TODO: Add unittest
        for i, col in enumerate(df.columns):
            min_dist = df.min(axis="index")[i]
            if min_dist > tol:
                return float(col)
        raise RuntimeError("Getting min dist failed.") 
[docs]
    @staticmethod
    def plot_evolution_from_data(
        df: pd.DataFrame,
        x_label: str | None = None,
        cb_label: str | None = None,
        cmap=plt.cm.plasma,  # pylint: disable=E1101
    ):
        """
        Plot the evolution with time for a given DataFrame. It can be RDF,
        atomic distribution or other characterization data we might
        implement in the future.
        Args:
            df (pandas.DataFrame): input DataFrame object, index is the radial
                distance in Angstrom, and column is the time step in ps.
            x_label (str): x label
            cb_label (str): color bar label
            cmap (color map): the color map used in heat map.
                cmocean.cm.thermal is recommended
        Returns:
            matplotlib.axes._subplots.AxesSubplot object
        """
        import seaborn as sns
        sns.set_style("white")
        plt.rcParams["axes.linewidth"] = 1.5
        plt.rcParams["font.family"] = "sans-serif"
        plt.rcParams["xtick.labelsize"] = 20
        plt.rcParams["ytick.labelsize"] = 20
        plt.rcParams["xtick.major.pad"] = 10
        fig, ax = plt.subplots(figsize=(12, 8), facecolor="w")
        ax = sns.heatmap(
            df,
            linewidths=0,
            cmap=cmap,
            annot=False,
            cbar=True,
            xticklabels=10,
            yticklabels=25,
            rasterized=True,
        )
        ax.set_ylim(ax.get_ylim()[::-1])
        ax.collections[0].colorbar.set_label(cb_label, fontsize=30)
        plt.xticks(rotation="horizontal")
        ax.set_xlabel(x_label, fontsize=30)
        ax.set_ylabel("Time (ps)", fontsize=30)
        plt.yticks(rotation="horizontal")
        return ax 
[docs]
    def plot_rdf_evolution(
        self,
        pair: tuple,
        cmap=plt.cm.plasma,  # pylint: disable=E1101
        df: pd.DataFrame = None,
    ):
        """
        Plot the RDF evolution with time for a given pair.
        Args:
            pair (str tuple): e.g. ("Na", "Na")
            cmap (color map): the color map used in heat map.
                cmocean.cm.thermal is recommended
            df (DataFrame): external data, index is the radial distance in
                Angstrom, and column is the time step in ps.
        Returns:
            matplotlib.axes._subplots.AxesSubplot object
        """
        if df is None:
            df = self.get_df(func=EvolutionAnalyzer.rdf, pair=pair)
        x_label, cb_label = f"$r$ ({pair[0]}-{pair[1]}) ($\\rm\\AA$)", "$g(r)$"
        return self.plot_evolution_from_data(df=df, x_label=x_label, cb_label=cb_label, cmap=cmap) 
[docs]
    def plot_atomic_evolution(
        self,
        specie: str,
        direction: str = "c",
        cmap=plt.cm.Blues,  # pylint: disable=E1101
        df: pd.DataFrame = None,
    ):
        """
        Plot the atomic distribution evolution with time for a given species.
        Args:
            specie (str): species string for an element.
            direction (str): Choose from "a", "b", "c". Default to "c".
            cmap (color map): the color map used in heat map.
            df (DataFrame): external data, index is the atomic distance in
                         Angstrom, and column is the time step in ps.
        Returns:
            matplotlib.axes._subplots.AxesSubplot object
        """
        if df is None:
            df = self.get_df(func=EvolutionAnalyzer.atom_dist, specie=specie, direction=direction)
        x_label, cb_label = (
            f"Atomic distribution along {direction}",
            "Probability",
        )
        return self.plot_evolution_from_data(df=df, x_label=x_label, cb_label=cb_label, cmap=cmap)