#! /usr/bin/python
""" This program calculates an extended background
estimate from ihope results where chisq is calculated at the first stage.
Results are output in the current directory
"""
import optparse
import glob
import numpy
import copy

from numpy import sqrt
from itertools import chain, combinations
from operator import attrgetter
from glue.ligolw import ligolw, lsctables, utils
from glue import segments

from pylal import gatherloudtrigs
from pylal import extended_background_utils as loud_event_coincs

class DefaultContentHandler(lsctables.ligolw.LIGOLWContentHandler):
    """
    Content handler class handed to utils.load_filename() when this script
    reads the loudest event summary XML files. It adds nothing to its base
    class.
    """
    pass
# NOTE(review): use_in() presumably makes the handler build lsctables table
# classes while parsing -- confirm against the glue.ligolw documentation.
lsctables.use_in(DefaultContentHandler)

# Maps a veto category level (int) to the tag used in the names of the
# output files and to pick the matching results database.
name_from_cat = {2:'CAT_2', 3:'CAT_3', 4:'CAT_4'}

def get_cat_from_name(fname):
    """
    Work out the veto category level encoded in a file name.

    The name is searched for the markers CAT<n>, CAT_<n> and CATEGORY_<n>,
    trying n = 4 first, then 3, then 2. The first level for which any
    marker is present is returned as an int. None is returned when the
    name contains none of the markers.
    """
    marker_prefixes = ("CAT", "CAT_", "CATEGORY_")
    # Highest category first, so a name holding several markers resolves
    # to the highest level.
    for level in (4, 3, 2):
        if any((prefix + str(level)) in fname for prefix in marker_prefixes):
            return level
    return None

def gather_zerofar(file_glob):
    """
    Collect the zero-FAR events from the loudest event summaries.

    Every file matching file_glob is read and the rows of its
    "loudest_events" table that have combined_far == 0 are grouped into
    events (one row per single detector trigger). Rows are grouped by veto
    category level (taken from the file name) and coinc_event_id, so rows
    read from files of different categories are never merged into one event.

    The extended background is only calculated from double coincidences, so
    for triples (quadruples, ...) only the two rows with the largest
    sngl_snr are kept.

    Returns a tuple (zerofar_table, cats):

    zerofar_table -- a table of the same type as the input "loudest_events"
        tables holding two consecutive rows per zero-FAR event (loudest
        first), or None if file_glob matched no file.
    cats -- a list with one entry per row of zerofar_table giving the veto
        category level of that row, so that cats[2*i] is the category of
        the i-th event.
    """
    zerofar_table = None
    events = {}
    # Keys of `events` in first-seen order; plain dicts give no ordering
    # guarantee and the event index ends up in the output file names.
    event_order = []

    # Read through the loudest event summaries (in sorted order, so that the
    # event numbering is reproducible) and look for zero far events.
    for fname in sorted(glob.glob(file_glob)):
        cat_level = get_cat_from_name(fname)

        xmldoc = utils.load_filename(fname, contenthandler=DefaultContentHandler, gz=fname.endswith("gz"))
        loudest = lsctables.table.get_table(xmldoc, "loudest_events")

        # hack to create an empty table of this custom table type: copy the
        # first table that was read and delete all of its rows
        if zerofar_table is None:
            zerofar_table = copy.deepcopy(loudest)
            while len(zerofar_table) > 0:
                del zerofar_table[0]

        # Store the rows of each zero far event (one per sngl trigger) in a
        # list, indexed by category level and coinc event id
        for event in loudest:
            if event.combined_far != 0:
                continue
            key = (cat_level, event.coinc_event_id)
            if key not in events:
                events[key] = []
                event_order.append(key)
            events[key].append(event)

    print("There are %s zero far events" % len(events))

    # We currently only calculate coincs from double events so for
    # triples (quadruples, ..) we remove all but the 2 loudest detectors
    cats = []
    for key in event_order:
        cat_level = key[0]
        # sort in descending order of SNR and keep the first 2
        loudest_first = sorted(events[key], key=attrgetter('sngl_snr'), reverse=True)
        for trg in loudest_first[0:2]:
            zerofar_table.append(trg)
            # one category entry per table row keeps cats aligned with it
            cats.append(cat_level)
    return zerofar_table, cats


def create_combined_trig_files(base, ifos, cats=[], new_snr_cut=0.):
    """
    Collect the loud single detector triggers and livetimes per ifo and
    veto category.

    For each ifo the full data trigger files
    {base}/full_data/{ifo}-INSPIRAL_FULL_DATA*.xml.gz and the cumulative
    veto files
    {base}/segments/{ifo}-LIGOLW_COMBINE_SEGMENTS_VETO_CAT*_CUMULATIVE-*-*.xml
    are globbed and handed, together with new_snr_cut, to
    gatherloudtrigs.get_loud_trigs(). Its two results are paired up with the
    veto files in order, and the category level of each entry is read from
    the veto file name.

    Returns a dictionary such that result[ifo][cat_level] is the tuple
    (triggers, livetime). Only category levels listed in cats are stored;
    if cats is empty every level found is stored.
    """
    trig_data = {}
    for ifo in ifos:
        ifo_tag = str(ifo)
        trig_pattern = base + "/full_data/" + ifo_tag + "-INSPIRAL_FULL_DATA*.xml.gz"
        veto_pattern = base + "/segments/" + ifo_tag + "-LIGOLW_COMBINE_SEGMENTS_VETO_CAT*_CUMULATIVE-*-*.xml"
        trig_files = glob.glob(trig_pattern)
        veto_files = glob.glob(veto_pattern)

        loud_trigs, livetimes = gatherloudtrigs.get_loud_trigs(trig_files, veto_files, new_snr_cut)

        per_cat = {}
        for cat_trigs, cat_livetime, veto_file in zip(loud_trigs, livetimes, veto_files):
            level = get_cat_from_name(veto_file)
            # skip levels that were not asked for (an empty cats keeps all)
            if len(cats) != 0 and level not in cats:
                continue
            per_cat[level] = (cat_trigs, cat_livetime)
        trig_data[ifo] = per_cat
    return trig_data


def _combined_new_snrs(sngl_new_snrs):
    """
    Combine consecutive pairs of single detector values in quadrature,
    i.e. return [sqrt(v[0]**2 + v[1]**2), sqrt(v[2]**2 + v[3]**2), ...].
    A trailing unpaired value is ignored.
    """
    return [sqrt(sngl_new_snrs[2 * i]**2 + sngl_new_snrs[2 * i + 1]**2)
            for i in range(len(sngl_new_snrs) // 2)]


def calc_background_coincs(ztrig, cats, single_trigs, ethinca, slide,
                           veto_window, param_range, coinc_threshold):
    """
    Get the background coincidences for each zero-FAR event.

    Each event has two corresponding triggers in ztrig, which must be
    consecutive entries; cats holds the veto category level of each entry
    of ztrig. single_trigs is indexed as single_trigs[ifo][cat] and gives a
    tuple (triggers, livetime).

    For every event the time-slide coincidences between the single detector
    triggers of its two ifos are found within the event's chirp mass bin
    (taken from param_range) and above coinc_threshold times the event's
    snr. This is done twice: using all coincs, and after removing coincs
    within veto_window of the event's end time.

    Returns a tuple (trig_csnrs, trig_vcsnrs). Both are lists with one
    entry per event, each entry being a tuple (list of combined new snr
    values, background livetime); the second list holds the values after
    vetoing the time around the event.
    """
    trig_csnrs = []
    trig_vcsnrs = []
    for ind in range(len(ztrig) // 2):
        t1 = ztrig[ind * 2]
        t2 = ztrig[ind * 2 + 1]
        cat = cats[ind * 2]

        # Use the ifos of *this* event's triggers (not those of the first
        # event in the table) to select the single detector triggers.
        d1_trigs, lv1 = single_trigs[t1.sngl_ifo][cat]
        d2_trigs, lv2 = single_trigs[t2.sngl_ifo][cat]
        print("ifos %s %s at cat %s veto" % (t1.sngl_ifo, t2.sngl_ifo, cat))

        mchirp = 0.5 * (t1.sngl_mchirp + t2.sngl_mchirp)
        mchirp_bins = loud_event_coincs.parse_mchirp_bins(param_range)
        min_mchirp, max_mchirp = loud_event_coincs.get_mchirp_bin(mchirp, mchirp_bins)

        # We only consider coincs above this threshold
        threshold = t1.snr * coinc_threshold

        print("Trigger at snr %s " % t1.snr)
        print("Using threshold at snr %s" % threshold)

        # Get the list of coincs
        coincs = loud_event_coincs.find_slide_coincs(d1_trigs, d2_trigs,
                            min_mchirp, max_mchirp, ethinca, slide, threshold)

        # Get the list of coincs after removing those caused by
        # events around the zero far trigger
        veto_start = t1.end_time - veto_window
        veto_end = t1.end_time + veto_window
        veto_coincs = loud_event_coincs.get_veto_coincs(coincs, veto_start, veto_end)

        # Now that we have the coincs, calculate the combined new snr
        # values; get_new_snr() results are consumed as consecutive pairs
        csnr = _combined_new_snrs(coincs.get_new_snr())
        vcsnr = _combined_new_snrs(veto_coincs.get_new_snr())

        # Get the livetime before and after considering the time removed
        # around the trigger
        bkgtime = loud_event_coincs.get_background_livetime(lv1, lv2, slide)
        vbkgtime = loud_event_coincs.get_background_livetime(lv1, lv2, slide, veto_window)
        print("background livetime for slide step %s is %s %s before/after vetoing candidate time"
              % (slide, bkgtime, vbkgtime))

        trig_csnrs.append((csnr, bkgtime))
        trig_vcsnrs.append((vcsnr, vbkgtime))

    return trig_csnrs, trig_vcsnrs

def all_subsets(ss):
    """
    Return an iterator over every subset of the sequence ss.

    Subsets are yielded as tuples in order of increasing size, starting
    with the empty tuple and ending with the tuple of all elements.
    """
    subsets_by_size = [combinations(ss, size) for size in range(len(ss) + 1)]
    return chain.from_iterable(subsets_by_size)

def find_in_list(l, s):
    """
    Return the first string of the sequence l that contains the string s
    as a substring, or None if no entry of l contains it.
    """
    for candidate in l:
        if candidate.find(s) >= 0:
            return candidate
    return None

def make_result_pages(opts, result_database, zerofar_trigs, cats,
                      bgtrigs, vbgtrigs):
    """
    Create a result page for each zero far event.

    opts -- parsed options; param_ranges, veto_window and xmin are read.
    result_database -- glob of results databases; for each event the first
        match whose name contains the event's category tag is used.
    zerofar_trigs -- table with two consecutive rows per zero far event.
    cats -- veto category level of each row of zerofar_trigs.
    bgtrigs, vbgtrigs -- per event tuples (new snr values, livetime) of the
        background before / after vetoing the time around the event.

    For every event two text files (LOUD_EVENT_BACKGROUND_<tag>.txt and
    LOUD_EVENT_VETOBACKGROUND_<tag>.txt) are written to the current
    directory and lalapps_cbc_plotrates.main() is called to make the plot.

    Raises ValueError if no results database is found for an event.
    """
    from pylal import lalapps_cbc_plotrates

    # Standard args passed to lalapps_cbc_plotrates
    sargs = [
        '--data-table', 'coinc_inspiral',
        '--single-table', 'sngl_inspiral',
        '--output-path', './',
        '--statistic', 'snr:Threshold $\\rho_c$',
        '--foreground-datatype', 'all_data',
        '--param-name', 'mchirp',  #FIXME do not hardcode
        '--lin-x',
        '--nbins', '50',
        '--plot-significance-bands',
        '--max-sigma', '5',
        '--dpi', '400', '--verbose',
    ]

    for i in range(len(zerofar_trigs) // 2):
        t1 = zerofar_trigs[i * 2]
        t2 = zerofar_trigs[i * 2 + 1]
        cat_level = cats[i * 2]

        # Store the background new snr values into text files that the
        # plotting program understands. The first line is the livetime
        # in seconds. The remaining lines are the new snr values of
        # coincs.
        veto_event_tag = name_from_cat[cat_level] + '_' + str(i)

        bg_file = 'LOUD_EVENT_BACKGROUND_' + veto_event_tag + '.txt'
        vbg_file = 'LOUD_EVENT_VETOBACKGROUND_' + veto_event_tag + '.txt'
        background, livetime = bgtrigs[i]
        vbackground, vlivetime = vbgtrigs[i]
        numpy.savetxt(bg_file, numpy.append(numpy.array(livetime), background))
        numpy.savetxt(vbg_file, numpy.append(numpy.array(vlivetime), vbackground))

        # Get the database that contains the trigger results
        dat_name = find_in_list(glob.glob(result_database), name_from_cat[cat_level])
        if dat_name is None:
            # fail here with a clear message instead of handing None to the
            # plotting code
            raise ValueError("No results database containing %s matches %s"
                             % (name_from_cat[cat_level], result_database))

        ifo1 = t1.sngl_ifo
        ifo2 = t2.sngl_ifo

        # The instruments, other than the two of this event, that may have
        # been on: those listed for the event plus V1, L1 and H1, each
        # taken once. (list.remove() would only drop the first occurrence
        # and leave duplicates of the event's own ifos behind.)
        extra_instr = []
        for ifo in t1.ifos.split(',') + ['V1', 'L1', 'H1']:
            if ifo not in (ifo1, ifo2) and ifo not in extra_instr:
                extra_instr.append(ifo)

        # We need to grab coincs from all combinations of times, so we
        # construct the instrument sets: the two ifos of the event together
        # with every subset of the extra instruments
        coinc_set = ";".join(
            "[" + ifo1 + "," + ifo2 + " in " + ','.join([ifo1, ifo2] + list(on)) + "]"
            for on in all_subsets(extra_instr))

        mchirp = 0.5 * (t1.sngl_mchirp + t2.sngl_mchirp)
        mchirp_bins = loud_event_coincs.parse_mchirp_bins(opts.param_ranges)
        min_mchirp, max_mchirp = loud_event_coincs.get_mchirp_bin(mchirp, mchirp_bins)
        mchirp_bin_str = '[' + str(min_mchirp) + ',' + str(max_mchirp) + ')'

        # Event specific arguments; all values are passed as strings, as
        # they would be on a command line
        args = [dat_name] + sargs
        args += ['--param-ranges', mchirp_bin_str]
        args += ['--include-only-coincs', coinc_set]
        args += ['--plot-special-time', str(t1.end_time)]
        args += ['--plot-special-window', str(opts.veto_window)]
        args += ['--xmin', str(opts.xmin)]
        args += ['--user-tag', veto_event_tag]
        args += ['--add-background-file', bg_file]
        args += ['--add-special-background-file', vbg_file]

        # Make plot
        print(args)
        lalapps_cbc_plotrates.main(args)

def main():
    """
    Parse the command line and run the extended background estimate.

    The zero-FAR events are gathered from the loudest event summaries, the
    single detector triggers of the ifos involved are collected, the
    background coincidences are calculated and a result page is made for
    each event.

    Raises ValueError if one of the required options is missing.
    """
    parser = optparse.OptionParser()
    parser.add_option("--loudest-event-glob",
                    action="store", type="string",
                    help="Glob of xml files with the loudest event summaries. REQUIRED")
    parser.add_option("-c","--new-snr-cut",
                    action="store", type="float", default=6.,
                    help="newSNR threshold to retain single-ifo triggers. Default=6.0")
    parser.add_option("-e","--e-thinca-parameter",
                    dest="ethinca", action="store", type="float", default=0.5,
                    help="Ethinca threshold for coincidence test. Default=0.5")
    parser.add_option("-a","--xmin", action="store", type="float", default=8.,
                    help="Set a minimum value for the x-axis. Typically will be "
                         "close to the combined SNR value at the trigger generation "
                         "threshold. Default=8.")
    parser.add_option("--coinc-threshold",
                    action="store", type="float",
                    help="Fraction of the zero-FAR event combined SNR above which to"
                         " start looking for coinc triggers. REQUIRED")
    parser.add_option("-v","--veto-window",
                    action="store", type="float",
                    help="The amount of time before and after a zero-FAR trigger to veto. REQUIRED")
    # NOTE(review): the default is the int 0 although the option is of type
    # string -- confirm that parse_mchirp_bins() accepts it.
    parser.add_option("--param-ranges",
                    action="store", type="string", default=0,
                    help="Chirp mass bins used in determining the combined FAR, e.g. "
                    "'[0,1.3);[1.3,3]'. Default=0")
    parser.add_option("-s","--slide-step",
                    action="store", type="float", default=5.,
                    help="Time slide step in seconds. Default=5.0")
    parser.add_option("--ihope-base-dir",
                    action="store", type="string",
                    help="The ihope base directory to obtain trigger files for the analysis. "
                         "Should be of the form '/path/to/dir/GPSSTART-GPSEND'. REQUIRED")

    (opts, _args) = parser.parse_args()

    # Check options: every option whose help text says REQUIRED must be
    # given, otherwise the analysis fails later on with an obscure error
    required = (
        ("loudest_event_glob", "a loudest-event glob"),
        ("coinc_threshold", "a coinc-threshold"),
        ("veto_window", "a veto-window"),
        ("ihope_base_dir", "an ihope-base-dir"),
    )
    for dest, description in required:
        if getattr(opts, dest) is None:
            raise ValueError("Cannot proceed without %s!" % description)

    print("Getting zero-FAR triggers")
    zerofar_trigs, cats = gather_zerofar(opts.loudest_event_glob)
    # Nothing to do if no table could be made or if it has no rows
    if zerofar_trigs is None or len(zerofar_trigs) == 0:
        print("No zero-FAR triggers found, exiting ...")
        return

    print("Collecting the sngl ifo triggers above threshold")
    ifos = set(zerofar_trigs.getColumnByName('sngl_ifo'))
    single_trigs = create_combined_trig_files(opts.ihope_base_dir, ifos, set(cats), opts.new_snr_cut)

    print("Calculating the background coincidences")
    bgtrigs, vbgtrigs = calc_background_coincs(zerofar_trigs, cats,
                           single_trigs, opts.ethinca, opts.slide_step,
                           opts.veto_window, opts.param_ranges,
                           opts.coinc_threshold)

    print("Making the result pages")
    results_glob = opts.ihope_base_dir + "/workflow_postproc/*PYCBCCFAR_CUMULATIVE_CAT_*sqlite"
    make_result_pages(opts, results_glob, zerofar_trigs, cats,
                      bgtrigs, vbgtrigs)

# Run the analysis only when the file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()

