#!/usr/bin/env python
#
# Copyright (C) 2013 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys, os
import matplotlib
matplotlib.use('Agg')
import pylab
from gstlal import far
import copy
import numpy
from optparse import OptionParser
from glue.ligolw import ligolw
from glue.ligolw import utils
from glue.ligolw import lsctables
from glue.ligolw import array
from glue.ligolw import param
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
lsctables.use_in(ligolw.LIGOLWContentHandler)

def parse_command_line():
	"""
	Parse this program's command line.

	Returns the tuple (options, filenames).  options is the optparse
	values object carrying the attributes base, marginalized_file and
	verbose;  filenames is the list of positional arguments that remain
	once the recognized options have been removed.
	"""
	option_parser = OptionParser()
	option_parser.add_option(
		"-b", "--base",
		metavar = "base",
		default = "cbc_plotsummary_",
		help = "Set the prefix for output filenames (default = \"cbc_plotsummary_\")."
	)
	option_parser.add_option(
		"--marginalized-file",
		metavar = "file",
		help = "Set the marginalized likelihood input file"
	)
	option_parser.add_option(
		"-v", "--verbose",
		action = "store_true",
		help = "Be verbose."
	)

	# parse_args() already yields the (options, positional arguments) pair
	return option_parser.parse_args()

# Parse the command line once;  the rest of the script reads the result
# through the module-level names "options" and "files".
options, files = parse_command_line()

# Running sum, over all input files, of the entries of each instrument's
# background (SNR, chi^2) histogram.  It is read again by the marginalized
# likelihood section further down to scale its error band.
total_count = {'H1':0., 'L1':0., 'V1':0.}

for f in files:
	Far = far.LocalRankingData.from_filenames([f], verbose = options.verbose)
	dists, segs = Far.distribution_stats, Far.livetime_seg
	counts = dists.raw_distributions.background_rates
	inj = dists.raw_distributions.injection_rates
	# deep copy so that the in-place division below does not modify the
	# injection histograms, which are plotted in their own right
	likely = copy.deepcopy(inj)

	for ifo in ['H1','L1', 'V1']:
		key = ifo+"_snr_chi"
		# bin-by-bin ratio of injection to background histogram entries.
		# NOTE(review): bins that are empty in the background divide by
		# zero here and yield inf/nan -- confirm that is acceptable
		likely[key].array /= counts[key].array
		total_count[ifo] += counts[key].array.sum()
		for name, obj in (("background", counts), ("injections", inj), ("likelihood", likely)):
			hist = obj[key]
			fig = pylab.figure(figsize=(6,5), facecolor = 'g')
			fig.patch.set_alpha(0.0)
			data = hist.array
			# the first and last bin along each axis are left out of the plot
			snr = hist.bins[0].centres()[1:-1]
			chi = hist.bins[1].centres()[1:-1]
			chi[0] = 0 # not inf
			ax = pylab.subplot(111)
			# the +1 keeps log10 finite for empty bins
			pylab.pcolormesh(snr, chi, numpy.log10(data.T +1)[1:-1,1:-1])
			# use a log axis whenever the binning's string form mentions "Log"
			if "Log" in str(hist.bins[0]):
				ax.set_xscale('log')
			if "Log" in str(hist.bins[1]):
				ax.set_yscale('log')
			pylab.colorbar()
			pylab.xlabel('SNR')
			pylab.ylabel('reduced chi^2 / SNR^2')
			pylab.ylim([chi[1], chi[-1]])
			pylab.xlim([snr[1],snr[-1]])
			pylab.title('%s: %s log base 10 (number + 1)' % (ifo, name))
			pylab.grid(color=(0.1,0.4,0.5), linewidth=2)
			pylab.savefig("%s%s-%s-%s.png" % (options.base, ifo, name, os.path.split(f)[1]))
			# pyplot keeps every figure alive until it is closed explicitly;
			# with nine figures per input file that adds up quickly
			pylab.close(fig)


# Optional second stage:  for every instrument combination found in the
# marginalized likelihood file, plot the trials-corrected false alarm
# probability against ln(likelihood) together with a shaded error band.
if options.marginalized_file:
	marg, procid = far.RankingData.from_xml(utils.load_filename(options.marginalized_file, contenthandler = ligolw.LIGOLWContentHandler, verbose = options.verbose))
	marg.compute_joint_cdfs()

	for k in marg.joint_likelihood_pdfs:
		# k is iterated as a collection of instrument names;  the sorted
		# concatenation labels both the title and the output file
		ifos = "".join(sorted(k))
		# total background histogram entries accumulated above for the
		# instruments in this combination
		N = sum(total_count[ifo] for ifo in k)
		trials = marg.trials_table[k]
		#FIXME hardcoded num_slides to 2, don't do that!
		m = trials.count * trials.count_below_thresh / trials.thresh / float(abs(marg.livetime_seg)) * marg.trials_table.num_nonzero_count() / 2
		fig = pylab.figure()
		# 1000 log-spaced likelihood values;  the lower end is never below 1e-3
		larray = numpy.logspace(max(numpy.log10(marg.minrank[k]), -3), 0.99 * numpy.log10(marg.maxrank[k]), 1000)
		ln_larray = numpy.log(larray)
		FAP = 1.0 - marg.cdf_interpolator[k](larray)
		# an explicit array, so that the arithmetic below is element-wise by
		# construction instead of relying on numpy's reflected operators
		# being picked up for a plain list
		FAPM = numpy.array([far.fap_after_trials(p, m) for p in FAP])
		pylab.semilogy(ln_larray, FAPM, 'k')
		# half-width of the band:  a fractional error of 1/sqrt(N*FAP).
		# NOTE(review): presumably a counting error on the N*FAP background
		# entries above each likelihood value -- confirm
		half_width = FAPM / numpy.sqrt(N*FAP)
		trials_high_error = FAPM + half_width
		trials_low_error = FAPM - half_width
		# clip the lower edge to the bottom of the plotted y range
		trials_low_error[trials_low_error < 1e-8] = 1e-8
		pylab.fill_between(ln_larray, trials_low_error, trials_high_error, alpha=0.25, color='k')
		pylab.grid()
		pylab.title("%s FAP vs Likelihood" % ifos)
		pylab.xlabel('ln(Likelihood)')
		pylab.ylabel('False Alarm Probability')
		pylab.ylim([1e-8, 1])
		pylab.savefig("%s%s-marginalized_likelihood_cdf.png" % (options.base, ifos))
		# release the figure;  pyplot otherwise keeps it alive until exit
		pylab.close(fig)


