#!/usr/bin/env python
#
# Copyright (C) 2012  Madeline Wade
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================

"""
Generate frame files for LHO coherent data.  Generate a LIGO Light-Weight XML veto segment file from the LHO null stream.  If the null stream exceeds a threshold, this is marked as a veto segment.
"""

import os
import sys
import numpy

import gobject
gobject.threads_init()
import pygtk
pygtk.require('2.0')
import pygst
pygst.require('0.10')
import gst
from optparse import OptionParser

from gstlal import pipeparts
from gstlal import lloidparts
from gstlal import reference_psd
from gstlal import simplehandler
from gstlal import coherent_null
from gstlal import inspiral
from gstlal import datasource

from glue import segments
from glue.ligolw import ligolw
from glue.ligolw import array
from glue.ligolw import param
from glue.ligolw import lsctables
array.use_in(ligolw.LIGOLWContentHandler)
param.use_in(ligolw.LIGOLWContentHandler)
lsctables.use_in(ligolw.LIGOLWContentHandler)
from glue.ligolw import utils
from glue.ligolw.utils import segments as ligolw_segments
from glue.ligolw.utils import process as ligolw_process

from pylal.datatypes import LIGOTimeGPS
from pylal.xlal.datatypes.real8frequencyseries import REAL8FrequencySeries

__author__ = "Madeline Wade <madeline.wade@ligo.org>"

#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#

parser = OptionParser()

# let the datasource module register its own command line options on the
# parser before the program-specific ones are added
datasource.append_options(parser)

parser.add_option(
	"--write-pipeline",
	metavar = "filename",
	help = "Write a DOT graph description of the as-built pipeline to this file (optional).  The environment variable GST_DEBUG_DUMP_DOT_DIR must be set for this option to work."
)
parser.add_option(
	"--track-psd",
	action = "store_true",
	help = "Track the H1 and H2 psds.  The filters that create the coherent stream will be updated throughout the run."
)
parser.add_option(
	"--reference-psd",
	metavar = "filename",
	help = ".xml file containing the H1 and H2 psds. Use this psd as first guess psd or fixed psd."
)
parser.add_option(
	"--write-psd",
	metavar = "filename",
	help = "Measure psds and write to .xml file.  Use these psds as a first guess psd or fixed psd."
)
parser.add_option(
	"--null-output",
	metavar = "filename",
	help = "Set the output filename for the null vetoes (default = write to stdout).  If the filename ends with \".gz\" it will be gzip-compressed."
)
parser.add_option(
	"--vetoes-name",
	metavar = "name",
	default = "vetoes_from_LHO_null",
	help = "Set the name of the vetoes segment lists in the output document (default = \"vetoes_from_LHO_null\")."
)
parser.add_option(
	"--write-null-frames",
	action = "store_true",
	help = "Write the null stream to frame files as well as out put a null veto segment list (optional)."
)
parser.add_option(
	"--sample-rate",
	metavar = "Hz",
	default = 4096,
	type = "int",
	help = "Sample rate at which to generate the data, should be less than or equal to the sample rate of the measured psds provided, default = 4096 Hz"
)
parser.add_option(
	"--verbose",
	action = "store_true",
	help = "Be verbose."
)


options, filenames = parser.parse_args()

# some source of PSD information is mandatory:  a file to read, a file to
# write after measuring, or on-the-fly tracking.  store_true options with
# no explicit default are None when the flag is absent.
no_psd_file = options.reference_psd is None and options.write_psd is None
if no_psd_file and options.track_psd is None:
	raise ValueError("must use --track-psd if --reference-psd or --write-psd are not given; you can use both simultaneously")

# reading a reference PSD and measuring a new one are mutually exclusive
psd_file_options = [opt_name for opt_name in ("reference_psd", "write_psd") if getattr(options, opt_name) is not None]
if len(psd_file_options) > 1:
	raise ValueError("must provide only one of --reference-psd or --write-psd")

gw_data_source_info = datasource.GWDataSourceInfo(options)

# fixed analysis parameters;  the two lengths are attached to options so
# that the handler class and the pipeline code read them from one place
options.psd_fft_length = 8
options.zero_pad_length = 2
quality = 9

#
# =============================================================================
#
#                                   Handler Class
#
# =============================================================================
#

class COHhandler(simplehandler.Handler):
	"""
	Bus message handler and shared state for the LHO coherent/null
	pipeline.

	The handler keeps the current H1 and H2 PSDs and the FIR filter
	parameters derived from them by
	coherent_null.psd_to_impulse_response(), pushes updated filters into
	the coherent/null bin, collects veto segments reported by the gate on
	the null stream, and writes the veto segment document at end of
	stream.

	The class reads the module-level "options" object produced by the
	command line parser (sample_rate, psd_fft_length, zero_pad_length,
	gps_end_time, vetoes_name, null_output, verbose).
	"""

	def __init__(self, mainloop, pipeline):
		"""
		Initialize filter state and attach to the pipeline's bus.

		mainloop is the main loop that is quit on EOS or error and
		pipeline is the gst pipeline whose bus messages are watched.
		"""

		# set various properties for psd and fir filter
		self.psd_fft_length = options.psd_fft_length
		self.zero_pad_length = options.zero_pad_length
		self.srate = options.sample_rate
		self.filter_length = int(self.srate*self.psd_fft_length)

		# set default psds for H1 and H2
		self.psd1 = self.build_default_psd(self.srate, self.filter_length)
		self.psd2 = self.build_default_psd(self.srate, self.filter_length)

		# set default impulse response and latency for H1 and H2.
		# the impulse responses computed from the default psds are
		# only used for their length and are replaced by zeros
		# until real psds are supplied.
		self.H1_impulse, self.H1_latency, self.H2_impulse, self.H2_latency, self.srate = coherent_null.psd_to_impulse_response(self.psd1, self.psd2)
		self.H1_impulse = numpy.zeros(len(self.H1_impulse))
		self.H2_impulse = numpy.zeros(len(self.H2_impulse))

		# psd1_change and psd2_change store when the psds have been updated
		self.psd1_change = 0
		self.psd2_change = 0

		# coherent null bin default
		self.cohnullbin = None

		# timestamps of the most recent veto "on" and "off"
		# transitions, None until the first one is received
		self.segment_start = None
		self.segment_stop = None

		# regular handler pipeline and message processing
		self.mainloop = mainloop
		self.pipeline = pipeline
		bus = pipeline.get_bus()
		bus.add_signal_watch()
		bus.connect("message", self.on_message)

		# veto segment list
		self.vetoes = segments.segmentlist()

	def on_message(self, bus, message):
		"""
		Bus "message" callback.

		On EOS the veto segments are written, the pipeline is set to
		NULL and the main loop is quit.  Info and warning messages
		are printed to stderr.  On error the pipeline is set to NULL,
		the main loop is quit and the program exits with the error
		text.
		"""
		if message.type == gst.MESSAGE_EOS:
			self.segment_EOS()
			self.pipeline.set_state(gst.STATE_NULL)
			self.mainloop.quit()
		elif message.type == gst.MESSAGE_INFO:
			gerr, dbgmsg = message.parse_info()
			print >>sys.stderr, "info (%s:%d '%s'): %s" % (gerr.domain, gerr.code, gerr.message, dbgmsg)
		elif message.type == gst.MESSAGE_WARNING:
			gerr, dbgmsg = message.parse_warning()
			print >>sys.stderr, "warning (%s:%d '%s'): %s" % (gerr.domain, gerr.code, gerr.message, dbgmsg)
		elif message.type == gst.MESSAGE_ERROR:
			gerr, dbgmsg = message.parse_error()
			self.pipeline.set_state(gst.STATE_NULL)
			self.mainloop.quit()
			sys.exit("error (%s:%d '%s'): %s" % (gerr.domain, gerr.code, gerr.message, dbgmsg))

	def build_default_psd(self, srate, filter_length):
		"""
		Return a REAL8FrequencySeries of ones starting at f0 = 0 with
		filter_length // 2 + 1 bins and frequency resolution
		srate / filter_length.
		"""
		psd = REAL8FrequencySeries()
		psd.deltaF = float(srate)/filter_length
		# explicit floor division so the bin count stays an integer
		# regardless of the division semantics in effect
		psd.data = numpy.ones(filter_length // 2 + 1)
		psd.f0 = 0.0
		return psd

	def add_cohnull_bin(self, elem):
		"""
		Record the coherent/null bin whose filter properties are set
		by update_fir_filter().
		"""
		self.cohnullbin = elem

	def update_fir_filter(self):
		"""
		Recompute the FIR filters from the current PSDs and install
		them on the coherent/null bin.  Both PSD change flags are
		cleared.
		"""
		self.psd1_change = 0
		self.psd2_change = 0
		self.H1_impulse, self.H1_latency, self.H2_impulse, self.H2_latency, self.srate = coherent_null.psd_to_impulse_response(self.psd1, self.psd2)
		self.cohnullbin.set_property("block-stride", self.srate)
		self.cohnullbin.set_property("H1-impulse", self.H1_impulse)
		self.cohnullbin.set_property("H2-impulse", self.H2_impulse)
		self.cohnullbin.set_property("H1-latency", self.H1_latency)
		self.cohnullbin.set_property("H2-latency", self.H2_latency)

	def update_psd1(self, elem):
		"""
		Replace the first PSD with the "mean-psd" and "delta-f"
		properties of elem and flag it as changed.
		"""
		self.psd1 = REAL8FrequencySeries(
			name = "PSD1",
			f0 = 0.0,
			deltaF = elem.get_property("delta-f"),
			data = numpy.array(elem.get_property("mean-psd"))
		)
		self.psd1_change = 1

	def update_psd2(self, elem):
		"""
		Replace the second PSD with the "mean-psd" and "delta-f"
		properties of elem and flag it as changed.
		"""
		self.psd2 = REAL8FrequencySeries(
			name = "PSD2",
			f0 = 0.0,
			deltaF = elem.get_property("delta-f"),
			data = numpy.array(elem.get_property("mean-psd"))
		)
		self.psd2_change = 1

	def fixed_filters(self, psd1, psd2):
		"""
		Adopt the given PSDs (objects with deltaF and data
		attributes) and compute the FIR filter parameters from them.
		The coherent/null bin is not touched;  the caller is expected
		to read the resulting attributes when building the pipeline.
		"""
		self.psd1 = REAL8FrequencySeries(
			name = "PSD1",
			f0 = 0.0,
			deltaF = psd1.deltaF,
			data = psd1.data)
		self.psd2 = REAL8FrequencySeries(
			name = "PSD2",
			f0 = 0.0,
			deltaF = psd2.deltaF,
			data = psd2.data)
		self.H1_impulse, self.H1_latency, self.H2_impulse, self.H2_latency, self.srate = coherent_null.psd_to_impulse_response(self.psd1, self.psd2)

	def segment_sighandler(self, elem, timestamp, segment_type):
		"""
		Gate signal callback.  segment_type "on" records timestamp as
		the start of a veto, "off" records it as the end of a veto
		and, if a start has been recorded, appends the veto segment.
		The timestamp is passed to LIGOTimeGPS as a nanosecond count.

		Raises ValueError for any other segment_type.
		"""
		if (segment_type == "on"):
			self.segment_start = timestamp
		elif (segment_type == "off"):
			self.segment_stop = timestamp
		else:
			raise ValueError("unrecognized message from mkhtgate signal handler")
		if (self.segment_start is not None and segment_type == "off"):
			self.vetoes.append(segments.segment(lsctables.LIGOTimeGPS(0, self.segment_start), lsctables.LIGOTimeGPS(0, self.segment_stop)))

	def segment_EOS(self):
		"""
		Close any veto still open at end of stream and write the veto
		segment lists, keyed by "H1H2", to options.null_output.
		"""
		# a veto is still open only if one was started and it has
		# not been closed since.  the start must be tested for None
		# first:  if no gate signal was ever received both
		# timestamps are None, and comparing them directly would
		# claim an open veto with no start time.
		veto_open = self.segment_start is not None and (self.segment_stop is None or self.segment_start >= self.segment_stop)
		if veto_open:
			# end the veto at the requested end time, in nanoseconds
			self.segment_stop = float(options.gps_end_time) * 1e9
			self.vetoes.append(segments.segment(lsctables.LIGOTimeGPS(0, self.segment_start), lsctables.LIGOTimeGPS(0, self.segment_stop)))

		vetoes = self.vetoes
		vetoes.coalesce()
		vetoes = segments.segmentlistdict.fromkeys(("H1H2",), vetoes)

		xmldoc = ligolw.Document()
		xmldoc.childNodes.append(ligolw.LIGO_LW())
		process = ligolw_process.register_to_xmldoc(xmldoc, "vetoes_from_LHO_null", options.__dict__)
		llwsegments = ligolw_segments.LigolwSegments(xmldoc)
		llwsegments.insert_from_segmentlistdict(vetoes, options.vetoes_name, comment = "Null vetoes")
		llwsegments.finalize(process)

		# gzip-compress when the output name ends with ".gz"
		utils.write_filename(xmldoc, options.null_output, gz = (options.null_output or "stdout").endswith(".gz"), verbose = options.verbose)


#
# =============================================================================
#
#                                             Main
#
# =============================================================================
#

if options.reference_psd is not None:
	# read the H1 and H2 PSDs from the reference file
	psd = reference_psd.read_psd_xmldoc(
		utils.load_filename(
			options.reference_psd,
			verbose = options.verbose,
			contenthandler = ligolw.LIGOLWContentHandler
		)
	)
	psd1, psd2 = psd["H1"], psd["H2"]
if options.write_psd is not None:
	# measure the PSD of each instrument with identical settings, then
	# save both to the requested file
	measure_kwargs = {
		"rate": options.sample_rate,
		"psd_fft_length": options.psd_fft_length,
		"verbose": options.verbose,
	}
	psd1 = reference_psd.measure_psd(gw_data_source_info, instrument = "H1", **measure_kwargs)
	psd2 = reference_psd.measure_psd(gw_data_source_info, instrument = "H2", **measure_kwargs)
	reference_psd.write_psd(options.write_psd, { "H1" : psd1, "H2" : psd2 }, verbose = options.verbose)

#
# Functions called by signal handler
#	

#FIXME: Probably only want to update psds (and filters) when they have changed "significantly enough"
def update_psd(elem, pspec, hand):
	"""
	Property-change callback for the PSD-tracking whiteners.

	elem is the element whose property changed;  its "name" property
	selects which PSD on the handler hand is refreshed ("H1_whitener"
	or "H2_whitener", any other name refreshes nothing).  pspec is
	unused.  Once both PSDs are flagged as changed the handler's FIR
	filters are rebuilt.
	"""
	whitener_name = elem.get_property("name")
	if whitener_name == "H1_whitener":
		hand.update_psd1(elem)
	elif whitener_name == "H2_whitener":
		hand.update_psd2(elem)

	# only rebuild the filters when a fresh PSD is available for both
	# instruments
	both_changed = hand.psd1_change == 1 and hand.psd2_change == 1
	if both_changed:
		hand.update_fir_filter()

#
# begin pipeline
#

pipeline = gst.Pipeline("coh_null_h1h2")
mainloop = gobject.MainLoop()
handler = COHhandler(mainloop, pipeline)

# when PSDs were read from or written to a file, seed the handler's filters
# from them instead of leaving the defaults in place
if not (options.reference_psd is None and options.write_psd is None):
	handler.fixed_filters(psd1, psd2)

#
# H1 branch
#

# H1 source, conditioned to 64-bit floats at the handler's sample rate
H1head = datasource.mkbasicsrc(pipeline, gw_data_source_info, "H1", verbose = options.verbose)

H1head = pipeparts.mkreblock(pipeline, H1head)
H1head = pipeparts.mkcapsfilter(pipeline, H1head, "audio/x-raw-float, width=64, rate=[%d,MAX]" % handler.srate)
H1head = pipeparts.mkresample(pipeline, H1head, quality = quality)
H1head = pipeparts.mkcapsfilter(pipeline, H1head, "audio/x-raw-float, width=64, rate=%d" % handler.srate)

# track psd
if options.track_psd is not None:
	H1psdtee = pipeparts.mktee(pipeline, H1head)
	# the whitener used for PSD tracking hangs off the tee.  (the tee, not
	# the not-yet-defined H1psd, must be its input.)  its output is
	# discarded;  only the mean-psd notifications are used.
	H1psd = pipeparts.mkwhiten(pipeline, H1psdtee, zero_pad = handler.zero_pad_length, fft_length = handler.psd_fft_length, name = "H1_whitener")
	H1psd.connect_after("notify::mean-psd", update_psd, handler)
	pipeparts.mkfakesink(pipeline, H1psd)

	# the main branch continues from the tee
	H1head = H1psdtee

#
# H2 branch
#

# H2 source, conditioned to the handler's sample rate
H2head = datasource.mkbasicsrc(pipeline, gw_data_source_info, "H2", verbose = options.verbose)

H2head = pipeparts.mkreblock(pipeline, H2head)
# NOTE(review): unlike the H1 branch these caps do not pin width=64 --
# confirm whether that difference is intentional
H2head = pipeparts.mkcapsfilter(pipeline, H2head, "audio/x-raw-float, rate=[%d,MAX]" % handler.srate)
H2head = pipeparts.mkresample(pipeline, H2head, quality = quality)
H2head = pipeparts.mkcapsfilter(pipeline, H2head, "audio/x-raw-float, rate=%d" % handler.srate)

# track psd
if options.track_psd is not None:
	H2psdtee = pipeparts.mktee(pipeline, H2head)
	# the whitener used for PSD tracking hangs off the tee.  (the tee, not
	# the not-yet-defined H2psd, must be its input.)  its output is
	# discarded;  only the mean-psd notifications are used.
	H2psd = pipeparts.mkwhiten(pipeline, H2psdtee, zero_pad = handler.zero_pad_length, fft_length = handler.psd_fft_length, name = "H2_whitener")
	H2psd.connect_after("notify::mean-psd", update_psd, handler)
	H2psd = pipeparts.mkfakesink(pipeline, H2psd)

	# the main branch continues from the tee
	H2head = H2psdtee

#
# Create coherent and null streams
#

# combine the two instrument streams using the handler's current filters
coherent_null_bin = pipeparts.mklhocoherentnull(
	pipeline,
	H1head,
	H2head,
	handler.H1_impulse,
	handler.H1_latency,
	handler.H2_impulse,
	handler.H2_latency,
	handler.srate
)

# the handler only needs the bin when it will be updating its filters
if options.track_psd is not None:
	handler.add_cohnull_bin(coherent_null_bin)

# the bin exposes one source pad per output stream
cohhead, nullhead = coherent_null_bin.get_pad("COHsrc"), coherent_null_bin.get_pad("NULLsrc")

#
# Create coherent and null frames
#

# coherent stream:  progress report, tag injection, then frame file output
cohhead = pipeparts.mkprogressreport(pipeline, cohhead, "progress_coherent")
cohhead = pipeparts.mktaginject(pipeline, cohhead, "instrument=H1H2, channel-name=LSC-STRAIN_HPLUS")
cohmux = pipeparts.mkframecppchannelmux(
	pipeline,
	{"H1H2:LSC-STRAIN_HPLUS" : cohhead},
	frame_duration = 1,
	frames_per_file = 4096
)
pipeparts.mkframecppfilesink(pipeline, cohmux, frame_type = "H1H2")

# null veto:  the null stream is whitened before being thresholded
nullhead = pipeparts.mkprogressreport(pipeline, nullhead, "progress_null")
nullhead = pipeparts.mkwhiten(pipeline, nullhead, zero_pad = handler.zero_pad_length, fft_length = handler.psd_fft_length)

# optionally branch the whitened null stream off into its own frame files
if options.write_null_frames is not None:
	nullhead = pipeparts.mktee(pipeline, nullhead)
	nullqueue = pipeparts.mkqueue(pipeline, nullhead)
	nullframehead = pipeparts.mktaginject(pipeline, nullqueue, "instrument=H1H2, channel-name=LSC-STRAIN_HNULL")
	nullmux = pipeparts.mkframecppchannelmux(
		pipeline,
		{"H1H2:LSC-STRAIN_HNULL" : nullframehead},
		frame_duration = 1,
		frames_per_file = 4096
	)
	pipeparts.mkframecppfilesink(pipeline, nullmux, frame_type = "H1H2_NULL")

# gate the null stream and translate the gate's signals into veto
# transitions for the handler
nullhead = datasource.mkhtgate(pipeline, nullhead, threshold = 8.0)
nullhead.set_property("emit-signals", True)
for gate_signal, segment_type in (
	("start", "off"),	# start means the start of non-gap output
	("stop", "on"),		# stop means the end of non-gap output
):
	nullhead.connect(gate_signal, handler.segment_sighandler, segment_type)
pipeparts.mkfakesink(pipeline, nullhead)

#
# Write the pipeline graph (optional) and run the pipeline
#

def write_pipeline_graph(stage):
	"""
	Write a DOT dump of the pipeline to "<--write-pipeline>.<stage>" if
	--write-pipeline was given;  otherwise do nothing.
	"""
	if options.write_pipeline is None:
		return
	pipeparts.write_dump_dot(pipeline, "%s.%s" % (options.write_pipeline, stage), verbose = options.verbose)


# graph of the pipeline as built, before any state change
write_pipeline_graph("NULL")

if options.verbose:
	print >>sys.stderr, "setting pipeline state to playing ..."
if pipeline.set_state(gst.STATE_PLAYING) == gst.STATE_CHANGE_FAILURE:
	raise RuntimeError("pipeline failed to enter PLAYING state")

# graph of the pipeline after the state change to PLAYING was requested
write_pipeline_graph("PLAYING")

if options.verbose:
	print >>sys.stderr, "running pipeline ..."

mainloop.run()
