from __future__ import absolute_import
from __future__ import print_function

def interval_overlap(a, b, getitem=None):
    """
    Given two collections of intervals, report which intervals of a
    overlap which intervals of b.

    Two intervals overlap when ``a_start < b_end and b_start < a_end``,
    so intervals that merely touch (e.g. (0, 3) and (3, 5)) do not
    overlap. A zero-length interval (x, x) overlaps exactly those
    intervals that strictly contain x.

    Parameters
    ----------
    a, b : Iterables of intervals (2-tuple like objects)
        The intervals are used as dictionary keys and set members in the
        result, so they must be hashable.

    getitem (optional) : A function
        If given, call getitem(i, 0) and getitem(i, 1) for each element
        i of a and of b instead of i[0], i[1].

    Returns
    -------

    Two dictionaries, giving for each interval in a the set
    of intervals in b that overlap with that one, and vice-versa.
    Equal intervals appearing several times in the same input share a
    single entry.

    Raises
    ------
    ValueError
        If an interval starts after it ends.

    Example
    -------
    >>> a_hits, b_hits = interval_overlap([(0, 4), (2, 6)], [(1, 3), (5, 7)])
    >>> a_hits == {(0, 4): set([(1, 3)]), (2, 6): set([(1, 3), (5, 7)])}
    True
    >>> b_hits == {(1, 3): set([(0, 4), (2, 6)]), (5, 7): set([(2, 6)])}
    True
    """

    # Event kinds. Their numeric order is the tie-break at equal
    # coordinates: intervals ending at x are closed first, then zero-length
    # intervals at x are matched, and only then intervals starting at x are
    # opened. This is what makes touching intervals non-overlapping.
    END = 0
    POINT = 1
    START = 2

    if getitem is None:
        getitem = lambda q, n: q[n]

    # Materialise the inputs so they can be indexed (and so that
    # one-shot iterables are accepted too).
    sides = (list(a), list(b))

    # Build the event list directly; every event is
    # (coordinate, kind, side, index). Only the coordinate and small
    # integers are compared while sorting, never the intervals themselves.
    events = []
    for side, items in enumerate(sides):
        for n, item in enumerate(items):
            lo = getitem(item, 0)
            hi = getitem(item, 1)
            if lo > hi:
                raise ValueError(
                    "interval start %r is greater than its end %r" % (lo, hi)
                )
            if lo == hi:
                events.append((lo, POINT, side, n))
            else:
                events.append((lo, START, side, n))
                events.append((hi, END, side, n))

    # O(n log(n))
    events.sort()

    overlaps = ({}, {})
    # Currently open intervals of each side, tracked by index rather than
    # by value so that duplicated intervals are opened and closed
    # independently of each other.
    visible = (set(), set())

    for _, kind, side, n in events:
        if kind == END:
            visible[side].remove(n)
            continue

        item = sides[side][n]
        other = 1 - side
        # setdefault (not assignment) so that a duplicate of an already
        # seen interval extends the shared entry instead of replacing it.
        found = overlaps[side].setdefault(item, set())
        for m in visible[other]:
            other_item = sides[other][m]
            found.add(other_item)
            # The entry exists: other_item was registered when it started.
            overlaps[other][other_item].add(item)

        # A zero-length interval is matched but never stays open.
        if kind == START:
            visible[side].add(n)

    return overlaps[0], overlaps[1]