from __future__ import absolute_import

import unittest
from .interval_overlap import interval_overlap

class TestIntervals(unittest.TestCase):
    """Tests for ``interval_overlap``.

    As exercised below, ``interval_overlap(a, b)`` takes two sequences of
    ``(start, end)`` tuples and returns a pair of dicts:

    - the first maps every interval of ``a`` to the intervals of ``b`` it
      overlaps;
    - the second maps every interval of ``b`` to the intervals of ``a`` it
      overlaps.

    The brute-force reference in ``test_general`` treats intervals as
    half-open, so intervals that merely touch at an endpoint do not overlap.
    """

    # assertItemsEqual has been renamed to assertCountEqual in Python 3;
    # provide it under the Python 3 name also for Python 2.
    if not hasattr(unittest.TestCase, "assertCountEqual"):
        assertCountEqual = unittest.TestCase.assertItemsEqual

    def test_edge_case(self):
        """Intervals sharing only an endpoint must not overlap, both ways."""
        bin1 = (0, 1)
        bin2 = (1, 2)
        ov1, ov2 = interval_overlap([bin1], [bin2])
        # No ``return`` here: returning from the first assertion would make
        # the second one unreachable and leave ``ov2`` unchecked.
        self.assertEqual(ov1, {bin1: set()})
        self.assertEqual(ov2, {bin2: set()})

    def test_simple(self):
        """Two partially overlapping intervals are reported in both maps."""
        bin1 = (0, 2)
        bin2 = (1, 3)
        ov1, ov2 = interval_overlap([bin1], [bin2])
        self.assertEqual(ov1, {bin1: set([bin2])})
        self.assertEqual(ov2, {bin2: set([bin1])})

    def test_general(self):
        """Compare against a brute-force check on every pair of inputs.

        Each input is a tuple of distinct intervals ``(i, j)`` with
        ``0 <= i < j < N``. Every such tuple (every subset of all possible
        intervals) is paired with every other one, and the result of
        ``interval_overlap`` is compared with a naive all-pairs check.
        """
        from itertools import combinations, product

        # Do not enlarge this: there are 2 ** (N * (N - 1) / 2) input sets and
        # all pairs of them are tested (64 * 64 = 4096 pairs for N = 4).
        N = 4
        all_intervals = list(combinations(range(N), 2))

        def powerset():
            # Yield every subset of all_intervals, from the empty one upwards.
            for length in range(len(all_intervals) + 1):
                for subset in combinations(all_intervals, length):
                    yield subset

        def overlap(ai, bi):
            # Half-open intervals [start, end) overlap iff one of them starts
            # inside the other.
            return ai[0] <= bi[0] < ai[1] or bi[0] <= ai[0] < bi[1]

        for a, b in product(powerset(), repeat=2):
            ov_a, ov_b = interval_overlap(a, b)

            # Compare the result with a dumb brute-force approach. Every input
            # interval gets an entry, even when it overlaps nothing.
            naive_ov_a = {ai: set() for ai in a}
            naive_ov_b = {bi: set() for bi in b}
            for ai, bi in product(a, b):
                if overlap(ai, bi):
                    # Set insertion is idempotent; no membership test needed.
                    naive_ov_a[ai].add(bi)
                    naive_ov_b[bi].add(ai)

            for naive, ov in ((naive_ov_a, ov_a), (naive_ov_b, ov_b)):
                # Build the diagnostic once per map rather than once per key.
                msg = ("\nTest with: a = {0}, b = {1}\nnaive: {2}\n"
                       "interval_overlap: {3}\n{4}\n{5}").format(
                           a, b, naive, ov,
                           sorted(naive.items()), sorted(ov.items()))
                self.assertCountEqual(naive.keys(), ov.keys(), msg=msg)
                for k in ov:
                    self.assertCountEqual(naive[k], ov[k], msg=msg)