#!/usr/bin/python
#
# Copyright (C) 2007  Patrick Brady, Nickolas Fotopoulos
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
"""
This code computes the probabilities that go into the distance upper limits
for the CBC external trigger search.

This code relies heavily on NumPy advanced indexing.  Reference:
http://scipy.org/Cookbook/Indexing
"""

from __future__ import division

__author__ = "Nickolas Fotopoulos <nvf@gravity.phys.uwm.edu>"
__prog__ = "pylal_grblikelihood"
__title__ = "GRB interpretation"

import itertools
import optparse
import os.path as op
import cPickle as pickle
import shutil
import sys

import numpy
numpy.seterr(all="raise")  # throw an exception on any funny business

from pylal import CoincInspiralUtils
from pylal import git_version
from pylal import grbsummary
from pylal import InspiralUtils
from pylal import rate

# Sentinel log-likelihood assigned to every empty trial (one with no
# candidate).  Being -infinity, it can never win when a maximum is taken
# over trials or categories.
GRBL_EMPTY_TRIAL = -numpy.inf

################################################################################
# utility functions

class BackgroundEstimator(object):
    """
    Interface definition for background estimators.

    An instance is built from the loudest (binned) background data.
    Concrete subclasses implement __call__, which returns ln(p(c|0))
    for each bin of the trial that is passed in.
    """

    def __init__(self, loudest_in_bg_trials):
        """
        Store the loudest background statistics for use by subclasses.
        """
        self.loudest_in_bg_trials = loudest_in_bg_trials

    def __call__(self, loudest_in_trial):
        """
        Evaluate ln(p(c|0)); must be overridden by a concrete subclass.
        """
        raise NotImplementedError()

class ConstantBackgroundEstimator(BackgroundEstimator):
    """
    BackgroundEstimator that estimates p(c|0) as the fraction of background
    trials louder than the candidate, clipped from below at eps.  The
    estimate is therefore constant (equal to eps) for any loudest_in_trial
    greater than the loudest background event.
    """
    def __init__(self, loudest_in_bg_trials, eps):
        """
        Initialize ConstantBackgroundEstimator to return a ln(pc0) such that
        pc0[pc0 < eps] = eps. loudest_in_bg_trials can be any shape as long
        as the first axis (axis=0) is the trials axis.  If eps is not a
        scalar, then we require
        numpy.shape(eps) == loudest_in_bg_trials.shape[1:].

        Raises ValueError if a non-scalar eps has the wrong shape.
        """
        # numpy.shape() also copes with array-likes (e.g. lists), which have
        # no .shape attribute of their own.
        if not numpy.isscalar(eps) \
            and numpy.shape(eps) != numpy.shape(loudest_in_bg_trials)[1:]:
            raise ValueError("eps.shape != loudest_in_bg_trials.shape[1:]")
        self.eps = eps
        BackgroundEstimator.__init__(self, loudest_in_bg_trials)

    def __call__(self, loudest_in_trial):
        """
        Return ln(p(c|0)) for each bin of loudest_in_trial.  Where
        p(c|0) < eps, ln(eps) is returned instead.
        """
        # Indexing with None prepends a length-1 trials axis so the candidate
        # broadcasts against every background trial; the Ellipsis stands for
        # all of the remaining (bin) axes, whatever they are.
        is_louder = self.loudest_in_bg_trials > loudest_in_trial[None, ...]
        num_louder = is_louder.sum(axis=0, dtype=float)
        pc0 = num_louder / self.loudest_in_bg_trials.shape[0]
        # clip from below so that the logarithm stays finite
        numpy.putmask(pc0, pc0 < self.eps, self.eps)
        return numpy.log(pc0)

class SimpleExpExtrapEstimator(BackgroundEstimator):
    """
    BackgroundEstimator that takes a fraction 'fac' of the loudest points
    and extrapolates the p(c|0) distribution beyond a fiducial point chosen
    from them.
    """
    def __init__(self, loudest_in_bg_trials, offsource_loudest_by_trial_cat, fac):
        """
        Initialize SimpleExpExtrapEstimator with the extrapolation parameters.
        Use the (fac*N)th point from the end as a fiducial point, where
        N is the number of nonempty trials (entries > 0) in each column of
        offsource_loudest_by_trial_cat.  The fiducial point is never taken
        further out than the loudest background trial.

        loudest_in_bg_trials is indexed as [trial, bin]; below the fiducial
        point the estimate comes from a ConstantBackgroundEstimator with
        eps = 1e-10 built on the same data.
        """
        BackgroundEstimator.__init__(self, loudest_in_bg_trials)

        # ascending sort along the trials axis; the input is left untouched
        sorted_loudest = numpy.sort(loudest_in_bg_trials, axis=0)

        self.default_evaluator = \
            ConstantBackgroundEstimator(loudest_in_bg_trials, 1e-10)

        num_nonempty_offsource_trials = \
            (offsource_loudest_by_trial_cat > 0).sum(axis=0)

        # Count back from the loudest trial.  If int(fac*N) is 0, the index
        # -0 would silently select element 0 (the *quietest* trial), so clamp
        # the offset to at least 1, i.e. the loudest trial.
        self.fiducial_x = numpy.asarray(
            [sorted_loudest[-max(int(fac * N), 1), i]
             for i, N in enumerate(num_nonempty_offsource_trials)])
        self.fiducial_y = self.default_evaluator(self.fiducial_x)

    def __call__(self, loudest_in_trial):
        """
        Return ln(p(c|0)).  Where c is louder than the fiducial point, use
        the simplest possible extrapolation anchored on that point:
        ln p = fiducial_y + fiducial_x**1.5 - c**1.5.
        """
        # counting estimate from the ConstantBackgroundEstimator
        log_pc0 = self.default_evaluator(loudest_in_trial)
        # extrapolated value, evaluated for every bin
        extrap = self.fiducial_y + \
            self.fiducial_x**1.5 - loudest_in_trial**1.5
        # keep the extrapolation only in bins louder than the fiducial point
        numpy.putmask(log_pc0, loudest_in_trial > self.fiducial_x, extrap)
        return log_pc0

class ForegroundEstimator(object):
    """
    Interface definition for foreground estimators.

    A concrete subclass is built from the loudest (binned) injection data;
    its __call__ then returns ln(p(c(mc)|h(m2))) for each bin.
    """

    def __init__(self, inj_loudest_by_inj_mc_m2, num_sims_by_m2):
        """
        Abstract constructor; concrete subclasses supply their own.
        """
        raise NotImplementedError()

    def __call__(self, trial_loudest_by_mc):
        """
        Abstract evaluation; concrete subclasses supply their own.
        """
        raise NotImplementedError()

class ConstantForegroundEstimator(ForegroundEstimator):
    """
    ConstantForegroundEstimator evaluates the probability of getting events
    louder than any of the given events within each m2 bin.
    """
    def __init__(self, inj_loudest_by_inj_mc, num_sims_by_m2, m2_bin_by_inj,
        weight_by_inj=None, eps=None):
        """
        Initialize ConstantForegroundEstimator to return a ln(p(c|h)) such that
        pch[pch < eps] = eps.

        inj_loudest_by_inj_mc is an array indexed as [injection, mc];
        num_sims_by_m2 is an array with one entry per m2 bin; m2_bin_by_inj
        gives the m2 bin index of each injection.  weight_by_inj, if given,
        holds one weight per injection and is normalized to sum to the
        number of injections; otherwise every injection has weight 1.

        eps may be a scalar or an array with
        eps.shape == inj_loudest_by_inj_mc.shape[1:] (one floor per mc bin).
        If eps is None, then clip at 1/2 of the smallest otherwise-possible
        value in each m2 bin.

        Raises ValueError if a non-scalar eps has the wrong shape.
        """
        self.num_sims_by_m2 = num_sims_by_m2 + 1e-10  # prevent division by 0
        self.m2_bin_by_inj = m2_bin_by_inj
        self.dtype = inj_loudest_by_inj_mc.dtype
        self.shape = inj_loudest_by_inj_mc.shape
        self.num_inj = self.shape[0]
        self.num_m2 = num_sims_by_m2.shape[0]
        self.num_mc = self.shape[1]

        if weight_by_inj is not None:
            # normalize to N (the number of injections); force a float copy
            # so that the in-place scaling also works for integer weights
            self.weight_by_inj = numpy.array(weight_by_inj, dtype=float)
            self.weight_by_inj *= self.num_inj / self.weight_by_inj.sum()
        else:
            self.weight_by_inj = numpy.ones(self.num_inj, dtype=int)

        if eps is not None:
            if numpy.isscalar(eps):
                self.eps = self.dtype.type(eps)
            elif numpy.shape(eps) == self.shape[1:]:
                self.eps = eps
            else:
                raise ValueError("eps must be None, a scalar, or else "
                    "eps.shape must be equal to "
                    "inj_loudest_by_inj_mc.shape[1:].")
        else:
            # half of the smallest nonzero weighted count, per (mc, m2) bin
            # NOTE(review): this divides by the raw num_sims_by_m2, so an m2
            # bin with no injections divides by zero here -- confirm that
            # callers never pass empty m2 bins.
            self.eps = 0.5 * \
                self.weight_by_inj[self.weight_by_inj.nonzero()].min() / \
                num_sims_by_m2[None, :].repeat(self.shape[1], axis=0)

        # repack to speed inner loop by allowing us to call dot(): each
        # injection's loudest values are stored only in its own m2 slot,
        # leaving zeros in every other m2 slot
        self.inj_loudest_by_inj_mc_m2 = \
            numpy.zeros(self.shape + (self.num_m2,), dtype=self.dtype)
        self.inj_loudest_by_inj_mc_m2[numpy.arange(self.num_inj, dtype=int),
                                      :, m2_bin_by_inj] = inj_loudest_by_inj_mc

    def __call__(self, trial_loudest_by_mc):
        """
        Return ln(p(c|h)) as an array indexed as [mc, m2].  Where
        p(c|h) < eps, ln(eps) is returned instead.
        """
        louder_by_inj_mc_m2 = \
            self.inj_loudest_by_inj_mc_m2 > trial_loudest_by_mc[None, :, None]

        # Do a weighted count
        # NB: self.weight_by_inj is already normalized in __init__
        # NB: Use the BLAS call dot() to speed up this multiplication and sum:
        #     num_louder_by_mc_m2 = (louder_by_inj_mc_m2 \
        #        * self.weight_by_inj[:, None, None]).sum(axis=0, dtype=float)
        flat_louder = louder_by_inj_mc_m2\
            .reshape((self.num_inj, self.num_mc * self.num_m2))\
            .astype(self.weight_by_inj.dtype)
        num_louder_by_mc_m2 = numpy.dot(self.weight_by_inj, flat_louder)\
            .reshape((self.num_mc, self.num_m2))
        pch_by_mc_m2 = num_louder_by_mc_m2 / self.num_sims_by_m2[None, :]

        # Clip from below.  A per-mc vector eps must be aligned with the mc
        # axis explicitly: putmask() would repeat it in flat (mc, m2) order
        # and so apply the wrong floor to most bins.
        eps = numpy.asarray(self.eps)
        if eps.ndim == 1:
            eps = eps[:, None]
        pch_by_mc_m2 = numpy.where(pch_by_mc_m2 < eps, eps, pch_by_mc_m2)

        return numpy.log(pch_by_mc_m2)

def amap(func, arr):
    """
    Array map.  Loop over the first dimension of an array and call func
    on each of the subarrays, returning the results stacked into a new
    array whose first axis matches the first axis of arr.
    """
    # Build a real list: where map() is lazy (Python 3), numpy.array() would
    # wrap the iterator in a 0-d object array instead of stacking results.
    return numpy.array([func(subarr) for subarr in arr])

def apply_along_axis(func1d,axis,arr,*args):
    """
    NB: Adapted from numpy 1.3.0, because this function is broken in numpy
        1.0 and 1.0.4.  Local changes: the numpy imports, numpy.prod in
        place of the removed alias numpy.product, an explicit list for the
        index bookkeeping, and a range check for negative axes.

    Apply function to 1-D slices along the given axis.

    Execute `func1d(a, *args)` where `func1d` takes 1-D arrays and `a` is
    each 1-D slice of `arr` along the given axis in turn.

    Parameters
    ----------
    func1d : function
        This function should be able to take 1-D arrays. It is applied to 1-D
        slices of `arr` along the specified axis.
    axis : integer
        Axis along which `func1d` is applied.  Negative values count from
        the last axis.
    arr : ndarray
        Input array.
    args : any
        Additional arguments to `func1d`.

    Returns
    -------
    out : ndarray
        The output array. If `func1d` returns a scalar, `out` has the shape
        of `arr` with the `axis` dimension removed.  Otherwise the shape of
        `out` is identical to the shape of `arr`, except along the `axis`
        dimension, whose length is equal to the size of the return value of
        `func1d`.

    Raises
    ------
    ValueError
        If `axis` is out of range for `arr`.

    See Also
    --------
    apply_over_axes : Apply a function repeatedly over multiple axes.

    Examples
    --------
    >>> def my_func(a):
    ...     \"\"\"Average first and last element of a 1-D array\"\"\"
    ...     return (a[0] + a[-1]) * 0.5
    >>> b = np.array([[1,2,3], [4,5,6], [7,8,9]])
    >>> np.apply_along_axis(my_func, 0, b)
    array([4., 5., 6.])
    >>> np.apply_along_axis(my_func, 1, b)
    array([2., 5., 8.])

    """
    from numpy import asarray, zeros, prod, isscalar
    arr = asarray(arr)
    nd = arr.ndim
    if axis < 0:
        axis += nd
    if (axis >= nd):
        raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d."
            % (axis,nd))
    if axis < 0:
        raise ValueError("axis must be at least -arr.ndim; rank=%d." % nd)
    # ind is the running multi-index over every axis except `axis`
    ind = [0]*(nd-1)
    # i is the full index into arr: a slice at `axis`, integers elsewhere
    i = zeros(nd,'O')
    indlist = list(range(nd))
    indlist.remove(axis)
    i[axis] = slice(None,None)
    # extents of the axes that are iterated over
    itershape = tuple([arr.shape[d] for d in indlist])
    Ntot = prod(itershape)
    i.put(indlist, ind)
    # the first result decides the dtype and layout of the output
    res = func1d(arr[tuple(i.tolist())],*args)
    #  if res is a number, then we have a smaller output array
    if isscalar(res):
        outarr = zeros(itershape,asarray(res).dtype)
        outarr[tuple(ind)] = res
        k = 1
        while k < Ntot:
            # increment the index, carrying into slower-varying axes
            ind[-1] += 1
            n = -1
            while (ind[n] >= itershape[n]) and (n > (1-nd)):
                ind[n-1] += 1
                ind[n] = 0
                n -= 1
            i.put(indlist,ind)
            res = func1d(arr[tuple(i.tolist())],*args)
            outarr[tuple(ind)] = res
            k += 1
        return outarr
    else:
        # vector result: same shape as arr except along `axis`
        outshape = list(arr.shape)
        outshape[axis] = len(res)
        outarr = zeros(outshape,asarray(res).dtype)
        outarr[tuple(i.tolist())] = res
        k = 1
        while k < Ntot:
            # increment the index, carrying into slower-varying axes
            ind[-1] += 1
            n = -1
            while (ind[n] >= itershape[n]) and (n > (1-nd)):
                ind[n-1] += 1
                ind[n] = 0
                n -= 1
            i.put(indlist, ind)
            res = func1d(arr[tuple(i.tolist())],*args)
            outarr[tuple(i.tolist())] = res
            k += 1
        return outarr

def log_sum(trial_log_L_by_m2):
    """
    Return ln(sum L) over the given array of ln L values, the detection
    statistic for the loudest event.  If every entry equals
    GRBL_EMPTY_TRIAL, return GRBL_EMPTY_TRIAL.

    NB: After a lot of experimentation, ln(sum L) turns out to optimize the ROC
    curve.  It is also the Bayesian answer.
    """
    if (trial_log_L_by_m2 == GRBL_EMPTY_TRIAL).all():
        return GRBL_EMPTY_TRIAL

    # to avoid overflow, factor out common part of exponential; underflows OK!
    # ln(sum exp(ln L_i)) = ln(exp(C) * sum exp(ln L_i - C))
    #                     = C + ln(sum exp(ln L_i - C))
    common = trial_log_L_by_m2.max()

    # Underflow only drops negligible terms from the sum, so silence it
    # locally; numpy may otherwise be configured to raise on underflow.
    old_settings = numpy.seterr(under="ignore")
    try:
        scaled_sum = numpy.exp(trial_log_L_by_m2 - common).sum()
    finally:
        numpy.seterr(**old_settings)
    return numpy.log(scaled_sum) + common

def max_L(trial_log_L_by_mc_m2):
    """
    Return, for each m2 bin, the largest ln L among the rows (candidates)
    that actually exist.  A row counts as existing when its first entry
    differs from GRBL_EMPTY_TRIAL.  If no row exists, the first row is
    returned unchanged.
    """
    has_candidate = trial_log_L_by_mc_m2[:, 0] != GRBL_EMPTY_TRIAL
    if has_candidate.any():
        # maximize over the existing candidates only
        return trial_log_L_by_mc_m2[has_candidate, :].max(axis=0)
    # every row is empty, so any one of them will do
    return trial_log_L_by_mc_m2[0, :]

def parse_args():
    """
    Parse the command line and return the tuple (options, arguments).

    Raises ValueError if any of --relic-onsource, --relic-offsource or
    --relic-injections is missing.
    """
    parser = optparse.OptionParser(version=git_version.verbose_msg)

    # cache input
    parser.add_option("--relic-onsource", help="output of pylal_relic "
        "containing the onsource loudest coincs.")
    parser.add_option("--relic-offsource", help="output of pylal_relic "
        "containing the offsource loudest coincs.")
    parser.add_option("--relic-injections", help="output of pylal_relic "
        "containing the loudest injection coincs.")

    # odds and ends
    parser.add_option("--reweight-D", action="store_true", default=False,
        help="enable a reweighting of injections from a uniform distribution "
        "in D to a uniform distribution in D^3.")
    parser.add_option("--user-tag", help="a tag with which to label outputs")
    # a version mismatch is fatal unless --version-mismatch-ok is given
    parser.set_defaults(onmismatch="raise")
    parser.add_option("--version-mismatch-ok", action="store_const",
        const="warn", dest="onmismatch", help="by default, the program "
        "will halt if the version id stored in the pickle does not match the "
        "running version id; specify this option to reduce to a warning")
    parser.add_option("--verbose", action="store_true", default=False,
        help="extra information to the console")

    options, arguments = parser.parse_args()

    # check that mandatory switches are present
    for opt in ("relic_onsource", "relic_offsource", "relic_injections"):
        if getattr(options, opt) is None:
            # report the switch as typed on the command line (with dashes)
            raise ValueError("--%s is required" % opt.replace("_", "-"))

    return options, arguments

################################################################################
# parse arguments
opts, args = parse_args()

##############################################################################
# read everything into memory
# All vetoes have been taken into account

def _load_relic(path):
    """
    Unpickle and return the contents of one pylal_relic output file.
    """
    # NOTE(review): unpickling can execute arbitrary code; only pass files
    # that come from a trusted pylal_relic run.
    # Open in binary mode, as binary pickle protocols require, and close
    # the handle explicitly rather than leaving it to garbage collection.
    relic_file = open(path, "rb")
    try:
        return pickle.load(relic_file)
    finally:
        relic_file.close()

if opts.verbose:
    print("Reading in bin definitions and loudest statistics...")
# NB: id, statistic, mc_bins and mc_ifo_cats are overwritten by each file in
# turn; only the values from the injections file survive.
id, statistic, mc_bins, mc_ifo_cats, onsource_loudest_by_cat \
    = _load_relic(opts.relic_onsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, statistic, mc_bins, mc_ifo_cats, offsource_loudest_by_trial_cat \
    = _load_relic(opts.relic_offsource)
git_version.check_match(id, onmismatch=opts.onmismatch)
id, statistic, mc_bins, mc_ifo_cats, m2_bins, D_bins, m2_D_by_inj, inj_loudest_by_inj_cat \
    = _load_relic(opts.relic_injections)
git_version.check_match(id, onmismatch=opts.onmismatch)
if statistic not in ("effective_snr", "new_snr"):
    raise NotImplementedError("unsupported statistic: %r" % (statistic,))

# compute useful quantities
num_inj = len(inj_loudest_by_inj_cat)
# m2 bin index of every injection; m2_D[0] is the value looked up in m2_bins
m2_bin_by_inj = [m2_bins[m2_D[0]] for m2_D in m2_D_by_inj]
# number of injections that fall into each m2 bin
num_sims_by_m2 = numpy.zeros(len(m2_bins), dtype=int)
for m2_bin in m2_bin_by_inj:
    num_sims_by_m2[m2_bin] += 1

##############################################################################
# Initialize background and foreground estimation functions
# NB: for robustness, the extrapolation is anchored inside the tail rather
# than at its very end; the last argument (0.3) is the tail fraction handed
# to SimpleExpExtrapEstimator as 'fac'.
num_offsource_trials = offsource_loudest_by_trial_cat.shape[0]
compute_log_pc0_by_cat = SimpleExpExtrapEstimator(
    offsource_loudest_by_trial_cat, offsource_loudest_by_trial_cat, 0.3)

# Example of how to switch to constant extrapolator
# compute_log_pc0_by_mc = ConstantBackgroundEstimator(\
#     offsource_loudest_by_trial_mc,
#     0.5 / offsource_loudest_by_trial_mc.shape[0])

# Compute weighting of uniform in distance to uniform in volume
# (no reweighting unless --reweight-D was given)
weight_by_inj = None
if opts.reweight_D:
    if opts.verbose:
        print("Computing D to D^3 reweighting...")
    # each injection is weighted by the square of m2_D[1]
    weight_by_inj = [m2_D[1]**2 for m2_D in m2_D_by_inj]

compute_log_pch_by_cat_m2 = ConstantForegroundEstimator(
    inj_loudest_by_inj_cat, num_sims_by_m2, m2_bin_by_inj,
    weight_by_inj=weight_by_inj)

################################################################################
# compute the likelihoods of the loudest on-source coincidences
if opts.verbose:
    print("Computing on-source likelihoods...")
log_pc0_by_mc    = compute_log_pc0_by_cat(onsource_loudest_by_cat)
log_pch_by_mc_m2 = compute_log_pch_by_cat_m2(onsource_loudest_by_cat)

# NB: Here are examples of how to swap pure effective SNR or log IFAR.
# Just remember to do it for off-source and injection cases as well.

# log_L_by_mc_m2 = onsource_loudest_by_mc[:, None].repeat(len(m2_bins), axis=1)
# log_L_by_mc_m2 = -log_pc0_by_mc[:, None].repeat(len(m2_bins), axis=1)
# ln L = ln p(c|h) - ln p(c|0), broadcast over the m2 axis
log_L_by_mc_m2 = log_pch_by_mc_m2 - log_pc0_by_mc[:, None]

# mark non-candidates (categories whose loudest statistic is not positive)
actual_candidate_mask = onsource_loudest_by_cat > 0
log_L_by_mc_m2[~actual_candidate_mask, :] = GRBL_EMPTY_TRIAL

# for upper limit, get the max L in each m2 bin; only from actual candidates
log_L_by_m2 = max_L(log_L_by_mc_m2)

# for detection, find the log(sum L) of the loudest candidate; sum over m2
log_sum_L = apply_along_axis(log_sum, 1, log_L_by_mc_m2).max()

# write the on-source likelihoods to disk
onsource_outname = __prog__ + "_onsource"
if opts.user_tag is not None:
    onsource_outname += "_" + opts.user_tag
onsource_outname += ".pickle"
# close the file explicitly so the data is flushed to disk right away
onsource_outfile = open(onsource_outname, "wb")
try:
    pickle.dump((git_version.id, log_pc0_by_mc, log_pch_by_mc_m2,
                 actual_candidate_mask, log_L_by_m2, log_sum_L),
                onsource_outfile, -1)
finally:
    onsource_outfile.close()

################################################################################
# Compute off-source likelihoods
if opts.verbose:
    print("Computing off-source likelihoods...")

# evaluate both estimators trial by trial
offsource_pc0_by_trial_cat = amap(
    compute_log_pc0_by_cat, offsource_loudest_by_trial_cat)
offsource_pch_by_trial_cat_m2 = amap(
    compute_log_pch_by_cat_m2, offsource_loudest_by_trial_cat)
# ln L = ln p(c|h) - ln p(c|0), broadcast over the m2 axis
offsource_log_L_by_trial_cat_m2 = offsource_pch_by_trial_cat_m2 \
    - offsource_pc0_by_trial_cat[:, :, None]

# mark non-candidates (entries whose loudest statistic is exactly zero)
offsource_log_L_by_trial_cat_m2[offsource_loudest_by_trial_cat == 0., :] = \
    GRBL_EMPTY_TRIAL

# for upper limit, get the max L in each m2 bin; only from actual candidates
offsource_log_L_by_trial_m2 = amap(max_L, offsource_log_L_by_trial_cat_m2)

# for detection, find the log(sum L) of the loudest candidate; sum over m2
offsource_log_sum_L_by_trial = \
    apply_along_axis(log_sum, 2, offsource_log_L_by_trial_cat_m2)\
    .max(axis=1)

# write the off-source likelihoods to disk
offsource_outname = onsource_outname.replace("_onsource", "_offsource", 1)
# close the file explicitly so the data is flushed to disk right away
offsource_outfile = open(offsource_outname, "wb")
try:
    pickle.dump((git_version.id, offsource_pc0_by_trial_cat,
        offsource_pch_by_trial_cat_m2, offsource_log_L_by_trial_m2,
        offsource_log_sum_L_by_trial), offsource_outfile, -1)
finally:
    offsource_outfile.close()

################################################################################
# Compute injection likelihoods
# FIXME: This can be done in O(N_{inj}) time because L is monotonic with SNR
# within an mc bin, so a threshold in L can be converted to a threshold in SNR
# for each mc bin.
if opts.verbose:
    print("Computing injection likelihoods...")

# evaluate both estimators injection by injection
inj_log_pc0_by_trial_cat = \
    amap(compute_log_pc0_by_cat, inj_loudest_by_inj_cat)
inj_log_pch_by_trial_cat_m2 = \
    amap(compute_log_pch_by_cat_m2, inj_loudest_by_inj_cat)
# ln L = ln p(c|h) - ln p(c|0), broadcast over the m2 axis
inj_log_L_by_trial_cat_m2 = \
    inj_log_pch_by_trial_cat_m2 - inj_log_pc0_by_trial_cat[:, :, None]

# mark non-candidates (entries whose loudest statistic is exactly zero)
inj_log_L_by_trial_cat_m2[inj_loudest_by_inj_cat == 0., :] = GRBL_EMPTY_TRIAL

# for upper limit, get the max L in each m2 bin; only from actual candidates
inj_log_L_by_trial_m2 = amap(max_L, inj_log_L_by_trial_cat_m2)

# for detection, find the log(sum L) of the loudest candidate; sum over m2
inj_log_sum_L_by_trial = \
    apply_along_axis(log_sum, 2, inj_log_L_by_trial_cat_m2)\
    .max(axis=1)

# write the injection likelihoods to disk
inj_outname = onsource_outname.replace("_onsource", "_injections", 1)
# close the file explicitly so the data is flushed to disk right away
inj_outfile = open(inj_outname, "wb")
try:
    pickle.dump((git_version.id, inj_log_pc0_by_trial_cat,
                 inj_log_pch_by_trial_cat_m2, inj_log_L_by_trial_m2,
                 inj_log_sum_L_by_trial), inj_outfile, -1)
finally:
    inj_outfile.close()

