#!/usr/bin/env python
#
# Copyright (C) 2011 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This program makes a dag to recolor frames
"""

__author__ = 'Chad Hanna <chad.hanna@ligo.org>'


##############################################################################
# import standard library modules
import sys, os, copy, math
import subprocess, socket, tempfile

##############################################################################
# import the modules we need to build the pipeline
from glue import iterutils
from glue import pipeline
from glue import lal
from glue.ligolw import lsctables
from glue import segments
import glue.ligolw.utils as utils
import glue.ligolw.utils.segments as ligolw_segments
from optparse import OptionParser
from gstlal import datasource

def which(prog):
	"""
	Return the location of the executable prog as printed by the external
	"which" command (surrounding whitespace stripped).

	Raises ValueError, after writing a diagnostic to stderr, when "which"
	prints nothing, i.e. prog could not be found on the current PATH.
	"""
	proc = subprocess.Popen(['which', prog], stdout=subprocess.PIPE)
	# communicate() drains stdout and waits for the child, so neither the
	# pipe nor the process is left dangling
	out = proc.communicate()[0].strip()
	if not out:
		# the message contains exactly one placeholder, so it must be
		# formatted with exactly one value
		msg = "ERROR: could not find %s in your path, have you built the proper software and source the proper env. scripts?" % (prog,)
		sys.stderr.write(msg + "\n")
		raise ValueError(msg)
	return out

def log_path():
	"""
	Return the local scratch directory for the current user on the
	clusters known below, chosen by matching the fully qualified host
	name.  Returns None when the host matches none of them.

	The user name is read from the USER environment variable, but only
	once a host has matched.
	"""
	#FIXME add more hosts as you need them
	scratch_roots = (
		('caltech.edu', '/usr1/'),
		('phys.uwm.edu', '/localscratch/'),
		('aei.uni-hannover.de', '/local/user/'),
		('phy.syr.edu', '/usr1/'),
	)
	fqdn = socket.getfqdn()
	# first matching domain wins
	for domain, root in scratch_roots:
		if domain in fqdn:
			return root + os.environ['USER']
	return None


class bank_DAG(pipeline.CondorDAG):
	"""
	A CondorDAG that gives every node a retry count and a running
	"macroid" number, and that collects cache entries describing files the
	nodes are expected to produce.

	Attributes set here:
	basename      -- name given to the constructor; used for the DAG file,
	                 the log file prefix and the ".cache" file name
	jobsDict      -- empty dict, available for callers
	node_id       -- number of nodes added so far
	output_cache  -- list of entries written out by write_cache()
	"""

	def __init__(self, name, logpath = log_path()):
		"""
		Create an empty log file named <name>.dag.log.<random> in logpath
		and initialise the DAG with it.

		The default for logpath is computed once, when this class is
		defined.  A logpath of None selects the system default temporary
		directory.
		"""
		self.basename = name
		# mkstemp() creates the file atomically (no window between picking
		# the name and creating the file).  Prefix and directory are passed
		# explicitly instead of through the deprecated tempfile.template /
		# tempfile.tempdir module globals, so no process-wide state changes.
		fd, logfile = tempfile.mkstemp(prefix = self.basename + '.dag.log.', dir = logpath)
		os.close(fd)
		pipeline.CondorDAG.__init__(self,logfile)
		self.set_dag_file(self.basename)
		self.jobsDict = {}
		self.node_id = 0
		self.output_cache = []

	def add_node(self, node):
		"""
		Add node to the DAG with 3 retries and a unique "macroid" macro
		(the 1-based count of nodes added so far).
		"""
		node.set_retry(3)
		self.node_id += 1
		node.add_macro("macroid", self.node_id)
		pipeline.CondorDAG.add_node(self, node)

	def write_cache(self):
		"""
		Write str() of every entry in output_cache, one per line, to
		<basename>.cache in the current directory.
		"""
		out = self.basename + ".cache"
		# the with block closes the file even if a write fails
		with open(out, "w") as f:
			for c in self.output_cache:
				f.write(str(c)+"\n")

#
# Classes for generating reference psds
#

class gstlal_reference_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_reference_psd executable
	"""
	def __init__(self, executable=which('gstlal_reference_psd'), tag_base='gstlal_reference_psd'):
		"""
		Set up a vanilla universe job running executable.

		The default executable is looked up with which() once, when this
		class is defined.  tag_base names the submit file (tag_base +
		'.sub') and prefixes the stdout/stderr files under logs/.
		"""
		self.__prog__ = 'gstlal_reference_psd'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		condor_cmds = (
			('getenv', 'True'),
			('requirements', 'Memory > 1999'), #FIXME is this enough?
			('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes"),
		)
		for cmd, value in condor_cmds:
			self.add_condor_cmd(cmd, value)
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem; the macroid macro is set per
		# node by bank_DAG.add_node
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_median_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_ninja_median_of_psds executable
	"""
	def __init__(self, executable=which('gstlal_ninja_median_of_psds'), tag_base='gstlal_ninja_median_of_psds'):
		"""
		Set up a vanilla universe job running executable.

		The default executable is looked up with which() once, when this
		class is defined.  tag_base names the submit file (tag_base +
		'.sub') and prefixes the stdout/stderr files under logs/.
		"""
		self.__prog__ = 'gstlal_ninja_median_of_psds'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		condor_cmds = (
			('getenv', 'True'),
			('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes"),
		)
		for cmd, value in condor_cmds:
			self.add_condor_cmd(cmd, value)
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem; the macroid macro is set per
		# node by bank_DAG.add_node
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_ninja_smooth_reference_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_ninja_smooth_reference_psd
	executable
	"""
	def __init__(self, executable=which('gstlal_ninja_smooth_reference_psd'), tag_base='gstlal_ninja_smooth_reference_psd'):
		"""
		Set up a vanilla universe job running executable.

		The default executable is looked up with which() once, when this
		class is defined.  tag_base names the submit file (tag_base +
		'.sub') and prefixes the stdout/stderr files under logs/.
		"""
		self.__prog__ = 'gstlal_ninja_smooth_reference_psd'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		condor_cmds = (
			('getenv', 'True'),
			('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes"),
		)
		for cmd, value in condor_cmds:
			self.add_condor_cmd(cmd, value)
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem; the macroid macro is set per
		# node by bank_DAG.add_node
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_reference_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_reference_psd node

	Measures a PSD from frame data between gps_start_time and gps_end_time
	for one instrument/channel and writes it to
	<cwd>/<instrument>-<start>-<end>-reference_psd.xml.gz.  The resulting
	path is available as self.output_name and a matching cache entry is
	appended to dag.output_cache.
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, instrument, channel, injections=None, p_node=None):
		"""
		job             -- the Condor job this node belongs to
		dag             -- DAG to add the node to; must provide
		                   output_cache and add_node()
		frame_cache     -- value for --frame-cache
		gps_start_time,
		gps_end_time    -- integers (formatted with %d in the output name)
		instrument,
		channel         -- combined into --channel-name=<instrument>=<channel>
		injections      -- optional value for --injections; skipped when
		                   empty or None
		p_node          -- optional iterable of parent nodes
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("gps-start-time", gps_start_time)
		self.add_var_opt("gps-end-time", gps_end_time)
		self.add_var_opt("data-source", "frames")
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		if injections:
			self.add_var_opt("injections", injections)
		path = os.getcwd()
		output_name = self.output_name = '%s/%s-%d-%d-reference_psd.xml.gz' % (path, instrument, gps_start_time, gps_end_time)
		self.add_var_opt("write-psd",output_name)
		# NOTE(review): output_name is absolute, so the URL below contains
		# "localhost//..."; kept as is to leave the cache file unchanged
		dag.output_cache.append(lal.CacheEntry(instrument, "-", segments.segment(gps_start_time, gps_end_time), "file://localhost/%s" % (output_name,)))
		# None (the default) means "no parents"; avoids a shared mutable
		# default list
		for p in p_node or ():
			self.add_parent(p)
		dag.add_node(self)


class gstlal_ninja_smooth_reference_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_ninja_smooth_reference_psd node

	Takes the PSD file input_psd and writes a smoothed version next to it.
	The output path is available as self.output_name.
	"""
	def __init__(self, job, dag, instrument, input_psd, p_node=None):
		"""
		job        -- the Condor job this node belongs to
		dag        -- DAG to add the node to
		instrument -- value for --instrument
		input_psd  -- path of the PSD file to smooth (--input-psd)
		p_node     -- optional iterable of parent nodes
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		#FIXME shouldn't be hardcoding stuff like this
		# Rename only the file itself: a directory whose name happens to
		# contain "reference_psd" must not be rewritten, otherwise the
		# output would point into a different (likely missing) directory.
		directory, filename = os.path.split(input_psd)
		output_name = self.output_name = os.path.join(directory, filename.replace('reference_psd', 'smoothed_reference_psd'))
		self.add_var_opt("instrument", instrument)
		self.add_var_opt("input-psd", input_psd)
		self.add_var_opt("output-psd", output_name)
		# None (the default) means "no parents"; avoids a shared mutable
		# default list
		for p in p_node or ():
			self.add_parent(p)
		dag.add_node(self)


class gstlal_median_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_median_psd node

	Passes every file in input_psds as a file argument to the job and
	asks for the result to be written to output, which is also available
	as self.output_name.
	"""
	def __init__(self, job, dag, instrument, input_psds, output, p_node=None):
		"""
		job        -- the Condor job this node belongs to
		dag        -- DAG to add the node to
		instrument -- value for --instrument
		input_psds -- iterable of input PSD file names
		output     -- value for --output-name
		p_node     -- optional iterable of parent nodes
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		output_name = self.output_name = output
		self.add_var_opt("instrument", instrument)
		self.add_var_opt("output-name", output_name)
		for psd in input_psds:
			self.add_file_arg(psd)
		# None (the default) means "no parents"; avoids a shared mutable
		# default list
		for p in p_node or ():
			self.add_parent(p)
		dag.add_node(self)


#
# classes for generating recolored frames
#

class gstlal_recolor_frames_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_recolor_frames executable
	"""
	def __init__(self, executable=which('gstlal_recolor_frames'), tag_base='gstlal_recolor_frames'):
		"""
		Set up a vanilla universe job running executable.

		The default executable is looked up with which() once, when this
		class is defined.  tag_base names the submit file (tag_base +
		'.sub') and prefixes the stdout/stderr files under logs/.
		"""
		self.__prog__ = 'gstlal_recolor_frames'
		self.__executable = executable
		self.__universe = 'vanilla'
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		condor_cmds = (
			('getenv', 'True'),
			('requirements', 'Memory > 1999'), #FIXME is this enough?
			('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes"),
		)
		for cmd, value in condor_cmds:
			self.add_condor_cmd(cmd, value)
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem; the macroid macro is set per
		# node by bank_DAG.add_node
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_recolor_frames_node(pipeline.CondorDAGNode):
	"""
	A gstlal_recolor_frames node

	Runs the recoloring job on frame data between gps_start_time and
	gps_end_time for one instrument/channel.
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, instrument, channel, reference_psd, recolor_psd, injections=None, output_channel_name = None, duration = 4096, output_path = None, frame_type = None, p_node=None):
		"""
		job                 -- the Condor job this node belongs to
		dag                 -- DAG to add the node to
		frame_cache         -- value for --frame-cache
		gps_start_time,
		gps_end_time        -- values for --gps-start-time / --gps-end-time
		instrument, channel -- combined into
		                       --channel-name=<instrument>=<channel>
		reference_psd       -- value for --reference-psd
		recolor_psd         -- value for --recolor-psd
		duration            -- value for --duration (default 4096)
		p_node              -- optional iterable of parent nodes

		injections, output_channel_name, output_path and frame_type are
		optional; each is passed to the job only when it is not None.
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_var_opt("data-source", "frames")
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("gps-start-time",gps_start_time)
		self.add_var_opt("gps-end-time",gps_end_time)
		# Options defaulting to None are skipped when unset, so that the
		# job is not handed a None value for them.
		if output_channel_name is not None:
			self.add_var_opt("output-channel-name", output_channel_name)
		if injections is not None:
			self.add_var_opt("injections", injections)
		if output_path is not None:
			self.add_var_opt("output-path", output_path)
		if frame_type is not None:
			self.add_var_opt("frame-type", frame_type)
		self.add_var_opt("recolor-psd",recolor_psd)
		self.add_var_opt("reference-psd",reference_psd)
		self.add_var_opt("duration", duration)
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		# None (the default) means "no parents"; avoids a shared mutable
		# default list
		for p in p_node or ():
			self.add_parent(p)
		dag.add_node(self)


def breakupsegs(seglists, min_segment_length):
	"""
	Replace, in place, every segment list in seglists by a new
	segmentlist holding copies of only those segments whose length
	(abs(seg)) is strictly greater than min_segment_length.

	seglists maps instrument to segment list.  Nothing is returned.
	"""
	# iterate over a snapshot of the keys because values are reassigned
	for instrument in list(seglists):
		long_enough = segments.segmentlist(
			segments.segment(seg)
			for seg in seglists[instrument]
			if abs(seg) > min_segment_length
		)
		seglists[instrument] = long_enough


def parse_command_line():
	"""
	Parse and validate the command line.

	Returns the tuple (options, filenames, framecache, inchannels,
	outchannels, outpaths, frametypes), where framecache maps instrument
	to cache file name (from --frame-cache) and the remaining dictionaries
	are built by datasource.channel_dict_from_channel_list() from the
	corresponding INSTRUMENT=VALUE options.  outpaths is empty when
	--output-path is not given.

	Raises ValueError if a required option is missing, a --frame-cache
	entry is not of the form INSTRUMENT=file, or the instruments given to
	--frame-cache, --input-channel and --output-channel differ.
	"""
	parser = OptionParser(description = __doc__)
	parser.add_option("--segment-file", metavar = "filename", help = "Set the name of the xml file to get segments (required).")
	parser.add_option("--min-segment-length", metavar = "seconds", help = "Set the minimum segment length (required)", type="float")
	parser.add_option("--injections", metavar = "filename", help = "Set the name of the xml file to get injections from (optional)")
	parser.add_option("--recolor-psd", metavar = "filename", help = "Set the name of the xml file to get the recolor psd")
	parser.add_option("--frame-cache", metavar = "filenames", help = "Set the frame cache files in format H1=H1.cache,H2=H2.cache, etc..")
	#FIXME get this from the cache?
	parser.add_option("--input-channel", metavar = "name", action = "append", help = "Set the channel name like H1=LSC-STRAIN. can be given more than once and is required")
	parser.add_option("--output-channel", metavar = "name", action = "append", help = "Set the channel name like H1=LSC-STRAIN. can be given more than once and is required")
	parser.add_option("--output-path", metavar = "PATH", action = "append", help = "Set the instrument dependent output path for frames, defaults to current working directory. eg H1=/path/to/H1/frames. Can be given more than once.")
	parser.add_option("--frame-type", metavar = "name", action = "append", help = "Set the instrument dependent frame type, H1=TYPE. Can be given more than once and is required for each instrument processed.")

	options, filenames = parser.parse_args()

	# report every missing required option at once; the channel options
	# are documented as required, so they are checked here as well
	fail = ""
	for option in ("segment_file", "min_segment_length", "frame_cache", "frame_type", "input_channel", "output_channel"):
		if getattr(options, option) is None:
			fail += "must provide option %s\n" % (option)
	if fail:
		raise ValueError(fail)

	framecache = {}
	for c in options.frame_cache.split(','):
		# split on the first "=" only, so a cache path that itself
		# contains "=" (or the instrument name) is left intact
		ifo, sep, cache = c.partition("=")
		if not sep:
			raise ValueError("--frame-cache entry '%s' is not of the form INSTRUMENT=cachefile" % (c,))
		framecache[ifo] = cache
	inchannels = datasource.channel_dict_from_channel_list(options.input_channel)
	outchannels = datasource.channel_dict_from_channel_list(options.output_channel)
	frametypes = datasource.channel_dict_from_channel_list(options.frame_type)
	if options.output_path is None:
		outpaths = {}
	else:
		outpaths = datasource.channel_dict_from_channel_list(options.output_path)

	if not (set(framecache) == set(inchannels) == set(outchannels)):
		raise ValueError('--frame-cache, --input-channel and --output-channel must contain same instruments')

	return options, filenames, framecache, inchannels, outchannels, outpaths, frametypes


options, filenames, frame_cache, inchannels, outchannels, outpaths, frametypes = parse_command_line()

# The job classes point their stdout/stderr files into logs/.  Creating the
# directory is best effort: it may already exist from an earlier run.
try:
	os.mkdir("logs")
except OSError:
	pass
dag = bank_DAG("recolor_pipe")

# load the analysis segments and drop those not longer than the requested
# minimum
seglists = ligolw_segments.segmenttable_get_by_name(utils.load_filename(options.segment_file), "datasegments").coalesce()
breakupsegs(seglists, options.min_segment_length)

psdJob = gstlal_reference_psd_job()
smoothJob = gstlal_ninja_smooth_reference_psd_job()
medianJob = gstlal_median_psd_job()
colorJob = gstlal_recolor_frames_job()

smoothnode = {}
mediannode = {}

# Stage 1: per instrument, measure a reference PSD for every segment, smooth
# each one, then combine all smoothed PSDs into a single median PSD.
for instrument, seglist in seglists.iteritems():
	if not seglist:
		# no segment survived the length cut: there is nothing to take
		# a median of and no frames will be recolored for this
		# instrument, so do not create a median node without inputs
		continue
	smoothnode[instrument] = {}
	for seg in seglist:
		#FIXME if there are segments without frame caches this will barf
		psdnode = gstlal_reference_psd_node(psdJob, dag, frame_cache[instrument], int(seg[0]), int(seg[1]), instrument, inchannels[instrument], injections=options.injections, p_node=[])
		smoothnode[instrument][seg] = gstlal_ninja_smooth_reference_psd_node(smoothJob, dag, instrument, psdnode.output_name,  p_node=[psdnode])

	smooth_nodes = smoothnode[instrument].values()
	mediannode[instrument] = gstlal_median_psd_node(medianJob, dag, instrument, [v.output_name for v in smooth_nodes], "%s_median_psd.xml.gz" % instrument, p_node=smooth_nodes)

# Stage 2: recolor every segment, using the instrument's median PSD as the
# reference PSD; each recolor node waits for that median node.
for instrument, seglist in seglists.iteritems():
	# instruments without an explicit --output-path get None
	output_path = outpaths.get(instrument)
	for seg in seglist:
		gstlal_recolor_frames_node(colorJob, dag, frame_cache[instrument], int(seg[0]), int(seg[1]), instrument, inchannels[instrument], reference_psd=mediannode[instrument].output_name, recolor_psd = options.recolor_psd, injections=options.injections, output_channel_name = outchannels[instrument], output_path = output_path, frame_type = frametypes[instrument], p_node=[mediannode[instrument]])

dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()



