#!/usr/bin/python

# Copyright (C) 2012 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Plot zero-lag triggers from a coherent inspiral analysis.
"""

from __future__ import division

import sys
import os
import optparse
import datetime
import time
import numpy
import re
import warnings
# silence these two specific warnings so they do not clutter script output
warnings.filterwarnings("ignore", message="Module dap was already imported")
warnings.filterwarnings("ignore", message="column name (.*) is not lower case")

# set backend for non-interactive plotting: when no X display is available,
# select matplotlib's file-only "agg" backend before pylal.plotutils (which
# exposes pylab) is imported below
if not os.getenv("DISPLAY", None):
    import matplotlib
    matplotlib.use("agg", warn=False)


from pylal import (antenna, date, git_version, plotutils, MultiInspiralUtils,
                   coh_PTF_pyutils, htmlutils, ligolw_tisi)
from pylal.dq import dqSegmentUtils
from lal import GPSTimeNow
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS

from glue import iterutils, markup, segments, lal as cache
from glue.ligolw import ilwd, ligolw, table, lsctables, utils as ligolw_utils
from glue.ligolw.utils import (process as ligolw_process, ligolw_add,
                               search_summary as ligolw_search_summary)
# attach the coh_PTF veto helpers to MultiInspiralTable as class attributes
lsctables.MultiInspiralTable.veto = coh_PTF_pyutils.veto
lsctables.MultiInspiralTable.vetoed = coh_PTF_pyutils.vetoed

# set up metadata
__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# set up timer: elapsed_time() returns seconds since this module was loaded
start = time.time()
elapsed_time = lambda: time.time()-start

# globals
# VERBOSE and PROFILE gate print_verbose output; the __main__ block
# overwrites them from the --verbose/--profile command-line options
VERBOSE = False
PROFILE = False
# labels used as keys of the per-search data dictionaries
FOREGROUND = "foreground"
BACKGROUND = "background"
SIMULATIONS = "simulations"
# NOTE: the __main__ block rebinds found_inj/missed_inj to False when no
# simulation cache is given; SEARCHES below keeps the original string
found_inj = "found_injections"
missed_inj = "missed_injections"
# plotting order of the searches; MARKERS, MARKER_EC, MARKER_FC and
# MARKER_LW are index-aligned with SEARCHES
SEARCHES = [BACKGROUND, FOREGROUND, found_inj]
MARKERS = ["o", "x", "+"]
MARKER_EC = ["k", "g", "r"]
MARKER_FC = ["b", "g", "r"]
MARKER_LW = [1, 2, 2]
GRAY = "#dddddd"
# default [min, max] axis limits for the various plotted quantities
SNR_LIM = [5, 100]
SNGL_SNR_LIM = [1, 200]
BESTNR_LIM = [5, 30]
CHISQ_LIM = [8, 2000]
BANK_CHISQ_LIM = [8, 2000]
CONT_CHISQ_LIM = [10, 5000]
NULL_SNR_LIM = [0, 30]

# conveniences
latex = plotutils.display_name
asarray = numpy.asarray
# under Python 2 make the module-level name `range` the lazy xrange
if sys.version_info[0] <= 2:
    range = xrange


def print_verbose(message, verbose=True, stream=sys.stdout, profile=True):
    """Print verbose messages to a file stream.

    Messages written to any stream other than ``sys.stderr`` are further
    gated by the module-level ``VERBOSE`` and ``PROFILE`` flags; messages
    written to ``sys.stderr`` depend only on the arguments given here.

    @param message
        text to print
    @param verbose
        flag to print or not, default: True (print, subject to VERBOSE
        for non-stderr streams)
    @param stream
        file object stream in which to print, default: sys.stdout (bound
        when this function is defined)
    @param profile
        flag to append the elapsed run time to messages that end in a
        newline, default: True (subject to PROFILE for non-stderr streams)
    """
    if stream != sys.stderr:
        # ordinary output only appears when the command-line flags allow it
        profile &= PROFILE
        verbose &= VERBOSE
    if profile and message.endswith("\n"):
        message = "%s (%.2f)\n" % (message.rstrip("\n"), elapsed_time())
    if verbose:
        stream.write(message)
        stream.flush()


def read_ligolw(lalcache):
    """Read a cache of LIGO_LW files into a single xmldoc.

    After merging, rows of the TimeSlideTable whose ``time_slide_id`` is a
    key of the mapping returned by ``ligolw_tisi.time_slides_vacuum`` are
    dropped, and every table's keys are remapped through that mapping
    (presumably collapsing redundant time slides onto equivalent ones).

    @param lalcache
        LAL-format cache object (GLUE) from which to read

    @returns
        the merged ligolw.Document
    """
    xmldoc = ligolw.Document()
    ligolw_add.ligolw_add(xmldoc, lalcache.pfnlist())
    time_slide_table = table.get_table(xmldoc,
                                       lsctables.TimeSlideTable.tableName)
    time_slide_mapping = ligolw_tisi.time_slides_vacuum(
                             time_slide_table.as_dict())
    # test membership against the dict itself (O(1) hash lookup) rather than
    # .keys(), which builds a list under Python 2 (O(n) scan per row)
    iterutils.inplace_filter(lambda row: row.time_slide_id not in
                             time_slide_mapping, time_slide_table)
    for tbl in xmldoc.getElementsByTagName(ligolw.Table.tagName):
        tbl.applyKeyMapping(time_slide_mapping)
    return xmldoc


def read_found_missed(xmldoc, loudest_by="snr"):
    """Split the injections in a document into found and missed sets.

    Each row of the CoincTable is treated as one found injection: its
    CoincMap entries must identify exactly one SimInspiral row and at least
    one MultiInspiral row, of which the loudest (by `loudest_by`) is kept.
    SimInspiral rows not referenced by any coinc are returned as missed.

    @param xmldoc
        LIGO_LW document (for example from `read_ligolw`) containing the
        Coinc, CoincMap, SimInspiral and MultiInspiral tables
    @param loudest_by
        column name by which to select most significant event in the case
        of multiple MultiInspiral events mapped to one SimInspiral; a
        ``MultiInspiral.get_<loudest_by>`` method is used in preference to
        the raw column when one exists

    @returns
        tuple of (found MultiInspiralTable, found SimInspiralTable, missed
        SimInspiralTable); the two found tables are row-aligned

    @raises ValueError
        if a coinc maps to more than one SimInspiral, to no known
        SimInspiral, or to no MultiInspiral events
    @raises RuntimeError
        if the found and missed counts do not sum to the number of
        injections
    """
    coinc_table = table.get_table(xmldoc, lsctables.CoincTable.tableName)
    coinc_map_table = table.get_table(xmldoc,
                                      lsctables.CoincMapTable.tableName)
    all_sims = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)
    all_multis = table.get_table(xmldoc,
                                 lsctables.MultiInspiralTable.tableName)
    found_sims = table.new_from_template(all_sims)
    found_multis = table.new_from_template(all_multis)
    missed_sims = table.new_from_template(all_sims)

    # index the coinc map and the injections once, instead of rescanning
    # both tables for every coinc (quadratic in the number of injections)
    maps_by_coinc = {}
    for row in coinc_map_table:
        maps_by_coinc.setdefault(row.coinc_event_id, []).append(row)
    sim_by_id = {}
    for row in all_sims:
        # keep the first row for any duplicated simulation_id
        sim_by_id.setdefault(row.simulation_id, row)

    # the ranking function does not depend on the coinc, so choose it once
    getter_name = "get_%s" % loudest_by
    if hasattr(lsctables.MultiInspiral, getter_name):
        rank = getattr(lsctables.MultiInspiral, getter_name)
    else:
        rank = lambda row: getattr(row, loudest_by)

    for coinc in coinc_table:
        maps = maps_by_coinc.get(coinc.coinc_event_id, [])
        sim_ids = [row.event_id for row in maps
                   if row.table_name == "sim_inspiral"]
        if len(sim_ids) > 1:
            raise ValueError("More than one SimInspiral associated with this "
                             "found injection coinc.")
        if not sim_ids or sim_ids[0] not in sim_by_id:
            raise ValueError("No SimInspiral found for this found injection "
                             "coinc.")
        found_sims.append(sim_by_id[sim_ids[0]])

        multi_ids = set(row.event_id for row in maps
                        if row.table_name == "multi_inspiral")
        # filter in table order so ties in the ranking keep table order
        multis = [row for row in all_multis if row.event_id in multi_ids]
        if not multis:
            raise ValueError("No MultiInspiral event associated with this "
                             "found injection coinc.")
        if len(multis) > 1:
            multis.sort(key=rank, reverse=True)
        found_multis.append(multis[0])

    found_ids = set(row.simulation_id for row in found_sims)
    missed_sims.extend([row for row in all_sims
                        if row.simulation_id not in found_ids])
    if len(found_sims) + len(missed_sims) != len(all_sims):
        raise RuntimeError("%d found and %d missed identified from a total "
                           "of %d. Oops." % (len(found_sims), len(missed_sims),
                                             len(all_sims)))
    print_verbose("%d found injections identified.\n" % len(found_sims),
                  profile=False)
    print_verbose("%d missed injections identified.\n" % len(missed_sims),
                  profile=False)
    return found_multis, found_sims, missed_sims


def separate_time_slides(mi_table, return_index=False):
    """Split the events of a MultiInspiralTable by time slide.

    The slides are the distinct ``time_slide_id`` values present in
    `mi_table`; rows are matched to a slide by comparing the string form
    of their ``time_slide_id``.

    @param mi_table
        MultiInspiralTable (any iterable of rows with ``time_slide_id``)
    @param return_index
        if True, return one boolean numpy array per slide marking which
        rows of `mi_table` belong to it; otherwise return one new
        MultiInspiralTable per slide containing that slide's rows

    @returns
        list with one entry per distinct time slide
    """
    distinct_ids = set(row.time_slide_id for row in mi_table)
    slide_ids = [str(slide_id) for slide_id in distinct_ids]

    if not return_index:
        per_slide = []
        for slide_id in slide_ids:
            slide_table = table.new_from_template(mi_table)
            slide_table.extend(row for row in mi_table
                               if str(row.time_slide_id) == slide_id)
            per_slide.append(slide_table)
        return per_slide

    row_slide_ids = numpy.asarray([str(row.time_slide_id)
                                   for row in mi_table])
    return [row_slide_ids == slide_id for slide_id in slide_ids]


def write_multi_inspiral_table(mi_table, outfile, search):
    """Write a MultiInspiralTable to a new LIGO_LW file.

    The document receives a process row for this program (whose comment
    names `search`), a search summary row, and finally `mi_table` itself.

    @param mi_table
        MultiInspiralTable to write; it is appended to the new document.
        May be empty, in which case no instruments are recorded.
    @param outfile
        output file path; gzip-compressed when it ends with ".gz"
    @param search
        label interpolated into the process comment
    """
    xmldoc = ligolw.Document()
    xmldoc.appendChild(ligolw.LIGO_LW())
    # an empty table has no row from which to read the instruments, so
    # fall back to append_process's default rather than raising IndexError
    ifos = mi_table[0].get_ifos() if len(mi_table) else None
    process = ligolw_process.append_process(xmldoc, program=__file__,
                                            version=__version__,
                                            comment=("Loudest events in "
                                                     "coh_PTF inspiral %s "
                                                     "analysis" % search),
                                            ifos=ifos)
    xmldoc.childNodes[-1].appendChild(
        lsctables.New(lsctables.SearchSummaryTable))
    ligolw_search_summary.append_search_summary(xmldoc, process)
    process.end_time = int(GPSTimeNow())
    xmldoc.childNodes[-1].appendChild(mi_table)
    ligolw_utils.write_filename(xmldoc, outfile, gz=outfile.endswith(".gz"))


def write_sim_inspiral_table(sim_table, outfile, search):
    """Write a SimInspiralTable to a new LIGO_LW file.

    The document receives a process row for this program (whose comment
    names `search`), a search summary row, and finally `sim_table` itself.

    @param sim_table
        SimInspiralTable to write; it is appended to the new document
    @param outfile
        output file path; gzip-compressed when it ends with ".gz"
    @param search
        label interpolated into the process comment
    """
    doc = ligolw.Document()
    root = ligolw.LIGO_LW()
    doc.appendChild(root)
    proc = ligolw_process.append_process(
        doc, program=__file__, version=__version__,
        comment="Loudest events in coh_PTF inspiral %s analysis" % search)
    root.appendChild(lsctables.New(lsctables.SearchSummaryTable))
    ligolw_search_summary.append_search_summary(doc, proc)
    proc.end_time = int(GPSTimeNow())
    root.appendChild(sim_table)
    ligolw_utils.write_filename(doc, outfile, gz=outfile.endswith(".gz"))


def plot_xy(outfile, xdata, ydata, xlabel, ylabel, xlim=None, ylim=None,
           logx=False, logy=False, fill=None, plot_contours=None,
           fill_above=None, fill_below=None, line=None):
    plot = plotutils.ScatterPlot(xlabel, ylabel)
    for key,marker,lc,lw in zip(SEARCHES, MARKERS, MARKER_EC, MARKER_LW):
        if not xdata.has_key(key):
           continue
        plot.add_content(xdata[key], ydata[key], label=latex(key),
                         marker=marker, edgecolor=lc, linewidth=lw)
    if plot_contours:
        coh_PTF_pyutils.plot_contours(plot.ax, contours["snr"],
                                      contours[plot_contours],
                                      contours["color"])

    if fill_above:
        x, y = fill_above
        plot.ax.plot(x, y, "k-")
        ymax = ylim and ylim[1] or plot.ax.get_ylim()[1]
        polyx = numpy.concatenate((x, [max(x), min(x)]))
        polyy = numpy.concatenate((y, [ymax, ymax]))
        plot.ax.fill(polyx, polyy, color='#dddddd', alpha=0.8)
    if fill_below:
        x, y = fill_below
        plot.ax.plot(x, y, "k-")
        ymin = ylim and ylim[0] or plot.ax.get_ylim()[0]
        polyx = numpy.concatenate((x, [max(x), min(x)]))
        polyy = numpy.concatenate((y, [ymin, ymin]))
        plot.ax.fill(polyx, polyy, color='#dddddd', alpha=0.5)
    if line:
        x, y = line
        plot.ax.plot(x, y, "k--")

    plot.finalize()
    plotutils.add_colorbar(plot.ax, visible=False)
    if logx:
        plot.ax.set_xscale("log")
    if logy:
        plot.ax.set_yscale("log")
    if xlim:
        plot.ax.set_xlim(*xlim)
    if ylim:
        plot.ax.set_ylim(*ylim)
    plot.savefig(outfile)
    plot.close()


def plot_hist(outfile, data, xlabel, ylabel, xlim=None, ylim=None, logx=False,
             logy=False, num_bins=50):
    """Plot cumulative histograms of each search's data and save the figure.

    The BACKGROUND entry is drawn as background content; every other search
    present in `data` is drawn as foreground markers. `data` is not
    modified.

    @param outfile
        path of the output image
    @param data
        dict keyed by search name (entries of SEARCHES); the BACKGROUND
        entry may be a single numpy array or a list of arrays (e.g. one per
        time slide), other entries are single sequences
    @param xlabel
        x-axis label
    @param ylabel
        y-axis label
    @param xlim
        optional [min, max] x-axis limits
    @param ylim
        optional [min, max] y-axis limits
    @param logx
        force a logarithmic x-axis
    @param logy
        force a logarithmic y-axis
    @param num_bins
        number of histogram bins, default: 50
    """
    plot = plotutils.CumulativeHistogramPlot(xlabel, ylabel)
    for key in SEARCHES:
        if key not in data:
            continue
        values = data[key]
        if key == BACKGROUND:
            # add_background takes a list of data sets: wrap a lone array
            # locally rather than overwriting the caller's dict entry
            if isinstance(values, numpy.ndarray):
                values = [values]
            plot.add_background(values, label=latex(key))
        else:
            plot.add_content(values, label=latex(key), linestyle="",
                             marker="^")
    plot.finalize(num_bins=num_bins)
    # apply axis options after finalize, as plot_xy does
    if logx:
        plot.ax.set_xscale("log")
    if logy:
        plot.ax.set_yscale("log")
    if xlim:
        plot.ax.set_xlim(*xlim)
    if ylim:
        plot.ax.set_ylim(*ylim)
    plot.savefig(outfile)
    plot.close()


def plot_found_missed(outfile, found_xy, missed_xy, nonzerofap_xyz,
                      xlabel, ylabel, xlim=None, ylim=None,
                      logx=False, logy=True):
    """Plot injections as found/missed/coloured based on FAP.
    """
    plot = plotutils.ColorbarScatterPlot(xlabel, ylabel,
                                         "False alarm probability (FAP)")
    plot.ax.scatter(*missed_xy, edgecolor="r", marker="x")#, label="Missed")
    plot.ax.scatter(*found_xy, edgecolor="k", marker="x")#, label="Found")
    if len(nonzerofap_xyz[0]):
        plot.add_content(*nonzerofap_xyz,
                         edgecolors='none')#, label="Non-zero FAP")
    else:
        plot.add_content([1e-100], [1e-100], [0.5], edgecolors="none",
                         visible=True)#, label="Non-zero FAP")
    if not xlim:
        xlim = [0.8, 2.7]
    ylim = [0.8,500]
    plot.finalize(clim=[0,1])
    if logx:
        plot.ax.set_xscale("log")
    if logy:
        plot.ax.set_yscale("log")
    if xlim:
        plot.ax.set_xlim(*xlim)
    if ylim:
        plot.ax.set_ylim(*ylim)
    plot.savefig(outfile)
    plot.savefig(outfile.replace(".png", ".pdf"), bbox_inches="tight")
    plot.close()


def parse_threshold(option, opt_str, value, parser):
    """optparse callback: validate and store a positive threshold.

    Non-positive values are reported through ``parser.error``. For
    ``--sngl-snr`` the value is appended, as a float, to the list already
    held in the option's destination; for any other option the float
    value replaces the destination.
    """
    if value <= 0:
        parser.error("%s must be positive" % opt_str)
    destination = option.dest
    if opt_str != "--sngl-snr":
        setattr(parser.values, destination, float(value))
        return
    getattr(parser.values, destination).append(float(value))


if __name__ == "__main__":

    parser = optparse.OptionParser(description=__doc__,
                                   epilog="For help, ask.ligo.org",
                                   formatter=optparse.IndentedHelpFormatter(4))
    parser.add_option("-p", "--profile", action="store_true", default=False,
                      help="timestamp output, default: %default")
    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="show verbose output, default: %default")
    parser.add_option("-V", "--version", action="version",\
                      help="show program's version number and exit")
    parser.version = git_version.verbose_msg

    # input options
    inputopts = parser.add_option_group("Input options")
    inputopts.add_option("--gps-start-time", action="store", type="float",
                         help="GPS start time for plots.")
    inputopts.add_option("--gps-end-time", action="store", type="float",
                         help="GPS end time for plots.")
    inputopts.add_option("-b", "--background-cache", action="append",
                         type="string", default=None, metavar="FILE",
                         help=("path to LAL-format cache file of XML files "
                               "containing MultiInspiralTables from time-slide"
                               " time-slide analysis (no injections). "
                               "Can be given multiple times."))
    inputopts.add_option("-f", "--foreground-cache", action="append",
                         type="string", default=None, metavar="FILE",
                         help=("path to LAL-format cache file of XML files "
                               "containing MultiInspiralTables from zero-lag"
                               "analysis (no injections). "
                               "Can be given multiple times."))
    inputopts.add_option("-y", "--simulation-cache", action="append",
                         type="string", default=None, metavar="FILE",
                         help=("path to LAL-format cache file of XML files "
                               "containing MultiInspiralTables representing "
                               "the results of software injections. "
                               "Can be given multiple times."))

    # sbv options
    sbvopts = parser.add_option_group("Signal-based veto options")
    sbvopts.add_option("-S", "--snr-threshold", action="store",
                       type="float", default=6.0,
                       help=("veto threshold on (new) SNR, default: %default"))
    sbvopts.add_option("-N", "--null-snr-threshold", action="store",
                       type="float", default=6.0,
                       help=("veto threshold on null SNR, default: %default"))
    sbvopts.add_option("-W", "--null-weight-threshold", action="store",
                       type="float", default=4.25,
                       help=("threshold on null SNR above which events get "
                             "down-ranked, default: %default"))
    sbvopts.add_option("-g", "--sngl-snr-threshold", action="store",
                       type="string", default="4.0,4.0",
                       help=("comma-separated list of veto thresholds on "
                             "single-detector SNR, default: %default"))

    # efficiency options
    effopts = parser.add_option_group("Efficiency options",
                                      ("Tunable parameters for efficiency"
                                       "calculation"))
    effopts.add_option("-U", "--upper-inj-dist", action="store", type="float",
                       default=100,
                       help=("upper limit on injection efficiency distance, "
                             "default: %default"))
    effopts.add_option("-L", "--lower-inj-dist", action="store", type="float",
                       default=0, help=("lower limit on injection efficiency "
                                        "distance, default: %default"))
    effopts.add_option("-n", "--num-distance-bins", action="store",
                       type="int", default=20,
                       help=("The number of bins used to calculate"
                             "injection efficiency. default: %default"))

    # output options
    outopts = parser.add_option_group("Output options")
    outopts.add_option("-o", "--output-directory", action="store",
                       type="string", default=os.getcwd(), metavar="DIR",
                       help="output directory for all plots, default: %default")
    outopts.add_option("-t", "--output-tag", action="store", type="string",
                       default="COH_PTF_INSPIRAL", metavar="TAG",
                       help="output tag for all plots, default: %default")

    # parse args, printing help if nothing given
    opts,args = parser.parse_args()
    if len(vars(opts).keys()) == 0:
        parser.print_help()
        sys.exit(0)

    start_time = opts.gps_start_time
    end_time = opts.gps_end_time
    VERBOSE = opts.verbose
    PROFILE = opts.profile
    opts.sngl_snr_threshold = map(float, opts.sngl_snr_threshold.split(","))

    #
    # load multi_inspirals
    #

    full_cache = cache.Cache()

    # list files
    if opts.background_cache is None:
        background = False
    else:
        background = BACKGROUND
        bg_cache = cache.Cache.fromfilenames(opts.background_cache)
        bg_cache.sort(key=lambda e: e.segment[0])
        full_cache.extend(bg_cache)
    if opts.foreground_cache is None:
        foreground = False
    else:
        foreground = FOREGROUND
        fg_cache = cache.Cache.fromfilenames(opts.foreground_cache)
        fg_cache.sort(key=lambda e: e.segment[0])
        full_cache.extend(fg_cache)
    if opts.simulation_cache is None:
        simulations = False
        found_inj = False
        missed_inj = False
    else:
        simulations = SIMULATIONS
        found_inj = found_inj
        missed_inj = missed_inj
        sim_cache = cache.Cache.fromfilenames(opts.simulation_cache)
        sim_cache.sort(key=lambda e: e.segment[0])
        full_cache.extend(sim_cache)

    if len(full_cache) == 0:
        raise optparse.OptionValueError("No files were loaded via "
                                        "--background-cache, "
                                        "--foreground-cache, or "
                                        "--simulation-cache. "
                                        "Please double check.")

    # get start and end times of cache
    if not start_time or not end_time:
        seglist = segments.segmentlist([e.segment for e in
                                        full_cache]).coalesce()
        s,e = seglist.extent()
    if not start_time:
        start_time = int(s)
    if not end_time:
        end_time = int(e)
    duration = end_time - start_time

    # load triggers and injections
    num_slides = 0
    if background:
        print_verbose("Reading background events:\n", profile=False)
        xmldoc = read_ligolw(bg_cache)
        time_slide_table = table.get_table(xmldoc,
                                           lsctables.TimeSlideTable.tableName)
        bg_mi_table = table.get_table(
                              xmldoc, lsctables.MultiInspiralTable.tableName)
        slide_dict = time_slide_table.as_dict()
        num_slides = len(slide_dict)
        print_verbose("%d time slides identified.\n" % num_slides)
    else:
        bg_mi_table = None
    if foreground:
        print_verbose("Reading foreground events:\n", profile=False)
        xmldoc = read_ligolw(fg_cache)
        fg_mi_table = table.get_table(
                             xmldoc, lsctables.MultiInspiralTable.tableName)
    else:
        fg_mi_table = None
    if simulations:
        print_verbose("Reading simulation events:\n", profile=False)
        xmldoc = read_ligolw(sim_cache)
        sim_mi_table, sim_found_inj, sim_missed_inj = read_found_missed(
                                                             xmldoc,
                                                             loudest_by="snr")
    else:
        sim_mi_table = None

    # find ifos
    ifos = (bg_mi_table and bg_mi_table[0].ifos or
            fg_mi_table and fg_mi_table[0].ifos or
            sim_mi_table and sim_mi_table[0].ifos)
    ifos = sorted(lsctables.instrument_set_from_ifos(ifos))
    print_verbose("Interferometer list: %s\n" % (", ".join(ifos)),
                  profile=False)

    #
    # set output directory
    #

    if not os.path.isdir(opts.output_directory):
        os.makedirs(opts.output_directory)
    os.chdir(opts.output_directory)

    #
    # extract data
    #

    # pack MultiInspiralTables
    mi_tables = {}
    sim_tables = {}
    for x,tab in zip([foreground, background, found_inj],
                     [fg_mi_table, bg_mi_table, sim_mi_table]):
        if x:
            mi_tables[x] = tab
    if simulations:
        sim_tables[found_inj] = sim_found_inj
        sim_tables[missed_inj] = sim_missed_inj
    # trigger data
    mi_time = dict((key, asarray(tab.get_end()).astype(float)) for
                   key, tab in mi_tables.iteritems())
    mi_snr = dict((key, asarray(tab.get_column("snr"))) for
                   key, tab in mi_tables.iteritems())
    mi_bestnr = dict((key, asarray(tab.get_bestnr())) for
                     key, tab in mi_tables.iteritems())
    mi_sngl_snr = dict((ifo, dict((key, asarray(tab.get_sngl_snr(ifo))) for
                                  key, tab in mi_tables.iteritems())) for
                       ifo in ifos)
    mi_new_snr = dict((key, asarray(tab.get_new_snr(column="chisq"))) for
                      key, tab in mi_tables.iteritems())
    mi_bank_new_snr = dict((key, asarray(tab.get_new_snr(column="bank_chisq")))
                           for key, tab in mi_tables.iteritems())
    mi_cont_new_snr = dict((key, asarray(tab.get_new_snr(column="cont_chisq")))
                           for key, tab in mi_tables.iteritems())
    mi_chisq = dict((key, asarray(tab.get_column("chisq"))) for
                    key, tab in mi_tables.iteritems())
    mi_bank_chisq = dict((key, asarray(tab.get_column("bank_chisq"))) for
                         key, tab in mi_tables.iteritems())
    mi_cont_chisq = dict((key, asarray(tab.get_column("cont_chisq"))) for
                         key, tab in mi_tables.iteritems())
    mi_null_snr = dict((key, asarray(tab.get_null_snr())) for
                       key, tab in mi_tables.iteritems())
    print_verbose("Columns extracted from MultiInspiralTable(s).\n")

    # injection data
    sites = set(ifo[0] for ifo in ifos)
    sim_time = dict((key, asarray(tab.get_column("geocent_end_time")) +
                         asarray(tab.get_column("geocent_end_time_ns")*1e-9))
                   for key, tab in sim_tables.iteritems())
    sim_mchirp = dict((key, asarray(tab.get_column("mchirp"))) for
                     key, tab in sim_tables.iteritems())
    sim_ra = dict((key, asarray(tab.get_column("longitude"))) for
                 key, tab in sim_tables.iteritems())
    sim_dec = dict((key, asarray(tab.get_column("latitude"))) for
                 key, tab in sim_tables.iteritems())
    sim_dist = dict((key, asarray(tab.get_column("distance"))) for
                   key, tab in sim_tables.iteritems())
    sim_sngl_eff_dist = dict((key, asarray([tab.get_column("eff_dist_%s"
                                                     % (ifo.lower())) for
                                       ifo in sites])) for
                        key,tab in sim_tables.iteritems())
    sim_eff_dist = dict((key, numpy.power(numpy.power(sim_sngl_eff_dist[key],
                                                     -1).sum(0), -1)) for
                        key, tab in sim_tables.iteritems())
    sim_dec_dist = dict()
    for key,tab in sim_tables.iteritems():
        sim_dec_dist[key] = numpy.zeros(len(tab))
        for k in range(len(tab)):
            sim_dec_dist[key][k] = sorted([sim_sngl_eff_dist[key][i][k] for
                                           i in range(len(sites))])[-2]
        print sim_dec_dist[key]

    #
    # get time slide values
    #

    mi_slide_snr = dict()
    mi_slide_bestnr = dict()
    if background:
        slides = separate_time_slides(bg_mi_table, return_index=True)
        mi_slide_snr[background] = [mi_snr[background][slide] for
                                    slide in slides]
        mi_slide_bestnr[background] = [mi_bestnr[background][slide]
                                       for slide in slides]
        print_verbose("Time slide SNRs separated.\n")
        print_verbose("Events per slide:\n%s\n"
                      % (", ".join(map(str,map(len,mi_slide_snr[background])))),
                      profile=False)
    if foreground:
        mi_slide_snr[foreground] = mi_snr[foreground]
        mi_slide_bestnr[foreground] = mi_bestnr[foreground]
    if simulations:
        mi_slide_snr[found_inj] = mi_snr[found_inj]
        mi_slide_bestnr[found_inj] = mi_bestnr[found_inj]

    #
    # print loudest events
    #

    if simulations:
        # build new table of closest missed injections
        num_loudest = min(10, len(sim_tables[missed_inj]))
        close_missed_idx = sim_dist[missed_inj].argsort()[:10]
        close_missed_events = table.new_from_template(sim_tables[missed_inj])
        close_missed_events.extend(
            numpy.asarray(sim_tables[missed_inj])[close_missed_idx])
        close_missed_xml = ("%s-%s_CLOSE_MISSED-%d-%d.xml.gz"
                       % ("".join(ifos), opts.output_tag.upper(),
                          int(start_time), int(round(duration))))
        write_sim_inspiral_table(close_missed_events, close_missed_xml,
                                 missed_inj)
        print_verbose("Closest missed %d %s events written to\n%s\n"
                      % (num_loudest, missed_inj, close_missed_xml))
    elif foreground or background:
        key = foreground or background
        num_loudest = min(10, len(mi_tables[key]))
        loudest_idx = mi_bestnr[key].argsort()[::-1][:10]
        # build new table of loudest events
        loudest_events = table.new_from_template(mi_tables[key])
        loudest_events.extend(numpy.asarray(mi_tables[key])[loudest_idx])
        loudest_xml = ("%s-%s_LOUDEST-%d-%d.xml.gz"
                       % ("".join(ifos), opts.output_tag.upper(),
                          int(start_time), int(round(duration))))
        write_multi_inspiral_table(loudest_events, loudest_xml, key)
        print_verbose("Loudest %d %s events written to\n%s\n"
                      % (num_loudest, key, loudest_xml))

    #
    # calculate injection efficiency
    #

    if simulations:
        if opts.lower_inj_dist == 0:
            efficiency_bins = numpy.linspace(opts.lower_inj_dist,
                                             opts.upper_inj_dist,
                                             opts.num_distance_bins+1,
                                             endpoint=True)
        else:
            efficiency_bins = numpy.logspace(numpy.log10(opts.lower_inj_dist),
                                             numpy.log10(opts.upper_inj_dist),
                                             opts.num_distance_bins+1,
                                             endpoint=True)
        deltaD = numpy.diff(efficiency_bins)
        efficiency_distance = efficiency_bins[:-1] + deltaD/2
        if background:
            fap = asarray([(bestnr <= mi_bestnr[background]).sum() for bestnr in
                           mi_bestnr[found_inj]]) / mi_bestnr[background].size
        else:
            fap = numpy.zeros(mi_bestnr[found_inj].size)
        found_by_dist = numpy.histogram(sim_dist[found_inj][fap==0],
                                        bins=efficiency_bins)[0]
        missed_by_dist = numpy.histogram(list(sim_dist[found_inj][fap!=0]) +
                                         list(sim_dist[missed_inj]),
                                         bins=efficiency_bins)[0]
        efficiency_by_dist = found_by_dist / (found_by_dist + missed_by_dist)
        sngl_efficiency_by_dist = dict()
        for i,ifo in enumerate(sites):
            found_by_dist = numpy.histogram(
                                sim_sngl_eff_dist[found_inj][i,:][fap==0],
                                bins=efficiency_bins)[0]
            sim_sngl_eff_dist[found_inj][i,:]
            missed_by_dist = numpy.histogram(
                               list(sim_sngl_eff_dist[found_inj][i,:][fap!=0]) +
                               list(sim_sngl_eff_dist[missed_inj][i,:]),
                               bins=efficiency_bins)[0]
            sngl_efficiency_by_dist[ifo] = (found_by_dist /
                                            (found_by_dist + missed_by_dist))

    #
    # setup plots
    #

    plotdir = "plots"
    if not os.path.isdir(plotdir):
        os.mkdir(plotdir)

    # set plot params
    plotutils.set_rcParams()
    plotutils.pylab.rcParams.update({
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman"],
        "axes.labelsize" : 21,
        "xtick.labelsize" : 18,
        "ytick.labelsize" : 18,
        "axes.axisbelow" : True})

    # set chisq DOF
    chisq_dof = (bg_mi_table and bg_mi_table[0].chisq_dof or
                 fg_mi_table and fg_mi_table[0].chisq_dof or
                 sim_mi_table and sim_mi_table[0].chisq_dof)
    bank_chisq_dof = (bg_mi_table and bg_mi_table[0].bank_chisq_dof or
                      fg_mi_table and fg_mi_table[0].bank_chisq_dof or
                      sim_mi_table and sim_mi_table[0].bank_chisq_dof)
    cont_chisq_dof = (bg_mi_table and bg_mi_table[0].cont_chisq_dof or
                      fg_mi_table and fg_mi_table[0].cont_chisq_dof or
                      sim_mi_table and sim_mi_table[0].cont_chisq_dof)

    # get contours: one contour per new-SNR level, always including the
    # configured new-SNR threshold.  The de-duplicated levels are sorted so
    # the contours come out in a deterministic ascending order (iterating a
    # set gives an arbitrary order).
    new_snr_contours = sorted(set([5.5, 6, 6.5, 7, 8, 9, 10, 11] +
                                  [opts.snr_threshold]))
    # position of the threshold level, used to pick out the veto contours
    veto_contour_index = new_snr_contours.index(opts.snr_threshold)
    contours = coh_PTF_pyutils.calculate_contours(
                   new_snrs=new_snr_contours,
                   null_thresh=opts.null_snr_threshold,
                   new_snr_thresh=opts.snr_threshold,
                   chisq_dof=chisq_dof, bank_chisq_dof=bank_chisq_dof,
                   cont_chisq_dof=cont_chisq_dof)
    # label the outputs of calculate_contours by name; NOTE(review): assumes
    # it returns (bank_chisq, cont_chisq, chisq, null, snr, color) in this
    # order -- zip() would silently drop any extra outputs
    contours = dict(zip(("new_snr", "bank_chisq", "cont_chisq", "chisq",
                         "null", "snr", "color"),
                        [new_snr_contours]+list(contours)))
    # null-SNR contour rescaled from the null-SNR threshold to the
    # null-weight threshold
    contours["null_weight"] = (contours["null"] / opts.null_snr_threshold *
                               opts.null_weight_threshold)
    # contours at the new-SNR threshold, used as veto boundaries when shading
    veto_contours = dict((key, contours[key][veto_contour_index]) for
                         key in ["new_snr", "bank_chisq", "cont_chisq",
                                 "chisq"])

    # get plot time: express trigger and injection times in the chosen
    # display unit, measured from the analysis start
    time_unit, time_str = plotutils.time_axis_unit(end_time - start_time)
    mi_plot_time = dict((key, (val - start_time)/time_unit)
                        for key, val in mi_time.iteritems())
    sim_plot_time = dict((key, (val - start_time)/time_unit)
                         for key, val in sim_time.iteritems())
    plot_duration = float(duration)/time_unit

    # human-readable x-axis label anchored on the UTC start date
    utc_fields = date.XLALGPSToUTC(LIGOTimeGPS(start_time))[:6]
    start_date = datetime.datetime(*utc_fields)
    start_date_str = start_date.strftime("%B %d %Y, %H:%M:%S %ZUTC")
    time_label = "Time (%s) since %s (%s)" % (time_str, start_date_str,
                                              int(start_time))

    # set plot tag: file-name template whose remaining "%s" slot takes the
    # per-plot tag
    name_template = "%s-%s_%s-%d-%d.png"
    name_fields = ("".join(sorted(ifos)), opts.output_tag, "%s",
                   int(start_time), int(round(duration)))
    plotname = os.path.join(plotdir, name_template % name_fields)

    #
    # plot versus time
    #

    # coherent quantities versus time, drawn in this order:
    # (file tag, y data, y-axis label, extra plot_xy options)
    coh_time_plots = (
        ("TIME_SNR", mi_snr, "Coherent signal-to-noise ratio (SNR)",
         dict(ylim=SNR_LIM, logy=True)),
        ("TIME_NULL_SNR", mi_null_snr, "Null signal-to-noise ratio (SNR)",
         dict(ylim=NULL_SNR_LIM,
              fill_below=([0, plot_duration],
                          [opts.null_snr_threshold]*2))),
        ("TIME_BESTNR", mi_bestnr,
         "$\\chi^2$ re-weighted signal-to-noise ratio (SNR)",
         dict(ylim=BESTNR_LIM, logy=True,
              fill_below=([0, plot_duration], [opts.snr_threshold]*2))),
    )
    for time_tag, time_ydata, time_ylabel, time_opts in coh_time_plots:
        outfile = plotname % time_tag
        plot_xy(outfile, mi_plot_time, time_ydata, time_label, time_ylabel,
                xlim=[0, plot_duration], **time_opts)
        print_verbose("%s written.\n" % outfile)

    # single-detector SNR versus time, one plot per interferometer, shaded
    # below the coherent SNR threshold
    for i, ifo in enumerate(ifos):
        ifo_upper = ifo.upper()
        outfile = plotname % ("TIME_SNR%s" % ifo_upper)
        plot_xy(outfile, mi_plot_time, mi_sngl_snr[ifo], time_label,
                "%s signal-to-noise ratio (SNR)" % ifo_upper,
                xlim=[0, plot_duration], ylim=SNGL_SNR_LIM, logy=True,
                fill_below=([0, plot_duration], [opts.snr_threshold]*2))
        print_verbose("%s written.\n" % outfile)

    #
    # plot num slides
    #

    if background and not simulations:
        # consecutive slide numbers around zero, with slot 0 left out; slot
        # 0 is filled in a second colour only when there is foreground data
        xdata = numpy.arange(num_slides + 1) - num_slides//2
        xdata = xdata[xdata != 0]
        ydata = numpy.asarray([len(slide_snr)
                               for slide_snr in mi_slide_snr[background]])

        plot = plotutils.BarPlot("Slide number", "Number of events")
        plot.add_content(xdata, ydata, color=MARKER_FC[0], width=1)
        if foreground:
            plot.add_content([0], [len(mi_slide_snr[foreground])],
                             color=MARKER_FC[1], width=1)
            xdata = sorted(list(xdata) + [0])
        plot.finalize(alpha=1.0)
        # one tick per bar
        plot.ax.xaxis.set_ticks(xdata)

        outfile = plotname % "NUM_EVENTS"
        plot.savefig(outfile)
        plot.close()
        print_verbose("%s written.\n" % outfile)

    # figure size used by the remaining plots
    plotutils.pylab.rcParams.update({"figure.figsize": [8, 6]})

    #
    # plot SNR histograms
    #

    outfile = plotname % "BESTNR_HISTOGRAM"
    try:
        plot_hist(outfile, mi_slide_bestnr,
                  "$\\chi^2$ re-weighted signal-to-noise ratio (SNR)",
                  "Cumulative number of events", xlim=[5, 20])
    except RuntimeError as e:
        # a "maximum recursion" failure is downgraded to a warning and the
        # histogram is skipped; any other RuntimeError propagates
        if not re.search("maximum recursion", str(e)):
            raise
        warnings.warn(str(e), RuntimeWarning)
    else:
        print_verbose("%s written.\n" % outfile)

    #
    # plot versus snr
    #

    # plot BestNR against coherent SNR on log-log axes
    outfile = plotname % "SNR_BESTNR"
    plot_xy(outfile, mi_snr, mi_bestnr,
            "Coherent signal-to-noise ratio (SNR)",
            "$\\chi^2$ re-weighted signal-to-noise ratio (SNR)",
            xlim=SNR_LIM, ylim=BESTNR_LIM, logx=True, logy=True)
    print_verbose("%s written.\n" % outfile)

    # plot sngl SNR against coherent SNR for each detector, shading below
    # the first single-detector SNR threshold and, when a second threshold
    # is configured, drawing it as a line
    for ifo in ifos:
        outfile = plotname % ("SNR_SNR%s" % ifo.upper())
        second_veto = ({"line": (SNR_LIM, [opts.sngl_snr_threshold[1]]*2)}
                       if len(opts.sngl_snr_threshold) > 1 else {})
        try:
            plot_xy(outfile, mi_snr, mi_sngl_snr[ifo],
                    "Coherent signal-to-noise ratio (SNR)",
                    "%s SNR" % ifo.upper(), xlim=SNR_LIM, ylim=SNGL_SNR_LIM,
                    logx=True, logy=True,
                    fill_below=(SNR_LIM, [opts.sngl_snr_threshold[0]]*2),
                    **second_veto)
        except RuntimeError as e:
            # only the known "maximum recursion" failure triggers the
            # unshaded fallback; any other RuntimeError is a real error
            if not re.search("maximum recursion", str(e)):
                raise
            warnings.warn("Maximum recursion error caught (bug in "
                          "matplotlib < 1), plot will not be shaded")
            plot_xy(outfile, mi_snr, mi_sngl_snr[ifo],
                    "Coherent signal-to-noise ratio (SNR)",
                    "%s SNR" % ifo.upper(), xlim=SNR_LIM, ylim=SNGL_SNR_LIM,
                    logx=True, logy=True,
                    **second_veto)
        print_verbose("%s written.\n" % outfile)

    # plot null snr against coherent SNR, shading above the null-SNR veto
    # contour and drawing the null-weight contour as a line
    outfile = plotname % "SNR_NULL"
    try:
        plot_xy(outfile, mi_snr, mi_null_snr,
                "Coherent signal-to-noise ratio (SNR)",
                "Null SNR", xlim=SNR_LIM, ylim=NULL_SNR_LIM, logx=True,
                logy=False, fill_above=(contours["snr"], contours["null"]),
                line=(contours["snr"], contours["null_weight"]))
    except RuntimeError as e:
        # only the known "maximum recursion" failure triggers the fallback
        # (null contours drawn without shading); any other RuntimeError is
        # a real error and propagates
        if not re.search("maximum recursion", str(e)):
            raise
        warnings.warn("Maximum recursion error caught (bug in "
                      "matplotlib < 1), plot will not be shaded")
        plot_xy(outfile, mi_snr, mi_null_snr,
                "Coherent signal-to-noise ratio (SNR)",
                "Null SNR", xlim=SNR_LIM, ylim=NULL_SNR_LIM, logx=True,
                logy=False, plot_contours="null",
                line=(contours["snr"], contours["null_weight"]))
    print_verbose("%s written.\n" % outfile)

    # plot chisqs: chi-squared against coherent SNR with the chi-squared
    # contours, shading above the contour at the new-SNR veto threshold
    outfile = plotname % "SNR_CHISQ"
    try:
        plot_xy(outfile, mi_snr, mi_chisq,
                "Coherent signal-to-noise ratio (SNR)", "$\\chi^2$",
                xlim=SNR_LIM, ylim=CHISQ_LIM, logx=True, logy=True,
                plot_contours="chisq",
                fill_above=(contours["snr"], veto_contours["chisq"]))
    except RuntimeError as e:
        if not re.search("maximum recursion", str(e)):
            raise
        # recursion failure: warn and redraw without the shaded region
        warnings.warn("Maximum recursion error caught (bug in "
                      "matplotlib < 1), plot will not be shaded")
        plot_xy(outfile, mi_snr, mi_chisq,
                "Coherent signal-to-noise ratio (SNR)", "$\\chi^2$",
                xlim=SNR_LIM, ylim=CHISQ_LIM, logx=True, logy=True,
                plot_contours="chisq")
    print_verbose("%s written.\n" % outfile)


    # plot bank chi-squared against coherent SNR with its contours, shading
    # above the contour at the new-SNR veto threshold
    outfile = plotname % "SNR_BANK_CHISQ"
    try:
        plot_xy(outfile, mi_snr, mi_bank_chisq,
                "Coherent signal-to-noise ratio (SNR)", "Bank $\\chi^2$",
                xlim=SNR_LIM, ylim=BANK_CHISQ_LIM,
                logx=True, logy=True, plot_contours="bank_chisq",
                fill_above=(contours["snr"], veto_contours["bank_chisq"]))
    except RuntimeError as e:
        if not re.search("maximum recursion", str(e)):
            raise
        # tell the user the shading was dropped, as the other SNR plots do
        warnings.warn("Maximum recursion error caught (bug in "
                      "matplotlib < 1), plot will not be shaded")
        plot_xy(outfile, mi_snr, mi_bank_chisq,
                "Coherent signal-to-noise ratio (SNR)",
                "Bank $\\chi^2$", xlim=SNR_LIM, ylim=BANK_CHISQ_LIM,
                logx=True, logy=True, plot_contours="bank_chisq")
    print_verbose("%s written.\n" % outfile)

    # plot auto-correlation chi-squared against coherent SNR with its
    # contours, shading above the contour at the new-SNR veto threshold
    outfile = plotname % "SNR_CONT_CHISQ"
    try:
        plot_xy(outfile, mi_snr, mi_cont_chisq,
                "Coherent signal-to-noise ratio (SNR)",
                "Auto-correlation $\\chi^2$", xlim=SNR_LIM, ylim=CONT_CHISQ_LIM,
                logx=True, logy=True, plot_contours="cont_chisq",
                fill_above=(contours["snr"], veto_contours["cont_chisq"]))
    except RuntimeError as e:
        if not re.search("maximum recursion", str(e)):
            raise
        # tell the user the shading was dropped, as the other SNR plots do
        warnings.warn("Maximum recursion error caught (bug in "
                      "matplotlib < 1), plot will not be shaded")
        plot_xy(outfile, mi_snr, mi_cont_chisq,
                "Coherent signal-to-noise ratio (SNR)",
                "Auto-correlation $\\chi^2$", xlim=SNR_LIM,
                ylim=CONT_CHISQ_LIM, logx=True, logy=True,
                plot_contours="cont_chisq")

    print_verbose("%s written.\n" % outfile)

    #
    # plot found/missed injections
    #

    if simulations:
        # found/missed versus chirp mass and decisive distance: injections
        # with fap == 0 are drawn as found, those found with a non-zero FAP
        # are passed separately together with their FAP values
        outfile = plotname % "FOUND_MISSED_MCHIRP_DEC_DIST"
        plot_found_missed(outfile,
                          (sim_mchirp[found_inj][fap==0],
                           sim_dec_dist[found_inj][fap==0]),
                          (sim_mchirp[missed_inj], sim_dec_dist[missed_inj]),
                          (sim_mchirp[found_inj][fap!=0],
                           sim_dec_dist[found_inj][fap!=0], fap[fap!=0]),
                          "$\\mathcal{M}_c$ (${\\rm M_\\odot}$)",
                          "Injected decisive distance (${\\rm Mpc}$)",
                          ylim=[opts.lower_inj_dist, opts.upper_inj_dist])
        print_verbose("%s written.\n" % outfile)

        # the same found/missed split versus time since the analysis start
        outfile = plotname % "FOUND_MISSED_TIME_DEC_DIST"
        plot_found_missed(outfile,
                          (sim_plot_time[found_inj][fap==0],
                           sim_dec_dist[found_inj][fap==0]),
                          (sim_plot_time[missed_inj], sim_dec_dist[missed_inj]),
                          (sim_plot_time[found_inj][fap!=0],
                           sim_dec_dist[found_inj][fap!=0], fap[fap!=0]),
                          time_label,
                          "Injected decisive distance (${\\rm Mpc}$)",
                          xlim=[0, plot_duration],
                          ylim=[opts.lower_inj_dist, opts.upper_inj_dist])
        print_verbose("%s written.\n" % outfile)

        # overall efficiency versus distance on a logarithmic distance axis
        outfile = plotname % "EFFICIENCY"
        plot = plotutils.SimplePlot("Distance (${\\rm Mpc}$)", "Efficiency")
        plot.add_content(efficiency_distance, efficiency_by_dist)
        plot.finalize()
        plot.ax.set_xscale("log")
        plot.ax.set_xlim(1, 1e3)
        plot.ax.set_ylim(0, 1.1)
        plot.savefig(outfile)
        plot.close()
        print_verbose("%s written.\n" % outfile)

        # per-site plots; row i of sim_sngl_eff_dist holds the effective
        # distances for site `ifo`
        for i, ifo in enumerate(sites):
            # found/missed versus chirp mass and this site's effective distance
            outfile = plotname % ("FOUND_MISSED_MCHIRP_%s_EFF_DIST"
                                  % ifo.upper())
            plot_found_missed(outfile,
                              (sim_mchirp[found_inj][fap==0],
                               sim_sngl_eff_dist[found_inj][i,:][fap==0]),
                              (sim_mchirp[missed_inj],
                               sim_sngl_eff_dist[missed_inj][i,:]),
                              (sim_mchirp[found_inj][fap!=0],
                               sim_sngl_eff_dist[found_inj][i,:][fap!=0],
                               fap[fap!=0]),
                              "Chirp mass ($\\mathrm{M}_{\\odot}$)",
                              "%s effective distance (Mpc)" % ifo,
                              ylim=[opts.lower_inj_dist, opts.upper_inj_dist])
            print_verbose("%s written.\n" % outfile)

            # found/missed versus time and this site's effective distance
            outfile = plotname % ("FOUND_MISSED_TIME_%s_EFF_DIST"
                                  % ifo.upper())
            plot_found_missed(outfile,
                              (sim_plot_time[found_inj][fap==0],
                               sim_sngl_eff_dist[found_inj][i,:][fap==0]),
                              (sim_plot_time[missed_inj],
                               sim_sngl_eff_dist[missed_inj][i,:]),
                              (sim_plot_time[found_inj][fap!=0],
                               sim_sngl_eff_dist[found_inj][i,:][fap!=0],
                               fap[fap!=0]),
                              time_label,
                              "%s effective distance (${\\rm Mpc}$)" % ifo,
                              xlim=[0, plot_duration],
                              ylim=[opts.lower_inj_dist, opts.upper_inj_dist])
            print_verbose("%s written.\n" % outfile)

            # single-site efficiency versus effective distance (log axis);
            # build the tag first so the file-name template is formatted
            # exactly once, like the other per-site plots
            outfile = plotname % ("%s_EFFICIENCY" % ifo)
            plot = plotutils.SimplePlot("%s effective distance (${\\rm Mpc}$)"
                                        % ifo, "Efficiency")
            plot.add_content(efficiency_distance, sngl_efficiency_by_dist[ifo])
            plot.finalize()
            plot.ax.set_xscale("log")
            plot.ax.set_ylim(0, 1.1)
            plot.ax.set_xlim(1, 1e3)
            plot.savefig(outfile)
            plot.close()
            print_verbose("%s written.\n" % outfile)
