#!/usr/bin/python
#
# Copyright (C) 2009  Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import matplotlib
matplotlib.rcParams.update({
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	"figure.dpi": 300,
	"savefig.dpi": 300,
	"text.usetex": True
})
from matplotlib import figure
from matplotlib import patches
from matplotlib import colorbar
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import sys


from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import utils
from pylal import git_version
from pylal import rate


lsctables.use_in(ligolw.LIGOLWContentHandler)


__author__ = "Kipp Cannon <kipp.cannon@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the command line.  Returns the tuple (options, filenames),
	where options is the options object produced by OptionParser and
	filenames is the list of positional arguments.  If no positional
	arguments were given, filenames is [None].
	"""
	version_text = "Name: %%prog\n%s" % git_version.verbose_msg
	parser = OptionParser(version = version_text)
	parser.add_option(
		"-v", "--verbose",
		action = "store_true",
		help = "Be verbose."
	)
	parser.add_option(
		"-s", "--smooth-bins",
		action = "store_true",
		help = "Perform smoothing."
	)
	options, filenames = parser.parse_args()

	# an empty argument list is replaced with a single None entry so
	# that the main loop still processes one input
	if not filenames:
		filenames = [None]
	return options, filenames


#
# =============================================================================
#
#                                   Utilities
#
# =============================================================================
#


def binned_ratios_extractor(xmldoc):
	"""
	Generator yielding a (name, BinnedRatios) pair for each LIGO_LW
	element in xmldoc that has a Name attribute containing
	"pylal_rate_binnedratios".  The name that is yielded, and that is
	passed to rate.BinnedRatios.from_xml(), is the element's Name with
	the ":pylal_rate_binnedratios" text removed.
	"""
	# collect all of the matching names before decoding any of the
	# elements
	names = []
	for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName):
		if not elem.hasAttribute("Name"):
			continue
		if "pylal_rate_binnedratios" not in elem.Name:
			continue
		names.append(elem.Name.replace(u":pylal_rate_binnedratios", u""))

	for name in names:
		binned_ratios = rate.BinnedRatios.from_xml(xmldoc, name)
		yield name, binned_ratios


def binned_ratios_plot(title, xcoords, ycoords, data, ncontours = 9, xlims = None, ylims = None, datamin = -200, datamax = 200):
	"""
	Build and return a 4x3 matplotlib Figure containing a contour plot
	of data with a colour bar to its right.

	title is used as the plot title, and is also parsed to configure
	the axes:  if it contains "mtotal" only the x axis is logarithmic,
	otherwise both axes are;  the y axis label is the text preceding
	the first "-vs-" (with everything up to and including the last "1-"
	in that text removed), and the x axis label is the text following
	the first "-vs-" up to the next space.  IndexError is raised if
	title does not contain "-vs-".

	xcoords, ycoords and data are passed to Axes.contour() along with
	ncontours.  Before plotting, the values in data are clipped to the
	interval [datamin, datamax];  the array passed in by the caller is
	not modified.  xlims and ylims are passed to set_xlim() and
	set_ylim() respectively.
	"""
	fig = figure.Figure(figsize=(4,3))
	FigureCanvas(fig)
	axes = fig.add_axes((.15,.15,.55,.75))
	# FIXME: should extract the binned array type some other way
	if 'mtotal' in title:
		axes.semilogx()
	else:
		axes.loglog()

	# clamp out-of-range values (including infinities, e.g. from
	# log10(0)) to the requested range.  numpy.clip() returns a new
	# array, so the caller's data is left untouched.
	data = numpy.clip(data, datamin, datamax)

	cset = axes.contour(xcoords, ycoords, data, ncontours)
	# separate axes to the right of the plot to hold the colour bar
	cbar = fig.add_axes((.75,.15,.1,.75))
	colorbar.Colorbar(cbar, cset)
	axes.set_xlim(xlims)
	axes.set_ylim(ylims)
	axes.set_title(title)
	# FIXME: should extract the name of the dimensions some other way
	axes.set_ylabel(title.split('-vs-')[0].split('1-')[-1])
	axes.set_xlabel(title.split('-vs-')[1].split(' ')[0])
	return fig


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


options, filenames = parse_command_line()


#
# Load each input document and accumulate the binned ratios found in it.
# Binned ratios with the same name are summed together.
#


saved_ratios = {}
for n, filename in enumerate(filenames):
	if options.verbose:
		print >>sys.stderr, "%d/%d:" % (n + 1, len(filenames)),
	xmldoc = utils.load_filename(filename, verbose = options.verbose, contenthandler = ligolw.LIGOLWContentHandler)
	for name, ratios in binned_ratios_extractor(xmldoc):
		if name in saved_ratios:
			saved_ratios[name] += ratios
		else:
			saved_ratios[name] = ratios


#
# Write three plots for each set of binned ratios:  the numerator, the
# denominator, and their ratio.
#


for name, ratios in saved_ratios.items():
	# FIXME: currently only plots 2D binned ratios
	if '-vs-' not in name:
		continue

	if options.smooth_bins:
		if options.verbose:
			print >>sys.stderr, "smoothing %s ..." % name
		window_length = 30
		rate.filter_binned_ratios(ratios, rate.gaussian_window(ratios.numerator.bins.shape[0]/window_length, ratios.numerator.bins.shape[1]/window_length, sigma=2*window_length))

	ratios.to_pdf()

	# quantities shared by all three plots.  the numerator's binning is
	# used for the coordinates and axis limits of every plot.
	bins = ratios.numerator.bins
	xcoords, ycoords = bins.centres()
	xlims = (bins.min[0], bins.max[0])
	ylims = (bins.min[1], bins.max[1])
	# NOTE(review):  presumably "_" is replaced because text.usetex is
	# enabled and "_" is special in LaTeX --- confirm
	plot_name = name.replace("_", "-")

	# the numerator and denominator are plotted before
	# logregularize() is applied
	for suffix, label, binned_array in (("num", "Numerator", ratios.numerator), ("den", "Denominator", ratios.denominator)):
		outname = "%s_%s.png" % (name, suffix)
		if options.verbose:
			print >>sys.stderr, "writing %s ..." % outname
		fig = binned_ratios_plot("%s %s" % (plot_name, label), xcoords, ycoords, numpy.transpose(numpy.log10(binned_array.array)), ncontours=49, xlims=xlims, ylims=ylims, datamin=-20, datamax=0)
		fig.savefig(outname)

	# NOTE(review):  presumably this adjusts empty bins so that log10()
	# of the ratio can be evaluated --- see pylal.rate.  one call is
	# sufficient provided the operation is idempotent.
	ratios.logregularize()

	outname = "%s_ratio.png" % name
	if options.verbose:
		print >>sys.stderr, "writing %s ..." % outname
	fig = binned_ratios_plot("%s Likelihood Ratio" % plot_name, xcoords, ycoords, numpy.transpose(numpy.log10(ratios.ratio())), ncontours = 49, xlims=xlims, ylims=ylims, datamin=-20, datamax=20)
	fig.savefig(outname)
