#!/usr/bin/env python
#
# Copyright (C) 2011 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This program makes a dag to run gstlal_inspiral offline
"""

__author__ = 'Chad Hanna <channa@caltech.edu>'

##############################################################################
# import standard modules
import sys, os, copy, math
import subprocess, socket, tempfile

##############################################################################
# import the modules we need to build the pipeline
from glue import iterutils
from glue import pipeline
from glue import lal
from glue.ligolw import lsctables
from glue import segments
from glue.ligolw import array
from glue.ligolw import param
import glue.ligolw.utils as utils
import glue.ligolw.utils.segments as ligolw_segments
from optparse import OptionParser
from gstlal.svd_bank import read_bank
from gstlal import inspiral, inspiral_pipe
from gstlal import dagparts as gstlaldagparts
import numpy
from pylal.datatypes import LIGOTimeGPS
from gstlal import datasource


#
# Utility functions
#


def T050017_filename(instruments, description, start, end, extension, path = None):
	"""
	Build a file name of the form INSTRUMENTS-DESCRIPTION-START-DURATION.EXT.

	instruments is either a string, which is used verbatim, or an
	iterable of instrument names, which are sorted and concatenated.
	The duration is end - start, and both start and the duration are
	formatted with %d.  Leading and trailing dots are stripped from
	extension.  If path is not None the name is prefixed with "path/".
	"""
	# isinstance() rather than an exact type comparison so that str
	# subclasses are used verbatim instead of having their characters sorted
	if not isinstance(instruments, str):
		instruments = "".join(sorted(instruments))
	duration = end - start
	extension = extension.strip('.')
	filename = '%s-%s-%d-%d.%s' % (instruments, description, start, duration, extension)
	if path is not None:
		return '%s/%s' % (path, filename)
	return filename


#
# Classes for the svd banks
#


class gstlal_svd_bank_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_svd_bank executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.  In addition to the InspiralJob
	setup, the condor command request_memory is set to '1999'.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_svd_bank'), tag_base='gstlal_svd_bank'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)
		self.add_condor_cmd('request_memory', '1999') #FIXME is this enough?


class gstlal_svd_bank_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_svd_bank invocation.

	Every keyword is translated into the command line option of the
	matching name; bank_id and autocorrelation_length are only passed
	when they are not None, and --identity-transform only when
	identity_transform is true.  svd_bank_name is passed as
	--write-svd-bank.  The ifo argument is accepted but not used here.
	p_node is forwarded to InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, template_bank, ifo, svd_bank_name, flow = 40, reference_psd = "psd.xml", tolerance = 0.9995, FAP = 0.5, snr = 4.0, clipleft = 0, clipright = 0, samples_min = 1024, samples_max_256 = 1024, samples_max_64 =  2048, samples_max = 4096, p_node=[], autocorrelation_length = None, bank_id = None, identity_transform = False):

		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)

		# options that are always present, in the order they are added
		always = (
			("flow", flow),
			("snr-threshold", snr),
			("svd-tolerance", tolerance),
			("reference-psd", reference_psd),
			("template-bank", template_bank),
			("ortho-gate-fap", FAP),
			("samples-min", samples_min),
			("samples-max", samples_max),
			("samples-max-64", samples_max_64),
			("samples-max-256", samples_max_256),
			("clipleft", clipleft),
			("clipright", clipright),
		)
		for name, value in always:
			self.add_var_opt(name, value)

		# options that are dropped entirely when left at None
		maybe = (
			("bank-id", bank_id),
			("autocorrelation-length", autocorrelation_length),
		)
		for name, value in maybe:
			if value is not None:
				self.add_var_opt(name, value)

		if identity_transform:
			self.add_var_arg("--identity-transform")
		self.add_var_opt("write-svd-bank", svd_bank_name)

#
# Classes to plot the horizon distance
#

class gstlal_plot_psd_horizon_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_plot_psd_horizon executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_plot_psd_horizon'), tag_base='gstlal_plot_psd_horizon'):
		"""
		Delegate to InspiralJob with the given executable and tag_base.
		"""
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_plot_psd_horizon_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_plot_psd_horizon invocation.

	output_name is passed as the first argument, followed by every
	entry of files as a file argument.  The output name is also kept
	on the node as self.output_name.  p_node is forwarded to
	InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, files, output_name, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		self.output_name = output_name
		self.add_var_arg(output_name)
		for f in files:
			self.add_file_arg(f)

#
# Classes for generating reference psds
#

class gstlal_reference_psd_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_reference_psd executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_reference_psd'), tag_base='gstlal_reference_psd'):
		"""
		Delegate to InspiralJob with the given executable and tag_base.
		"""
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_reference_psd_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_reference_psd invocation on frame data.

	The PSD is written to a file named by T050017_filename() with the
	description "REFERENCE_PSD" inside job.output_path; that name is
	available afterwards as self.output_name.  --injections is only
	passed when injections is truthy.  p_node is forwarded to
	InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, frame_cache, frame_segments_file, frame_segments_name, gps_start_time, gps_end_time, instruments, channel_dict, psd_fft_length = 16, injections=None, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)

		# data selection options
		for name, value in (
			("frame-cache", frame_cache),
			("frame-segments-file", frame_segments_file),
			("frame-segments-name", frame_segments_name),
			("gps-start-time", gps_start_time),
			("gps-end-time", gps_end_time),
			("data-source", "frames"),
		):
			self.add_var_opt(name, value)

		channel_list = datasource.pipeline_channel_list_from_channel_dict(channel_dict, ifos = instruments)
		self.add_var_opt("channel-name", channel_list)
		self.add_var_opt("psd-fft-length", psd_fft_length)
		if injections:
			self.add_var_opt("injections", injections)

		# remember where the PSD goes so that dependent nodes can find it
		self.output_name = T050017_filename(instruments, "REFERENCE_PSD", gps_start_time, gps_end_time, '.xml.gz', path = job.output_path)
		self.add_var_opt("write-psd", self.output_name)


#
# gstlal_s5_pbh_summary_page
#


class gstlal_s5_pbh_summary_page_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_s5_pbh_summary_page executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_s5_pbh_summary_page'), tag_base='gstlal_s5_pbh_summary_page'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_s5_pbh_summary_page_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_s5_pbh_summary_page invocation.

	name_tag, web_dir and title become --output-name-tag,
	--webserver-dir and --title.  --open-box is added with an empty
	value when open_box is true.  p_node is forwarded to InspiralNode
	(presumably the parent nodes).
	"""
	def __init__(self, job, dag, name_tag, web_dir, title, open_box=True, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		for name, value in (
			("output-name-tag", name_tag),
			("webserver-dir", web_dir),
			("title", title),
		):
			self.add_var_opt(name, value)
		if open_box:
			self.add_var_opt("open-box", "")


#
# gstlal_inspiral_plotsummary
#


class gstlal_inspiral_plotsummary_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral_plotsummary executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral_plotsummary'), tag_base='gstlal_inspiral_plotsummary'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_inspiral_plotsummary_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral_plotsummary invocation.

	segments_name, base and tmp_space become --segments-name, --base
	and --tmp-space; every entry of input is appended as a file
	argument.  The tmp_space default is inspiral_pipe.log_path() as
	evaluated when this class is defined.  p_node is forwarded to
	InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, base, input=[], tmp_space=inspiral_pipe.log_path(), p_node=[], segments_name="datasegments"):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		for name, value in (
			("segments-name", segments_name),
			("base", base),
			("tmp-space", tmp_space),
		):
			self.add_var_opt(name, value)
		for filename in input:
			self.add_file_arg(filename)

#
# gstlal_inspiral_plot_sensitivity
#

class gstlal_inspiral_plot_sensitivity_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral_plot_sensitivity executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral_plot_sensitivity'), tag_base='gstlal_inspiral_plot_sensitivity'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_inspiral_plot_sensitivity_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral_plot_sensitivity invocation.

	base is split at its last "/": the trailing component becomes
	--user-tag and, when base contains a "/", everything before it
	(including the slash) becomes --output-dir.  The veto segments name
	is fixed to "vetoes".  Each bin_by_* / include_play flag adds the
	corresponding switch.  injdbs are appended as file arguments and
	each entry of zldbs as a "--zero-lag-database FILE" file argument.
	"""
	# FIXME are all needed options here?
	def __init__(self, job, dag, base, injdbs=[], zldbs=[], tmp_space=inspiral_pipe.log_path(), p_node=[], bin_by_total_mass=True, bin_by_mass1_mass2=False, bin_by_mass_ratio=True, include_play=True):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)

		components = base.split("/")
		tag = components[-1]
		self.add_var_opt("user-tag", tag)
		if len(components) > 1:
			# the tag never contains "/", so dropping its length from
			# the end leaves exactly the directory part with its slash
			self.add_var_opt("output-dir", base[:len(base) - len(tag)])
		self.add_var_opt("tmp-space", tmp_space)
		self.add_var_opt("veto-segments-name", "vetoes")

		switches = (
			(bin_by_total_mass, "--bin-by-total-mass"),
			(bin_by_mass1_mass2, "--bin-by-mass1-mass2"),
			(bin_by_mass_ratio, "--bin-by-mass-ratio"),
			(include_play, "--include-play"),
		)
		for enabled, switch in switches:
			if enabled:
				self.add_var_arg(switch)

		for database in injdbs:
			self.add_file_arg(database)
		for database in zldbs:
			self.add_file_arg("--zero-lag-database %s"%database)


#
# gstlal_inspiral_plot_likelihoods
#


class gstlal_inspiral_plot_likelihoods_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral_plot_likelihoods executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral_plot_likelihoods'), tag_base='gstlal_inspiral_plot_likelihoods'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_inspiral_plot_likelihoods_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral_plot_likelihoods invocation.

	A single argument of the form "url=URL > OUTPUT" is added, where
	OUTPUT defaults to URL with ".html" appended.  The base argument is
	accepted but not used here.  p_node is forwarded to InspiralNode
	(presumably the parent nodes).
	"""
	def __init__(self, job, dag, base, url, output = None, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		destination = output if output is not None else url+".html"
		self.add_var_arg("url=%s > %s" % (url, destination))
		


#
# lalapps_run_sqlite
#


class lalapps_run_sqlite_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the lalapps_run_sqlite executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('lalapps_run_sqlite'), tag_base='lalapps_run_sqlite'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class lalapps_run_sqlite_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one lalapps_run_sqlite invocation.

	sql_file and tmp_space become --sql-file and --tmp-space, and every
	entry of input is appended as a file argument.  self.output_name is
	only set when exactly one input is given, to that input.  The
	tmp_space default is inspiral_pipe.log_path() as evaluated when
	this class is defined.
	"""
	def __init__(self, job, dag, sql_file, input=[], tmp_space=inspiral_pipe.log_path(), p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		for name, value in (("sql-file", sql_file), ("tmp-space", tmp_space)):
			self.add_var_opt(name, value)
		# with a single database the job's output is unambiguous
		if len(input) == 1:
			self.output_name = input[0]
		for database in input:
			self.add_file_arg(database)


#
# ligolw_sqlite
#


class ligolw_sqlite_from_xml_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the ligolw_sqlite executable, tagged
	'ligolw_sqlite_from_xml' by default.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('ligolw_sqlite'), tag_base='ligolw_sqlite_from_xml'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class ligolw_sqlite_to_xml_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the ligolw_sqlite executable, tagged
	'ligolw_sqlite_to_xml' by default.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('ligolw_sqlite'), tag_base='ligolw_sqlite_to_xml'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class ligolw_sqlite_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one ligolw_sqlite invocation.

	--extract is passed only when extract is not None, --replace (with
	an empty value) only when replace is true.  Entries of input that
	are None are skipped, the others are appended as file arguments.
	database and extract are kept as self.output_db_name and
	self.output_xml_name.  The tmp_space default is
	inspiral_pipe.log_path() as evaluated when this class is defined.
	"""
	def __init__(self, job, dag, database, input=[], replace=True, tmp_space=inspiral_pipe.log_path(), extract=None, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		self.output_db_name = database
		self.output_xml_name = extract

		if extract is not None:
			self.add_var_opt("extract", extract)
		self.add_var_opt("database", database)
		if replace:
			self.add_var_opt("replace", "")
		self.add_var_opt("tmp-space", tmp_space)

		for filename in input:
			if filename is None:
				continue
			self.add_file_arg(filename)


#
# ligolw_inspinjfind
#


class ligolw_inspinjfind_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the ligolw_inspinjfind executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('ligolw_inspinjfind'), tag_base='ligolw_inspinjfind'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class ligolw_inspinjfind_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one ligolw_inspinjfind invocation.

	The xml file is passed as the only argument and is also recorded as
	self.output_name.  --time-window is fixed to 0.9.  p_node is
	forwarded to InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, xml, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		self.output_name = xml
		#FIXME make a parameter?
		self.add_var_opt("time-window", 0.9)
		self.add_var_arg(xml)


#
# gstlal_inspiral
#


class gstlal_inspiral_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.  In addition to the InspiralJob
	setup, three condor commands are set: requirements
	'( CAN_RUN_MULTICORE )', request_cpus '8' and
	+RequiresMultipleCores 'True'.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral'), tag_base='gstlal_inspiral'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)
		self.add_condor_cmd('requirements', '( CAN_RUN_MULTICORE )')
		self.add_condor_cmd('request_cpus', '8')
		self.add_condor_cmd('+RequiresMultipleCores', 'True')

class gstlal_inspiral_inj_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral executable, tagged
	'gstlal_inspiral_inj' by default.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.  In addition to the InspiralJob
	setup, three condor commands are set: requirements
	'( CAN_RUN_MULTICORE )', request_cpus '8' and
	+RequiresMultipleCores 'True'.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral'), tag_base='gstlal_inspiral_inj'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)
		self.add_condor_cmd('requirements', '( CAN_RUN_MULTICORE )')
		self.add_condor_cmd('request_cpus', '8')
		self.add_condor_cmd('+RequiresMultipleCores', 'True')


def sim_tag_from_inj_file(injections):
	"""
	Return the injection file name with every occurrence of '.xml' and
	then of '.gz' removed.
	"""
	tag = injections
	for piece in ('.xml', '.gz'):
		tag = tag.replace(piece, '')
	return tag


class gstlal_inspiral_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral invocation on frame data.

	The node is put in the "INSPIRAL" category.  time_slide_file,
	ht_gate_thresh, injections, blind_injections and vetoes are only
	turned into options when they are not None.

	The output database name is built by T050017_filename() from
	instruments, number and the GPS interval, inside job.output_path,
	and stored as self.output_name.  For injection runs the name also
	carries the tag from sim_tag_from_inj_file(); for non-injection
	runs self.background_name (description ..._LLOID_SNR_CHI) is set as
	well.  Constructing a node increments job.number by one.
	"""
	#FIXME add veto segments name
	def __init__(self, job, dag, frame_cache, frame_segments_file, frame_segments_name, gps_start_time, gps_end_time, channel_dict, reference_psd, svd_bank, tmp_space=inspiral_pipe.log_path(), ht_gate_thresh=None, injections=None, control_peak_time = 8, coincidence_threshold = 0.020, vetoes=None, time_slide_file=None, fir_stride = 8, instruments = "H1H2L1", number = 1, psd_fft_length = 16, blind_injections = None, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)

		self.set_category("INSPIRAL")

		if time_slide_file is not None:
			self.add_var_opt("time-slide-file", time_slide_file)
		if ht_gate_thresh is not None:
			self.add_var_opt("ht-gate-threshold", ht_gate_thresh)
		self.add_var_opt("psd-fft-length", psd_fft_length)
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("frame-segments-file", frame_segments_file)
		self.add_var_opt("frame-segments-name", frame_segments_name)
		self.add_var_opt("gps-start-time",gps_start_time)
		self.add_var_opt("gps-end-time",gps_end_time)
		self.add_var_opt("channel-name", datasource.pipeline_channel_list_from_channel_dict(channel_dict))
		self.add_var_opt("reference-psd", reference_psd)
		self.add_var_opt("svd-bank", svd_bank)
		self.add_var_opt("tmp-space", tmp_space)
		self.add_var_opt("track-psd", "")
		self.add_var_opt("control-peak-time", control_peak_time)
		self.add_var_opt("coincidence-threshold", coincidence_threshold)
		self.add_var_opt("fir-stride", fir_stride)
		self.add_var_opt("data-source", "frames")
		self.injections = injections
		if self.injections is not None:
			self.add_var_opt("injections", injections)
		if blind_injections is not None:
			self.add_var_opt("blind-injections", blind_injections)
		if vetoes is not None:
			self.add_var_opt("veto-segments-file", vetoes)

		# only the description part of the output name depends on
		# whether this is an injection run
		if self.injections is not None:
			description = '%04d_LLOID_%s' % (number, sim_tag_from_inj_file(self.injections))
		else:
			description = '%04d_LLOID' % number
			# only non-injection runs get a background file name
			self.background_name = T050017_filename(instruments, '%04d_LLOID_SNR_CHI' % number, gps_start_time, gps_end_time, '.xml.gz', path = job.output_path)
		self.output_name = T050017_filename(instruments, description, gps_start_time, gps_end_time, '.sqlite', path = job.output_path)

		job.number += 1
		self.add_var_opt("output", self.output_name)
		# FIXME either make this correct or remove it
		# dag.output_cache.append(lal.CacheEntry(instruments, "-", segments.segment(gps_start_time, gps_end_time), "file://localhost/%s" % (self.output_name,)))


#
# gstlal_inspiral_calc_likelihood
#


class gstlal_inspiral_calc_likelihood_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral_calc_likelihood executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral_calc_likelihood'), tag_base='gstlal_inspiral_calc_likelihood'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_inspiral_calc_likelihood_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral_calc_likelihood invocation.

	likelihood_files must contain at least one entry: all of them are
	passed as --likelihood-file and the first one determines the
	directory and base name of the file written with
	--write-likelihood (base name prefixed by likelihood_output_name).
	That file name is kept as self.background_name; both are skipped
	when likelihood_output_name is None.  The input list is appended as
	file arguments and stored, by reference, as self.output_names;
	self.output_name is set only when there is exactly one input.
	"""
	def __init__(self, job, dag, likelihood_files = [], synthesize_injections = 1000000, input = [], likelihood_output_name = "post_calc_likelihood_", background_prior = 1.0, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)

		joined_likelihood_files = pipeline_dot_py_append_opts_hack("likelihood-file", likelihood_files)
		self.add_var_opt("likelihood-file", joined_likelihood_files)
		self.add_var_opt("tmp-space", inspiral_pipe.log_path())
		for database in input:
			self.add_file_arg(database)

		# the written likelihood file sits next to the first input
		# likelihood file, with a prefix in front of its base name
		directory, basename = os.path.split(likelihood_files[0])
		if likelihood_output_name is not None:
			self.background_name = os.path.join(directory, likelihood_output_name + basename)
			self.add_var_opt("write-likelihood", self.background_name)

		for name, value in (
			("background-prior", background_prior),
			("synthesize-injections", synthesize_injections),
		):
			self.add_var_opt(name, value)

		self.output_names = input
		if len(input) == 1:
			self.output_name = input[0]


#
# gstlal_compute_far_from_snr_chisq_histograms
#


class gstlal_compute_far_from_snr_chisq_histograms_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_compute_far_from_snr_chisq_histograms
	executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_compute_far_from_snr_chisq_histograms'), tag_base='gstlal_compute_far_from_snr_chisq_histograms'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_compute_far_from_snr_chisq_histograms_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_compute_far_from_snr_chisq_histograms
	invocation.

	background_bins_files and inj_input are optional: their options are
	left out when the list is empty (or, for background_bins_files,
	None).  noninj_input must contain at least one database, because
	--non-injection-db is always passed.  self.output_names is the
	concatenation inj_input + noninj_input.
	"""
	def __init__(self, job, dag, background_bins_files = [], noninj_input = [], inj_input = [], p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		# the option-joining helper indexes the first element, so the
		# empty default list has to be skipped just like None
		if background_bins_files:
			self.add_var_opt("background-bins-file", pipeline_dot_py_append_opts_hack("background-bins-file", background_bins_files))
		self.add_var_opt("tmp-space", inspiral_pipe.log_path())
		if len(inj_input) > 0:
			self.add_var_opt("injection-dbs", pipeline_dot_py_append_opts_hack("injection-dbs", inj_input))
		self.add_var_opt("non-injection-db", pipeline_dot_py_append_opts_hack("non-injection-db", noninj_input))
		self.output_names = inj_input + noninj_input


#
# gstlal_inspiral_marginalize_likelihood
#


class gstlal_inspiral_marginalize_likelihood_job(inspiral_pipe.InspiralJob):
	"""
	Job description for the gstlal_inspiral_marginalize_likelihood
	executable.

	The default executable path is looked up with inspiral_pipe.which()
	once, when this class is defined.
	"""
	def __init__(self, executable=inspiral_pipe.which('gstlal_inspiral_marginalize_likelihood'), tag_base='gstlal_inspiral_marginalize_likelihood'):
		inspiral_pipe.InspiralJob.__init__(self, executable, tag_base)


class gstlal_inspiral_marginalize_likelihood_node(inspiral_pipe.InspiralNode):
	"""
	DAG node for one gstlal_inspiral_marginalize_likelihood invocation.

	Every entry of background_bins_files is passed as an argument and
	output as --output; output is also kept as self.output_name.
	p_node is forwarded to InspiralNode (presumably the parent nodes).
	"""
	def __init__(self, job, dag, output, background_bins_files, p_node=[]):
		inspiral_pipe.InspiralNode.__init__(self, job, dag, p_node)
		self.output_name = output
		for bins_file in background_bins_files:
			self.add_var_arg(bins_file)
		self.add_var_opt("output", output)


#
# Utility functions
#


# FIXME surely this is in glue
def parse_cache_str(instr):
	"""
	Parse a string like "H1=H1.cache,L1=L1.cache" into a dictionary
	mapping each instrument to its cache file name.

	None and the empty string give an empty dictionary, and empty
	entries (e.g. from a trailing comma) are ignored.  Each entry is
	split at its first "=" only, so the file name may itself contain
	"=".  An entry without any "=" is mapped to itself.
	"""
	dictcache = {}
	if not instr:
		return dictcache
	for entry in instr.split(','):
		if not entry:
			continue
		ifo, separator, cache = entry.partition("=")
		if not separator:
			# no "=" at all: keep mapping the entry to itself
			cache = entry
		dictcache[ifo] = cache
	return dictcache

def pipeline_dot_py_append_opts_hack(opt, vals):
	"""
	Fold several values of one option into a single option value.

	The first value is returned followed by " --opt value" for each of
	the remaining values, so that passing the result as the value of
	--opt repeats the option on the command line.  vals must not be
	empty.
	"""
	first = vals[0]
	repeated = "".join(" --%s %s" % (opt, extra) for extra in vals[1:])
	if not repeated:
		# a single value is handed back untouched
		return first
	return first + repeated

def extract_all_nodes_by_inj(nodes):
	"""
	Group a sequence of pairs by their first element.

	Returns a dictionary mapping each distinct first element to the
	list of second elements, in the order they were encountered.
	"""
	grouped = {}
	for pair in nodes:
		key = pair[0]
		if key not in grouped:
			grouped[key] = []
		grouped[key].append(pair[1])
	return grouped

def parse_command_line():
	"""
	Parse the command line.

	Returns the (options, filenames) pair from the option parser, with
	options.num_banks converted from a comma separated string into a
	list of integers.

	Raises ValueError if --bank-cache or --num-banks is missing or if
	--overlap is odd.
	"""
	parser = OptionParser(description = __doc__)

	# generic data source options
	datasource.append_options(parser)
	parser.add_option("--psd-fft-length", metavar = "s", default = 16, type = "int", help = "FFT length, default 16s")

	# SVD bank construction options
	parser.add_option("--overlap", metavar = "num", type = "int", default = 0, help = "set the factor that describes the overlap of the sub banks, must be even!")
	parser.add_option("--autocorrelation-length", type = "int", default = 201, help = "The minimum number of samples to use for auto-chisquared, default 201 should be odd")
	parser.add_option("--samples-min", type = "int", default = 1024, help = "The minimum number of samples to use for time slices default 1024")
	parser.add_option("--samples-max-256", type = "int", default = 1024, help = "The maximum number of samples to use for time slices with frequencies above 256Hz, default 1024")
	parser.add_option("--samples-max-64", type = "int", default = 2048, help = "The maximum number of samples to use for time slices with frequencies above 64Hz, default 2048")
	parser.add_option("--samples-max", type = "int", default = 4096, help = "The maximum number of samples to use for time slices with frequencies below 64Hz, default 4096")
	parser.add_option("--bank-cache", metavar = "filenames", help = "Set the bank cache files in format H1=H1.cache,H2=H2.cache, etc..")
	parser.add_option("--tolerance", metavar = "float", type = "float", default = 0.9999, help = "set the SVD tolerance, default 0.9999")
	parser.add_option("--flow", metavar = "num", type = "float", default = 40, help = "set the low frequency cutoff, default 40 (Hz)")
	parser.add_option("--identity-transform", action = "store_true", help = "Use identity transform, i.e. no SVD")

	# trigger generation options
	parser.add_option("--vetoes", metavar = "filename", help = "Set the veto xml file.")
	parser.add_option("--time-slide-file", metavar = "filename", help = "Set the time slide table xml file")
	parser.add_option("--web-dir", metavar = "directory", help = "Set the web directory like /home/USER/public_html")
	parser.add_option("--fir-stride", type="int", metavar = "secs", default = 8, help = "Set the duration of the fft output blocks, default 8")
	parser.add_option("--control-peak-time", type="int", default = 8, metavar = "secs", help = "Set the peak finding time for the control signal, default 8")
	parser.add_option("--coincidence-threshold", metavar = "value", type = "float", default = 0.005, help = "Set the coincidence window in seconds (default = 0.005).  The light-travel time between instruments will be added automatically in the coincidence test.")
	parser.add_option("--max-segment-length", type="int", metavar = "dur", default = 30000, help = "Break up segments longer than dur seconds into shorter (contiguous, non-overlapping) segments. Default 30000 seconds.")
	parser.add_option("--num-banks", metavar = "str", help = "the number of banks per job. can be given as a list like 1,2,3,4 then it will split up the bank cache into N groups with M banks each.")
	parser.add_option("--max-inspiral-jobs", type="int", metavar = "jobs", help = "Set the maximum number of gstlal_inspiral jobs to run simultaneously, default no constraint.")
	parser.add_option("--ht-gate-threshold", type="float", help="set a threshold on whitened h(t) to veto glitches")
	parser.add_option("--do-iir-pipeline", action="store_true", help = "run the iir pipeline instead of lloid")
	parser.add_option("--blind-injections", metavar = "filename", help = "Set the name of an injection file that will be added to the data without saving the sim_inspiral table or otherwise processing the data differently.  Has the effect of having hidden signals in the input data. Separate injection runs using the --injections option will still occur.")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose")

	# Override the datasource injection option
	parser.remove_option("--injections")
	parser.add_option("--injections", action = "append", help = "append injection files to analyze")

	options, filenames = parser.parse_args()

	# check the required options before touching their values, so that a
	# missing one is reported by name instead of failing on None
	fail = ""
	for option in ("bank_cache", "num_banks"):
		if getattr(options, option) is None:
			fail += "must provide option %s\n" % (option)
	if fail:
		raise ValueError(fail)

	options.num_banks = [int(v) for v in options.num_banks.split(",")]

	if options.overlap % 2:
		raise ValueError("overlap must be even")

	return options, filenames

#
# MAIN
#

options, filenames = parse_command_line()
bank_cache = parse_cache_str(options.bank_cache)

detectors = datasource.GWDataSourceInfo(options)
channel_dict = detectors.channel_dict

instruments = "".join(sorted(bank_cache.keys()))

#FIXME a hack to find the sql paths
# the sql files are expected in <prefix>/share/gstlal, where <prefix> is the
# parent of the directory holding the executables.  Only that last directory
# component is dropped, so a "bin" elsewhere in the path (e.g. in a user
# name) is left alone.
bin_path = os.path.split(inspiral_pipe.which('gstlal_reference_psd'))[0]
share_path = os.path.join(os.path.dirname(bin_path), 'share', 'gstlal')
options.cluster_sql_file = os.path.join(share_path, 'simplify_and_cluster.sql')
options.injection_sql_file = os.path.join(share_path, 'inj_simplify_and_cluster.sql')


# an already existing logs directory is fine, any other failure is not
try:
	os.mkdir("logs")
except OSError:
	if not os.path.isdir("logs"):
		raise
dag = inspiral_pipe.DAG("trigger_pipe")

# optionally throttle the number of simultaneously running inspiral jobs
if options.max_inspiral_jobs is not None:
	dag.add_maxjobs_category("INSPIRAL", options.max_inspiral_jobs)

#
# setup the job classes
#

# reference PSD measurement, SVD bank construction and horizon plot jobs
refPSDJob = gstlal_reference_psd_job()
svdJob = gstlal_svd_bank_job(tag_base="gstlal_svd_bank")
horizonJob = gstlal_plot_psd_horizon_job()

# --do-iir-pipeline is a store_true option, i.e. None unless it was given on
# the command line; when given, both inspiral jobs run gstlal_iir_inspiral
# instead of the default gstlal_inspiral executable
if options.do_iir_pipeline is not None:
	gstlalInspiralJob = gstlal_inspiral_job(executable=inspiral_pipe.which('gstlal_iir_inspiral'))
	gstlalInspiralInjJob = gstlal_inspiral_inj_job(executable=inspiral_pipe.which('gstlal_iir_inspiral'))
else:
	gstlalInspiralJob = gstlal_inspiral_job()
	gstlalInspiralInjJob = gstlal_inspiral_inj_job()

# post-processing jobs; the same job class is instantiated twice with a
# different tag_base where separate non-injection / injection (or
# closed / open box) variants are needed
calcLikelihoodJob = gstlal_inspiral_calc_likelihood_job()
calcLikelihoodJobInj = gstlal_inspiral_calc_likelihood_job(tag_base='gstlal_inspiral_calc_likelihood_inj')
gstlalInspiralComputeFarFromSnrChisqHistogramsJob = gstlal_compute_far_from_snr_chisq_histograms_job()
ligolwInspinjFindJob = ligolw_inspinjfind_job()
toSqliteJob = ligolw_sqlite_from_xml_job()
toXMLJob = ligolw_sqlite_to_xml_job()
lalappsRunSqliteJob = lalapps_run_sqlite_job()
plotSummaryJob = gstlal_inspiral_plotsummary_job()
plotSensitivityJob = gstlal_inspiral_plot_sensitivity_job()
openpageJob = gstlal_s5_pbh_summary_page_job(tag_base = 'gstlal_s5_pbh_summary_page_open')
pageJob = gstlal_s5_pbh_summary_page_job()
marginalizeJob = gstlal_inspiral_marginalize_likelihood_job()

#
# Setup analysis segments
#

allsegs = detectors.frame_segments
boundary_seg = detectors.seg

# only instruments that have a template bank can be analyzed
analyzable_instruments_set = set(bank_cache.keys())

# combinations that are never analyzed on their own
excluded_combos = (frozenset(('H1', 'H2')), frozenset(('L1', 'H2')))

# build, for every combination of 2 or more detectors, the segments during
# which exactly that combination was on
segsdict = segments.segmentlistdict()
for n in range(2, 1 + len(analyzable_instruments_set)):
	for ifo_combos in iterutils.choices(list(analyzable_instruments_set), n):
		combo_key = frozenset(ifo_combos)
		# never analyze H1H2 or H2L1 times
		if combo_key in excluded_combos:
			print >> sys.stderr, "not analyzing: ", ifo_combos, " only time"
			continue
		# times when all of these instruments are on and none of the others
		combo_segs = allsegs.intersection(ifo_combos) - allsegs.union(analyzable_instruments_set - set(ifo_combos))
		# restrict to the requested analysis interval
		combo_segs &= segments.segmentlist([boundary_seg])
		# grow every segment by 2048 at both ends, then split long ones
		combo_segs = combo_segs.protract(2048)
		segsdict[combo_key] = gstlaldagparts.breakupsegs(combo_segs, options.max_segment_length, 2048)


#
# Variables to track certain dag nodes
#

likelihood_nodes = {}
likelihood_nodes_inj = {}
noninj_nodes = []
# --injections uses action = "append", so options.injections is None rather
# than an empty list when the option is never given
inj_nodes = dict((inj, []) for inj in (options.injections or []))

#
# FIXME FIXME FIXME!!! 
# From here down don't make explicit references to the opaque options provided
# by datasource for maintainability
#

#
# Precompute the PSDs and banks for each segment
#

def hash_seg(ifos, seg):
	"""
	Return the dictionary key used to look up the node associated with
	the instrument combination ifos and the segment seg.  The key is
	simply the pair (ifos, seg).
	"""
	# FIXME what is a good way to hash the segment?
	key = (ifos, seg)
	return key

# reference PSD nodes keyed by hash_seg(ifos, seg)
psd_nodes = {}
bank_nodes = {}
for ifos in segsdict:
	# only pass along the channels of the instruments in this combination
	this_channel_dict = {}
	for k in ifos:
		if k in channel_dict:
			this_channel_dict[k] = channel_dict[k]
	for seg in segsdict[ifos]:
		psd_node = gstlal_reference_psd_node(refPSDJob, dag, options.frame_cache, options.frame_segments_file, options.frame_segments_name, seg[0].seconds, seg[1].seconds, ifos, this_channel_dict, injections=None, psd_fft_length = options.psd_fft_length, p_node=[])
		psd_nodes[hash_seg(ifos, seg)] = psd_node

# common prefix for the plot file names, built from the analysis boundary
name_tag = "plots/gstlal-%d-%d_" % (int(boundary_seg[0]), int(boundary_seg[1]))

# plot the horizon distance from the output of every PSD node
psd_output_names = [node.output_name for node in psd_nodes.values()]
gstlal_plot_psd_horizon_node(horizonJob, dag, psd_output_names, name_tag + "horizon.png", p_node = psd_nodes.values())

#
# loop over banks to run gstlal inspiral pre clustering and far computation
#

# Split the bank cache into groups using options.num_banks.  Everything in the
# loop below is done once per group, with i the index of the group.
bank_groups = inspiral_pipe.build_bank_groups(bank_cache, options.num_banks)

for i, bank_group in enumerate(bank_groups):
	
	# list of (injection file or None, inspiral node) tuples for this bank
	# group; None marks the non-injection run
	inspiral_nodes = []

	# bank and inspiral jobs by segment
	for ifos in segsdict:
		for seg in segsdict[ifos]:
			svd_nodes = []
			s = "" # bank input file string for the inspiral jobs
			for ifo, files in bank_group.items():
				# only instruments that are part of this combination
				if ifo not in ifos:
					continue
				for n, f in enumerate(files):
					# handle template bank clipping
					# the first file of the first group is not clipped on
					# the left and the last file of the last group is not
					# clipped on the right; every other edge is clipped by
					# half of options.overlap
					if (n == 0) and (i == 0):
						clipleft = 0
					else:
						clipleft = options.overlap / 2
					if (i == len(bank_groups) - 1) and (n == len(files) -1):
						clipright = 0
					else:
						clipright = options.overlap / 2

					svd_bank_name = T050017_filename(ifo, '%d_%d_SVD' % (i, n), seg[0].seconds, seg[1].seconds, '.xml.gz', path = svdJob.output_path)
					# s accumulates comma separated "ifo:filename" pairs
					s += '%s:%s,' % (ifo, svd_bank_name)
					# FIXME bank id should be global??
					# NOTE(review): samples_max_64 is given options.samples_max,
					# the same value as samples_max, while samples_max_256 has
					# its own option; this may have been meant to be
					# options.samples_max_64 -- confirm against the option parser
					svd_nodes.append(gstlal_svd_bank_node(svdJob, dag, f, ifo, svd_bank_name, tolerance = options.tolerance, reference_psd = psd_nodes[hash_seg(ifos, seg)].output_name, flow = options.flow, clipleft = clipleft, clipright = clipright, samples_min = options.samples_min, samples_max_256 = options.samples_max_256, samples_max_64 = options.samples_max, samples_max = options.samples_max, autocorrelation_length = options.autocorrelation_length, bank_id = n, p_node = [psd_nodes[hash_seg(ifos, seg)]], identity_transform = options.identity_transform))

			# remove the trailing comma left by the loop above
			s = s.strip(',')

			# only use a channel dict with the relevant channels
			this_channel_dict = dict((k, channel_dict[k]) for k in ifos if k in channel_dict)

			# non injections
			noninjnode = gstlal_inspiral_node(gstlalInspiralJob, dag, options.frame_cache, options.frame_segments_file, options.frame_segments_name, seg[0].seconds, seg[1].seconds, this_channel_dict, reference_psd=psd_nodes[hash_seg(ifos, seg)].output_name, svd_bank=s, injections=None, vetoes=options.vetoes, time_slide_file=options.time_slide_file, control_peak_time = options.control_peak_time, fir_stride = options.fir_stride, coincidence_threshold=options.coincidence_threshold, number = i, instruments = "".join(sorted(ifos)), ht_gate_thresh = options.ht_gate_threshold, blind_injections = options.blind_injections, psd_fft_length = options.psd_fft_length, p_node = svd_nodes)
			inspiral_nodes.append((None, noninjnode))

			# injections
			# one inspiral node per injection file; unlike the non-injection
			# call above, time_slide_file and blind_injections are not passed
			for injections in options.injections:	
				injnode = gstlal_inspiral_node(gstlalInspiralInjJob, dag, options.frame_cache, options.frame_segments_file, options.frame_segments_name, seg[0].seconds, seg[1].seconds, this_channel_dict, reference_psd=psd_nodes[hash_seg(ifos, seg)].output_name, svd_bank=s, injections=injections, vetoes=options.vetoes, control_peak_time = options.control_peak_time, fir_stride = options.fir_stride, coincidence_threshold=options.coincidence_threshold, number = i, instruments = "".join(sorted(ifos)), ht_gate_thresh = options.ht_gate_threshold, psd_fft_length = options.psd_fft_length, p_node = svd_nodes)
				inspiral_nodes.append((injections, injnode))

	# likelihood jobs for non injections
	# a single node per bank group, fed the background and output files of
	# every non-injection inspiral node
	likelihood_nodes[i] = gstlal_inspiral_calc_likelihood_node(calcLikelihoodJob, dag, likelihood_files = [node[1].background_name for node in inspiral_nodes if node[0] is None], input = [node[1].output_name for node in inspiral_nodes if node[0] is None], p_node=[node[1] for node in inspiral_nodes if node[0] is None])
	
	# likelihood jobs for injections
	# the likelihood files still come from the non-injection nodes, so both
	# the injection and the non-injection inspiral nodes are parents
	likelihood_nodes_inj[i] = []
	for injections in options.injections:
		likelihood_nodes_inj[i].append(gstlal_inspiral_calc_likelihood_node(calcLikelihoodJobInj, dag, likelihood_output_name = None, likelihood_files = [node[1].background_name for node in inspiral_nodes if node[0] is None], input = [node[1].output_name for node in inspiral_nodes if node[0] == injections], p_node=[node[1] for node in inspiral_nodes if node[0] == injections]+[node[1] for node in inspiral_nodes if node[0] is None]))

	# after assigning the likelihoods cluster
	# merge_nodes holds (injection file or None, clustering node) tuples
	merge_nodes = []
	files_to_group = 10
	# extract_all_nodes_by_inj is defined elsewhere; presumably it maps each
	# injection file (or None) to the list of its inspiral nodes -- TODO confirm
	for inj, nodes in extract_all_nodes_by_inj(inspiral_nodes).items():
		if inj is None:
			# 10 at a time so the jobs take a bit longer to run
			for n in range(0, len(nodes), files_to_group):
				merge_nodes.append((inj, lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.cluster_sql_file, input=[node.output_name for node in nodes[n:n+files_to_group]], p_node=[likelihood_nodes[i]])))
		else:
			# same grouping, but with the injection sql file and all of the
			# injection likelihood nodes of this bank group as parents
			for n in range(0, len(nodes), files_to_group):
				merge_nodes.append((inj, lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.injection_sql_file, input=[node.output_name for node in nodes[n:n+files_to_group]], p_node=likelihood_nodes_inj[i])))

	# merge over the segments and cluster again

	# non-injection database for this bank group: gather every non-injection
	# inspiral output into one sqlite file, then run the cluster sql on it
	noninjdb = T050017_filename(instruments, '%04d_LLOID' % i, int(boundary_seg[0]), int(boundary_seg[1]), '.sqlite')
	sqlitenode = ligolw_sqlite_node(toSqliteJob, dag, noninjdb, input=[node[1].output_name for node in inspiral_nodes if node[0] is None], p_node = [node[1] for node in merge_nodes if node[0] is None])
	sqlitenode = lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.cluster_sql_file, input=[sqlitenode.output_db_name], p_node=[sqlitenode])
	noninj_nodes.append(sqlitenode)

	# the same for each injection file, using the injection sql file
	for inj in options.injections:
		injdb = T050017_filename(instruments, '%04d_LLOID_%s' % (i, sim_tag_from_inj_file(inj)), int(boundary_seg[0]), int(boundary_seg[1]), '.sqlite')
		sqlitenode = ligolw_sqlite_node(toSqliteJob, dag, injdb, input=[node[1].output_name for node in inspiral_nodes if node[0] ==  inj], p_node = [node[1] for node in merge_nodes if node[0] == inj])
		sqlitenode = lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.injection_sql_file, input=[sqlitenode.output_db_name], p_node=[sqlitenode])
		inj_nodes[inj].append(sqlitenode)
	

#
# After all of the likelihood ranking and preclustering is finished, put
# everything into single databases based on the injection file (or lack thereof).
#

# name of the database that collects the non-injection results of all bank groups
noninjdb = T050017_filename(instruments, 'ALL_LLOID', int(boundary_seg[0]), int(boundary_seg[1]), '.sqlite')

# merge the per-bank-group databases together with the veto and frame
# segment files
noninj_merge_inputs = [f.output_name for f in noninj_nodes]
noninj_merge_inputs += [options.vetoes, options.frame_segments_file]
sqlitenode = ligolw_sqlite_node(toSqliteJob, dag, noninjdb, input=noninj_merge_inputs, p_node=noninj_nodes)

# cluster the merged database
noninjsqlitenode = lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.cluster_sql_file, input=[noninjdb], p_node=[sqlitenode])

#
# injection DBs
#

# names of the final injection databases, in the order of options.injections
injdbs = []
# nodes that must finish before the marginalization jobs are run
p_nodes = [noninjsqlitenode]
for injections in options.injections:

	# final database name for this injection file and the name of the XML
	# file it is temporarily converted to
	sim_tag = sim_tag_from_inj_file(injections)
	injdb = T050017_filename(instruments, 'ALL_LLOID_%s' % sim_tag, int(boundary_seg[0]), int(boundary_seg[1]), '.sqlite')
	injdbs.append(injdb)
	injxml = injdb+".xml.gz"

	# only the nodes that were run with this injection file
	thisinjnodes = inj_nodes[injections]

	# merge the per-bank-group databases with the veto, frame segment and
	# injection files
	inj_merge_inputs = [f.output_name for f in thisinjnodes]
	inj_merge_inputs += [options.vetoes, options.frame_segments_file, injections]
	mergenode = ligolw_sqlite_node(toSqliteJob, dag, injdb, input=inj_merge_inputs, p_node = thisinjnodes)

	# cluster
	clusternode = lalapps_run_sqlite_node(lalappsRunSqliteJob, dag, options.cluster_sql_file, input=[injdb], p_node=[mergenode])

	# convert to XML
	xmlnode = ligolw_sqlite_node(toXMLJob, dag, injdb, replace=False, extract=injxml, p_node=[clusternode])

	# find injections
	inspinjnode = ligolw_inspinjfind_node(ligolwInspinjFindJob, dag, injxml, p_node=[xmlnode])

	# convert back to sqlite
	finalnode = ligolw_sqlite_node(toSqliteJob, dag, injdb, input=[injxml], p_node=[inspinjnode])
	p_nodes.append(finalnode)


# compute FAPs and FARs
# the likelihood files of the bank groups are first marginalized in groups of
# margnum files, then the partial results are marginalized into a single file
margnum = 16
final_margout = "marginalized_likelihood.xml.gz"
margin = [node.background_name for node in likelihood_nodes.values()]
margout = []
margnodes = []
for group_index, first in enumerate(range(0, len(margin), margnum)):
	group_output = "%d_marginalized_likelihood.xml.gz" % (group_index,)
	margout.append(group_output)
	group_node = gstlal_inspiral_marginalize_likelihood_node(marginalizeJob, dag, group_output, margin[first:first+margnum], p_node = p_nodes)
	margnodes.append(group_node)
margnode = gstlal_inspiral_marginalize_likelihood_node(marginalizeJob, dag, final_margout, margout, p_node = margnodes)

# the FAR job runs on the final databases once the marginalization is done
farnode = gstlal_compute_far_from_snr_chisq_histograms_node(gstlalInspiralComputeFarFromSnrChisqHistogramsJob, dag, background_bins_files = [final_margout], noninj_input = [noninjdb], inj_input = injdbs, p_node = [margnode])

# make summary plots from the non-injection and injection databases
summary_plot_node = gstlal_inspiral_plotsummary_node(plotSummaryJob, dag, name_tag, input=[noninjdb] + injdbs, p_node=[farnode], segments_name=options.frame_segments_name)

# make sensitivity plots
sensitivity_plot_node = gstlal_inspiral_plot_sensitivity_node(plotSensitivityJob, dag, name_tag, injdbs=injdbs, zldbs=[noninjdb], p_node=[farnode])

# both web pages wait for both plotting nodes
plotnodes = [summary_plot_node, sensitivity_plot_node]

# make the closed box and the open box web pages
page_gps_bounds = (int(boundary_seg[0]), int(boundary_seg[1]))
gstlal_s5_pbh_summary_page_node(pageJob, dag, name_tag, options.web_dir, title="gstlal-%d-%d-closed-box" % page_gps_bounds, open_box = False, p_node=plotnodes)
gstlal_s5_pbh_summary_page_node(openpageJob, dag, name_tag, options.web_dir, title="gstlal-%d-%d-open-box" % page_gps_bounds, p_node=plotnodes)

#
# all done
#

# Every node has been added; have the dag object write its output files.
# Going by the method names these are the submit files, the dag file, a
# script and a cache -- dag is constructed earlier in the file, so confirm
# the details against its class.
dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()
