#!/usr/bin/env python
#
# Copyright (C) 2013 Chad Hanna, Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file gstlal_inspiral_plot_background
# A program to plot the likelihood distributions in noise of a gstlal inspiral analysis
#
# ### Command line interface
#
#	+ `--database` [filename]: Retrieve search results from this database (optional).  Can be given multiple times.
#	+ `--database-cache` [filename]: Retrieve search results from all databases in this LAL cache (optional).  See lalapps_path2cache.
#	+ `--max-snr` [value] (float): Plot SNR PDFs up to this value of SNR (default = 200).
#	+ `--max-log-lambda` [value] (float): Plot ranking statistic CCDFs, etc., up to this value of the natural logarithm of the likelihood ratio (default = 40).
#	+ `--min-log-lambda` [value] (float): Plot ranking statistic CCDFs, etc., down to this value of the natural logarithm of the likelihood ratio (default = -5).
#	+ `--output-dir` [path]: Write output to this directory (default = ".").
#	+ `--tmp-space` [path]: Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.
#	+ `--user-tag` [tag]: Set the adjustable component of the description fields in the output filenames (default = "ALL").
#	+ `--verbose`: Be verbose.

import itertools
import math
import matplotlib
matplotlib.rcParams.update({
	"font.size": 10.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 600,
	"savefig.dpi": 600,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import sqlite3
import sys


from glue.lal import CacheEntry
from glue.ligolw import dbtables
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue import segmentsUtils
from pylal import ligolw_thinca


from gstlal import far
from gstlal import plotfar
from gstlal import inspiral_pipe


def parse_command_line():
	"""
	Parse the command line (sys.argv) and return the tuple (options,
	filenames).

	options is the optparse values object after the following
	post-processing:

	- if --database-cache was given, the path of every entry in the
	  cache file is appended to options.database (which is created as
	  an empty list first if no --database option was given);
	- options.scatter_log_lambdas, if given, is converted from its
	  comma-delimited "low:high" string form to a coalesced segment
	  list with float boundaries;
	- options.user_tag is converted to upper case.

	filenames is the list of positional arguments left over after
	option parsing.

	Raises ValueError if --output-format is not one of ".png", ".pdf"
	or ".svg".
	"""
	parser = OptionParser()
	parser.add_option("-d", "--database", metavar = "filename", action = "append", help = "Retrieve search results from this database (optional).  Can be given multiple times.")
	parser.add_option("-c", "--database-cache", metavar = "filename", help = "Retrieve search results from all databases in this LAL cache (optional).  See lalapps_path2cache.")
	parser.add_option("--max-snr", metavar = "SNR", default = 200., type = "float", help = "Plot SNR PDFs up to this value of SNR (default = 200).")
	parser.add_option("--max-log-lambda", metavar = "value", default = 40., type = "float", help = "Plot ranking statistic CCDFs, etc., up to this value of the natural logarithm of the likelihood ratio (default = 40).")
	# NOTE:  the help text must quote the same number as default =
	parser.add_option("--min-log-lambda", metavar = "value", default = -5., type = "float", help = "Plot ranking statistic CCDFs, etc., down to this value of the natural logarithm of the likelihood ratio (default = -5).")
	parser.add_option("--scatter-log-lambdas", metavar = "[low]:[high][,[low]:[high]...]", help = "Overlay scatter plots of candidate parameter co-ordinates on various PDF plots, limiting the candidates to those whose log likelihood ratios are in the given range (default = don't overlay scatter plot).  Use a range of \":\" to plot all candidates.")
	parser.add_option("--output-dir", metavar = "output-dir", default = ".", help = "Write output to this directory (default = \".\").")
	parser.add_option("--output-format", metavar = "extension", default = ".png", help = "Select output format by choosing the filename extension (default = \".png\").")
	parser.add_option("-t", "--tmp-space", metavar = "path", help = "Path to a directory suitable for use as a work area while manipulating the database file.  The database file will be worked on in this directory, and then moved to the final location when complete.  This option is intended to improve performance when running in a networked environment, where there might be a local disk with higher bandwidth than is available to the filesystem on which the final output will reside.")
	parser.add_option("--threshold", metavar = "log likelihood ratio", type = "float", help = "When plotting ranking statistic PDFs and CCDFs, discard everything below this threshold and renormalize accordingly (default = plot PDFs and CCDFs as found).")
	parser.add_option("--user-tag", metavar = "user-tag", default = "ALL", help = "Set the adjustable component of the description fields in the output filenames (default = \"ALL\").")
	parser.add_option("--dont-apply-extinction", action = "store_true", help = "Don't apply the extinction model")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	if options.database_cache is not None:
		if options.database is None:
			options.database = []
		# use a context manager so the cache file is closed as soon
		# as it has been read instead of whenever the garbage
		# collector gets to it
		with open(options.database_cache) as cache_file:
			options.database += [CacheEntry(line).path for line in cache_file]

	if options.output_format not in (".png", ".pdf", ".svg"):
		raise ValueError("invalid --output-format \"%s\"" % options.output_format)

	if options.scatter_log_lambdas is not None:
		options.scatter_log_lambdas = segmentsUtils.from_range_strings(options.scatter_log_lambdas.split(","), boundtype = float).coalesce()

	options.user_tag = options.user_tag.upper()

	return options, filenames


def load_distributions(filenames, ln_likelihood_ratio_threshold = None, verbose = False):
	"""
	Load and sum the parameter distribution data, ranking statistic
	data and segment lists from the likelihood control documents named
	in filenames.

	Each file is parsed with far.parse_likelihood_control_doc().  The
	parameter distribution and ranking data objects found in the files
	are accumulated with +=, the segment lists with |=.  After all
	files are loaded .finish() is called on whichever of the two
	accumulated objects is present;  unless the module-level
	options.dont_apply_extinction is set, the ranking data is first
	replaced with the result of its .new_with_extinction() method.

	ln_likelihood_ratio_threshold is accepted but is not used by this
	function.  If verbose is true progress is printed to stderr.

	Returns the tuple (coinc_params_distributions, ranking_data,
	seglists).  Either of the first two may be None, but never both.

	Raises ValueError if any one file contains neither kind of data,
	or if no data is loaded at all (e.g., filenames is empty).
	"""
	coinc_params_distributions, ranking_data, seglists = None, None, None
	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d:" % (n, len(filenames)),
		# load the file for *this* iteration.  this used to read
		# filenames[0] every time around the loop, which summed the
		# first file with itself len(filenames) times and ignored
		# all of the others
		this_coinc_params_distributions, this_ranking_data, this_seglists = far.parse_likelihood_control_doc(ligolw_utils.load_filename(filename, contenthandler = far.ThincaCoincParamsDistributions.LIGOLWContentHandler, verbose = verbose))
		if this_coinc_params_distributions is None and this_ranking_data is None:
			raise ValueError("%s contains no parameter distribution data" % filename)
		if coinc_params_distributions is None:
			coinc_params_distributions = this_coinc_params_distributions
		elif this_coinc_params_distributions is not None:
			coinc_params_distributions += this_coinc_params_distributions
		if ranking_data is None:
			ranking_data = this_ranking_data
		elif this_ranking_data is not None:
			ranking_data += this_ranking_data
		if seglists is None:
			seglists = this_seglists
		else:
			seglists |= this_seglists
	if coinc_params_distributions is None and ranking_data is None:
		raise ValueError("no parameter distribution data loaded")
	if coinc_params_distributions is not None:
		coinc_params_distributions.finish(verbose = verbose)
	# the input files are allowed to carry parameter distributions
	# only, so everything that touches ranking_data must tolerate None
	if ranking_data is not None:
		# NOTE:  options is the module-level result of
		# parse_command_line(), it is not passed in as an argument
		if not options.dont_apply_extinction:
			# FIXME signal PDFs are still not right.
			ranking_data, _ = ranking_data.new_with_extinction(ranking_data)
		ranking_data.finish(verbose = verbose)
	# FIXME:  hack to fake some segments so we can generate file names
	# if the input files didn't provide segment information
	try:
		seglists.extent_all()
	except ValueError:
		seglists[None] = lsctables.segments.segmentlist([lsctables.segments.segment(0, 0)])
	return coinc_params_distributions, ranking_data, seglists


def load_search_results(filenames, ln_likelihood_ratio_threshold = None, tmp_path = None, verbose = False):
	"""
	Collect the inspiral coincs from the SQLite databases named in
	filenames.

	For each database a working copy is obtained with
	dbtables.get_connection_filename() (tmp_path is passed through to
	it), queried, and then discarded again.  Every coinc whose
	coinc_def_id matches the inspiral coinc definer is retrieved;  if
	ln_likelihood_ratio_threshold is not None only coincs with
	coinc_event.likelihood >= the threshold are retrieved.  A coinc is
	counted as background if its time slide has any non-zero offset,
	and as zero-lag otherwise.

	Each coinc is represented as a dictionary mapping instrument to
	(snr, chisq), carrying two extra attributes:
	.ln_likelihood_ratio (from coinc_event.likelihood) and .fap (from
	coinc_inspiral.false_alarm_rate).

	Returns the tuple (background_ln_likelihood_ratios,
	zerolag_ln_likelihood_ratios, background_sngls, zerolag_sngls).
	The first two are sorted lists of (ln_likelihood_ratio, fap)
	tuples, the last two are lists of the dictionaries described
	above.  If verbose is true progress is printed to stderr.
	"""
	# dict subclass to allow it to carry extra attributes.  defined
	# once here rather than once per input file
	class snglsdict(dict):
		__slots__ = ("ln_likelihood_ratio", "fap")

	# result columns:  0 = coinc_event_id, 1 = likelihood,
	# 2 = false_alarm_rate, 3 = is-background flag, 4 = ifo, 5 = snr,
	# 6 = chisq.  the ORDER BY is required by the itertools.groupby()
	# below, which only groups adjacent rows
	query = """
SELECT
	coinc_event_map.coinc_event_id,
	coinc_event.likelihood,
	coinc_inspiral.false_alarm_rate,
	EXISTS (
		SELECT
			*
		FROM
			time_slide
		WHERE
			time_slide.time_slide_id == coinc_event.time_slide_id
			AND time_slide.offset != 0
	),
	sngl_inspiral.ifo,
	sngl_inspiral.snr,
	sngl_inspiral.chisq
FROM
	sngl_inspiral
	JOIN coinc_event_map ON (
		coinc_event_map.table_name == 'sngl_inspiral'
		AND coinc_event_map.event_id == sngl_inspiral.event_id
	)
	JOIN coinc_event ON (
		coinc_event.coinc_event_id == coinc_event_map.coinc_event_id
	)
	JOIN coinc_inspiral ON (
		coinc_inspiral.coinc_event_id == coinc_event.coinc_event_id
	)
WHERE
	coinc_event.coinc_def_id == ?
	AND (? ISNULL OR coinc_event.likelihood >= ?)
ORDER BY
	coinc_event_map.coinc_event_id
		"""

	background_sngls = []
	zerolag_sngls = []

	for n, filename in enumerate(filenames, 1):
		if verbose:
			print >>sys.stderr, "%d/%d: %s" % (n, len(filenames), filename)
		working_filename = dbtables.get_connection_filename(filename, tmp_path = tmp_path, verbose = verbose)
		# the try/finally pairs ensure the connection is closed and
		# the working copy is discarded even if the database turns
		# out to be unreadable or is missing a table
		try:
			connection = sqlite3.connect(working_filename)
			try:
				xmldoc = dbtables.get_xml(connection)
				definer_id = lsctables.CoincDefTable.get_table(xmldoc).get_coinc_def_id(ligolw_thinca.InspiralCoincDef.search, ligolw_thinca.InspiralCoincDef.search_coinc_type, create_new = False)

				cursor = connection.cursor()
				for coinc_event_id, rows in itertools.groupby(cursor.execute(query, (definer_id, ln_likelihood_ratio_threshold, ln_likelihood_ratio_threshold)), lambda row: row[0]):
					rows = list(rows)
					# the coinc-level columns are the
					# same in every row of the group, so
					# take them from the first one
					is_background = rows[0][3]
					# {instrument: (snr, chisq)} dictionary
					sngls = snglsdict((row[4], (row[5], row[6])) for row in rows)
					sngls.ln_likelihood_ratio = rows[0][1]
					sngls.fap = rows[0][2]
					if is_background:
						background_sngls.append(sngls)
					else:
						zerolag_sngls.append(sngls)
				cursor.close()
			finally:
				connection.close()
		finally:
			dbtables.discard_connection_filename(filename, working_filename, verbose = verbose)

	background_ln_likelihood_ratios = sorted((sngls.ln_likelihood_ratio, sngls.fap) for sngls in background_sngls)
	zerolag_ln_likelihood_ratios = sorted((sngls.ln_likelihood_ratio, sngls.fap) for sngls in zerolag_sngls)

	return background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios, background_sngls, zerolag_sngls


class SnrChiColourNorm(matplotlib.colors.Normalize):
	"""
	Colour normalization that maps data to the open interval (0, 1)
	with an arctangent of the logarithm of the value.

	With xbar = (ln vmax + ln vmin) / 2 and
	delta = (ln vmax - ln vmin) / 2, a value x is mapped to

		(arctan((ln x - xbar) * (pi / 2) / delta) + pi / 2) / pi

	so the geometric mean of vmin and vmax maps to 0.5, and values
	arbitrarily far outside of [vmin, vmax] still map to distinct
	numbers inside (0, 1) instead of saturating.
	"""
	def _log_centre_and_halfwidth(self):
		"""
		Return (xbar, delta), the centre and half-width of the
		interval [ln vmin, ln vmax].
		"""
		vmin = math.log(self.vmin)
		vmax = math.log(self.vmax)
		return (vmax + vmin) / 2., (vmax - vmin) / 2.

	def __call__(self, value, clip = None):
		"""
		Return value (a scalar or an array) mapped into (0, 1).
		Any of vmin and vmax that is still None is first set from
		the data with autoscale_None().  The clip argument is
		accepted for compatibility with the parent class'
		signature but is ignored.
		"""
		value, is_scalar = self.process_value(value)
		# keep the data strictly positive and finite so that the
		# logarithms below are defined.  done in place on the
		# array obtained from process_value()
		numpy.clip(value, 1e-200, 1e+200, value)

		self.autoscale_None(value)

		xbar, delta = self._log_centre_and_halfwidth()
		pi_2 = math.pi / 2.

		value = (numpy.arctan((numpy.log(value) - xbar) * pi_2 / delta) + pi_2) / math.pi
		return value[0] if is_scalar else value

	def inverse(self, value):
		"""
		Inverse of __call__():  map normalized value(s) in (0, 1)
		back to data values, limited to the range [1e-200,
		1e+200].  Requires vmin and vmax to have been set.
		"""
		xbar, delta = self._log_centre_and_halfwidth()
		pi_2 = math.pi / 2.

		value = numpy.exp(numpy.tan(value * math.pi - pi_2) * delta / pi_2 + xbar)
		# don't use value as the out argument of clip():  for
		# scalar input numpy.exp() returns a numpy scalar, which
		# cannot be written to, and clip() raises TypeError
		return numpy.clip(value, 1e-200, 1e+200)

	def _autoscale(self, vmin, vmax):
		"""
		Given the smallest and largest data values, return the
		(vmin, vmax) to use for the normalization.  In log space
		the returned interval is 1/4 of the width of the data's
		interval and is centred 2/3 of the data's half-width above
		the data's centre.
		"""
		vmin = math.log(vmin)
		vmax = math.log(vmax)
		xbar = (vmax + vmin) / 2.
		delta = (vmax - vmin) / 2.
		xbar += 2. * delta / 3.
		delta /= 4.
		vmin = math.exp(xbar - delta)
		vmax = math.exp(xbar + delta)
		return vmin, vmax

	def autoscale(self, A):
		"""
		Unconditionally set vmin and vmax from the data in A.
		"""
		self.vmin, self.vmax = self._autoscale(numpy.ma.min(A), numpy.ma.max(A))

	def autoscale_None(self, A):
		"""
		Set whichever of vmin and vmax is None from the data in A,
		leaving limits that are already set untouched.
		"""
		if self.vmin is not None and self.vmax is not None:
			# nothing to set, don't bother scanning the data
			return
		vmin, vmax = self._autoscale(numpy.ma.min(A), numpy.ma.max(A))
		if self.vmin is None:
			self.vmin = vmin
		if self.vmax is None:
			self.vmax = vmax


#
# command line
#


options, filenames = parse_command_line()


#
# load input
#


# parameter distributions, ranking statistic data and segment lists from
# the likelihood files named on the command line
coinc_params_distributions, ranking_data, seglists = load_distributions(
	filenames,
	ln_likelihood_ratio_threshold = options.threshold,
	verbose = options.verbose
)

# the FAP/FAR calculator can only be built if ranking data was found
fapfar = None
if ranking_data is not None:
	fapfar = far.FAPFAR(ranking_data, livetime = far.get_live_time(seglists))


# candidates from the databases, if any were given.  sngls is left as None
# unless --scatter-log-lambdas requested that candidates be overlaid on
# the plots
background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios, sngls = None, None, None
if options.database:
	background_ln_likelihood_ratios, zerolag_ln_likelihood_ratios, background_sngls, zerolag_sngls = load_search_results(
		options.database,
		ln_likelihood_ratio_threshold = options.threshold,
		tmp_path = options.tmp_space,
		verbose = options.verbose
	)
	if options.scatter_log_lambdas is not None:
		# background candidates first, then zero-lag
		sngls = [candidate for candidate in itertools.chain(background_sngls, zerolag_sngls) if candidate.ln_likelihood_ratio in options.scatter_log_lambdas]


#
# plots
#

# SNR and Chisquared
# one plot for each (instrument, PDF type) pair.  plot_snr_chi_pdf()
# returns None when it has nothing to draw for a pair, in which case no
# file is written
for instrument, snr_chi_type in itertools.product(("H1","L1","V1"), ("background_pdf", "injection_pdf", "zero_lag_pdf", "LR")):
	fig = plotfar.plot_snr_chi_pdf(coinc_params_distributions, instrument, snr_chi_type, options.max_snr, sngls = sngls)
	if fig is not None:
		# the interval spanned by all segments sets the start and
		# end times in the file name
		extent = seglists.extent_all()
		description = "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_%s_SNRCHI2" % (options.user_tag, snr_chi_type.upper())
		plotname = inspiral_pipe.T050017_filename(instrument, description, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)


# Trigger and event rates
# a single figure covering all instruments
fig = plotfar.plot_rates(coinc_params_distributions, ranking_data)
# the interval spanned by all segments sets the start and end times in
# the file name
extent = seglists.extent_all()
description = "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_RATES" % options.user_tag
plotname = inspiral_pipe.T050017_filename("H1L1V1", description, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
if options.verbose:
	print >>sys.stderr, "writing %s" % plotname
fig.savefig(plotname)


# SNR PDFs
# one plot for each entry in the SNR joint PDF cache.  the cache keys are
# (instruments, horizon distances) pairs and are visited in order of
# their sorted horizon distances.  the sort key indexes the pair instead
# of unpacking it in the lambda's parameter list because tuple parameters
# are not valid syntax in Python 3
for (instruments, horizon_distances) in sorted(coinc_params_distributions.snr_joint_pdf_cache.keys(), key = lambda cache_key: sorted(cache_key[1])):
	horizon_distances = dict(horizon_distances)	# stored as frozen set of key/value pairs in cache key
	fig = plotfar.plot_snr_joint_pdf(coinc_params_distributions, instruments, horizon_distances, options.max_snr, sngls = sngls)
	# plot_snr_joint_pdf() returns None when there is nothing to draw
	if fig is not None:
		# the interval spanned by all segments sets the start and
		# end times in the file name
		extent = seglists.extent_all()
		plotname = inspiral_pipe.T050017_filename(instruments, "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_SNR_PDF_%s" % (options.user_tag, "_".join(["%s_%s" % (k, horizon_distances[k]) for k in sorted(horizon_distances)]) ), int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)


# ranking statistic PDFs and CCDFs
if ranking_data is not None:
	log_lambda_range = (options.min_log_lambda, options.max_log_lambda)

	# one noise PDF plot for each entry of background_likelihood_pdfs.
	# an instruments key that is false is labelled "COMBINED" in the
	# file name
	for instruments, _pdf in ranking_data.background_likelihood_pdfs.items():
		fig = plotfar.plot_likelihood_ratio_pdf(ranking_data, instruments, log_lambda_range, "Noise", binnedarray_string = "background_likelihood_pdfs")
		extent = seglists.extent_all()
		plotname = inspiral_pipe.T050017_filename(instruments or "COMBINED", "GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_NOISE_LIKELIHOOD_RATIO_PDF" % options.user_tag, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)

	# two CCDF plots:  the "openbox" plot is drawn with the zero-lag
	# candidates, the "closedbox" plot passes the background candidates
	# through the same keyword argument instead
	for description_format, ln_likelihood_ratios in (
		("GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_NOISE_LIKELIHOOD_RATIO_CCDF_openbox", zerolag_ln_likelihood_ratios),
		("GSTLAL_INSPIRAL_PLOT_BACKGROUND_%s_NOISE_LIKELIHOOD_RATIO_CCDF_closedbox", background_ln_likelihood_ratios),
	):
		fig = plotfar.plot_likelihood_ratio_ccdf(fapfar, log_lambda_range, zerolag_ln_likelihood_ratios = ln_likelihood_ratios)
		extent = seglists.extent_all()
		plotname = inspiral_pipe.T050017_filename("COMBINED", description_format % options.user_tag, int(extent[0]), int(extent[1]), options.output_format, path = options.output_dir)
		if options.verbose:
			print >>sys.stderr, "writing %s" % plotname
		fig.savefig(plotname)
