bin.py
f3dedd8d
 # coding=utf-8
 #
 # The MIT License (MIT)
 #
 # Copyright (c) 2014 Lorenz Hüdepohl, Emmanuel Stamou
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
 # in the Software without restriction, including without limitation the rights
 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 # copies of the Software, and to permit persons to whom the Software is
 # furnished to do so, subject to the following conditions:
 #
 # The above copyright notice and this permission notice shall be included in
 # all copies or substantial portions of the Software.
 #
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 
22348233
 from __future__ import absolute_import
 from __future__ import print_function
 
f3dedd8d
 from .arithmetic_decorators import add_operators
 
 def projected(t, axis):
     """
     Return ``t`` as a tuple with the entry at position ``axis`` removed.

     ``t`` may be any iterable; it is converted to a tuple first. Negative
     values of ``axis`` count from the end, as with ordinary sequence indexing.

     Raises IndexError if ``axis`` is out of range for ``t``.

     Example
     -------
     >>> projected((1, 2, 3), 1)
     (1, 3)
     >>> projected([1, 2, 3], -1)
     (1, 2)
     """
     t = tuple(t)
     n = len(t)
     # Plain slicing silently misbehaves here: a negative axis would splice
     # the whole tuple back in (t[axis+1:] == t[0:] for axis == -1), and an
     # axis >= n would return t unchanged.
     if not -n <= axis < n:
         raise IndexError("axis {0} out of range for a {1}-tuple".format(axis, n))
     if axis < 0:
         axis += n
     return t[:axis] + t[axis + 1:]
 
 def bin_operator(func):
     """
     Lift a binary function on bin values into a binary method on bins.

     The returned callable takes ``(self, other)``:

     * if ``other`` looks like a bin (has a ``low`` attribute) it must be
       compatible with ``self`` (same bounds), otherwise an
       Exception("Bins not compatible") is raised; its ``value`` is used;
     * otherwise ``other`` itself (e.g. a scalar) is used as the right operand.

     The result is a new Bin with ``self``'s bounds and value
     ``func(self.value, rhs)``. Presumably this is what ``add_operators``
     uses to generate Bin's arithmetic methods -- confirm in
     arithmetic_decorators.
     """
     def wrapped_bin_operator(self, other):
         rhs = other
         if hasattr(other, "low"):
             # Bin-like right operand: bounds must match exactly.
             if not self.compatible(other):
                 raise Exception("Bins not compatible")
             rhs = other.value
         return Bin(self.low, self.high, func(self.value, rhs))

     return wrapped_bin_operator
 
22348233
 def bin_hash(abin):
     """
     Hash a bin by its bounds and its value.

     The hash key is the tuple ``(low, high, value)``. That matches the
     equality used by ``Bin.__eq__`` (same bounds and same value), so equal
     bins hash equally. ``abin.value`` must itself be hashable. Presumably
     ``add_operators`` installs this as ``Bin.__hash__`` -- confirm in
     arithmetic_decorators.
     """
     key = (abin.low, abin.high, abin.value)
     return hash(key)
 
 @add_operators(bin_operator, bin_hash)
f3dedd8d
 class Bin(object):
     """
     An axis-aligned, n-dimensional bin carrying a value.

     A bin is described by two equal-length tuples ``low`` and ``high`` (one
     entry per dimension, with ``low[i] < high[i]`` in every dimension), an
     arbitrary ``value`` and the number of dimensions ``ndim``.

     Example
     -------
     >>> b = Bin([0., 0.], [1., 2.], 3.0)
     >>> b.ndim
     2
     >>> b
     Bin((0.0, 0.0), (1.0, 2.0), 3.0)
     """

     def __init__(self, *args):
         """
         Construct a new bin

           bin = Bin(low, high, value)
           bin = Bin(otherbin)

         ``low`` and ``high`` may be any iterables of equal length; they are
         stored as tuples. ``Bin(otherbin)`` copies the ``low``, ``high`` and
         ``value`` attributes of another bin-like object.

         Raises TypeError for any number of arguments other than 1 or 3, and
         Exception if the bounds differ in length or ``low < high`` does not
         hold in every dimension.
         """
         if len(args) == 1:
             otherbin, = args
             low, high, value = otherbin.low, otherbin.high, otherbin.value
         elif len(args) == 3:
             low, high, value = args
         else:
             # Fail loudly instead of falling through with ``low`` unbound.
             raise TypeError("Bin() takes either 1 or 3 arguments "
                             "({0} given)".format(len(args)))

         # Materialize first so that iterators/generators are accepted and
         # can be measured and iterated more than once.
         low = tuple(low)
         high = tuple(high)

         if not len(low) == len(high):
             raise Exception("len(low) != len(high)")

         if not all(l < h for l, h in zip(low, high)):
             raise Exception("low < high not fullfilled")

         self.low = low
         self.high = high
         self.value = value
         self.ndim = len(self.low)

     def center(self):
         """
         Returns the center of the bin, one coordinate per dimension

         Example
         -------
         >>> b = Bin([0., 0.], [10., 10.], 1.0)
         >>> b.center()
         (5.0, 5.0)
         """
         return tuple(0.5 * (h + l) for l, h in zip(self.low, self.high))

     def half_width(self):
         """
         Returns the half-width of the bin, one entry per dimension

         Example
         -------
         >>> b = Bin([0., 0.], [20., 40.], 1.0)
         >>> b.half_width()
         (10.0, 20.0)
         """
         return tuple(0.5 * (h - l) for l, h in zip(self.low, self.high))

     def __str__(self):
         ranges = ", ".join("[{0} - {1}]".format(a,b) for a,b in zip(self.low, self.high))
         return "{0} = {1}".format(ranges, self.value)

     def __repr__(self):
         return "Bin({0.low}, {0.high}, {0.value})".format(self)

     def compatible(self, other):
         """
         Return True if ``other`` has exactly the same bounds as ``self``;
         the values are not compared
         """
         return self.low == other.low and self.high == other.high

     def __eq__(self, other):
         """
         Bins are equal when they have the same bounds and the same value.

         Returns NotImplemented for objects without an ``ndim`` attribute (so
         e.g. ``bin == None`` is simply False) and raises TypeError when
         comparing bins of different dimensionality.
         """
         try:
             other_ndim = other.ndim
         except AttributeError:
             return NotImplemented
         if not self.ndim == other_ndim:
             raise TypeError("Cannot compare bins with different dimensions!")
         if not self.compatible(other):
             return False
         return self.value == other.value

     def __ne__(self, other):
         # Python 2 does not derive != from __eq__; without this, != would
         # fall back to identity comparison there.
         result = self.__eq__(other)
         if result is NotImplemented:
             return result
         return not result

     def __bool__(self):
         # Python 3
         return bool(self.value)

     def __nonzero__(self):
         # Python 2
         return self.__bool__()

     def projected(self, axis):
         """
         Return a new bin with dimension ``axis`` removed from both bounds;
         the value is carried over unchanged
         """
         return Bin(projected(self.low, axis), projected(self.high, axis), self.value)

     def _check_point_dimension(self, point):
         # Shared guard of __contains__ and point_inside.
         if not len(point) == self.ndim:
             raise Exception("Must specify a {0}-dimensional point for this bin!".format(self.ndim))

     def __contains__(self, point):
         """
         Returns True if point is inside or on the surface of the bin's domain

         Raises Exception if ``point`` does not have ``ndim`` coordinates.

         Example
         -------
         >>> b = Bin([0., 0.], [1., 1.], 1.0)
         >>> print(b)
         [0.0 - 1.0], [0.0 - 1.0] = 1.0

         >>> (0.5, 0.5) in b
         True

         >>> (0.5, 1.0) in b
         True

         >>> (0.5, 1.1) in b
         False

         >>> (0.01, 0.99) in b
         True
         """
         self._check_point_dimension(point)
         return all(l <= x <= h for x, l, h in zip(point, self.low, self.high))

     def point_inside(self, point):
         """
         Returns True if point is strictly inside of the bin's domain

         Raises Exception if ``point`` does not have ``ndim`` coordinates.

         Example
         -------
         >>> b = Bin([0., 0.], [1., 1.], 1.0)
         >>> print(b)
         [0.0 - 1.0], [0.0 - 1.0] = 1.0

         >>> b.point_inside((0.5, 0.5))
         True

         >>> b.point_inside((0.5, 1.0))
         False

         >>> b.point_inside((1e-5, 0.9999))
         True
         """
         self._check_point_dimension(point)
         return all(l < x < h for x, l, h in zip(point, self.low, self.high))

     def split_on_axis(self, axis, coord):
         """
         Split-up a bin into two, along given value "coord" on axis "axis".

         New bin values are obtained as volume-weighted fractions of the original value

         ``axis`` may be negative, counting from the last dimension. If
         ``coord`` lies exactly on one of the bin's boundaries along ``axis``,
         nothing is split and the one-element tuple ``(self,)`` is returned.
         A warning is emitted when one of the pieces would be extremely thin.

         Raises Exception if ``coord`` lies outside the bin along ``axis``.

         Example
         -------
         >>> b = Bin([0.], [1.], 1.0)
         >>> print(b)
         [0.0 - 1.0] = 1.0

         >>> b1, b2 = b.split_on_axis(0, 0.25)
         >>> print(b1)
         [0.0 - 0.25] = 0.25

         >>> print(b2)
         [0.25 - 1.0] = 0.75

         >>> b = Bin([0., 0.], [1., 1.], 1.0)
         >>> print(b)
         [0.0 - 1.0], [0.0 - 1.0] = 1.0

         >>> b1, b2 = b.split_on_axis(1, 0.25)
         >>> print(b1)
         [0.0 - 1.0], [0.0 - 0.25] = 0.25

         >>> print(b2)
         [0.0 - 1.0], [0.25 - 1.0] = 0.75

         """
         if axis < 0:
             # Normalize so the slicing below (axis+1) works for negative axes.
             axis += self.ndim
         low_a = self.low[axis]
         high_a = self.high[axis]

         if not low_a <= coord <= high_a:
             raise Exception("Value {0} is outside bin in axis {1}".format(coord, axis))

         # float() avoids Python 2 floor division when the bounds are ints.
         w = float(high_a - low_a)
         l1 = (coord - low_a) / w
         l2 = (high_a - coord) / w

         if 0 < l1 < 1e-7 or 0 < l2 < 1e-7:
             import warnings
             warnings.warn(("Split into extremely degenerate bins along dimension {0}, "
                            "check your bin boundaries:\n"
                            "  {1:20} {2:20} {3:20}\n  {4!r:20} {5!r:20} {6!r:20}\n").format(
                                 axis,
                                 "left", "split", "right",
                                 low_a, coord, high_a),
                           stacklevel=2)

         if not abs(l1 + l2 - 1.0) < 1e-11:
             raise Exception("Something wrong here")

         v1 = float(l1) * self.value
         v2 = float(l2) * self.value

         if l1 > 0 and l2 > 0:
             b1 = Bin(self.low,
                      self.high[:axis] + (coord,) + self.high[axis+1:],
                      v1)
             b2 = Bin(self.low[:axis] + (coord,) + self.low[axis+1:],
                      self.high,
                      v2)
             return b1, b2
         elif l1 == 1.0 or l2 == 1.0:
             # coord sits on a boundary: there is nothing to split.
             return self,
         else:
             raise Exception("Something wrong in split_in_axis")

     def split(self, point):
         """
         Divide a bin into 2^self.ndim bins around the given point within the bin

         New bin values are obtained as volume-weighted fractions of the original value

         Coordinates of ``point`` lying on the bin's boundary do not split
         along that axis, so fewer than 2^ndim bins are returned then.
         Raises Exception if ``point`` is not contained in the bin.

         Example
         -------
         >>> b = Bin([0., 0.], [1., 1.], 1.0)
         >>> print(b)
         [0.0 - 1.0], [0.0 - 1.0] = 1.0

         >>> b1, b2, b3, b4 = b.split([0.25, 0.25])
         >>> print(b1)
         [0.0 - 0.25], [0.0 - 0.25] = 0.0625

         >>> print(b2)
         [0.0 - 0.25], [0.25 - 1.0] = 0.1875

         >>> print(b3)
         [0.25 - 1.0], [0.0 - 0.25] = 0.1875

         >>> print(b4)
         [0.25 - 1.0], [0.25 - 1.0] = 0.5625

         >>> b.split([0.25, 1.5])
         Traceback (most recent call last):
           ...
         Exception: Point not contained in bin!
         """
         if point not in self:
             raise Exception("Point not contained in bin!")

         subbins = [self]
         for axis, coord in enumerate(point):
             subbins = [piece
                        for sub in subbins
                        for piece in sub.split_on_axis(axis, coord)]

         return tuple(subbins)

     def vertices(self):
         """
         Return all 2**ndim corner points of the bin as a tuple of tuples
         """
         from itertools import product
         return tuple(product(*zip(self.low, self.high)))

     def volume(self):
         """
         Return a new bin with the same bounds whose value is this bin's
         volume (the product of its edge lengths)

         Example
         -------
         >>> Bin([0., 0.], [2., 3.], 1.0).volume()
         Bin((0.0, 0.0), (2.0, 3.0), 6.0)
         """
         # reduce is not a builtin on Python 3.
         from functools import reduce
         from operator import mul
         edges = (h - l for l, h in zip(self.low, self.high))
         # Start from 1 so a zero-dimensional bin gets the empty product
         # instead of a TypeError from reduce() on an empty sequence.
         return Bin(self.low, self.high, reduce(mul, edges, 1))

     def overlaps(self, other):
         """
         Return true if the domain of self overlaps with the domain of other

         Bins that merely touch along a face do not overlap. Raises Exception
         if the bins have different dimensions.

         Example
         -------
         >>> bin1 = Bin([0.], [0.25], 0.25)
         >>> bin2 = Bin([0.25], [1.0], 0.75)
         >>> bin3 = Bin([0.0], [1.0], 1.0)
         >>> bin1.overlaps(bin2)
         False
         >>> bin1.overlaps(bin3)
         True
         """
         if not self.ndim == other.ndim:
             raise Exception("Cannot operate on Bins with different dimension!")
         for n in range(self.ndim):
             if not(self.low[n] <= other.low[n] < self.high[n] or other.low[n] <= self.low[n] < other.high[n]):
                 return False
         return True

     def overlap_fraction(self, other):
         """
         Return the volume ratio between the overlap of self and other and self itself.

         The result is 0.0 for disjoint bins and 1.0 when other covers self.
         Raises Exception if the bins have different dimensions.

         Examples
         --------

         >>> bin1 = Bin([0.], [1.0], 0.5)
         >>> bin2 = Bin([0.5], [4.5], 1.)
         >>> bin3 = Bin([1.0], [2.0], 1.)
         >>> bin1.overlap_fraction(bin2)
         0.5
         >>> bin2.overlap_fraction(bin1)
         0.125
         >>> bin3.overlap_fraction(bin1)
         0.0
         """
         if not self.ndim == other.ndim:
             raise Exception("Cannot operate on Bins with different dimension!")

         def clip(a, lo, hi):
             # Clamp a into [lo, hi].
             if a < lo:
                 return lo
             elif lo <= a <= hi:
                 return a
             elif hi < a:
                 return hi
             else:
                 # Only reachable for unordered values such as NaN.
                 raise Exception("Tertium datur")

         r = 1.0
         for lo, hi, olo, ohi in zip(self.low, self.high, other.low, other.high):
             # float() avoids Python 2 floor division for integer bounds.
             r *= (clip(ohi, lo, hi) - clip(olo, lo, hi)) / float(hi - lo)

         return r

     def overlap_points(self, other):
         """
         Return the edge points of the intersection hypercube
         of self and other -- provided self and other overlap --
         that are not coincident with both an edge point of self
         and other at the same time

         The points are returned as a list of tuples on both Python 2 and 3.
         """
         from itertools import product
         shared_vertices = set(self.vertices()).intersection(other.vertices())
         low = [max(a, b) for a, b in zip(self.low, other.low)]
         high = [min(a, b) for a, b in zip(self.high, other.high)]
         return [p for p in product(*zip(low, high)) if p not in shared_vertices]