import datetime
import inspect
import json
import os
import shutil
from importlib import import_module
from pathlib import Path
from datetime import timedelta

import numpy as np
import pandas as pd

from .log import logger
from .introspection import infer_args_from_method


def check_directory_exists_and_if_not_mkdir(directory):
    """ Checks if the given directory exists and creates it if it does not exist

    Parameters
    ==========
    directory: str
        Name of the directory

    """
    Path(directory).mkdir(parents=True, exist_ok=True)


class BilbyJsonEncoder(json.JSONEncoder):
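    """
    A :class:`json.JSONEncoder` that serializes the bilby, numpy, pandas, and
    (if installed) astropy types found in results files as tagged dictionaries
    which :func:`decode_bilby_json` can convert back to the original objects.

    Examples
    ========
    A minimal sketch serializing a numpy array (the output shown is
    illustrative)::

        import json
        import numpy as np

        json.dumps({"x": np.arange(3)}, cls=BilbyJsonEncoder)
        # '{"x": {"__array__": true, "content": [0, 1, 2]}}'
    """
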
    def default(self, obj):
        from ..prior import MultivariateGaussianDist, Prior, PriorDict
        from ...gw.prior import HealPixMapPriorDist
        from ...bilby_mcmc.proposals import ProposalCycle

        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, PriorDict):
            return {"__prior_dict__": True, "content": obj._get_json_dict()}
        if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
            return {
                "__prior__": True,
                "__module__": obj.__module__,
                "__name__": obj.__class__.__name__,
                "kwargs": dict(obj.get_instantiation_dict()),
            }
        if isinstance(obj, ProposalCycle):
            return str(obj)
        try:
            from astropy import cosmology as cosmo, units

            if isinstance(obj, cosmo.FLRW):
                return encode_astropy_cosmology(obj)
            if isinstance(obj, units.Quantity):
                return encode_astropy_quantity(obj)
            if isinstance(obj, units.PrefixUnit):
                return str(obj)
        except ImportError:
            logger.debug("Cannot import astropy, cannot write cosmological priors")
        if isinstance(obj, np.ndarray):
            return {"__array__": True, "content": obj.tolist()}
        if isinstance(obj, complex):
            return {"__complex__": True, "real": obj.real, "imag": obj.imag}
        if isinstance(obj, pd.DataFrame):
            return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
        if isinstance(obj, pd.Series):
            return {"__series__": True, "content": obj.to_dict()}
        if inspect.isfunction(obj):
            return {
                "__function__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if inspect.isclass(obj):
            return {
                "__class__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if isinstance(obj, timedelta):
            return {
                "__timedelta__": True,
                "__total_seconds__": obj.total_seconds(),
            }
        return json.JSONEncoder.default(self, obj)


def encode_astropy_cosmology(obj):
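    """
    Encode an astropy cosmology object as a dictionary of its instantiation
    arguments, tagged with ``__cosmology__`` and ``__name__`` so that
    :func:`decode_astropy_cosmology` can rebuild it.
    """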
    cls_name = obj.__class__.__name__
    dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)}
    dct["__cosmology__"] = True
    dct["__name__"] = cls_name
    return dct


def encode_astropy_quantity(quantity):
    dct = dict(
        __astropy_quantity__=True, value=quantity.value, unit=str(quantity.unit)
    )
    if isinstance(dct["value"], np.ndarray):
        dct["value"] = list(dct["value"])
    return dct


def decode_astropy_cosmology(dct):
    try:
        from astropy import cosmology as cosmo

        cosmo_cls = getattr(cosmo, dct["__name__"])
        del dct["__cosmology__"], dct["__name__"]
        return cosmo_cls(**dct)
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be " "properly loaded."
        )
        return dct


def decode_astropy_quantity(dct):
    try:
        from astropy import units

        if dct["value"] is None:
            return None
        else:
            del dct["__astropy_quantity__"]
            return units.Quantity(**dct)
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be " "properly loaded."
        )
        return dct


def load_json(filename, gzip):
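    """ Load a JSON file, decoding bilby types with :func:`decode_bilby_json`

    Parameters
    ==========
    filename: str
        Name of the file to load
    gzip: bool
        Whether the file is gzipped; files with a ``.gz`` extension are
        decompressed regardless of this flag

    Returns
    =======
    dictionary: dict
        The decoded contents of the file

    Examples
    ========
    A minimal sketch, assuming a file previously written with
    ``json.dump(..., cls=BilbyJsonEncoder)`` (the file name is illustrative)::

        dictionary = load_json("result.json", gzip=False)
    """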
    if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz":
        import gzip as gzip_module

        with gzip_module.GzipFile(filename, "r") as file:
            json_str = file.read().decode("utf-8")
        dictionary = json.loads(json_str, object_hook=decode_bilby_json)
    else:
        with open(filename, "r") as file:
            dictionary = json.load(file, object_hook=decode_bilby_json)
    return dictionary


def decode_bilby_json(dct):
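    """
    Convert a tagged dictionary produced by :class:`BilbyJsonEncoder` back
    into the corresponding python object. Intended to be passed to
    :func:`json.loads` as the ``object_hook``; dictionaries without a
    recognised tag are returned unchanged.

    Examples
    ========
    A minimal sketch inverting the encoder (values are illustrative)::

        import json

        json_str = '{"x": {"__array__": true, "content": [0, 1, 2]}}'
        json.loads(json_str, object_hook=decode_bilby_json)
        # {'x': array([0, 1, 2])}
    """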
    if dct.get("__prior_dict__", False):
        cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        obj = cls._get_from_json_dict(dct)
        return obj
    if dct.get("__prior__", False):
        try:
            cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        except AttributeError:
            logger.debug(
                "Unknown prior class for parameter {}, defaulting to base Prior object".format(
                    dct["kwargs"]["name"]
                )
            )
            from ..prior import Prior

            cls = Prior
        obj = cls(**dct["kwargs"])
        return obj
    if dct.get("__cosmology__", False):
        return decode_astropy_cosmology(dct)
    if dct.get("__astropy_quantity__", False):
        return decode_astropy_quantity(dct)
    if dct.get("__array__", False):
        return np.asarray(dct["content"])
    if dct.get("__complex__", False):
        return complex(dct["real"], dct["imag"])
    if dct.get("__dataframe__", False):
        return pd.DataFrame(dct["content"])
    if dct.get("__series__", False):
        return pd.Series(dct["content"])
    if dct.get("__function__", False) or dct.get("__class__", False):
        default = ".".join([dct["__module__"], dct["__name__"]])
        return getattr(import_module(dct["__module__"]), dct["__name__"], default)
    if dct.get("__timedelta__", False):
        return timedelta(seconds=dct["__total_seconds__"])
    return dct


def recursively_decode_bilby_json(dct):
    """
    Recursively call `decode_bilby_json`

    Parameters
    ----------
    dct: dict
        The dictionary to decode

    Returns
    -------
    dct: dict
        The original dictionary with all the elements decoded if possible
    """
    dct = decode_bilby_json(dct)
    if isinstance(dct, dict):
        for key in dct:
            if isinstance(dct[key], dict):
                dct[key] = recursively_decode_bilby_json(dct[key])
    return dct


def decode_from_hdf5(item):
    """
    Decode an item from HDF5 format to python type.

    This converts the ``__none__`` sentinel to None, decodes byte strings,
    and converts arrays of byte strings to lists of str.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    item: object
        Item to be decoded

    Returns
    -------
    output: object
        Converted input item
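
    Examples
    --------
    Illustrative conversions::

        decode_from_hdf5(b"__none__")             # None
        decode_from_hdf5(b"label")                # "label"
        decode_from_hdf5(np.array([b"a", b"b"]))  # ["a", "b"]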
    """
    if isinstance(item, str) and item == "__none__":
        output = None
    elif isinstance(item, bytes) and item == b"__none__":
        output = None
    elif isinstance(item, (bytes, bytearray)):
        output = item.decode()
    elif isinstance(item, np.ndarray):
        if item.size == 0:
            output = item
        elif "|S" in str(item.dtype) or isinstance(item[0], bytes):
            output = [it.decode() for it in item]
        else:
            output = item
    elif isinstance(item, np.bool_):
        output = bool(item)
    else:
        output = item
    return output


def encode_for_hdf5(key, item):
    """
    Encode an item to an HDF5 saveable format.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    key: str
        Identifier for the item, used when raising errors
    item: object
        Object to be encoded, specific options are provided for Bilby types

    Returns
    -------
    output: object
        Input item converted into HDF5 saveable format
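
    Examples
    --------
    Illustrative conversions (the key is only used in error messages)::

        encode_for_hdf5("x", None)        # "__none__"
        encode_for_hdf5("x", [1, 2, 3])   # array([1, 2, 3])
        encode_for_hdf5("x", ["a", "b"])  # [b"a", b"b"]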
    """
    from ..prior.dict import PriorDict

    if isinstance(item, np.integer):
        item = int(item)
    elif isinstance(item, np.floating):
        item = float(item)
    elif isinstance(item, np.complexfloating):
        item = complex(item)
    if isinstance(item, np.ndarray):
        # Numpy's wide unicode strings are not supported by hdf5
        if item.dtype.kind == 'U':
            logger.debug(f'converting dtype {item.dtype} for hdf5')
            item = np.array(item, dtype='S')
    if isinstance(item, (np.ndarray, int, float, complex, str, bytes)):
        output = item
    elif item is None:
        output = "__none__"
    elif isinstance(item, list):
        item_array = np.array(item)
        if len(item) == 0:
            output = item
        elif np.issubdtype(item_array.dtype, np.number):
            output = np.array(item)
        elif issubclass(item_array.dtype.type, str) or item[0] is None:
            output = list()
            for value in item:
                if isinstance(value, str):
                    output.append(value.encode("utf-8"))
                elif isinstance(value, bytes):
                    output.append(value)
                elif value is None:
                    output.append(b"__none__")
                else:
                    output.append(str(value).encode("utf-8"))
        else:
            raise ValueError(f'Cannot save {key}: {type(item)} type')
    elif isinstance(item, PriorDict):
        output = json.dumps(item._get_json_dict())
    elif isinstance(item, pd.DataFrame):
        output = item.to_dict(orient="list")
    elif inspect.isfunction(item) or inspect.isclass(item):
        output = dict(
            __module__=item.__module__, __name__=item.__name__, __class__=True
        )
    elif isinstance(item, dict):
        output = item.copy()
    elif isinstance(item, tuple):
        output = {str(ii): elem for ii, elem in enumerate(item)}
    elif isinstance(item, datetime.timedelta):
        output = item.total_seconds()
    else:
        raise ValueError(f'Cannot save {key}: {type(item)} type')
    return output


def recursively_load_dict_contents_from_group(h5file, path):
    """
    Recursively load an HDF5 file into a dictionary

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open h5py file object
    path: str
        Path within the HDF5 file

    Returns
    -------
    output: dict
        The contents of the HDF5 file unpacked into the dictionary.
    """
    import h5py

    output = dict()
    for key, item in h5file[path].items():
        if isinstance(item, h5py.Dataset):
            output[key] = decode_from_hdf5(item[()])
        elif isinstance(item, h5py.Group):
            output[key] = recursively_load_dict_contents_from_group(
                h5file, path + key + "/"
            )
    return output


def recursively_save_dict_contents_to_group(h5file, path, dic):
    """
    Recursively save a dictionary to an HDF5 group

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open HDF5 file
    path: str
        Path inside the HDF5 file
    dic: dict
        The dictionary containing the data
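
    Examples
    --------
    A round-trip sketch, assuming ``h5py`` is available (the file name is
    illustrative)::

        import h5py
        import numpy as np

        data = dict(a=np.arange(3), b=dict(c="text"))
        with h5py.File("output.hdf5", "w") as h5file:
            recursively_save_dict_contents_to_group(h5file, "/", data)
        with h5py.File("output.hdf5", "r") as h5file:
            recovered = recursively_load_dict_contents_from_group(h5file, "/")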
    """
    for key, item in dic.items():
        item = encode_for_hdf5(key, item)
        if isinstance(item, dict):
            recursively_save_dict_contents_to_group(h5file, path + key + "/", item)
        else:
            h5file[path + key] = item


def safe_file_dump(data, filename, module):
    """ Safely dump data to a .pickle file

    Parameters
    ==========
    data:
        data to dump
    filename: str
        The file to dump to
    module: pickle, dill, str
        The python module to use. If a string, the module will be imported
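
    Examples
    ========
    A minimal sketch using the standard-library pickle module (the file name
    is illustrative)::

        safe_file_dump({"a": 1}, "samples.pickle", "pickle")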
    """
    if isinstance(module, str):
        module = import_module(module)
    temp_filename = filename + ".temp"
    with open(temp_filename, "wb") as file:
        module.dump(data, file)
    shutil.move(temp_filename, filename)


def move_old_file(filename, overwrite=False):
    """ Moves or removes an old file.

    Parameters
    ==========
    filename: str
        Name of the file to be moved
    overwrite: bool, optional
        If True, remove the existing file; otherwise rename it to
        filename + '.old'
    """
    if os.path.isfile(filename):
        if overwrite:
            logger.debug("Removing existing file {}".format(filename))
            os.remove(filename)
        else:
            logger.debug(
                "Renaming existing file {} to {}.old".format(filename, filename)
            )
            shutil.move(filename, filename + ".old")
    logger.debug("Saving result to {}".format(filename))


def safe_save_figure(fig, filename, **kwargs):
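    """ Save a matplotlib figure, creating the output directory if needed and
    retrying without tex labels if the first attempt fails

    Parameters
    ==========
    fig: matplotlib.figure.Figure
        The figure to save
    filename: str
        The file to save the figure to
    **kwargs:
        Additional keyword arguments passed to ``fig.savefig``
    """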
    check_directory_exists_and_if_not_mkdir(os.path.dirname(filename))
    from matplotlib import rcParams

    try:
        fig.savefig(fname=filename, **kwargs)
    except RuntimeError:
        logger.debug("Failed to save plot with tex labels turning off tex.")
        rcParams["text.usetex"] = False
        fig.savefig(fname=filename, **kwargs)