#!/usr/bin/python
#
# Copyright (C) 2012 Matthew West
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                 Preamble
#
# =============================================================================
#

from optparse import OptionParser
import sqlite3
import sys
import os

import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
from matplotlib import rcParams

from glue import iterutils
from glue.ligolw import dbtables
from pylal import git_version
from pylal import sngl_tmplt_far
from pylal import InspiralUtils

__author__ = "Matthew West <matthew.west@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

description = """
    %prog generates a plot comparing the all-possible-coincs
    method of background estimation to that of slides.
    """


#
# =============================================================================
#
#                                Command Line
#
# =============================================================================
#


def parse_command_line():
    """
    Parse the command line.

    Builds the option parser for this program, parses the process
    arguments and checks that the required options were supplied.
    Positional arguments are accepted but ignored.

    Returns the optparse options object.

    Raises ValueError if --database or --snr-stat is missing (or empty), or
    if --snr-stat is not one of rawsnr, snroverchi, effsnr, newsnr.
    """
    parser = OptionParser(
        version = "Name: %%prog\n%s" % git_version.verbose_msg,
        usage = "%prog [options]",
        description = description
        )
    # The available command line options
    parser.add_option("-d", "--database", action="store", type="string", default=None,
        help =
            "Input database to read. Can only input one at a time."
        )
    parser.add_option("-t", "--tmp-space", action="store", type="string", default=None,
        metavar = "PATH",
        help =
            "Location of local disk on which to do work. This is optional; " +
            "it is only used to enhance performance in a networked " +
            "environment. "
        )
    parser.add_option("-P", "--output-path", action="store", type="string",
        default = os.getcwd(), metavar = "PATH",
        help =
            "Optional. Path where the figures should be stored. Default is current directory."
        )
    parser.add_option("--snr-stat", action="store", type="string", default=None,
        metavar = "{rawsnr|snroverchi|effsnr|newsnr}",
        help =
            "Select the desired coincident ChiSq weighted SNR (required). " +
            "\"rawsnr\" is just the matched filter SNR and lacks any weighting " +
            "by a ChiSq statistic. The statistic \"snroverchi\" is used by gstlal. " +
            "\"effsnr\" uses the Effective SNR algorithm to calculate a ChiSq " +
            "weighted detection SNR.  \"newsnr\" uses the New SNR algorithm to " +
            "calculate a ChiSq weighted SNR."
        )
    parser.add_option("--sngls-threshold", type="float", default=5.5,
        metavar = "float",
        help =
            "The minimum chisq-weighted snr for singles triggers"
        )
    parser.add_option("--sngls-bin-width", type="float", default=0.01,
        metavar = "float",
        help =
            "The width of bins in the single-ifo snr histograms"
        )
    parser.add_option("--coinc-bin-width", type="float", default=0.05,
        metavar = "float",
        help =
            "The width of bins in the coinc snr histograms"
        )
    parser.add_option("-T", "--time-units", action="store", default='yr',
        metavar = "s, min, hr, days, OR yr",
        help =
            "Time units to use when calculating FARs (the units of the FARs will be the inverse of this). " +
            "Options are s, min, hr, days, or yr. Default is yr."
        )
    parser.add_option("-v", "--verbose", action="store_true", default=False,
        help =
            "Be verbose."
        )

    # positional arguments are not used by this program
    options, _ = parser.parse_args()

    # a required option counts as missing when it is unset or empty
    required_options = ["database", "snr_stat"]
    missing_options = [option for option in required_options if not getattr(options, option)]

    if missing_options:
        # report the options the way they are spelled on the command line
        missing_options = ', '.join([
            "--%s" % option.replace("_", "-")
            for option in missing_options
        ])
        raise ValueError("missing required option(s) %s" % missing_options)

    if options.snr_stat not in ("rawsnr", "snroverchi", "effsnr", "newsnr"):
        raise ValueError("unrecognized --snr-stat %s" % options.snr_stat)

    return options


#
# =============================================================================
#
#                                    Main
#
# =============================================================================
#


#
# Command line
#

opts = parse_command_line()

# Obtain the file name to work on (dbtables decides this from the input
# database and the optional --tmp-space location), then open it with sqlite
working_filename = dbtables.get_connection_filename(
    opts.database, tmp_path=opts.tmp_space, verbose=opts.verbose
)
connection = sqlite3.connect(working_filename)

# Only redirect sqlite's temporary storage when a work location was given
if opts.tmp_space:
    dbtables.set_temp_store_directory(
        connection, opts.tmp_space, verbose=opts.verbose
    )

# Register this connection with the dbtables module
dbtables.DBTable_set_connection(connection)

# create needed indices on tables if they don't already exist
#
# NOTE: 'index' is single-quoted so that SQLite always reads it as a string
# literal.  A double-quoted "index" is an identifier and is only treated as a
# string by builds that keep the legacy double-quoted-string quirk enabled.
current_indices = [row[0] for row in
    connection.execute("SELECT name FROM sqlite_master WHERE type = 'index'").fetchall()]

# (index name, statement creating it), listed in the order they are created
needed_indices = [
    ('si_idx', 'CREATE INDEX si_idx ON sngl_inspiral (event_id, snr, chisq, chisq_dof, mchirp, eta);\n'),
    ('ci_idx', 'CREATE INDEX ci_idx ON coinc_inspiral (coinc_event_id);\n'),
    ('cem_idx', 'CREATE INDEX cem_idx ON coinc_event_map (coinc_event_id, event_id);\n'),
    ('ts_idx', 'CREATE INDEX ts_idx ON time_slide (instrument, offset);\n'),
    ('e_idx', 'CREATE INDEX e_idx ON experiment (experiment_id, instruments);\n'),
    ('es_idx', 'CREATE INDEX es_idx ON experiment_summary (experiment_id, time_slide_id, datatype);\n'),
    ('em_idx', 'CREATE INDEX em_idx ON experiment_map (coinc_event_id, experiment_summ_id);\n'),
]

# only the statements for indices that are not present yet are run; the
# script is empty (a no-op) when every index already exists
sqlscript = ''.join(
    statement for name, statement in needed_indices
    if name not in current_indices)

connection.executescript( sqlscript )

# look up the template mass parameters; only the first distinct row is used
sqlquery = """
SELECT DISTINCT mchirp, eta, mass1, mass2
FROM sngl_inspiral
"""
template_rows = connection.execute( sqlquery ).fetchall()
mchirp, eta, m1, m2 = template_rows[0]
# the component masses are rounded to two decimals (they appear in the plot title)
m1, m2 = round(m1, 2), round(m2, 2)

# collect the distinct ifos entries recorded in the search summary
sqlquery = """
SELECT DISTINCT ifos
FROM search_summary
"""
all_ifos = [row[0] for row in connection.execute( sqlquery )]

# build the single-ifo snr histograms for this template
if opts.verbose:
    print >> sys.stdout, 'Making single-ifo snr histograms'
sngl_ifo_hist, sngl_ifo_midbins = sngl_tmplt_far.all_sngl_snr_hist(
    connection, mchirp, eta, all_ifos,
    min_snr=opts.sngls_threshold,
    sngls_width=opts.sngls_bin_width,
    snr_stat=opts.snr_stat)

# single-ifo analyzed times
T_i = sngl_tmplt_far.get_singles_times(connection, verbose=opts.verbose)

# one row per time_slide_id, with its offsets joined into a comma-separated string
sqlquery = """
SELECT time_slide_id, group_concat(offset)
FROM time_slide GROUP BY time_slide_id
"""
time_shifts = connection.execute(sqlquery).fetchall()
tsid_list = [tsid for tsid, offsets in time_shifts]
# the zero-lag slide is the first one whose offsets are all zero
zerolag_candidates = [
    tsid for tsid, offsets in time_shifts
    if not any(map(float, offsets.split(',')))
]
zerolag_tsid = zerolag_candidates[0]

# assemble the plot basename; its leading '%s' is left as a placeholder that
# is filled in when each figure is saved

# first '-'-separated piece of the database name that contains 'TMPLT'
tmplt_pieces = [piece for piece in opts.database.split('-') if 'TMPLT' in piece]
tmplt_id = tmplt_pieces[0]

# first two '_'-separated fields of a segment_definer name
definer_name = connection.execute(
    'SELECT DISTINCT name FROM segment_definer'
    ).fetchone()[0]
vetoCat = definer_name.split('_')[:2]

# gps start and end time recorded for the experiment
start_time, end_time = connection.execute(
    'SELECT DISTINCT gps_start_time, gps_end_time FROM experiment'
    ).fetchone()

basename_fields = [
    '%s',
    'all_possible_coincs',
    tmplt_id,
    '_'.join(sorted(vetoCat)),
    str(start_time),
    str(end_time - start_time),
]
plot_basename = '-'.join(basename_fields)

# matplotlib settings shared by every figure; they do not depend on the ifo
# combination, so they are applied once before looping
plotting_params = {
    'font.size': 16,
    'text.usetex': True,
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'axes.grid': False,
    'axes.titlesize': 'medium',
    'axes.labelsize': 'medium',
    'grid.color': 'k',
    'grid.linestyle': '-',
    'legend.fontsize': 'medium',
    'legend.loc': 'upper right'
}
rcParams.update(plotting_params)

# name of the selected statistic as shown in the x-axis label
stat_name = {
    "newsnr": "NewSNR",
    "effsnr": "EffSNR",
    "snroverchi": "SNRoverCHI",
}.get(opts.snr_stat, "SNR")

# one figure per combination of two or more ifos
for num_ifos in range(2, len(all_ifos)+1):
    for ifos in iterutils.choices(all_ifos, num_ifos):
        # use slides to determine the coincidence window
        tau = sngl_tmplt_far.get_coinc_window(connection, ifos)
        # make combined snr bins, from sqrt(2) times the singles threshold up
        # to just past the quadrature sum of the largest single-ifo bins
        max_comb_snr = sngl_tmplt_far.quadrature_sum([max(sngl_ifo_midbins[ifo]) for ifo in ifos])
        combined_bins = np.arange(
            2**(1./2)*opts.sngls_threshold,
            max_comb_snr + 2*opts.coinc_bin_width,
            opts.coinc_bin_width)

        # ------------------------ Foreground Coincs ------------------------ #
        if opts.verbose:
            print >> sys.stdout, 'Creating rates distribution for foreground coincs'

        zerolag_coincs = sngl_tmplt_far.coinc_snr_hist(
            connection,
            ifos,
            mchirp, eta,
            min_snr=opts.sngls_threshold,
            datatype="all_data",
            slide_id=zerolag_tsid,
            combined_bins=combined_bins,
            snr_stat=opts.snr_stat)['snr_hist']
        # get the amount of inclusive coinc time
        T_z = sngl_tmplt_far.coinc_time(connection, 'all_data', ifos, opts.time_units, tsid=zerolag_tsid)
        # create rates dict for zerolag coincs
        rate_zerolag = sngl_tmplt_far.compute_cumrate(zerolag_coincs, T_z)

        # ----------------------- All Possible Coincs ----------------------- #
        if opts.verbose:
            print >> sys.stdout, 'Creating rates distribution from all-possible-coincs'

        apc_hist = sngl_tmplt_far.all_possible_coincs(
            sngl_ifo_hist,
            sngl_ifo_midbins,
            combined_bins,
            ifos
        )
        # compute effective background time
        T_b = sngl_tmplt_far.eff_bkgd_time(T_i, tau, ifos, opts.time_units)
        # create rates dict for all-possible-coincs; the zerolag histogram is
        # subtracted from the all-possible-coincs histogram first
        rate_apc = sngl_tmplt_far.compute_cumrate(apc_hist-zerolag_coincs, T_b)

        # ----------------------- Lots of Time Slides ----------------------- #
        if opts.verbose:
            print >> sys.stdout, 'Creating rates distribution from time-slides'

        slide_coinc_hist = sngl_tmplt_far.coinc_snr_hist(
            connection,
            ifos,
            mchirp, eta,
            min_snr=opts.sngls_threshold,
            datatype="slide",
            combined_bins=combined_bins,
            snr_stat=opts.snr_stat)['snr_hist']
        # total slide time
        T_s = sngl_tmplt_far.coinc_time(connection, 'slide', ifos, opts.time_units)
        # create rates dict for time-slides
        rate_slides = sngl_tmplt_far.compute_cumrate(slide_coinc_hist, T_s)

        # ------------------------- Plotting Commands ----------------------- #
        if opts.verbose:
            print >> sys.stdout, 'Making Comparison Plot'

        # make new figure
        fig = pyplot.figure()
        ax = fig.add_subplot(111)

        # making the plots with errorbars; each rate is drawn at the lower
        # edge of its combined-snr bin
        ax.errorbar(
            combined_bins[:-1], rate_apc["mode"],
            yerr=(rate_apc["lower_lim"],rate_apc["upper_lim"]),
            color='r', label='All Possible Coincs',
            linestyle='None', marker='.', markersize=12)
        # NOTE(review): the slide count and step in this label are hard-coded
        # and are not read from the database -- confirm they match the input
        ax.errorbar(
            combined_bins[:-1], rate_slides["mode"],
            yerr=(rate_slides["lower_lim"],rate_slides["upper_lim"]),
            color='b', label='100,000 5sec Slides',
            linestyle='None', marker='.', markersize=12)
        ax.errorbar(
            combined_bins[:-1], rate_zerolag["mode"],
            yerr=(rate_zerolag["lower_lim"],rate_zerolag["upper_lim"]),
            color='c', label='Zerolag Coincs',
            linestyle='None', marker='.', markersize=12)

        # set axes scales and tick marks
        # NOTE(review): the axis limits are fixed and do not follow
        # combined_bins or the computed rates
        ax.set_xscale('linear')
        ax.set_xlim(7.5, 11.5)
        ax.xaxis.set_minor_locator(pyplot.MultipleLocator(0.1))
        ax.set_yscale('log')
        ax.set_ylim(1e-7, 1e2)
        ax.grid(which='major', linewidth=0.75)
        ax.grid(which='minor', linewidth=0.25)

        # make plot labels, title, & legend
        ifos_str = ','.join(sorted(ifos))
        quad_sum_string = '+'.join(["\\rho^2_{%s}" % ifo for ifo in ifos])

        ax.set_xlabel(r'Combined %s: $\sqrt{%s}$' % (stat_name, quad_sum_string))
        # the rate unit is the inverse of --time-units, the same unit that
        # is passed to the coinc-time helpers above
        ax.set_ylabel(r'Measured FAR $(%s^{-1})$ for $\rho_c \geq X$' % opts.time_units)
        ax.set_title(
            r'Rate of %s Events $(%s M_{\odot}-%s M_{\odot})$' %(ifos_str, m1, m2)
            )
        ax.legend(numpoints=1)

        # save the figure inside the output directory; os.path.join supplies
        # the path separator that plain string concatenation would omit
        plot_name = os.path.join(
            opts.output_path,
            plot_basename % ifos_str.replace(',',''))
        fig.savefig(plot_name + '.png', dpi=200, format='png')

        # close figure and clear memory
        pyplot.close(fig)

#
#   Finished cycling over experiments; exit
#
# close the database connection, then let dbtables discard the working file
# name obtained at start-up
connection.close()
dbtables.discard_connection_filename(
    opts.database,
    working_filename,
    verbose=opts.verbose)

if opts.verbose:
    print >> sys.stdout, "Finished!"

# exit with a success status
sys.exit(0)

