#!/usr/bin/env python
#
# Copyright (C) 2010  Kipp Cannon, Chad Hanna, Drew Keppel
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This program makes a dag to generate svd banks
"""

__author__ = 'Chad Hanna <channa@caltech.edu>'

##############################################################################
# import standard modules and append the lalapps prefix to the python path
import sys, os, copy, math
import subprocess, socket, tempfile

##############################################################################
# import the modules we need to build the pipeline
from glue import iterutils
from glue import pipeline
from glue import lal
from glue.ligolw import lsctables
from glue import segments
from optparse import OptionParser
from gstlal import dagparts

## @file gstlal_inspiral_svd_bank_pipe
# This program will make a HTCondor DAG to automate the creation of svd bank files; see gstlal_inspiral_svd_bank_pipe for more information

## @package gstlal_inspiral_svd_bank_pipe
#
# ### Graph of the condor DAG
#
# @dot
# digraph G {
#	// graph properties
#
#	rankdir=LR;
#	compound=true;
#	node [shape=record fontsize=10 fontname="Verdana"];     
#	edge [fontsize=8 fontname="Verdana"];
#
#	// nodes
#
#	"gstlal_svd_bank" [URL="\ref gstlal_svd_bank"];
# }
# @enddot
#
# This DAG implements only a single job type; gstlal_svd_bank
#
# ### Usage cases
#
# - Typical usage case for H1
#
#		$ gstlal_inspiral_svd_bank_pipe --autocorrelation-length 351 --instrument H1 --reference-psd reference_psd.xml --bank-cache H1_split_bank.cache --overlap 10 --flow 15 --output-name H1_bank
#
# - Please add more!
#
# ### Command line options
# 
#	+ `--instrument` [ifo]: set the name of the instrument, required
#	+ `--reference-psd` [file]: Set the name of the reference psd file, required
#	+ `--bank-cache` [file]: Set the name of the bank cache, required
#	+ `--overlap` [int]: Set the factor that describes the overlap of the sub banks, must be even!
#	+ `--identity-transform`: Turn off the SVD and use the identity reconstruction matrix
#	+ `--autocorrelation-length` [int]: The number of samples to use for auto-chisquared, default 201; should be odd
#	+ `--samples-min` [int]: The minimum number of samples to use for time slices, default 1024
#	+ `--samples-max-256` [int]: The maximum number of samples to use for time slices with frequencies above 256Hz, default 1024
#	+ `--samples-max-64` [int]: The maximum number of samples to use for time slices with frequencies between 64Hz and 256 Hz, default 2048
#	+ `--samples-max` [int]: The maximum number of samples to use for time slices with frequencies below 64Hz, default 4096
#	+ `--stagger`: A hacky way to stagger the number of samples used in the SVD if you have a broad mass range and the cache is organized by chirp mass from small to big.  The first half gets 1024, the second gets 2048. This helps balance the number of slices with the number of principal components when the template duration varies wildly through the cache.  Using this option disables the samples-* options.  FIXME this will be removed once we know how to automate this in a better way.
#	+ `--tolerance` [float]: Set the SVD tolerance, default 0.9995
#	+ `--flow` [float]: Set the low frequency cutoff, default 40 (Hz)
#	+ `--output-name` [file]: Set the base name of the output, required
#	+ `--verbose`: Be verbose.

class bank_DAG(pipeline.CondorDAG):
	"""
	A CondorDAG whose files are named after a base name.  Every node added
	gets a retry count and a unique "macroid" macro, and the DAG collects
	the output cache entries appended by its nodes.
	"""

	def __init__(self, name, logpath = None):
		"""
		Create the DAG.

		name: base name used for the DAG file, as the prefix of the DAG log
		file and for the output cache (name + ".cache").
		logpath: directory in which the DAG log file is created; defaults to
		dagparts.log_path(), evaluated when the DAG is constructed.
		"""
		if logpath is None:
			logpath = dagparts.log_path()
		self.basename = name
		# mkstemp atomically creates the (empty) log file under a unique name
		# carrying the "<name>.dag.log." prefix, without touching the
		# tempfile module's global settings.
		fd, logfile = tempfile.mkstemp(prefix = self.basename + '.dag.log.', dir = logpath)
		os.close(fd)
		pipeline.CondorDAG.__init__(self, logfile)
		self.set_dag_file(self.basename)
		self.jobsDict = {}
		# counter used to give every node a distinct $(macroid)
		self.node_id = 0
		# cache entries for the files the jobs will write; see write_cache()
		self.output_cache = []

	def add_node(self, node):
		"""
		Add node to the DAG with 3 retries and a unique "macroid" macro
		(the job's stdout/stderr file names use $(macroid)).
		"""
		node.set_retry(3)
		self.node_id += 1
		node.add_macro("macroid", self.node_id)
		pipeline.CondorDAG.add_node(self, node)

	def write_cache(self):
		"""
		Write self.output_cache, one entry per line, to <basename>.cache.
		"""
		with open(self.basename + ".cache", "w") as f:
			for c in self.output_cache:
				f.write(str(c) + "\n")

class gstlal_svd_bank_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_svd_bank executable.

	The job runs in the vanilla universe with the submitting environment,
	uses the requirement "Memory > 1999", forces serial KMP/MKL libraries
	through the environment and writes per-node stdout/stderr under logs/.
	"""
	def __init__(self, executable=dagparts.which('gstlal_svd_bank'), tag_base='gstlal_svd_bank'):
		"""
		executable: path of the gstlal_svd_bank program (default: located
		with dagparts.which)
		tag_base: base name for the submit file and the log files
		"""
		self.__prog__ = 'gstlal_svd_bank'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base

		for cmd_name, cmd_value in (
			('getenv', 'True'),
			('requirements', 'Memory > 1999'),  # FIXME is this enough?
			('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes"),
		):
			self.add_condor_cmd(cmd_name, cmd_value)

		self.set_sub_file(tag_base + '.sub')
		# one stdout/stderr pair per node, distinguished by macroid/process
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_svd_bank_node(pipeline.CondorDAGNode):
	"""
	A gstlal_svd_bank node that processes one template bank file.

	On construction the node adds itself to dag and appends a cache entry
	for its output file to dag.output_cache.
	"""
	def __init__(self, job, dag, template_bank, ifo, flow = 40, reference_psd = "psd.xml", tolerance = 0.9995, FAP = 0.5, snr = 4.0, clipleft = 0, clipright = 0, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, p_node = None, autocorrelation_length = None, bank_id = None, identity_transform = False):
		"""
		job: the gstlal_svd_bank job this node runs
		dag: the DAG to add the node to; must provide add_node() and an
		output_cache list (as bank_DAG does)
		template_bank: path of the input template bank file; the SVD bank is
		written in the same directory as svd_<file name>
		ifo: instrument name recorded in the output cache entry
		p_node: iterable of parent nodes (default: no parents)
		The remaining arguments are passed through as the corresponding
		gstlal_svd_bank command line options; bank_id and
		autocorrelation_length are omitted when None, and
		--identity-transform is added only when identity_transform is true.
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		self.add_var_opt("flow", flow)
		self.add_var_opt("snr-threshold", snr)
		self.add_var_opt("svd-tolerance", tolerance)
		self.add_var_opt("reference-psd", reference_psd)
		self.add_var_opt("template-bank", template_bank)
		self.add_var_opt("ortho-gate-fap", FAP)
		self.add_var_opt("samples-min", samples_min)
		self.add_var_opt("samples-max", samples_max)
		self.add_var_opt("samples-max-64", samples_max_64)
		self.add_var_opt("samples-max-256", samples_max_256)
		self.add_var_opt("clipleft", clipleft)
		self.add_var_opt("clipright", clipright)
		if bank_id is not None:
			self.add_var_opt("bank-id", bank_id)
		if autocorrelation_length is not None:
			self.add_var_opt("autocorrelation-length", autocorrelation_length)
		if identity_transform:
			self.add_var_arg("--identity-transform")
		# Put the output next to the input.  os.path.join (rather than
		# head + "/svd_" + tail) avoids producing "/svd_<name>", a file in
		# the filesystem root, when template_bank has no directory part.
		bank_dir, bank_file = os.path.split(template_bank)
		svd_bank_name = os.path.join(bank_dir, "svd_" + bank_file)
		self.add_var_opt("write-svd-bank", svd_bank_name)
		# "file://localhost" + path is only a valid URL for an absolute path
		dag.output_cache.append(lal.CacheEntry(ifo, "-", segments.segment(0, 999999999), "file://localhost%s" % (os.path.abspath(svd_bank_name),)))
		if p_node is not None:
			for p in p_node:
				self.add_parent(p)
		dag.add_node(self)

def parse_command_line():
	"""
	Parse and validate the command line (sys.argv).

	Returns (options, filenames) as produced by OptionParser.parse_args().
	Exits through OptionParser.error() (SystemExit) when a required option
	is missing, and raises ValueError when --overlap is odd or negative.
	"""
	parser = OptionParser()
	parser.add_option("--instrument", help = "set the name of the instrument, required")
	parser.add_option("--reference-psd", metavar = "file", help = "Set the name of the reference psd file, required")
	parser.add_option("--bank-cache", metavar = "file", help = "Set the name of the bank cache, required")
	parser.add_option("--overlap", metavar = "num", type = "int", default = 0, help = "set the factor that describes the overlap of the sub banks, must be even!")
	parser.add_option("--identity-transform", default = False, action = "store_true", help = "turn off the SVD and use the identity reconstruction matrix")
	parser.add_option("--autocorrelation-length", type = "int", default = 201, help = "The number of samples to use for auto-chisquared, default 201 should be odd")
	parser.add_option("--samples-min", type = "int", default = 1024, help = "The minimum number of samples to use for time slices default 1024")
	parser.add_option("--samples-max-256", type = "int", default = 1024, help = "The maximum number of samples to use for time slices with frequencies above 256Hz, default 1024")
	parser.add_option("--samples-max-64", type = "int", default = 2048, help = "The maximum number of samples to use for time slices with frequencies between 64Hz and 256Hz, default 2048")
	parser.add_option("--samples-max", type = "int", default = 4096, help = "The maximum number of samples to use for time slices with frequencies below 64Hz, default 4096")
	parser.add_option("--stagger", action = "store_true", help = "A hacky way to stagger the number of samples used in the SVD if you have a broad mass range and the cache is organized by chirp mass from small to big.  The first half gets 1024 the second gets 2048. This helps balance the number of slices with the number of principle components when the template duration varies wildly through the cache.  Using this option disables the samples-* options.  FIXME this will be removed once we know how to automate this in a better way.")
	parser.add_option("--tolerance", metavar = "float", type = "float", default = 0.9995, help = "set the SVD tolerance, default 0.9995")
	parser.add_option("--flow", metavar = "num", type = "float", default = 40, help = "set the low frequency cutoff, default 40 (Hz)")
	parser.add_option("--output-name", help = "set the base name of the output, required")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	options, filenames = parser.parse_args()

	# These options have no default and the DAG cannot be built without
	# them; report all missing ones up front with a usage error.
	required = ("instrument", "reference_psd", "bank_cache", "output_name")
	missing = [name for name in required if getattr(options, name) is None]
	if missing:
		parser.error("missing required option(s): %s" % ", ".join("--" + name.replace("_", "-") for name in missing))

	# the overlap is split evenly between the two sides of each sub bank
	if options.overlap % 2:
		raise ValueError("overlap must be even")
	if options.overlap < 0:
		raise ValueError("overlap must be non-negative")

	return options, filenames

options, filenames = parse_command_line()


# get input arguments
ifo = options.instrument
ref_psd = options.reference_psd
input_cache = options.bank_cache

# The jobs write their stdout/stderr under logs/.  An already existing
# directory is fine; any other failure to create it is reported.
try:
	os.mkdir("logs")
except OSError:
	if not os.path.isdir("logs"):
		raise
dag = bank_DAG(options.output_name)

svdJob = gstlal_svd_bank_job(tag_base="gstlal_svd_bank")
svdNode = {}

# Sample counts used with --stagger, each entry being
# (samples_min, samples_max_256, samples_max_64, samples_max).  The cache is
# divided in order into len(stagger) equal parts, one entry per part.
# NOTE(review): the --stagger help text says the first half gets 1024 and the
# second half 2048, but this table gives the first half samples_min 2048;
# confirm which is intended.
stagger = ((2048, 2048, 2048, 4096), (1024, 1024, 2048, 4096))

# assumes cache is sorted by chirpmass or whatever the SVD sorting algorithm that was chosen
with open(input_cache) as cache_file:
	files = [lal.CacheEntry(line).path for line in cache_file]

for i, f in enumerate(files):
	# handle the edges by not clipping so you retain the template bank as intended.
	# overlap is validated to be even; // keeps the clip an integer.
	clipleft = clipright = options.overlap // 2
	if i == 0:
		clipleft = 0
	if i == len(files) - 1:
		clipright = 0
	if options.stagger:
		samps = stagger[int(float(i) / len(files) * len(stagger))]
	else:
		samps = (options.samples_min, options.samples_max_256, options.samples_max_64, options.samples_max)
	svdNode[f] = gstlal_svd_bank_node(svdJob, dag, f, ifo, tolerance = options.tolerance, reference_psd = ref_psd, flow = options.flow, clipleft = clipleft, clipright = clipright, samples_min = samps[0], samples_max_256 = samps[1], samples_max_64 = samps[2], samples_max = samps[3], autocorrelation_length = options.autocorrelation_length, bank_id = i, identity_transform = options.identity_transform)

dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()



