#!/usr/bin/env python
#
# Copyright (C) 2013  Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


import logging
import math
import matplotlib
# Plot styling shared by every figure this script produces.  The update is
# applied immediately after importing matplotlib, ahead of the figure and
# Agg backend imports that follow.
matplotlib.rcParams.update({
	# font sizes
	"font.size": 8.0,
	"axes.titlesize": 10.0,
	"axes.labelsize": 10.0,
	"xtick.labelsize": 8.0,
	"ytick.labelsize": 8.0,
	"legend.fontsize": 8.0,
	# resolution;  plot_psds() converts its pixel width to inches using
	# the figure's DPI
	"figure.dpi": 100,
	"savefig.dpi": 100,
	# the title and axis labels set in plot_psds() contain TeX markup
	"text.usetex": True,
	"path.simplify": True
})
from matplotlib import figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy
from optparse import OptionParser
import os.path
import StringIO
import sys
import time
import urlparse


from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
from glue.ligolw import lsctables
from glue.ligolw import utils
from gstlal.reference_psd import horizon_distance
try:
	from ligo.gracedb import cli as gracedb
except ImportError:
	from ligo import gracedb
from ligo.lvalert.utils import get_LVAdata_from_stdin
from pylal import series as lal_series


# width-to-height aspect ratio of the figure created in plot_psds()
golden_ratio = (1 + math.sqrt(5)) / 2


class ContentHandler(ligolw.LIGOLWContentHandler):
	"""
	Content handler passed to utils.load_fileobj() when reading the
	PSD and coinc documents.  It adds nothing of its own to
	ligolw.LIGOLWContentHandler;  the array, param and lsctables
	modules are each given the chance to modify it through their
	use_in() functions immediately after the class is defined.
	"""


# same registration order as always:  array, then param, then lsctables
for ligolw_module in (array, param, lsctables):
	ligolw_module.use_in(ContentHandler)


#
# =============================================================================
#
#                                   Library
#
# =============================================================================
#


def get_filename(gracedb_client, graceid, filename, retries = 3, retry_delay = 10.0, ignore_404 = False):
	"""
	Retrieve the file named filename attached to the GraceDB event
	graceid and return the response object.

	The request is made with gracedb_client.rest().  A response whose
	status is 200 is returned to the caller as-is.  If ignore_404 is
	true, a 404 response is logged and None is returned immediately
	without retrying.  Any other response is logged and the request is
	repeated after pausing retry_delay seconds, for at most retries
	attempts in total.

	Raises ValueError if retries is less than 1, and Exception
	(carrying the status and reason of the last response) if none of
	the attempts succeeded.
	"""
	if retries < 1:
		# without at least one attempt there would be no response
		# to describe in the failure message at the bottom
		raise ValueError("retries must be at least 1 (got %r)" % (retries,))

	path = "/event/%s/files/%s" % (graceid, filename)
	for attempts_left in range(retries - 1, -1, -1):
		logging.info("retrieving \"%s\" for %s" % (filename, graceid))
		response = gracedb_client.rest(path)
		if response.status == 200:
			return response
		if response.status == 404 and ignore_404:
			logging.warning("retrieving \"%s\" for %s: (%d) %s.  skipping ..." % (filename, graceid, response.status, response.reason))
			return None
		if not attempts_left:
			# that was the last attempt:  don't make the caller
			# wait out a pause that no retry will follow
			break
		logging.warning("retrieving \"%s\" for %s: (%d) %s.  pausing ..." % (filename, graceid, response.status, response.reason))
		time.sleep(retry_delay)
	raise Exception("retrieving \"%s\" for %s: (%d) %s" % (filename, graceid, response.status, response.reason))


def get_psds(gracedb_client, graceid, filename = "psd.xml.gz", ignore_404 = False):
	"""
	Fetch filename for the event graceid with get_filename(), parse
	it as XML using ContentHandler, and return whatever
	lal_series.read_psd_xmldoc() extracts from the document
	(plot_psds() consumes it as a mapping of instrument to PSD).

	Returns None when get_filename() returns None, which it does for
	a 404 response if ignore_404 is true.
	"""
	fileobj = get_filename(gracedb_client, graceid, filename = filename, ignore_404 = ignore_404)
	if fileobj is not None:
		# load_fileobj() returns a sequence whose first element is
		# the document
		xmldoc = utils.load_fileobj(fileobj, ContentHandler)[0]
		return lal_series.read_psd_xmldoc(xmldoc)
	return None


def get_coinc_xmldoc(gracedb_client, graceid, filename = "coinc.xml"):
	"""
	Fetch filename for the event graceid with get_filename() and
	return the XML document parsed from it using ContentHandler.
	Failures to retrieve the file propagate as the exception raised
	by get_filename().
	"""
	fileobj = get_filename(gracedb_client, graceid, filename = filename)
	# load_fileobj() returns a sequence whose first element is the
	# document
	xmldoc = utils.load_fileobj(fileobj, ContentHandler)[0]
	return xmldoc


def _psd_band(psd, f_low, f_high):
	"""
	Return the samples of psd.data lying between the frequencies f_low
	and f_high (same units as psd.f0 and psd.deltaF) that are finite
	and strictly positive, i.e. those that can be used to set the
	limits of a logarithmic axis.  The result can be empty.
	"""
	# clamp the indices at 0:  when the series starts above the
	# requested frequency the computed index is negative, and a slice
	# would interpret that as an offset from the end of the array
	first = max(int((f_low - psd.f0) / psd.deltaF), 0)
	last = max(int((f_high - psd.f0) / psd.deltaF), 0)
	band = psd.data[first:last]
	return band[numpy.isfinite(band) & (band > 0)]


def plot_psds(psds, coinc_xmldoc, plot_width = 640, colours = {"H1": "r", "H2": "b", "L1": "g", "V1": "m"}):
	"""
	Return a matplotlib Figure showing the spectral densities in psds
	together with, for each instrument that has a row in the
	sngl_inspiral table of coinc_xmldoc, the spectrum produced by
	horizon_distance() for that trigger's SNR.

	psds is a mapping of instrument name to PSD object;  entries whose
	value is None are skipped.  coinc_xmldoc must contain exactly one
	row in each of the coinc_event, coinc_inspiral and process tables
	and at least one sngl_inspiral row, otherwise an exception is
	raised.  plot_width is the width of the figure in pixels;  the
	height is plot_width divided by the golden ratio.  colours maps
	instrument name to the matplotlib colour of its curves;
	instruments missing from it are drawn in black.

	The figure is attached to an Agg canvas so it can be written out
	with its savefig() method.
	"""
	# the single-target unpacking requires exactly one row in each of
	# these tables
	coinc_event, = lsctables.table.get_table(coinc_xmldoc, lsctables.CoincTable.tableName)
	coinc_inspiral, = lsctables.table.get_table(coinc_xmldoc, lsctables.CoincInspiralTable.tableName)
	# offset_vector and process are not used below, but retrieving them
	# checks that the document carries the corresponding rows
	offset_vector = lsctables.table.get_table(coinc_xmldoc, lsctables.TimeSlideTable.tableName).as_dict()[coinc_event.time_slide_id]
	process, = lsctables.table.get_table(coinc_xmldoc, lsctables.ProcessTable.tableName)
	sngl_inspirals = dict((row.ifo, row) for row in lsctables.table.get_table(coinc_xmldoc, lsctables.SnglInspiralTable.tableName))

	# the masses are read from a single, arbitrary, row -- presumably
	# all triggers in the coinc share one template;  TODO confirm
	template = next(iter(sngl_inspirals.values()))
	mass1 = template.mass1
	mass2 = template.mass2
	end_time = coinc_inspiral.get_end()
	logging.info("%g Msun -- %g Msun event in %s at %.2f GPS" % (mass1, mass2, ", ".join(sorted(sngl_inspirals)), float(end_time)))

	fig = figure.Figure()
	FigureCanvas(fig)
	# plot_width is in pixels, set_size_inches() wants inches
	fig.set_size_inches(plot_width / float(fig.get_dpi()), int(round(plot_width / golden_ratio)) / float(fig.get_dpi()))
	axes = fig.gca()
	axes.grid(True)

	min_psds, max_psds = [], []
	for instrument, psd in sorted(psds.items()):
		if psd is None:
			continue
		psd_data = psd.data
		f = psd.f0 + numpy.arange(len(psd_data)) * psd.deltaF
		logging.info("found PSD for %s spanning [%g Hz, %g Hz]" % (instrument, f[0], f[-1]))
		# don't abort the plot over an instrument with no assigned
		# colour
		colour = colours.get(instrument, "k")
		# NOTE(review):  the 8 and 10 are presumably the SNR
		# threshold and lower frequency bound in Hz of the horizon
		# distance calculation -- confirm against horizon_distance()
		axes.loglog(f, psd_data, color = colour, alpha = 0.8, label = "%s (%.4g Mpc)" % (instrument, horizon_distance(psd, mass1, mass2, 8, 10)))
		if instrument in sngl_inspirals:
			snr = sngl_inspirals[instrument].snr
			logging.info("found %s event with SNR %g" % (instrument, snr))
			# horizon_distance() is called only for what it
			# writes into this list, whose two elements are
			# then used as the x and y co-ordinates of the curve
			inspiral_spectrum = [None, None]
			horizon_distance(psd, mass1, mass2, snr, 10, inspiral_spectrum = inspiral_spectrum)
			axes.loglog(inspiral_spectrum[0], inspiral_spectrum[1], color = colour, dashes = (5, 2), alpha = 0.8, label = "SNR = %.3g" % snr)
		# record the minimum from within the range 10 Hz -- 1 kHz
		band = _psd_band(psd, 10.0, 1000.0)
		if len(band):
			min_psds.append(band.min())
		# record the maximum from within the range 1 Hz -- 1 kHz
		band = _psd_band(psd, 1.0, 1000.0)
		if len(band):
			max_psds.append(band.max())

	axes.set_xlim((1.0, 3000.0))
	if min_psds and max_psds:
		# round the limits outwards to whole decades
		axes.set_ylim((10**math.floor(math.log10(min(min_psds))), 10**math.ceil(math.log10(max(max_psds)))))
	axes.set_title(r"Strain Noise Spectral Density for $%.3g\,\mathrm{M}_{\odot}$--$%.3g\,\mathrm{M}_{\odot}$ Merger at %.2f GPS" % (mass1, mass2, float(end_time)))
	axes.set_xlabel(r"Frequency (Hz)")
	axes.set_ylabel(r"Spectral Density ($\mathrm{strain}^2 / \mathrm{Hz}$)")
	axes.legend(loc = "lower left")

	return fig


def upload_fig(fig, gracedb_client, graceid, filename = "psd.png"):
	"""
	Render fig into an in-memory buffer and attach it to the event
	graceid under the name filename using gracedb_client.upload().
	The image format is taken from the extension of filename.

	Raises Exception if the response from upload() contains "error".
	"""
	# the extension includes the leading ".", which savefig() does not
	# want
	extension = os.path.splitext(filename)[-1]
	image = StringIO.StringIO()
	fig.savefig(image, format = extension[1:])
	logging.info("uploading \"%s\" for %s" % (filename, graceid))
	response = gracedb_client.upload(str(graceid), filename, filecontents = image.getvalue(), comment = "strain spectral densities")
	if "error" not in response:
		return
	raise Exception("upload of \"%s\" for %s failed: %s" % (filename, graceid, response["error"]))


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def parse_command_line():
	"""
	Parse the command line and configure the logging module.

	Returns the tuple (options, graceids), where options carries the
	no_upload, skip_404 and verbose flags and graceids is the list of
	positional arguments.  When no positional arguments are given
	verbose is forced on.
	"""
	parser = OptionParser()
	# all options are simple on/off switches
	for flags, help_text in (
		(("--no-upload",), "Write plots to disk."),
		(("--skip-404",), "Skip events that give 404 (file not found) errors (default is to abort)."),
		(("-v", "--verbose"), "Be verbose."),
	):
		parser.add_option(*flags, action = "store_true", help = help_text)
	options, graceids = parser.parse_args()

	# FIXME:  lvalert_listen doesn't allow command-line options
	if not graceids:
		options.verbose = True

	# can only call basicConfig once (otherwise need to switch to more
	# complex logging configuration)
	logging_config = {"format": "%(asctime)s:%(message)s"}
	if options.verbose:
		logging_config["level"] = logging.INFO
	logging.basicConfig(**logging_config)

	return options, graceids


#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#


options, graceids = parse_command_line()


if not graceids:
	# no events named on the command line:  take the event from the
	# alert supplied on stdin
	lvalert_data = get_LVAdata_from_stdin(sys.stdin, as_dict = True)
	logging.info("%(alert_type)s-type alert for event %(uid)s" % lvalert_data)
	logging.info("lvalert data: %s" % repr(lvalert_data))
	alert_file_path = urlparse.urlparse(lvalert_data["file"]).path
	filename = os.path.basename(alert_file_path)
	# only alerts for a coinc document are acted upon
	if not (filename in (u"coinc.xml",) or "_CBC_" in filename):
		logging.info("filename is not 'coinc.xml'.  skipping")
		sys.exit()
	graceids = [str(lvalert_data["uid"])]
	# pause to give the psd file a chance to get uploaded.
	time.sleep(8)


gracedb_client = gracedb.Client()


for graceid in graceids:
	psds = get_psds(gracedb_client, graceid, ignore_404 = options.skip_404)
	if psds is None:
		# the psd file was missing and --skip-404 was given
		continue
	coinc_xmldoc = get_coinc_xmldoc(gracedb_client, graceid)
	fig = plot_psds(psds, coinc_xmldoc)
	if not options.no_upload:
		upload_fig(fig, gracedb_client, graceid)
	else:
		filename = "psd_%s.png" % graceid
		logging.info("writing %s ..." % filename)
		fig.savefig(filename)
	logging.info("finished processing %s" % graceid)
