# coding=utf-8
#
# The MIT License (MIT)
#
# Copyright (c) 2014 Lorenz Hüdepohl, Emmanuel Stamou
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import absolute_import
from __future__ import print_function

from .arithmetic_decorators import add_operators
from .bin import Bin

def binned_operator(func):
    """
    Lift a binary operation on single bins to whole BinnedObjects.

    Returns a method ``wrapped_binned_operator(self, other)`` that applies
    ``func(self_bin, other_operand)`` to every bin of ``self`` (in iteration
    order) and collects the results into a new BinnedObject of the same
    dimension.

    ``other`` may be

      * another BinnedObject (anything with a ``bins`` attribute), which must
        be compatible with ``self``; the bins are paired in iteration order,
      * any other iterable, whose elements are paired with the bins of
        ``self`` in iteration order (``zip`` stops at the shorter of the two),
      * a non-iterable such as a scalar, which is paired with every bin.

    Raises
    ------
    Exception
      If ``other`` is a BinnedObject that is not compatible with ``self``.
    """
    from itertools import repeat

    def wrapped_binned_operator(self, other):
        if hasattr(other, "bins") and not self.compatible(other):
            raise Exception("Cannot operate on non-compatible BinnedObjects!")

        try:
            others = iter(other)
        except TypeError:
            # Non-iterable operands (scalars) are broadcast onto every bin.
            others = repeat(other)

        res = BinnedObject(ndim=self.ndim)
        for sbin, obin in zip(self, others):
            res.add_bin(func(sbin, obin))
        return res

    return wrapped_binned_operator

@add_operators(binned_operator)
class BinnedObject(object):
    """
    A collection of non-overlapping Bin objects of dimension ``ndim``.

    The bins are stored in ``self.bins``, a sorted container ordered by bin
    center (this is also the iteration order). For every dimension ``d`` there
    are two additional sorted indexes of the same bins: ``self.lows[d]``,
    ordered by ``bin.low[d]``, and ``self.highs[d]``, ordered by
    ``bin.high[d]``. These allow locating the bins that contain a point by
    bisection.

    The ``add_operators`` class decorator is applied with ``binned_operator``,
    presumably to provide bin-wise arithmetic (``rebin`` relies on ``*`` and
    ``/`` between BinnedObjects).
    """

    def __init__(self, ndim, initial_bins=None):
        """
        Create an empty BinnedObject, optionally filled with initial_bins.

        Parameters
        ----------
        ndim : int
          Number of dimensions, must be a positive integer.

        initial_bins : iterable of Bin objects, optional
          Inserted one by one with add_bin(), so they must not overlap.

        Raises
        ------
        Exception
          If ndim is not a positive int, or if initial_bins overlap.
        """
        if not isinstance(ndim, int) or not ndim > 0:
            raise Exception("Invalid number of dimensions: {0}".format(ndim))
        from sortedcontainers import SortedListWithKey

        self.ndim = ndim
        self.bins = SortedListWithKey(key=lambda b: b.center())

        # Key factories (rather than lambdas inside the generator expressions
        # below) so that every key function binds its own dimension index.
        def lowkey(dim):
            def key(bin):
                return bin.low[dim]
            return key

        def highkey(dim):
            def key(bin):
                return bin.high[dim]
            return key

        self.lows = tuple(SortedListWithKey(key=lowkey(i)) for i in range(ndim))
        self.highs = tuple(SortedListWithKey(key=highkey(i)) for i in range(ndim))

        if initial_bins is not None:
            for bin in initial_bins:
                self.add_bin(bin)

    def _insert(self, bin):
        """
        Add bin to self.bins and to all per-dimension edge indexes, without
        any overlap check.
        """
        self.bins.add(bin)
        for container in self.lows + self.highs:
            container.add(bin)

    def add_bin(self, bin):
        """
        Insert bin into self. Raises an exception if an existing bin overlaps
        with the new bin. Bins that merely touch (share a boundary) do not
        count as overlapping.

        Examples
        --------

        >>> bin1 = Bin([0.], [1.], 1.0)
        >>> bin2 = Bin([1.], [2.], 1.0)
        >>> a = BinnedObject(1)
        >>> a.add_bin(bin1)
        >>> a
        BinnedObject(1,
         (Bin((0.0,), (1.0,), 1.0),)
        )
        >>> a.add_bin(bin2)
        >>> a
        BinnedObject(1,
         (Bin((0.0,), (1.0,), 1.0),
          Bin((1.0,), (2.0,), 1.0),)
        )
        >>> a.add_bin(bin2)
        Traceback (most recent call last):
        ...
        Exception: Cannot add a bin [1.0 - 2.0] = 1.0, overlaps with other bin(s):
          [1.0 - 2.0] = 1.0
        """
        if len(self) > 0:
            bo = BinnedObject(bin.ndim, (bin,))
            ov_b, ov_s = bo.find_overlap(self)
            if ov_b[bin]:
                raise Exception("Cannot add a bin {0}, overlaps with other bin(s):\n  {1}"
                  "".format(bin, ", ".join(str(b) for b in ov_b[bin])))

        self._insert(bin)

    def add_bins(self, bins):
        """
        Insert several bins into self at once.

        Parameters
        ----------
        bins : iterable of Bin objects
          May be any iterable, including a one-shot generator; it is consumed
          exactly once.

        Raises
        ------
        Exception
          If the new bins overlap each other or any existing bin of self. In
          that case self is left unchanged.
        """
        bins = list(bins)
        if not bins:
            return
        # Building a BinnedObject from the new bins already rejects bins that
        # overlap each other (via add_bin), also when self is still empty.
        bo = BinnedObject(self.ndim, bins)
        if len(self) > 0:
            ov_b, ov_s = bo.find_overlap(self)
            if any(ov_b.values()):
                raise Exception("Cannot add bins, they overlap with existing bins")
        for bin in bins:
            self._insert(bin)

    def remove_bin(self, bin):
        """
        Remove bin from self and from all internal edge indexes.

        The bin must be present. sortedcontainers raises ValueError otherwise.
        """
        for container in (self.bins,) + self.lows + self.highs:
            container.remove(bin)

    def iadd_bin(self, bin):
        """
        Insert bin into self. In case existing bins are overlapping with the new
        bin, those bins are split along overlapping edges and the values are updated
        with volume-weighted fractions of the original. The value of every
        overlapping piece of the new bin is then added to the matching piece of
        self.

        Parameters
        ----------
        bin : Bin object with same dimension as self

        Returns
        -------
          self

        Examples
        --------

        >>> bin1 = Bin([0.], [1.], 1.0)
        >>> bin2 = Bin([1.], [2.], 1.0)
        >>> a = BinnedObject(1, (bin1, bin2))
        >>> a
        BinnedObject(1,
         (Bin((0.0,), (1.0,), 1.0),
          Bin((1.0,), (2.0,), 1.0),)
        )
        >>> new_bin = Bin([0.5], [1.5], 1.0)
        >>> a.iadd_bin(new_bin)
        BinnedObject(1,
         (Bin((0.0,), (0.5,), 0.5),
          Bin((0.5,), (1.0,), 1.0),
          Bin((1.0,), (1.5,), 1.0),
          Bin((1.5,), (2.0,), 0.5),)
        )
        """
        # With inplace=True, conforming_rebin splits the bins of self in place
        # and returns self itself as its first result. The pieces of the new
        # bin end up in `other`.
        _, other, _, ov_other = conforming_rebin(
            self, BinnedObject(self.ndim, (bin,)), inplace=True, return_overlap=True)

        for new_bin in other:
            if len(ov_other[new_bin]) == 0:
                # This piece does not overlap any existing bin.
                self.add_bin(new_bin)
            elif len(ov_other[new_bin]) == 1:
                # After conforming, an overlapping piece is expected to match
                # exactly one bin of self.
                self_bin = ov_other[new_bin].pop()
                self_bin.value += new_bin.value
            else:
                raise Exception("This shouldn't happen!")

        return self

    def _bins_containing(self, point):
        """
        Return the set of bins of self that have point strictly inside
        themselves (not on an edge or boundary surface). The set is empty if
        no bin contains point.
        """

        class mock_bin(object):
            # Minimal stand-in for a Bin: the key functions of self.lows and
            # self.highs only read .low / .high.
            def __init__(this, point):
                this.low = point
                this.high = point
                this.ndim = self.ndim

        bin_point = mock_bin(point)

        def intersecting_bins(dim):
            # Bins with low[dim] < point[dim] ...
            i0 = self.lows[dim].bisect_left(bin_point)
            below = set(self.lows[dim][:i0])
            # ... and high[dim] > point[dim].
            i1 = self.highs[dim].bisect_right(bin_point)
            above = set(self.highs[dim][i1:])
            return below.intersection(above)

        bins = intersecting_bins(0)
        for dim in range(1, self.ndim):
            bins.intersection_update(intersecting_bins(dim))
        return bins

    def find_bins(self, point):
        """
        Return all bins of self that have the given point strictly inside
        themselves (i.e. not as edge or on a boundary surface).

        Raises
        ------
        Exception
          If no bin contains point.
        """
        bins = self._bins_containing(point)
        if not bins:
            raise Exception("Point not contained in any bin!")
        return bins

    def find_bin(self, point):
        """
        Return the unique bin that encloses the given point.

        Raises
        ------
        Exception
          If no bin, or more than one bin, contains point.
        """
        bins = self.find_bins(point)
        if len(bins) != 1:
            raise Exception("More than one bin contains {0}:\n{1}".format(point, bins))
        return bins.pop()

    def __repr__(self):
        if not self.bins:
            # Avoid emitting the invalid tuple literal "(,)" for no bins.
            return "BinnedObject({0},\n ()\n)".format(self.ndim)
        return "BinnedObject({0},\n ({1},)\n)".format(self.ndim, ",\n  ".join(repr(b) for b in self))

    def __str__(self):
        return "BinnedObject(\n " + \
            "\n ".join(str(b) for b in self) + \
            "\n)"

    def __iter__(self):
        """Iterate over the bins, ordered by bin center."""
        return iter(self.bins)

    def __len__(self):
        """Number of bins."""
        return len(self.bins)

    def plot(self, ax=None, *args, **kwargs):
        """
        Plot self onto the matplotlib axes ax (default: the current axes).

        For ndim == 1 every bin is drawn as a horizontal line from its low to
        its high edge at its value. Values with a finite ``error`` get an
        additional vertical error bar at the bin center. Returns the list of
        created lines.

        For ndim == 2 every bin is drawn as a grey rectangle. Returns the list
        of patches. Other dimensions draw nothing and return None.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        if ax is None:
            ax = plt.gca()

        if self.ndim == 1:
            lines = []
            for bin in self:
                if hasattr(bin.value, "error") and np.isfinite(bin.value.error):
                    lines.extend(ax.plot([bin.low, bin.high],
                                         [bin.value.value, bin.value.value], *args, **kwargs))
                    lines.extend(ax.plot([bin.center(), bin.center()],
                                         [bin.value.value - bin.value.error, bin.value.value + bin.value.error], *args, **kwargs))
                elif hasattr(bin.value, "value"):
                    lines.extend(ax.plot([bin.low, bin.high], [bin.value.value, bin.value.value], *args, **kwargs))
                else:
                    lines.extend(ax.plot([bin.low, bin.high], [bin.value, bin.value], *args, **kwargs))

            return lines
        elif self.ndim == 2:
            from matplotlib.patches import Rectangle
            patches = []
            for bin in self:
                # Rectangles are shrunk by 0.05 (in data units) so that
                # adjacent bins remain visually separated.
                patches.append(Rectangle(bin.low, width=bin.high[0] - bin.low[0] - 0.05, height=bin.high[1] - bin.low[1] - 0.05, ec="black", fc="lightgrey", alpha=0.8))
                ax.add_artist(patches[-1])
            return patches

    def nice_plot(self, ax=None, marker=None, label=None, *args, **kwargs):
        """
        Plot a one-dimensional BinnedObject onto ax (default: the current axes).

        Bins whose value has a finite, non-zero ``error`` are drawn as filled
        rectangles spanning value +- error. All other bins are drawn as a
        horizontal line over the bin plus a marker at the bin center. Keyword
        arguments are passed to all created artists; positional ``args`` only
        to the LineCollection.

        Returns ax, or None (drawing nothing) if self.ndim != 1.
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches
        from matplotlib.collections import PatchCollection
        from matplotlib.collections import LineCollection

        import numpy as np
        if ax is None:
            ax = plt.gca()

        if self.ndim == 1:
            lines      = []
            points     = []
            rectangles = []
            for bin in self:
                low,    = bin.low
                high,   = bin.high
                center, = bin.center()

                if hasattr(bin.value, "error") and np.isfinite(bin.value.error) and bin.value.error != 0:
                    rectangles.append(matplotlib.patches.Rectangle((low, bin.value.value-bin.value.error), width=high-low, height=2*bin.value.error))
                elif hasattr(bin.value, "value"):
                    lines.append(((low, bin.value.value), (high, bin.value.value)))
                    points.append((center, bin.value.value))
                else:
                    lines.append(((low, bin.value), (high, bin.value)))
                    points.append((center, bin.value))

            ax.plot(*np.array(points).T, linestyle='none', marker=marker, label=label, **kwargs)
            ax.add_collection(LineCollection(lines, *args, **kwargs))

            ax.add_collection(PatchCollection(rectangles, **kwargs))
            # Invisible plot on the *same* axes, presumably so that the axes
            # autoscale to the bins (collections alone do not trigger that).
            self.plot(ax, alpha=0)

            # Invisible proxy rectangle carrying the label, so that the error
            # rectangles show up in a legend. Only bins that actually carry an
            # error are considered.
            errors = [bin.value.error for bin in self if hasattr(bin.value, "error")]
            if errors and \
               np.any([np.isfinite(e) for e in errors]) and \
               np.all([e != 0 for e in errors]):
                rec = matplotlib.patches.Rectangle((float('nan'), float('nan')), width=0, height=0, **kwargs)
                rec.set_label(label)
                ax.add_patch(rec)

            return ax

    def compatible(self, other):
        """
        Return True if other has the same dimension, the same number of bins,
        and every pair of bins (in iteration order) is compatible according to
        Bin.compatible().
        """
        if not self.ndim == other.ndim:
            return False

        if not len(self) == len(other):
            return False

        for self_bin, other_bin in zip(self.bins, other.bins):
            if not self_bin.compatible(other_bin):
                return False

        return True

    def __eq__(self, other):
        """
        Bin-wise comparison.

        Returns False if the binnings are not compatible. Otherwise returns a
        BinnedObject with the same bins whose values are ``sbin == obin``.

        Raises TypeError if the dimensions differ.
        """
        if not self.ndim == other.ndim:
            raise TypeError("Cannot compare binned objects with different dimensions!")
        if not self.compatible(other):
            return False
        res = BinnedObject(self.ndim)
        for sbin, obin in zip(self, other):
            res.add_bin(Bin(sbin.low, sbin.high, sbin == obin))
        return res

    def __bool__(self):
        """
        Truth value of a BinnedObject with at most one bin.

        An empty object is False. A single-bin object has the truth value of
        its bin. More bins raise TypeError; use all() or any() instead.
        """
        # Python 3
        if len(self) > 1:
            raise TypeError("The truth of a binned object with more than one bin is ambigous! Use self.any() or self.all()")
        if len(self) == 0:
            return False
        return bool(self.bins[0])

    def __nonzero__(self):
        # Python 2
        return self.__bool__()

    def all(self):
        """True if every bin is truthy (True for no bins)."""
        return all(bool(b) for b in self.bins)

    def any(self):
        """True if at least one bin is truthy (False for no bins)."""
        return any(bool(b) for b in self.bins)

    def sum(self, axis=None):
        """
        Sum the bin values.

        Parameters
        ----------
        axis : None, int or iterable of ints

          None (or 0 for a one-dimensional object): return the plain sum of
          all bin values.

          int: project every bin along that axis (Bin.projected) and
          accumulate the projections with iadd_bin into a BinnedObject of
          dimension ndim - 1.

          iterable: sum over every given axis, highest axis first (presumably
          so that the remaining axis numbers stay valid).

        Raises
        ------
        Exception
          If a scalar axis >= ndim is given.
        """
        from numpy import isscalar
        if axis is None or axis == 0 and self.ndim == 1:
            return sum(b.value for b in self)
        elif isscalar(axis):
            if axis >= self.ndim:
                raise Exception("Cannot sum over non-existent axis {0} for {1}d BinnedObject".format(axis, self.ndim))
            res = BinnedObject(self.ndim - 1)
            for bin in self.bins:
                res.iadd_bin(bin.projected(axis))
        else:
            res = self
            for axis in reversed(sorted(axis)):
                res = res.sum(axis)

        return res

    def volume(self):
        """
        Return a new BinnedObject built from ``b.volume()`` of every bin
        (presumably a Bin carrying the bin volume as its value; see
        Bin.volume).
        """
        res = BinnedObject(self.ndim)
        for b in self:
            res.add_bin(b.volume())
        return res

    def rebin(self, other, intrinsic=False):
        """
        Rebin self with the bin-structure of the BinnedObject other.

        Parameters
        ----------

        other : BinnedObject with the same dimension as self. Only its bin
          edges are used; its values only serve as a template for zero.

        intrinsic : If False, assume bin values are extrinsic in (scale with)
          bin size. That is, the new bin values are obtained by volume-weighted
          linear-combination of old bins overlapping into the new bin.

          If True, assume bin values are intrinsic in (do not scale with) bin size,
          new bin values are obtained by treating the bin values as the density of
          an extrinsic quantity

        Returns
        -------
          A new BinnedObject with the bins of other. Parts of self outside every
          bin of other do not contribute. Bins of other that are not covered by
          self keep the value ``ob.value * 0.0``.

        Raises
        ------
        Exception
          If the dimensions of self and other differ.
        """
        if other.ndim != self.ndim:
            raise Exception("Incompatible dimensions")

        res = BinnedObject(self.ndim)
        for ob in other:
            # A zero of the same value type as other's bin values.
            res.add_bin(Bin(ob.low, ob.high, ob.value * 0.0))

        if intrinsic:
            # Turn densities into extrinsic amounts before splitting.
            _, self_conf = conforming_rebin(res, self * self.volume())
        else:
            _, self_conf = conforming_rebin(res, self)

        # After conforming_rebin, every piece of self should lie either inside
        # exactly one bin of res or outside all of them, so its center
        # identifies the target bin.
        for splitted_bin in self_conf:
            center = splitted_bin.center()
            containing = res._bins_containing(center)
            if not containing:
                # This piece of self lies outside the binning of other.
                continue
            if len(containing) != 1:
                raise Exception("More than one bin contains {0}:\n{1}".format(center, containing))
            res_bin = containing.pop()
            res_bin.value += splitted_bin.value

        if intrinsic:
            res = res / res.volume()

        return res

    def find_overlap(self, bins):
        """
        ov_a, ov_b = a.find_overlap(b)

        Efficiently calculate the overlap of the bins of self and a collection of bins

        Parameters
        ----------

        bins : a collection (or any iterable) of Bin() objects

        Returns
        -------

        ov_a, ov_b : mappings bin -> set of bins

        ov_a maps every bin of "a" to the set of bins of "b" that overlap with
        it, and ov_b maps every bin of "b" to the set of bins of "a" that
        overlap with it. Two bins overlap if their intervals overlap (as
        decided by interval_overlap) in every dimension.

        Raises
        ------
        Exception
          If bins has an ``ndim`` attribute different from self.ndim.
        """
        from .interval_overlap import interval_overlap

        if hasattr(bins, "ndim") and self.ndim != bins.ndim:
            raise Exception("Cannot compute overlapping bins for quantities with different dimensions: {0} != {1}".format(
                self.ndim, bins.ndim))

        # Materialize once: bins may be a one-shot iterator, but it is needed
        # once per dimension.
        other_bins = list(bins)

        def getitem(ndim):
            # Accessor for interval_overlap: n == 0 -> lower, n == 1 -> upper
            # edge of a bin in dimension ndim.
            def get(q, n):
                if n == 0:
                    return q.low[ndim]
                elif n == 1:
                    return q.high[ndim]
                else:
                    raise Exception("This should not happen!")

            return get

        # Fresh list copies for every call, since it is not visible here
        # whether interval_overlap modifies its arguments.
        ov_self, ov_other = interval_overlap(list(self), list(other_bins), getitem=getitem(0))

        # Bins overlap only if they overlap in every single dimension.
        for ndim in range(1, self.ndim):
            ov_self_n, ov_other_n = interval_overlap(list(self), list(other_bins), getitem=getitem(ndim))
            for key in ov_self_n:
                ov_self[key].intersection_update(ov_self_n[key])
            for key in ov_other_n:
                ov_other[key].intersection_update(ov_other_n[key])

        return ov_self, ov_other

    def bounding_box(self):
        """
        Return the bounding box of self, a self.ndim long tuple
        of low and high coordinates that enclose all bins of self.
        Raises IndexError for an object without bins.

        Examples
        --------

        >>> bin1 = Bin([0., 0.], [1., 2.], 0.25)
        >>> bin2 = Bin([1., 0.], [2., 1.], 0.75)
        >>> a = BinnedObject(2, (bin1, bin2))
        >>> a.bounding_box()
        ((0.0, 2.0), (0.0, 2.0))
        """
        # The per-dimension edge indexes are sorted, so the extreme edges are
        # the first low and the last high entry.
        return tuple((self.lows[dim][0].low[dim],
                      self.highs[dim][-1].high[dim]) for dim in range(self.ndim))

    def __reduce__(self):
        """Pickle support: rebuild as BinnedObject(ndim, tuple_of_bins)."""
        return (self.__class__, (self.ndim, tuple(self),))

def conforming_rebin(a, b, inplace=False, return_overlap=False):
    """
    Take two BinnedObjects and return two BinnedObjects with a compatible binning.
    This is achieved by splitting all bins on any vertices of any overlapping other bin,
    i.e. at the points reported by ``Bin.overlap_points`` for every overlapping
    bin of the other object.

    New bin values are obtained as volume-weighted fractions of the original bin value,
    i. e. the values are assumed to be extrinsic.

    Parameters
    ----------
    a, b : BinnedObject
      The two objects to conform. They must have the same ndim.

    inplace : bool
      If True, modify the objects a and b (via remove_bin / add_bins) and return
      them. Otherwise a and b are left untouched and new BinnedObjects are
      returned.

    return_overlap : bool
      If True, return the overlap as one would get from a.find_overlap(b) as well.
      It is updated while splitting: entries of replaced bins are removed and
      the pieces are added.

    Returns
    -------
    (a_new, b_new), or (a_new, b_new, ov_a, ov_b) if return_overlap is True

    Raises
    ------
    Exception
      If a and b have different dimensions.

    Example
    -------
    >>> bin_a1 = Bin([0.], [0.25], 0.25)
    >>> bin_a2 = Bin([0.25], [1.0], 0.75)
    >>> a = BinnedObject(1, (bin_a1, bin_a2))
    >>>
    >>> bin_b1 = Bin([0.], [0.75], 1.5)
    >>> bin_b2 = Bin([0.75], [1.0], 0.5)
    >>> b = BinnedObject(1, (bin_b1, bin_b2))
    >>>
    >>> an, bn = conforming_rebin(a, b)
    >>> print(an)
    BinnedObject(
     [0.0 - 0.25] = 0.25
     [0.25 - 0.75] = 0.5
     [0.75 - 1.0] = 0.25
    )

    >>> print(bn)
    BinnedObject(
     [0.0 - 0.25] = 0.5
     [0.25 - 0.75] = 1.0
     [0.75 - 1.0] = 0.5
    )
    >>>
    >>> bin_a1 = Bin([0.], [1.], 1.0)
    >>> bin_b1 = Bin([0.25], [0.75], 0.5)
    >>>
    >>> a = BinnedObject(1, (bin_a1,))
    >>> b = BinnedObject(1, (bin_b1,))
    >>>
    >>> an, bn = conforming_rebin(a, b)
    >>> print(an)
    BinnedObject(
     [0.0 - 0.25] = 0.25
     [0.25 - 0.75] = 0.5
     [0.75 - 1.0] = 0.25
    )

    >>> print(bn)
    BinnedObject(
     [0.25 - 0.75] = 0.5
    )
    """
    if not a.ndim == b.ndim:
        raise Exception("Incompatible dimensions")

    # ov_a maps every bin of a to the set of overlapping bins of b, and ov_b
    # does the reverse. Both mappings are kept up to date below as bins are split.
    ov_a, ov_b = a.find_overlap(b)

    # Working lists of the current bins. Pieces of split bins are appended to
    # the list that is being iterated, so pieces created earlier in a pass are
    # visited (and possibly split further) later in the same pass.
    bins_a = list(a)
    bins_b = list(b)

    def split_bin(bin, splitpoints, overlap_self, overlap_other):
        # Split bin successively at every point of splitpoints that lies in
        # the current piece (``p in sb``). Return the list of pieces, or []
        # if bin was not split at all.
        #
        # If bin was split, each piece gets a fresh set in overlap_self. For
        # every bin that overlapped the original, the original is replaced in
        # both mappings by those pieces that overlap it. The entry of the
        # original bin itself in overlap_self is NOT removed here; callers do that.
        splitbins = [bin]
        for p in splitpoints:
            new_splitbins = []
            for sb in splitbins:
                if p in sb:
                    new_splitbins.extend(sb.split(p))
                else:
                    new_splitbins.append(sb)
            splitbins = new_splitbins
        if splitbins == [bin]:
            return []
        for sb in splitbins:
            overlap_self[sb] = set()
        for otherbin in overlap_self[bin]:
            overlap_other[otherbin].remove(bin)
            for sb in splitbins:
                if sb.overlaps(otherbin):
                    overlap_other[otherbin].add(sb)
                    overlap_self[sb].add(otherbin)
        return splitbins

    # Two symmetric passes. The first splits the bins of a (and the bins of b
    # they overlap), the second the bins of b (and the bins of a they overlap).
    # "self" and "other" are plain loop variables here, not a method receiver.
    for self, other, binlist_self, binlist_other, overlap_self, overlap_other in \
            ((a, b, bins_a, bins_b, ov_a, ov_b),
             (b, a, bins_b, bins_a, ov_b, ov_a)):
        # Replaced bins are removed from binlist_self only after the loop,
        # because binlist_self is iterated (and appended to) right now.
        to_remove = []
        for bin in binlist_self:
            split = False  # NOTE(review): assigned but never used
            # Split points of bin: collected from every overlapping bin of the
            # other object (semantics of overlap_points are defined in Bin).
            points = set(p for otherbin in overlap_self[bin] for p in otherbin.overlap_points(bin))
            if points:
                splitbins = split_bin(bin, points, overlap_self, overlap_other)
                if splitbins:
                    binlist_self.extend(splitbins)
                    to_remove.append(bin)
                    if inplace:
                        self.remove_bin(bin)
                        self.add_bins(splitbins)

                # Now split other bins
                # (the overlapping bins of the other object, at the same points).
                # NOTE(review): if bin was not split above, bin is still in
                # overlap_other[otherbin], so split_bin(otherbin, ...) mutates
                # overlap_self[bin], which is the set iterated by this loop.
                # Verify that overlap_points cannot produce this situation.
                for otherbin in overlap_self[bin]:
                    splitbins = split_bin(otherbin, points, overlap_other, overlap_self)
                    if splitbins:
                        binlist_other.extend(splitbins)
                        binlist_other.remove(otherbin)
                        del overlap_other[otherbin]
                        if inplace:
                            other.remove_bin(otherbin)
                            other.add_bins(splitbins)
                # NOTE(review): this also drops the entry of a bin that was NOT
                # split (splitbins == []) although it stays in binlist_self, so
                # the returned overlap may lack a key for it. Confirm.
                del overlap_self[bin]

        # Drop the bins that were replaced by their pieces.
        for r in to_remove:
            binlist_self.remove(r)

    if not inplace:
        # Build fresh objects from the working lists; a and b are untouched.
        a, b = BinnedObject(a.ndim, bins_a), BinnedObject(a.ndim, bins_b)

    if return_overlap:
        return a, b, ov_a, ov_b
    else:
        return a, b