from __future__ import absolute_import

import unittest
from . import Bin, BinnedObject, conforming_rebin

def powerset(M):
    """Yield every subset of *M* as a tuple, smallest subsets first.

    Subsets of equal size come out in ``itertools.combinations`` order, so
    the empty tuple is produced first and ``tuple(M)`` last.  *M* must
    support ``len()``.
    """
    from itertools import chain, combinations
    by_size = (combinations(M, size) for size in range(len(M) + 1))
    for subset in chain.from_iterable(by_size):
        yield subset

def exact_cover(X, Y, solution=None):
    """Yield every exact cover of the columns of *X* (Knuth's Algorithm X).

    Dict-of-sets formulation: instead of dancing links, columns are removed
    from and put back into *X* as candidate rows are tried.

    :param X: mapping ``column -> set of rows covering that column``.  It is
        mutated while the generator runs and is back to its original
        contents once the generator has been exhausted.
    :param Y: mapping ``row -> sequence of columns the row covers``.  Only
        read (``reversed()`` is applied to the values, so they must be
        sequences, not sets).
    :param solution: partial solution (list of rows) to extend; used by the
        recursion.  Omit it at top level -- a fresh list is then created.
    :yields: a new list of rows for each exact cover found.
    """
    if solution is None:
        # A shared ``[]`` default would leak rows from one call into the
        # next whenever an earlier generator was abandoned mid-iteration.
        solution = []
    if not X:
        yield list(solution)
    else:
        # Branch on the column with the fewest candidate rows; an empty
        # column yields nothing, which prunes this branch.
        c = min(X, key=lambda col: len(X[col]))
        # Snapshot: select() may remove rows from X[c] while we iterate.
        for r in list(X[c]):
            solution.append(r)
            cols = select(X, Y, r)
            for s in exact_cover(X, Y, solution):
                yield s
            deselect(X, Y, r, cols)
            solution.pop()

def select(X, Y, r):
    """Commit to row *r*: drop the columns it covers and all clashing rows.

    :returns: the removed column sets, in ``Y[r]`` order, to be handed back
        to :func:`deselect`.
    """
    cols = []
    for j in Y[r]:
        # Every row i that also covers column j now conflicts with r, so
        # detach it from all of its other columns.
        for i in X[j]:
            for k in Y[i]:
                if k != j:
                    X[k].remove(i)
        cols.append(X.pop(j))
    return cols

def deselect(X, Y, r, cols):
    """Undo :func:`select` for row *r*, restoring *X* to its prior state.

    Columns are restored in the reverse of the order select() removed them,
    so each ``cols.pop()`` returns the set belonging to column *j*.
    """
    for j in reversed(Y[r]):
        X[j] = cols.pop()
        for i in X[j]:
            for k in Y[i]:
                if k != j:
                    X[k].add(i)

def hypercube_partitions(N, ndim, f=None):
    """Yield every partition of an N x ... x N grid of unit cells into boxes.

    The grid spans ``[0, N]`` along each of the *ndim* axes.  Every way of
    tiling its unit cells with axis-aligned rectangular boxes is found as an
    exact cover (cells are columns, candidate boxes are rows) and yielded as
    ``BinnedObject(ndim, bins)``; each bin's value is ``f(bin.low, bin.high)``.

    :param N: number of unit cells along each axis.
    :param ndim: number of dimensions.
    :param f: callable ``f(low, high)`` giving a bin's value; when omitted
        every bin gets ``0.``.
    :yields: one ``BinnedObject`` per partition.
    """
    from itertools import product
    from numpy import arange

    if f is None:
        def f(x0, x1):
            return 0.

    def possible_bins():
        # Every box inside the grid, as (first_cell, last_cell) with both
        # corners inclusive; lengths[d] + 1 is the box extent along axis d.
        for lengths in product(range(N), repeat=ndim):
            for startpoint in product(*(range(N - l) for l in lengths)):
                yield startpoint, tuple(s + l for s, l in zip(startpoint, lengths))

    def gridnumber(point):
        # Flatten a cell coordinate to a unique integer in [0, N**ndim).
        return sum(p * N**n for n, p in enumerate(point))

    def iterate_points(quad):
        # Flat numbers of all cells inside the box quad = (first, last).
        for p in product(*(range(l, h + 1) for l, h in zip(*quad))):
            yield gridnumber(p)

    quads = {}  # box id -> (first_cell, last_cell)
    Y = {}      # box id -> list of cells it covers (exact-cover rows)
    for bin_name, quad in enumerate(possible_bins()):
        Y[bin_name] = list(iterate_points(quad))
        quads[bin_name] = quad

    # cell -> set of boxes covering it (exact-cover columns)
    X = {point_number: set() for point_number in range(N**ndim)}
    for bin_name in Y:
        for point_number in Y[bin_name]:
            X[point_number].add(bin_name)

    boundaries = arange(float(N + 1))

    def create_bin(bin_name):
        # Cell i spans [boundaries[i], boundaries[i + 1]].
        first, last = quads[bin_name]
        return Bin([boundaries[q] for q in first],
                   [boundaries[q + 1] for q in last], 0.0)

    for sol in exact_cover(X, Y):
        binlist = []
        for bin_name in sol:
            new_bin = create_bin(bin_name)
            new_bin.value = f(new_bin.low, new_bin.high)
            binlist.append(new_bin)
        # A cover of a non-empty grid always contains at least one box.
        assert binlist
        yield BinnedObject(ndim, binlist)

class TestBinningModules(unittest.TestCase):
    """Check Bin/BinnedObject/conforming_rebin invariants exhaustively.

    Most tests enumerate every box partition of small grids via
    ``hypercube_partitions`` and assert properties that must hold for each.
    Exact float comparisons are intended: all bin edges are small integers,
    so every integral below is exactly representable.
    """

    @staticmethod
    def _integrate_one(low, high):
        """Return the integral of 1 over the box [low, high] (its volume)."""
        res = 1
        for l, h in zip(low, high):
            res *= (h - l)
        return res

    @staticmethod
    def _integrate_cubes(low, high):
        """Return the integral of x0^3 * x1^3 * ... * xN^3 over the box.

        The integrand is separable, so this is prod(h^4/4 - l^4/4).
        """
        res = 1
        for l, h in zip(low, high):
            res *= (h**4/4. - l**4/4.)
        return res

    def setUp(self):
        pass

    def test_find_bin(self):
        """Every bin is found again when looking up its own center."""
        Dmax = 3
        Nmax = 2

        for dim in range(1, Dmax + 1):
            for N in range(1, Nmax + 1):
                for bo in hypercube_partitions(N, dim):
                    for b in bo:
                        self.assertEqual(b, bo.find_bin(b.center()))

    def test_conforming_rebin_1D(self):
        """Rebinning two 1D objects onto common edges preserves their sums.

        Every pair of interior edge sets drawn from {1, ..., N-1} is tried.
        """
        from itertools import product
        N = 5

        M = [float(i) for i in range(1, N)]

        def int1(x0, x1):
            """Integrate 1 from x0 to x1."""
            return x1 - x0

        def int2(x0, x1):
            """Integrate x^3 from x0 to x1."""
            return x1**4/4. - x0**4/4.

        for subset1, subset2 in product(powerset(M), repeat=2):
            # Append the upper edge N; the lower edge 0 starts the first bin.
            subset1 = subset1 + (N,)
            subset2 = subset2 + (N,)
            bins_1 = [Bin([0.], [subset1[0]], int1(0., subset1[0]))]
            bins_2 = [Bin([0.], [subset2[0]], int2(0., subset2[0]))]

            # Chain the remaining bins edge-to-edge.
            for bins, subset, f in ((bins_1, subset1, int1), (bins_2, subset2, int2)):
                for x in subset[1:]:
                    x0 = bins[-1].high[0]
                    bins.append(Bin([x0], [x], f(x0, x)))

            bo1 = BinnedObject(1, bins_1)
            bo2 = BinnedObject(1, bins_2)

            self.assertEqual(bo1.sum(), int1(0, N))
            self.assertEqual(bo2.sum(), int2(0, N))

            bo1, bo2 = conforming_rebin(bo1, bo2)

            self.assertEqual(bo1.sum(), int1(0, N))
            self.assertEqual(bo2.sum(), int2(0, N))
            self.assertEqual((bo1 + bo2).sum(), int1(0, N) + int2(0, N))

    def test_conforming_rebin_multi_d(self):
        """Multi-dimensional rebinning preserves sums; inplace matches copy."""
        from itertools import product

        f1 = self._integrate_one
        f2 = self._integrate_cubes

        Dmax = 3
        Nmax = 2

        for dim in range(1, Dmax + 1):
            for N in range(1, Nmax + 1):
                for a, b in product(hypercube_partitions(N, dim, f1), hypercube_partitions(N, dim, f2)):
                    self.assertTrue(a.lows)
                    self.assertTrue(b.lows)

                    expected_1 = f1([0.] * dim, [float(N)] * dim)
                    expected_2 = f2([0.] * dim, [float(N)] * dim)

                    self.assertEqual(a.sum(), expected_1)
                    self.assertEqual(b.sum(), expected_2)

                    a_rebin, b_rebin = conforming_rebin(a, b)

                    self.assertEqual(a_rebin.sum(), expected_1)
                    self.assertEqual(b_rebin.sum(), expected_2)
                    self.assertEqual((a_rebin + b_rebin).sum(), expected_1 + expected_2)

                    conforming_rebin(a, b, inplace=True)

                    self.assertTrue((a == a_rebin).all())
                    self.assertTrue((b == b_rebin).all())

    def test_sum(self):
        """Total and partial sums agree with direct integration of 1."""
        from itertools import combinations
        Dmax = 3
        Nmax = 2

        f1 = self._integrate_one

        def f1_partial(sumdims, N, lowdim_low, lowdim_high, ndim):
            """Volume of the ndim box that a summed-out bin stands for.

            Axes in *sumdims* span the full [0, N]; the remaining axes take,
            in order, the bounds of the lower-dimensional bin.
            """
            j = 0
            low = []
            high = []
            for i in range(ndim):
                if i in sumdims:
                    low.append(0.0)
                    high.append(float(N))
                else:
                    low.append(lowdim_low[j])
                    high.append(lowdim_high[j])
                    j = j + 1
            return f1(low, high)

        for dim in range(1, Dmax + 1):
            dims = range(dim)
            for N in range(1, Nmax + 1):
                for bo in hypercube_partitions(N, dim, f1):
                    # total sum
                    self.assertEqual(bo.sum(), f1((0.,) * bo.ndim, (N,) * bo.ndim))

                    # partial sums over every proper, non-empty set of axes
                    for length in range(1, len(dims)):
                        for sumdims in combinations(dims, length):
                            summed = bo.sum(sumdims)

                            recreated = BinnedObject(summed.ndim)
                            for summed_bin in summed:
                                recreated_bin = Bin(summed_bin)
                                recreated_bin.value = f1_partial(
                                    sumdims, N, summed_bin.low, summed_bin.high, dim)
                                recreated.add_bin(recreated_bin)

                            self.assertTrue((summed == recreated).all())
                            self.assertEqual(bo.sum(), summed.sum())
                            self.assertEqual(bo.sum(), recreated.sum())

    def test_pickle(self):
        """A BinnedObject survives a pickle round trip unchanged."""
        try:
            from cPickle import dumps, loads
        except ImportError:
            from pickle import dumps, loads
        Dmax = 2
        Nmax = 2

        for dim in range(1, Dmax + 1):
            for N in range(1, Nmax + 1):
                for bo in hypercube_partitions(N, dim):
                    # loads() only ever sees data this test just produced.
                    serial = dumps(bo)
                    self.assertTrue((bo == loads(serial)).all())