#!/usr/bin/python

import subprocess
import sys
import glob
import os
from glue import lal

from optparse import OptionParser

from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils
from pylal import git_version
from pylal import ligolw_cafe
from glue import pipeline
import ConfigParser
import tempfile
import string
from glue import iterutils

# Command line interface. Each table row is (option flags, add_option keywords);
# the rows are registered in the order given.
parser = OptionParser(version = git_version.verbose_msg, usage = "%prog [options] [databases]")
option_table = [
	(("-v", "--verbose"), dict(action = "store_true", help = "Be verbose.")),
	(("-n", "--number-of-trees"), dict(default=100, type="int")),
	(("-c", "--criterion-for-optimization"), dict(default=5, type="int")),
	(("-l", "--leaf-size"), dict(default=4, type="int")),
	(("-s", "--sampled-parameters"), dict(default=6, type="int")),
	(("-Z", "--exclude-variables"), dict(default="GPS", type="string", help = "Exclude input variables from the list, but put them in the output file.")),
	(("-i", "--ini-file"), dict()),
	(("-k", "--skip-file-generation"), dict(action = "store_true", help = "provide this option if you already have your .pat files in your directory and don't need to generate them again")),
	(("-K", "--link-generated-files"), dict(help = "supply path to a directory that already has the .pat files you want to use")),
	(("-p", "--log-path"), dict(help = "set dagman log path")),
]
for flags, settings in option_table:
	parser.add_option(*flags, **settings)
opts, trainingfiles = parser.parse_args()

# Suffix encoding the forest parameters; it is appended to the DAG name, the
# log file names and the names of the files produced by the forest jobs.
run_tag = ''.join([
	'_n', str(opts.number_of_trees),
	'_l', str(opts.leaf_size),
	'_s', str(opts.sampled_parameters),
	'_c', str(opts.criterion_for_optimization),
])

class auxmvc_DAG(pipeline.CondorDAG):
	"""
	CondorDAG for this script. The DAG file name is built from the ini file
	name and run_tag, and self.id counts the nodes added so far (the node
	classes in this file read it to set their "macroid" macro).
	"""
	def __init__(self, config_file, log_path):
		"""
		config_file = name of the ini file; with '.ini' removed and run_tag
		appended it becomes the base name of the DAG.
		log_path = directory in which the DAGMan log file is created; None
		means the default temporary directory.
		"""
		self.config_file = str(config_file)
		self.basename = self.config_file.replace('.ini','')+run_tag
		# These module-level settings are kept from the original code so that
		# any later tempfile use in this process sees the same values.
		tempfile.tempdir = log_path
		tempfile.template = self.basename + '.dag.log.'
		# mkstemp() creates the file atomically, unlike mktemp() followed by
		# open(), which leaves a window for another process to claim the name.
		# The prefix is passed explicitly instead of relying on the deprecated
		# tempfile.template global, and only the final path component is used
		# so that an ini file given with a directory cannot redirect the log.
		log_prefix = os.path.basename(self.basename) + '.dag.log.'
		log_fd, logfile = tempfile.mkstemp(prefix = log_prefix, dir = log_path)
		os.close(log_fd)
		pipeline.CondorDAG.__init__(self,logfile)
		self.set_dag_file(self.basename)
		self.jobsDict = {}
		self.id = 0

	def add_node(self, node):
		"""
		Add node to the DAG and advance the node counter.
		"""
		self.id+=1
		pipeline.CondorDAG.add_node(self, node)
		
class generate_files_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	generate_spr_files in the [condor] section of the ini file.
	"""
	def __init__(self, cp, tag_base='GENERATE_FILES'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'generate_spr_files'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','generate_spr_files').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		# commands written to the condor submit file
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		# stdout and stderr share everything except the extension
		log_prefix = 'logs/'+tag_base+run_tag
		self.set_stdout_file(log_prefix+'-$(macroid)-$(process).out')
		self.set_stderr_file(log_prefix+'-$(macroid)-$(process).err')

class generate_files_node(pipeline.CondorDAGNode):
	"""
	DAG node for a generate_files_job. The node adds itself to the dag.
	"""
	def __init__(self, job, dag, options, outputfiles, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		options = sequence of (name, value) pairs, each added as a variable
		option of the node.
		outputfiles = files registered as output files of the node.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		for opt in options:
			self.add_var_opt(opt[0],opt[1])
		for outputfile in outputfiles:
			self.add_output_file(outputfile)
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)

class train_forest_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	SprBaggerDecisionTreeApp in the [condor] section of the ini file.
	"""
	def __init__(self, cp, tag_base='TRAIN_FOREST'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'SprBaggerDecisionTreeApp'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','SprBaggerDecisionTreeApp').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		# commands written to the condor submit file
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		# stdout and stderr share everything except the extension
		log_prefix = 'logs/'+tag_base+run_tag
		self.set_stdout_file(log_prefix+'-$(macroid)-$(process).out')
		self.set_stderr_file(log_prefix+'-$(macroid)-$(process).err')

class train_forest_node(pipeline.CondorDAGNode):
	"""
	DAG node for a train_forest_job. It derives the names of the trained
	forest (.spr) and of the training output (.out) from the training file
	name and run_tag, and adds itself to the dag.
	"""
	def __init__(self, job, dag, trainingfile, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		trainingfile = name of the .pat file to train on.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		self.add_input_file(trainingfile)
		self.trainingfile = self.get_input_files()[0]
		self.trainedforest = self.trainingfile.replace('.pat',run_tag+'.spr')
		self.trainedout = self.trainingfile.replace('.pat',run_tag+'.out')
		# The complete command line is handed over as one argument string; the
		# forest parameters come from the command line options of this script.
		forest_args = "-a 1 -n %s -l %s -s %s -c %s -g 1 -i -d 1 -z %s -o %s -f %s %s" % (
			opts.number_of_trees,
			opts.leaf_size,
			opts.sampled_parameters,
			opts.criterion_for_optimization,
			opts.exclude_variables,
			self.trainedout,
			self.trainedforest,
			self.trainingfile)
		self.add_file_arg(forest_args)
		self.add_output_file(self.trainedforest)
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)

class use_forest_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	SprOutputWriterApp in the [condor] section of the ini file.
	"""
	def __init__(self, cp, tag_base='USE_FOREST'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'SprOutputWriterApp'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','SprOutputWriterApp').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		# commands written to the condor submit file
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		# stdout and stderr share everything except the extension
		log_prefix = 'logs/'+tag_base+run_tag
		self.set_stdout_file(log_prefix+'-$(macroid)-$(process).out')
		self.set_stderr_file(log_prefix+'-$(macroid)-$(process).err')

class use_forest_node(pipeline.CondorDAGNode):
	"""
	DAG node for a use_forest_job. It derives the name of the ranked output
	(.dat) from the name of the file to rank and run_tag, and adds itself to
	the dag.
	"""
	def __init__(self, job, dag, trainedforest, file_to_rank, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		trainedforest = name of the trained forest file.
		file_to_rank = name of the .pat file to be ranked.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		#FIXME add tmp file space
		self.add_macro("macroid", dag.id)
		self.add_input_file(trainedforest)
		self.add_input_file(file_to_rank)
		self.trainedforest = self.get_input_files()[0]
		self.file_to_rank = self.get_input_files()[1]
		# run_tag is the same '_n.._l.._s.._c..' suffix that used to be
		# rebuilt here by hand from the command line options.
		self.ranked_file = self.file_to_rank.replace('.pat',run_tag+'.dat')
		self.add_file_arg("-A -a 1 -Z %s %s %s %s" % (opts.exclude_variables,self.trainedforest, self.file_to_rank, self.ranked_file))
		self.add_output_file(self.ranked_file)
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)

class channels_sig_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	auxmvc_plot_mvsc_channels_significance in the [condor] section of the
	ini file.
	"""
	def __init__(self, cp, tag_base='CHSIG_PLOTS'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'auxmvc_plot_mvsc_channels_significance'
		self.__executable = cp.get('condor','auxmvc_plot_mvsc_channels_significance').strip()
		self.__universe = "vanilla"
		pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
		self.add_condor_cmd('getenv','True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		self.set_stdout_file('logs/'+tag_base+run_tag+'-$(macroid)-$(process).out')
		self.set_stderr_file('logs/'+tag_base+run_tag+'-$(macroid)-$(process).err')

class channels_sig_node(pipeline.CondorDAGNode):
	"""
	DAG node for a channels_sig_job. The node adds itself to the dag.
	"""
	def __init__(self, job, dag, input, usertag, options, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		input = value of the "input" option.
		usertag = value of the "usertag" option.
		options = sequence of (name, value) pairs, each added as a variable
		option of the node.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		self.add_var_opt("input", input)
		self.add_var_opt("usertag",usertag)
		for opt in options:
			self.add_var_opt(opt[0],opt[1])
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)





class result_plots_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	auxmvc_result_plots in the [condor] section of the ini file.
	"""
	def __init__(self, cp, tag_base='RESULT_PLOTS'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'auxmvc_result_plots'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','auxmvc_result_plots').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		# commands written to the condor submit file
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		# stdout and stderr share everything except the extension
		log_prefix = 'logs/'+tag_base+run_tag
		self.set_stdout_file(log_prefix+'-$(macroid)-$(process).out')
		self.set_stderr_file(log_prefix+'-$(macroid)-$(process).err')

class result_plots_node(pipeline.CondorDAGNode):
	"""
	DAG node for a result_plots_job. The node adds itself to the dag.
	"""
	def __init__(self, job, dag, datfiles, tag, options, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		datfiles = sequence of sequences; the first element of each entry is
		added as a file argument.
		tag = value of the "tag" option.
		options = sequence of (name, value) pairs, each added as a variable
		option of the node.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		for datfile in datfiles:
			self.add_file_arg(datfile[0])
		self.add_var_opt("tag",tag)
		for opt in options:
			self.add_var_opt(opt[0],opt[1])
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)

class combined_plot_job(pipeline.CondorDAGJob):
	"""
	Vanilla universe condor job for the executable configured as
	auxmvc_ROC_combiner in the [condor] section of the ini file.
	"""
	def __init__(self, cp, tag_base='COMBINED_PLOT'):
		"""
		cp = ConfigParser object from which the executable is read.
		tag_base = string used to name the submit file and the
		stdout/stderr files under logs/.
		"""
		self.__prog__ = 'auxmvc_ROC_combiner'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','auxmvc_ROC_combiner').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.tag_base = tag_base
		# commands written to the condor submit file
		self.add_condor_cmd('getenv','True')
		self.add_condor_cmd('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base+'.sub')
		# stdout and stderr share everything except the extension
		log_prefix = 'logs/'+tag_base+run_tag
		self.set_stdout_file(log_prefix+'-$(macroid)-$(process).out')
		self.set_stderr_file(log_prefix+'-$(macroid)-$(process).err')

class combined_plot_node(pipeline.CondorDAGNode):
	"""
	DAG node for a combined_plot_job. The node adds itself to the dag.
	"""
	def __init__(self, job, dag, picklefiles, run_tag, options, p_node=None):
		"""
		job = the job this node is an instance of.
		dag = DAG the node is added to; dag.id is used as the "macroid" macro.
		picklefiles = files added (as strings) as file arguments.
		run_tag = value of the "tag" option; inside this method the parameter
		hides the module level run_tag.
		options = sequence of (name, value) pairs, each added as a variable
		option of the node.
		p_node = optional list of parent nodes (default: no parents).
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		for picklefile in picklefiles:
			self.add_file_arg(str(picklefile))
		self.add_var_opt("tag",run_tag)
		for opt in options:
			self.add_var_opt(opt[0],opt[1])
		# None replaces the former mutable [] default
		for parent in p_node or []:
			self.add_parent(parent)
		dag.add_node(self)

###############################################################################
## MAIN #######################################################################
###############################################################################


### SET UP THE DAG

# Everything below needs the ini file, so fail early with a clear message.
if not opts.ini_file:
	parser.error("an ini file must be supplied with --ini-file")

# Directory for the condor stdout/stderr files. An already existing directory
# is fine; any other failure (e.g. permissions) is reported instead of ignored.
try:
	os.mkdir("logs")
except OSError:
	if not os.path.isdir("logs"):
		raise

cp = ConfigParser.ConfigParser()
#FIXME don't assume file name
ininame = opts.ini_file
# read() returns the list of files it could parse; an empty list means the
# ini file is missing or unreadable.
if not cp.read(ininame):
	parser.error("could not read ini file %s" % ininame)
dag = auxmvc_DAG(ininame, opts.log_path)

#generate_files
generate_job = generate_files_job(cp)
# values of the [generate_spr_files] section that are used repeatedly below
dq_cats = cp.get("generate_spr_files","DQ-cats").split(',')
output_tag = cp.get("generate_spr_files","output-tag")
roundrobin_number = int(cp.get("generate_spr_files","roundrobin-number"))
training_files = {} #dictionary keyed on DQ category
evaluation_files = {} #dictionary keyed on DQ category
for cat in dq_cats:
	set_prefixes = [cat+'_'+output_tag+'_set_'+str(i)+'_' for i in range(roundrobin_number)]
	training_files[cat] = [prefix+'training.pat' for prefix in set_prefixes]
	evaluation_files[cat] = [prefix+'evaluation.pat' for prefix in set_prefixes]

# link existing .pat files into the working directory unless a file with that
# name is already there
if opts.link_generated_files:
	link_dir = opts.link_generated_files
	for pat_name in os.listdir(link_dir):
		if os.path.splitext(pat_name)[1] != '.pat':
			continue
		link_name = os.path.join(os.getcwd(),pat_name)
		if not os.path.isfile(link_name):
			os.symlink(os.path.abspath(os.path.join(link_dir,pat_name)),link_name)

#SprBaggerDecisionTreeApp
train_job = {}
train_node = {}

#SprOutputWriterApp
rank_job = {}
rank_node = {}
# NOTE(review): zl_rank_job and zl_rank_node are not used by any node below;
# kept as in the original.
zl_rank_job = use_forest_job(cp)
zl_rank_node = {}

#plot_mvsc_channels_significance
ch_sig_job = channels_sig_job(cp)
ch_sig_nodes = {}

#result_plots
results_job = result_plots_job(cp)
results_node = {}
picklenodesdone=[]
picklefiles = []

#combined_plot
combined_job = combined_plot_job(cp)

#set up the nodes
if opts.skip_file_generation or opts.link_generated_files:
	# the .pat files already exist, so the training nodes have no parent
	generate_node = None
	train_parents = []
else:
	# flatten the per-category lists so that every output file is registered
	# by name rather than as a list of names
	all_pat_files = [pat for pats in list(training_files.values())+list(evaluation_files.values()) for pat in pats]
	generate_node = generate_files_node(generate_job, dag, cp.items("generate_spr_files"), all_pat_files)
	train_parents = [generate_node]
for cat in dq_cats:
	print(cat)
	catdatfiles=[]
	train_node[cat]={}
	rank_node[cat]={}
	train_job[cat] = train_forest_job(cp, tag_base="TRAIN_FOREST_"+cat)
	rank_job[cat] = use_forest_job(cp, tag_base="USE_FOREST_"+cat)
	for trainingfile in training_files[cat]:
		train_node[cat][trainingfile] = train_forest_node(train_job[cat], dag, trainingfile, p_node=train_parents)
		# each forest ranks the evaluation file of its own round-robin set
		rank_node[cat][trainingfile] = use_forest_node(rank_job[cat], dag, train_node[cat][trainingfile].trainedforest, trainingfile.replace('_training','_evaluation'), p_node=[train_node[cat][trainingfile]])
		catdatfiles.append(rank_node[cat][trainingfile].get_output_files())
	ch_sig_nodes[cat] = channels_sig_node(ch_sig_job, dag, "logs/"+train_job[cat].tag_base+run_tag+"*.out", cat+run_tag, cp.items("auxmvc_plot_mvsc_channels_significance"), p_node=list(rank_node[cat].values()))
	roctag = cat+'_'+output_tag+run_tag
	results_node[cat] = result_plots_node(results_job, dag, catdatfiles, roctag, cp.items("auxmvc_result_plots"), p_node=list(rank_node[cat].values()))
	picklefiles.append(cp.get("auxmvc_result_plots","output-dir")+'/ROC_'+roctag+'.pickle')
	picklenodesdone.append(results_node[cat])
# the combined plot waits for the result plots of every category
combined_node = combined_plot_node(combined_job, dag, picklefiles, run_tag, cp.items("auxmvc_ROC_combiner"), p_node=picklenodesdone)


dag.write_sub_files()
dag.write_dag()
dag.write_script()
