# coding=utf-8
#
# The MIT License (MIT)
#
# Copyright (c) 2014 Lorenz Hüdepohl, Emmanuel Stamou
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import absolute_import
from __future__ import print_function

from .arithmetic_decorators import add_operators

def projected(t, axis):
    """
    Return the entries of `t` as a tuple with the entry at position `axis` removed.

    A negative `axis` counts from the end, as in ordinary Python indexing.

    Raises
    ------
    IndexError
        if `axis` does not address an entry of `t`
    """
    t = tuple(t)
    n = len(t)
    index = axis + n if axis < 0 else axis
    if not 0 <= index < n:
        # Plain slicing would silently return `t` unchanged for a too-large
        # axis, and duplicate entries for a negative one, so reject explicitly.
        raise IndexError("axis {0} out of range for a {1}-dimensional tuple".format(axis, n))
    return t[:index] + t[index + 1:]

def bin_operator(func):
    """
    Lift a binary function on bin values to a binary operator on Bins.

    The returned callable takes (self, other). If `other` looks like a bin
    (it has a `low` attribute), it must be compatible with `self` and its
    `value` becomes the right-hand operand; otherwise `other` itself is used.
    The result is a new Bin with self's boundaries and value
    func(self.value, operand).
    """
    def wrapped_bin_operator(self, other):
        other_is_bin = hasattr(other, "low")
        if other_is_bin and not self.compatible(other):
            raise Exception("Bins not compatible")
        operand = other.value if other_is_bin else other
        return Bin(self.low, self.high, func(self.value, operand))
    return wrapped_bin_operator

def bin_hash(abin):
    """
    Hash a bin from its boundaries and its value.

    Bins compare equal when they are compatible (same low and high) and carry
    the same value, so the hash is built from exactly those three attributes.
    """
    identity = (abin.low, abin.high, abin.value)
    return hash(identity)

@add_operators(bin_operator, bin_hash)
class Bin(object):
    """
    A box-shaped bin spanning [low[i], high[i]] along every axis i, carrying
    an associated value.

    The ``add_operators`` decorator is given ``bin_operator`` and
    ``bin_hash``; presumably it installs the arithmetic operators and
    ``__hash__`` from them (confirm in ``arithmetic_decorators``).
    """
    def __init__(self, *args):
        """
        Construct a new bin

          bin = Bin(low, high, value)
          bin = Bin(otherbin)

        low and high are stored as tuples; they must have equal length and
        satisfy low[i] < high[i] in every dimension.

        Raises
        ------
        TypeError
            if called with neither 1 nor 3 arguments
        Exception
            if len(low) != len(high), or low < high fails in some dimension
        """
        if len(args) == 1:
            otherbin, = args
            low, high, value = otherbin.low, otherbin.high, otherbin.value
        elif len(args) == 3:
            low, high, value = args
        else:
            # Fail clearly instead of with an UnboundLocalError on `low` below
            raise TypeError("Bin() takes 1 or 3 arguments ({0} given)".format(len(args)))

        if not len(low) == len(high):
            raise Exception("len(low) != len(high)")

        if not all(l < h for l, h in zip(low, high)):
            raise Exception("low < high not fullfilled")

        self.low = tuple(low)
        self.high = tuple(high)
        self.value = value
        self.ndim = len(self.low)

    def center(self):
        """
        Returns the center of the bin

        Example
        -------
        >>> b = Bin([0., 0.], [10., 10.], 1.0)
        >>> b.center()
        (5.0, 5.0)
        """
        return tuple(map(lambda p : 0.5 * (p[1] + p[0]), zip(self.low, self.high)))

    def half_width(self):
        """
        Returns the half-width of the bin

        Example
        -------
        >>> b = Bin([0., 0.], [20., 40.], 1.0)
        >>> b.half_width()
        (10.0, 20.0)
        """
        return tuple(map(lambda p : 0.5 * (p[1] - p[0]), zip(self.low, self.high)))

    def __str__(self):
        """Human-readable form, e.g. ``[0.0 - 1.0], [0.0 - 1.0] = 1.0``"""
        ranges = ", ".join("[{0} - {1}]".format(a,b) for a,b in zip(self.low, self.high))
        return "{0} = {1}".format(ranges, self.value)

    def __repr__(self):
        """Constructor-like form: ``Bin(low_tuple, high_tuple, value)``"""
        return "Bin({0.low}, {0.high}, {0.value})".format(self)

    def compatible(self, other):
        """
        Return True if other has exactly the same boundaries (low and high) as self
        """
        if self.low != other.low:
            return False
        if self.high != other.high:
            return False
        return True

    def __eq__(self, other):
        """
        Two bins are equal if they are compatible and have the same value.

        Objects that are not bins (no ``ndim`` attribute) yield NotImplemented,
        so Python falls back to its default comparison rather than raising
        AttributeError.

        Raises
        ------
        TypeError
            if other is a bin of a different dimension
        """
        if not hasattr(other, "ndim"):
            return NotImplemented
        if not self.ndim == other.ndim:
            raise TypeError("Cannot compare bins with different dimensions!")
        if not self.compatible(other):
            return False
        return self.value == other.value

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out here
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __bool__(self):
        # Python 3: a bin is truthy iff its value is
        return bool(self.value)

    def __nonzero__(self):
        # Python 2
        return self.__bool__()

    def projected(self, axis):
        """
        Return a new bin with dimension `axis` dropped from low and high,
        keeping the same value
        """
        # `projected` here resolves to the module-level helper, not this method
        return Bin(projected(self.low, axis), projected(self.high, axis), self.value)

    def __contains__(self, point):
        """
        Returns True if point is inside or on the surface of the bin's domain

        Example
        -------
        >>> b = Bin([0., 0.], [1., 1.], 1.0)
        >>> print(b)
        [0.0 - 1.0], [0.0 - 1.0] = 1.0

        >>> (0.5, 0.5) in b
        True

        >>> (0.5, 1.0) in b
        True

        >>> (0.5, 1.1) in b
        False

        >>> (0.01, 0.99) in b
        True
        """

        if not len(point) == self.ndim:
            raise Exception("Must specify a {0}-dimensional point for this bin!".format(self.ndim))
        return all(l <= point[n] <= h for (n, (l,h)) in enumerate(zip(self.low, self.high)))

    def point_inside(self, point):
        """
        Returns True if point is strictly inside of the bin's domain

        Example
        -------
        >>> b = Bin([0., 0.], [1., 1.], 1.0)
        >>> print(b)
        [0.0 - 1.0], [0.0 - 1.0] = 1.0

        >>> b.point_inside((0.5, 0.5))
        True

        >>> b.point_inside((0.5, 1.0))
        False

        >>> b.point_inside((1e-5, 0.9999))
        True
        """

        if not len(point) == self.ndim:
            raise Exception("Must specify a {0}-dimensional point for this bin!".format(self.ndim))
        return all(l < point[n] < h for (n, (l,h)) in enumerate(zip(self.low, self.high)))

    def split_on_axis(self, axis, coord):
        """
        Split-up a bin into two, along given value "coord" on axis "axis".

        New bin values are obtained as volume-weighted fractions of the original value.
        If coord lies exactly on one of the bin's faces along `axis`, no split
        happens and the 1-tuple ``(self,)`` is returned.

        Example
        -------
        >>> b = Bin([0.], [1.], 1.0)
        >>> print(b)
        [0.0 - 1.0] = 1.0

        >>> b1, b2 = b.split_on_axis(0, 0.25)
        >>> print(b1)
        [0.0 - 0.25] = 0.25

        >>> print(b2)
        [0.25 - 1.0] = 0.75

        >>> b = Bin([0., 0.], [1., 1.], 1.0)
        >>> print(b)
        [0.0 - 1.0], [0.0 - 1.0] = 1.0

        >>> b1, b2 = b.split_on_axis(1, 0.25)
        >>> print(b1)
        [0.0 - 1.0], [0.0 - 0.25] = 0.25

        >>> print(b2)
        [0.0 - 1.0], [0.25 - 1.0] = 0.75

        """
        if not self.low[axis] <= coord <= self.high[axis]:
            raise Exception("Value {0} is outside bin in axis {1}".format(coord, axis))

        # Fractions of the width on either side of the cut.
        # NOTE(review): the file does not import `division` from __future__,
        # so on Python 2 integer boundaries floor-divide here; use floats.
        w = self.high[axis] - self.low[axis]
        l1 = (coord -  self.low[axis]) / w
        l2 = (self.high[axis] - coord) / w

        if 0 < l1 < 1e-7 or 0 < l2 < 1e-7:
            import warnings
            warnings.warn(("Split into extremely degenerate bins along dimension {0}, "
                           "check your bin boundaries:\n"
                           "  {1:20} {2:20} {3:20}\n  {4!r:20} {5!r:20} {6!r:20}\n").format(
                                axis,
                                "left", "split", "right",
                                self.low[axis], coord, self.high[axis]),
                          stacklevel=2)

        if not abs(l1 + l2 - 1.0) < 1e-11:
            raise Exception("Something wrong here")

        v1 = float(l1) * self.value
        v2 = float(l2) * self.value

        if l1 > 0:
            b1 = Bin(self.low,
                     self.high[:axis] + (coord,) + self.high[axis+1:],
                     v1)
        if l2 > 0:
            b2 = Bin(self.low[:axis] + (coord,) + self.low[axis+1:],
                     self.high,
                     v2)
        if l1 > 0 and l2 > 0:
            return b1, b2
        elif l1 == 1.0 or l2 == 1.0:
            # coord sits on a face of the bin: nothing to split
            return self,
        else:
            raise Exception("Something wrong in split_on_axis")

    def split(self, point):
        """
        Divide a bin into 2^self.ndim bins around the given point within the bin

        New bin values are obtained as volume-weighted fractions of the original value.
        Along any axis where the point lies on the bin's surface no cut is
        made, so fewer than 2^self.ndim bins are returned in that case.

        Example
        -------
        >>> b = Bin([0., 0.], [1., 1.], 1.0)
        >>> print(b)
        [0.0 - 1.0], [0.0 - 1.0] = 1.0

        >>> b1, b2, b3, b4 = b.split([0.25, 0.25])
        >>> print(b1)
        [0.0 - 0.25], [0.0 - 0.25] = 0.0625

        >>> print(b2)
        [0.0 - 0.25], [0.25 - 1.0] = 0.1875

        >>> print(b3)
        [0.25 - 1.0], [0.0 - 0.25] = 0.1875

        >>> print(b4)
        [0.25 - 1.0], [0.25 - 1.0] = 0.5625

        >>> b.split([0.25, 1.5])
        Traceback (most recent call last):
          ...
        Exception: Point not contained in bin!
        """
        if point not in self:
            raise Exception("Point not contained in bin!")

        # Cut successively along each axis; every existing piece is split in turn
        subbins = [self]
        for axis, coord in enumerate(point):
            subbins = [piece
                       for sub in subbins
                       for piece in sub.split_on_axis(axis, coord)]

        return tuple(subbins)

    def vertices(self):
        """
        Return the corner points of the bin as a tuple of coordinate tuples
        (every combination of low[i] / high[i] over all axes)
        """
        from itertools import product
        return tuple(product(*zip(self.low, self.high)))

    def volume(self):
        """
        Return a Bin with the same boundaries whose value is the product of
        the bin's edge lengths
        """
        from functools import reduce  # not a builtin on Python 3
        from operator import mul
        return Bin(self.low, self.high, reduce(mul, (h - l for (l,h) in zip(self.low, self.high))))

    def overlaps(self, other):
        """
        Return true if the domain of self overlaps with the domain of other

        Bins that merely touch along a face do not count as overlapping.

        Example
        -------
        >>> bin1 = Bin([0.], [0.25], 0.25)
        >>> bin2 = Bin([0.25], [1.0], 0.75)
        >>> bin3 = Bin([0.0], [1.0], 1.0)
        >>> bin1.overlaps(bin2)
        False
        >>> bin1.overlaps(bin3)
        True
        """
        if not self.ndim == other.ndim:
            raise Exception("Cannot operate on Bins with different dimension!")
        for n in range(self.ndim):
            if not(self.low[n] <= other.low[n] < self.high[n] or other.low[n] <= self.low[n] < other.high[n]):
                return False
        return True

    def overlap_fraction(self, other):
        """
        Return the volume of the overlap of self and other, divided by the volume of self.

        Examples
        --------

        >>> bin1 = Bin([0.], [1.0], 0.5)
        >>> bin2 = Bin([0.5], [4.5], 1.)
        >>> bin3 = Bin([1.0], [2.0], 1.)
        >>> bin1.overlap_fraction(bin2)
        0.5
        >>> bin2.overlap_fraction(bin1)
        0.125
        >>> bin3.overlap_fraction(bin1)
        0.0
        """
        if not self.ndim == other.ndim:
            raise Exception("Cannot operate on Bins with different dimension!")

        def clip(a, n):
            # Clamp coordinate a into self's extent [low[n], high[n]] on axis n
            if a < self.low[n]:
                return self.low[n]
            elif self.low[n] <= a <= self.high[n]:
                return a
            elif self.high[n] < a:
                return self.high[n]
            else:
                # Only reachable for unordered values such as NaN
                raise Exception("Tertium datur")

        r = 1.0

        # NOTE(review): on Python 2 integer boundaries floor-divide below
        for n in range(self.ndim):
            r *= (clip(other.high[n], n) - clip(other.low[n], n)) / (self.high[n] - self.low[n])

        return r

    def overlap_points(self, other):
        """
        Return the edge points of the intersection hypercube
        of self and other -- provided self and other overlap --
        that are not coincident with both an edge point of self
        and other at the same time.

        The result is a list on both Python 2 and Python 3.
        """
        from itertools import product
        shared_vertices = set(self.vertices()).intersection(set(other.vertices()))
        low = [max(pair) for pair in zip(self.low, other.low)]
        high = [min(pair) for pair in zip(self.high, other.high)]
        return [p for p in product(*zip(low, high)) if p not in shared_vertices]