#!/usr/bin/python
import scipy
from scipy import interpolate
import numpy
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3

sqlite3.enable_callback_tracebacks(True)

import sys
import glob
import copy
from optparse import OptionParser
import traceback

from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue.ligolw import utils
from glue.ligolw import table
from glue import segmentsUtils
from glue.ligolw.utils import process
from glue.ligolw.utils import segments as ligolw_segments

from pylal import db_thinca_rings
from pylal import llwapp
from pylal import rate
from pylal import series as lalseries
from pylal import SimInspiralUtils
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
from pylal import imr_utils
import matplotlib
matplotlib.use('Agg')
# Height-to-width ratio used for the default figure size: 2 / (1 + sqrt(5))
# is the reciprocal of the golden ratio (~0.618).
goldenratio = 2 / (1 + 5**.5)
# Global plot styling: small fonts sized for a 3.3 inch wide figure and
# TeX rendering of all text.
matplotlib.rcParams.update({
        "font.size": 8.0,
        "axes.titlesize": 8.0,
        "axes.labelsize": 8.0,
        "xtick.labelsize": 8.0,
        "ytick.labelsize": 8.0,
        "legend.fontsize": 6.0,
        "figure.figsize": (3.3,3.3*goldenratio),
        "figure.dpi": 200,
        # the subplot margins live under "figure.subplot.*"; the former
        # "subplots.*" spellings are not rc parameter names and are
        # rejected with KeyError by matplotlib versions that validate
        # keys passed to update()
        "figure.subplot.left": 0.2,
        "figure.subplot.right": 0.75,
        "figure.subplot.bottom": 0.15,
        "figure.subplot.top": 0.75,
        "savefig.dpi": 600,
        "text.usetex": True     # render all text with TeX
})
from scipy.interpolate import interp1d
from matplotlib import pyplot
from gstlal import reference_psd

from pylal import git_version
__author__ = "Chad Hanna <channa@ligo.caltech.edu>, Satya Mohapatra <satya@physics.umass.edu>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

def compute_horizon_interpolator(psd_files, verbose = False):
	"""
	Build one horizon-distance-versus-time interpolator per instrument
	from a collection of PSD files.

	psd_files is an iterable of file names.  Each file is loaded with
	utils.load_filename() and handed to lalseries.read_psd_xmldoc(),
	whose result is iterated as an instrument -> PSD mapping; entries
	whose PSD is None are ignored.  verbose is forwarded to
	utils.load_filename().

	Returns a dictionary mapping instrument name to a
	scipy.interpolate.interp1d object built from (int(psd.epoch),
	horizon distance) samples.  Instruments for which no PSD was found
	have no entry.  The samples are padded with the first value at time
	0 and the last value at time 2000000000 so that the interpolator can
	be evaluated outside of the span covered by the PSD files.
	"""
	times = {}
	horizons = {}
	for f in psd_files:
		xmldoc = utils.load_filename(f, verbose = verbose, contenthandler = ligolw.LIGOLWContentHandler)
		for ifo, psd in lalseries.read_psd_xmldoc(xmldoc).items():
			if psd is None:
				continue
			# setdefault() instead of a fixed (H1, L1, V1) key set so
			# that a PSD for any other instrument does not raise
			# KeyError
			times.setdefault(ifo, []).append(int(psd.epoch))
			# NOTE(review): the positional arguments are presumably the
			# component masses (1.4, 1.4), the SNR threshold (8) and the
			# frequency band (10, 2048) -- confirm against
			# reference_psd.horizon_distance()
			horizons.setdefault(ifo, []).append(reference_psd.horizon_distance(psd, 1.4, 1.4, 8, 10, 2048))

	interps = {}
	for ifo, t in times.items():
		# order the samples by time:  the files are not guaranteed to be
		# given chronologically, and the boundary padding below must
		# repeat the earliest and latest measurements
		samples = sorted(zip(t, horizons[ifo]), key = lambda sample: sample[0])
		# put extra points at the boundary
		t = [0] + [sample[0] for sample in samples] + [2000000000]
		h = [samples[0][1]] + [sample[1] for sample in samples] + [samples[-1][1]]
		interps[ifo] = interp1d(t, h)
	return interps

def parse_command_line():
	"""
	Parse the command line (sys.argv) with optparse.

	Returns the tuple (opts, filenames):  opts is the optparse values
	object and filenames is the list of positional arguments.  opts.far
	is always a list (empty when --far was not given) so that callers
	can iterate over it unconditionally;  opts.l_snr and opts.h_snr are
	left as None when their options were not given.

	Raises ValueError if no file names were given, or if --h-snr and
	--l-snr were both given but a different number of times.
	"""
	parser = OptionParser(version = git_version.verbose_msg, usage = "%prog [options] [file ...]", description = "%prog computes mass/mass upperlimit")
	parser.add_option("--output-tag", default = "", metavar = "name", help = "Set the file output name file")
	parser.add_option("--far", action = "append", type = "float", metavar = "Hz", help = "FAR to use for injection finding - can be given multiple times")
	parser.add_option("--live-time-program", default = "gstlal_inspiral", help = "Set the name of the live time program to use to get segments from the search summary table")
	parser.add_option("--veto-segments-name", default = "vetoes", help = "Set the name of the veto segments to use from the XML document.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--l-snr", action = "append", type = "float", metavar = "snr", help = "L1 snr - can be given multiple times")
	parser.add_option("--h-snr", action = "append", type = "float", metavar = "snr", help = "H1 snr - can be given multiple times")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose.")

	opts, filenames = parser.parse_args()

	# echo the raw option values exactly as parsed (same text as the
	# former print statement, but valid under both Python 2 and 3)
	sys.stdout.write("%s %s %s\n" % (opts.far, opts.l_snr, opts.h_snr))

	if not filenames:
		raise ValueError("must provide at least one database")

	# an "append" option that was never given is None;  normalize so
	# that iterating over the FAR thresholds cannot fail
	if opts.far is None:
		opts.far = []

	# the SNR thresholds are consumed pairwise;  a length mismatch would
	# otherwise be silently truncated to the shorter list
	if opts.h_snr is not None and opts.l_snr is not None and len(opts.h_snr) != len(opts.l_snr):
		raise ValueError("--h-snr and --l-snr must be given the same number of times")

	return opts, filenames


#
# Main:  for each instrument set, plot the fraction of injections recovered
# versus distance, one curve per FAR threshold and (optionally) one curve
# per H1/L1 SNR threshold pair derived from the horizon distance.
#

options, filenames = parse_command_line()

# the *.sqlite arguments are the databases to summarize
IMR = imr_utils.DataBaseSummary([f for f in filenames if f.endswith('.sqlite')], tmp_path = options.tmp_space, veto_segments_name = options.veto_segments_name, live_time_program = options.live_time_program, verbose = options.verbose)

# the horizon-based curves are drawn only when both SNR lists were given;
# the *.xml.gz arguments are then read as PSD files
plot_horizon = options.h_snr is not None and options.l_snr is not None
if plot_horizon:
	horizon_interp = compute_horizon_interpolator([f for f in filenames if f.endswith('.xml.gz')], options.verbose)

# --far is optional:  treat a missing option as "no FAR curves" instead of
# failing when iterating over None
fars = options.far if options.far is not None else []

colors = ['b', 'g', 'r', 'c', 'm', 'y']

for instruments, total_injections in IMR.total_injections_by_instrument_set.items():
	# make a figure
	fig = pyplot.figure()
	ax = fig.add_axes((.2,.15,.75,.8))
	dist_bins = imr_utils.guess_nd_bins(total_injections, {"distance": (15, rate.LinearBins)})

	for i, far in enumerate(fars):
		found_below_far = [s for f, s in IMR.found_injections_by_instrument_set[instruments] if f < far]
		if options.verbose:
			sys.stderr.write("Found %d injections in %s below FAR %e\n" % (len(found_below_far), "".join(instruments), far))
		eff, yerr = imr_utils.compute_search_efficiency_in_bins(found_below_far, total_injections, dist_bins)
		# label the curve with the FAR written as mantissa x 10^exponent
		exponent = numpy.floor(numpy.log10(far))
		# cycle through the colours so that more than len(colors) curves
		# do not raise IndexError
		ax.errorbar(dist_bins.centres()[0], eff.array, yerr = yerr.array, capsize = 3, color = colors[i % len(colors)], label = r"$\mathrm{Efficiency\,at\,%.0f\times 10^{%d}\,FAR}$" % (far / 10.**exponent, exponent))

	if plot_horizon:
		for i, (h_snr, l_snr) in enumerate(zip(options.h_snr, options.l_snr)):
			# the interpolated horizon is scaled by 8 / threshold,
			# i.e. inversely with the requested SNR
			h_scale = 8. / h_snr
			l_scale = 8. / l_snr
			# FIXME always assumes that H1 and L1 define the coincidence, probably roughly true unless you happen to be doing an H2 L1 search
			found_below_horizon = [sim for sim in total_injections if sim.get_chirp_dist('h') < horizon_interp['H1'](sim.h_end_time) * h_scale and sim.get_chirp_dist('l') < horizon_interp['L1'](sim.l_end_time) * l_scale]
			eff_horizon, y_horizon_err = imr_utils.compute_search_efficiency_in_bins(found_below_horizon, total_injections, dist_bins)
			ax.errorbar(dist_bins.centres()[0], eff_horizon.array, yerr = y_horizon_err.array, ls="--", capsize = 3, color = colors[i % len(colors)], label = r"$\mathrm{GN\,-\,H1/L1\,SNR\,%.2f/%.2f}$" % (h_snr, l_snr))

	ax.grid(True)
	ax.legend()
	ax.set_xlabel(r'$\mathrm{Distance\,(Mpc)}$')
	ax.set_ylabel(r'$\mathrm{Efficiency}$')
	fig.savefig("%s-%s.png" % ("".join(instruments), options.output_tag))
	# release the figure;  pyplot otherwise keeps every figure alive
	# until the interpreter exits
	pyplot.close(fig)

