#!/usr/bin/env python
#
# Copyright (C) 2013,2014  Kipp Cannon, Chad Hanna, Jacob Peoples
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# A program to compute the signal and noise rate posteriors of a gstlal inspiral analysis


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import h5py
import math
import matplotlib
# global plot styling, applied to matplotlib's rcParams before the figure
# and backend modules are imported below.  "text.usetex" turns on LaTeX
# rendering of text;  the axis labels in plot_rates() are written in LaTeX
# markup.
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import sqlite3
import sys


from glue.lal import CacheEntry
from glue.ligolw import dbtables
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.text_progress_bar import ProgressBar
from pylal import ligolw_thinca


from gstlal import far
from gstlal import rate_estimation


# the golden ratio, (1 + sqrt(5)) / 2.  plot_rates() divides by this when
# computing the figure height passed to set_size_inches()
golden_ratio = (1. + math.sqrt(5.)) / 2.


# module metadata.  the version and date strings are empty placeholders
# (see the FIXMEs)
__author__ = "Kipp Cannon <kipp.cannon@ligo.org>"
__version__ = "git id %s" % ""	# FIXME
__date__ = ""	# FIXME


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the command line and post-process the options.

	Returns the tuple (options, filenames).  filenames is the list of
	positional arguments extended with the paths read from
	--input-cache, if given.  In addition to the attributes created by
	the option parser, options carries:

	confidence_intervals:  list of floats parsed from the
	comma-delimited --confidence-intervals string.  Blank fields are
	skipped, so an empty string yields an empty list (which disables
	the confidence interval calculation).

	likelihood_filenames:  list of paths collected from
	--likelihood-cache followed by those from --likelihood-file.

	restrict_to_instruments:  None, or a frozenset of instrument names.

	Raises ValueError if a positive --samples count is requested
	without any likelihood files, if a confidence is not a valid
	float, or if --restrict-to-instruments cannot be parsed.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % "" # FIXME
	)
	parser.add_option("-c", "--confidence-intervals", metavar = "confidence[,...]", default = "0.95,0.999999", help = "Compute and report confidence intervals in the signal count for these confidences (default = \".95,.999999\", clear to disable).")
	parser.add_option("-i", "--input-cache", metavar = "filename", help = "Also process the files named in this LAL cache.  See lalapps_path2cache for information on how to produce a LAL cache file.")
	parser.add_option("--chain-file", metavar = "filename", help = "Read chain from this file, save chain to this file.")
	parser.add_option("--likelihood-file", metavar = "filename", action = "append", help = "Load ranking statistic PDFs from this file.  Can be given multiple times.")
	parser.add_option("--likelihood-cache", metavar = "filename", help = "Load ranking statistic PDFs from the files in this LAL cache.")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--with-background", action = "store_true", help = "Show background posterior on plot.")
	parser.add_option("--restrict-to-instruments", metavar = "name,name[,...]", help = "Compute rate posteriors by looking only at events seen in precisely these instruments (default = use all events).  WARNING:  this is a diagnostic hack for testing the behaviour of the code, the output is not equivalent to inferring the event rate using only the given instruments (that would require discarding triggers and re-ranking the coincs).")
	parser.add_option("--samples", metavar = "count", type = "int", help = "Run this many samples.  Set to a negative number to load and plot the contents of a previously-recorded chain file without doing any additional samples.")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	# the help text promises that clearing the option disables the
	# intervals, so blank fields must be skipped rather than handed to
	# float().  a real list (not a lazy map object) is built so that
	# truth-testing and repeated iteration behave as expected
	options.confidence_intervals = [float(conf) for conf in options.confidence_intervals.split(",") if conf.strip()]

	options.likelihood_filenames = []
	if options.likelihood_cache is not None:
		with open(options.likelihood_cache) as cache:
			options.likelihood_filenames += [CacheEntry(line).path for line in cache]
	if options.likelihood_file is not None:
		options.likelihood_filenames += options.likelihood_file
	# --samples has no default, so it is None when not given.  test for
	# that explicitly instead of relying on how None orders against
	# integers
	if not options.likelihood_filenames and options.samples is not None and options.samples > 0:
		raise ValueError("must provide --likelihood-cache and/or one or more --likelihood-file options")

	if options.input_cache:
		with open(options.input_cache) as cache:
			filenames += [CacheEntry(line).path for line in cache]

	if options.restrict_to_instruments is not None:
		try:
			options.restrict_to_instruments = frozenset(lsctables.instrument_set_from_ifos(options.restrict_to_instruments))
		except Exception:
			# the assignment above did not happen, so the option
			# still holds the user's original string
			raise ValueError("invalid --restrict-to-instruments '%s'" % options.restrict_to_instruments)

	return options, filenames


#
# =============================================================================
#
#                              Support Functions
#
# =============================================================================
#


def load_ranking_data(filenames, ln_likelihood_ratio_threshold, verbose = False):
	"""
	Load the ranking statistic PDF data from each of the named files,
	sum them together, zero the histogram bins below the
	log likelihood ratio threshold, and return the result after
	calling its .finish() method.

	filenames:  sequence of names of files to load.

	ln_likelihood_ratio_threshold:  bins below the one containing this
	value are set to 0 in the background, signal and zero-lag
	histograms.

	verbose:  if true, report progress on stderr.

	Raises ValueError if filenames is empty or if one of the files does
	not contain ranking statistic PDFs.
	"""
	if not filenames:
		raise ValueError("no likelihood files!")
	ranking_data = None
	for n, filename in enumerate(filenames, start = 1):
		if verbose:
			print >>sys.stderr, "%d/%d:" % (n, len(filenames)),
		xmldoc = ligolw_utils.load_filename(filename, contenthandler = far.ThincaCoincParamsDistributions.LIGOLWContentHandler, verbose = verbose)
		# only the middle element of the parsed document is used
		ignored, this_ranking_data, ignored = far.parse_likelihood_control_doc(xmldoc)
		xmldoc.unlink()
		if this_ranking_data is None:
			raise ValueError("'%s' does not contain ranking statistic PDFs" % filename)
		if ranking_data is None:
			ranking_data = this_ranking_data
		else:
			ranking_data += this_ranking_data
	# effect the zeroing of the PDFs below threshold by hacking the
	# histograms before running .finish().  do the indexing ourselves
	# to not 0 the bin @ threshold:  bins[0][x] is used as the index of
	# the bin containing x, and the slice stops just short of it
	for likelihood_rates in (ranking_data.background_likelihood_rates, ranking_data.signal_likelihood_rates, ranking_data.zero_lag_likelihood_rates):
		for binnedarray in likelihood_rates.values():
			binnedarray.array[:binnedarray.bins[0][ln_likelihood_ratio_threshold],] = 0.

	ranking_data.finish(verbose = verbose)

	return ranking_data


def load_search_results(filenames, ln_likelihood_ratio_threshold, restrict_to_instruments = None, tmp_path = None, verbose = False):
	"""
	Collect the log likelihood ratios of the inspiral coincs in the
	named SQLite database files.

	filenames:  sequence of database file names.

	ln_likelihood_ratio_threshold:  only coincs whose likelihood is at
	or above this value are retrieved.

	restrict_to_instruments:  if not None, a set of instrument names;
	only coincs whose coinc_inspiral.ifos column equals the string
	obtained from lsctables.ifos_from_instrument_set() for this set
	are retrieved.

	tmp_path:  passed to dbtables.get_connection_filename() as the
	location of the scratch work area.

	verbose:  if true, report progress on stderr.

	Returns the tuple (background_ln_likelihood_ratios,
	zerolag_ln_likelihood_ratios), two lists of the likelihood values.
	A coinc is counted as background if its time slide contains any
	non-zero offset, otherwise it is counted as zero-lag.
	"""
	background_ln_likelihood_ratios = []
	zerolag_ln_likelihood_ratios = []

	# the query is the same for every file, so build it once.  the
	# instrument restriction is passed as a bound parameter instead of
	# being pasted into the SQL text
	query = """
SELECT
	coinc_event.likelihood,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	)
FROM
	coinc_event
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
	AND coinc_event.likelihood >= ?"""
	extra_parameters = ()
	if restrict_to_instruments is not None:
		query += """
	AND coinc_inspiral.ifos == ?"""
		extra_parameters = (lsctables.ifos_from_instrument_set(restrict_to_instruments),)

	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d: %s" % (n, len(filenames), filename)
		working_filename = dbtables.get_connection_filename(filename, tmp_path = tmp_path, verbose = verbose)
		connection = sqlite3.connect(working_filename)

		# make sure the connection is released even if the document
		# or the query turns out to be bad
		try:
			xmldoc = dbtables.get_xml(connection)
			definer_id = lsctables.CoincDefTable.get_table(xmldoc).get_coinc_def_id(ligolw_thinca.InspiralCoincDef.search, ligolw_thinca.InspiralCoincDef.search_coinc_type, create_new = False)

			for ln_likelihood_ratio, is_background in connection.cursor().execute(query, (definer_id, ln_likelihood_ratio_threshold) + extra_parameters):
				if is_background:
					background_ln_likelihood_ratios.append(ln_likelihood_ratio)
				else:
					zerolag_ln_likelihood_ratios.append(ln_likelihood_ratio)
		finally:
			connection.close()
		dbtables.discard_connection_filename(filename, working_filename, verbose = verbose)

	return background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios


def plot_rates(signal_rate, noise_rate = None, confidence_intervals = None, restrict_to_instruments = None):
	"""
	Plot the signal rate posterior and, optionally, the background
	posterior, and return the matplotlib Figure.

	signal_rate:  object providing .bins[0].centres() for the x
	co-ordinates and .array for the y co-ordinates of the curve.

	noise_rate:  None, or an object with the same interface as
	signal_rate.  If given, it is drawn in a second set of axes above
	the signal posterior and a figure legend is added.

	confidence_intervals:  None, or a dictionary mapping confidence to
	a (mode, lo, hi) tuple.  If non-empty, the intervals are listed in
	a table in the signal axes.

	restrict_to_instruments:  None, or a collection of instrument
	names to be listed in the figure title.
	"""
	fig = figure.Figure()
	FigureCanvas(fig)
	# the figure is made taller when there are two sets of axes
	height = 6. if noise_rate is not None else 4.
	fig.set_size_inches((4., height / golden_ratio))
	if restrict_to_instruments is None:
		fig.suptitle("Event Rate Posterior")
	else:
		fig.suptitle("Event Rate Posterior (from %s only)" % ", ".join(sorted(restrict_to_instruments)))
	if noise_rate is None:
		signal_axes = fig.gca()
	else:
		signal_axes = fig.add_subplot(2, 1, 2)

	x = signal_rate.bins[0].centres()
	y = signal_rate.array
	line1, = signal_axes.plot(x, y, color = "k", linestyle = "-", label = "Signal")
	# raw strings keep the LaTeX backslashes from being read as
	# (invalid) Python escape sequences
	signal_axes.set_xlabel(r"Event Rate ($\mathrm{signals} / \mathrm{experiment}$)")
	signal_axes.set_ylabel(r"$P(\mathrm{signals} / \mathrm{experiment})$")
	#signal_axes.loglog()

	#signal_axes.set_ylim((1e-8, 1.))

	# skip the table when there are no intervals at all (None or an
	# empty dictionary):  a table cannot be built from zero rows
	if confidence_intervals:
		cell_text = [[r"$P(\mathrm{signals}/\mathrm{experiment} \in [%.4g, %.4g]) = %g$" % (lo, hi, conf)] for conf, (mode, lo, hi) in sorted(confidence_intervals.items())]
		table = signal_axes.table(cellText = cell_text, cellLoc = "left", colWidths = [.7], loc = "upper right")
		# white edges hide the cell borders
		for cell in table.get_celld().values():
			cell.set_edgecolor("w")

	if noise_rate is not None:
		background_axes = fig.add_subplot(2, 1, 1)
		x = noise_rate.bins[0].centres()
		y = noise_rate.array
		line2, = background_axes.plot(x, y, color = "r", linestyle = "-", label = "Background")
		background_axes.set_ylabel(r"$P(\mathrm{background\ count})$")
		fig.legend((line1, line2), (line1.get_label(), line2.get_label()), loc = "upper right")

	try:
		fig.tight_layout()
	except AttributeError:
		# old matplotlib. auto-layout not available
		pass
	return fig


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


#
# command line
#


options, filenames = parse_command_line()


#
# load ranking statistic PDFs
#


# a threshold of 0.0 is used both here and for the search results below
if options.likelihood_filenames:
	ranking_data = load_ranking_data(options.likelihood_filenames, 0.0, verbose = options.verbose)
else:
	ranking_data = None


#
# load search results
#


background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios = load_search_results(filenames, 0.0, restrict_to_instruments = options.restrict_to_instruments, tmp_path = options.tmp_space, verbose = options.verbose)


#
# calculate rate posteriors
#


# NOTE(review):  options.samples is validated by parse_command_line() but
# is not passed to anything in this script --- confirm whether
# calculate_rate_posteriors() is meant to receive it
if options.verbose:
	print >>sys.stderr, "calculating rate posteriors using %d likelihood ratios ..." % len(zerolag_ln_likelihood_ratios)
	progressbar = ProgressBar()
else:
	progressbar = None
signal_rate_pdf, noise_rate_pdf = rate_estimation.calculate_rate_posteriors(ranking_data, zerolag_ln_likelihood_ratios, restrict_to_instruments = options.restrict_to_instruments, progressbar = progressbar, chain_file = h5py.File(options.chain_file) if options.chain_file is not None else None)
del progressbar


#
# find confidence intervals
#


# use the confidences requested with --confidence-intervals, not a fixed
# pair.  each dictionary value is unpacked below as a (mode, lo, hi) tuple
if options.confidence_intervals:
	if options.verbose:
		print >>sys.stderr, "determining confidence intervals ..."
	confidence_intervals = dict((conf, rate_estimation.confidence_interval_from_binnedarray(signal_rate_pdf, conf)) for conf in options.confidence_intervals)
else:
	confidence_intervals = None
if options.verbose and confidence_intervals is not None:
	print >>sys.stderr, "rate posterior mean = %g signals/experiment" % rate_estimation.mean_from_pdf(signal_rate_pdf)
	# all modes are the same, pick one and report it
	print >>sys.stderr, "maximum-likelihood rate = %g signals/experiment" % next(iter(confidence_intervals.values()))[0]
	for conf, (mode, lo, hi) in sorted(confidence_intervals.items()):
		print >>sys.stderr, "%g%% confidence interval = [%g, %g] signals/experiment" % (conf * 100., lo, hi)


#
# save results
#


filename = "rate_posteriors.png"
if options.verbose:
	print >>sys.stderr, "writing %s ..." % filename
plot_rates(signal_rate_pdf, noise_rate = noise_rate_pdf if options.with_background else None, confidence_intervals = confidence_intervals, restrict_to_instruments = options.restrict_to_instruments).savefig(filename)

if options.verbose:
	print >>sys.stderr, "done"
