#!/usr/bin/python

import subprocess
import sys
import glob
import os
from glue import lal

from optparse import OptionParser

from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils
from pylal import git_version
from pylal import ligolw_cafe
from glue import pipeline
import ConfigParser
import tempfile
import string
from glue import iterutils

# Command-line interface.  Positional arguments are collected into
# `trainingfiles`.
parser = OptionParser(version = git_version.verbose_msg, usage = "%prog [options] [databases]")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
# Comma-separated string; forwarded verbatim to the training job as -n and
# folded into run_tag below.
parser.add_option("-n", "--numbers-of-neurons", type="string")
# -m, -y, -f and -g are forwarded to the training job and folded into run_tag.
parser.add_option("-m", "--max-epochs", default=1000, type="int")
# NOTE(review): -w, -x, -z, -a and -b are parsed here but nothing in this
# script forwards them to a job -- confirm whether they should be.
parser.add_option("-w", "--weights-min", default=-0.1, type="float")
parser.add_option("-x", "--weights-max", default=0.1, type="float")
parser.add_option("-y", "--increase-factor", default=1.2, type="float")
parser.add_option("-z", "--decrease-factor", default=0.5, type="float")
parser.add_option("-a", "--delta-min", default=0.0, type="float")
parser.add_option("-b", "--delta-max", default=50.0, type="float")
parser.add_option("-f", "--steep-hidden", default=0.5, type="float")
parser.add_option("-g", "--steep-out", default=0.5, type="float")
parser.add_option("-i", "--ini-file")
parser.add_option("-k", "--skip-file-generation", action = "store_true", help = "provide this option if you already have your .pat files in your directory and don't need to generate them again")
parser.add_option("-K", "--link-generated-files", help = "supply path to a directory that already has the .pat files you want to use")
parser.add_option("-c", "--skip-file-conversion", action = "store_true", help = "provide this option if you already have your .ann files in your directory and don't need to generate them again")
parser.add_option("-C", "--link-converted-files", help = "supply path to a directory that already has the .ann files you want to use")
# Forwarded as --subsets to the conversion and training jobs.
parser.add_option("-s", "--subsets", default=1, type="int")
parser.add_option("-p","--log-path", help = "set dagman log path")
(opts, trainingfiles) = parser.parse_args()

# These two options have no default and are dereferenced unconditionally
# further down, so report a usage error instead of failing with a traceback.
if opts.numbers_of_neurons is None:
	parser.error("--numbers-of-neurons is required")
if opts.ini_file is None:
	parser.error("--ini-file is required")

# Suffix identifying this set of network parameters; it is appended to DAG,
# submit, log and result file names.  Commas in the neuron list become 'n'
# and decimal points are dropped, e.g. "-n 10,10" with the default values
# gives "_n10n10_y12_f05g05_m1000".
run_tag = (
	'_n' + opts.numbers_of_neurons.replace(",", "n")
	+ '_y' + str(opts.increase_factor).replace(".", "")
	+ '_f' + str(opts.steep_hidden).replace(".", "")
	+ 'g' + str(opts.steep_out).replace(".", "")
	+ '_m' + str(opts.max_epochs)
)

class auxmvc_DAG(pipeline.CondorDAG):
	"""
	CondorDAG whose DAG file name is built from the ini file name plus
	run_tag, and which counts the nodes added to it in self.id.
	"""
	def __init__(self, config_file, log_path):
		"""
		config_file = name of the ini file; '.ini' is removed and run_tag
		appended to form the DAG base name (self.basename)
		log_path = directory in which the DAGMan log file is created; None
		selects the tempfile module's default directory
		"""
		self.config_file = str(config_file)
		self.basename = self.config_file.replace('.ini','')+run_tag
		# These tempfile module globals have always been set here; keep doing
		# so in case other code in the process relies on them.
		tempfile.tempdir = log_path
		tempfile.template = self.basename + '.dag.log.'
		# mkstemp() picks the name and creates the file in one atomic step,
		# unlike mktemp() followed by open(), where another process can claim
		# the name in between.  Only the final path component of the base
		# name is used as prefix so the file always lands in log_path.
		fd, logfile = tempfile.mkstemp(prefix=os.path.basename(self.basename) + '.dag.log.', dir=log_path)
		os.close(fd)
		pipeline.CondorDAG.__init__(self,logfile)
		self.set_dag_file(self.basename)
		self.jobsDict = {}
		# Incremented for every node added; node constructors read it to set
		# their "macroid" macro before adding themselves.
		self.id = 0

	def add_node(self, node):
		"""
		Increment the node counter, then add node to the DAG.
		"""
		self.id+=1
		pipeline.CondorDAG.add_node(self, node)
		
class generate_files_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	generate_spr_files option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='GENERATE_FILES'):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/
		"""
		self.__prog__ = 'generate_spr_files'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','generate_spr_files').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class generate_files_node(pipeline.CondorDAGNode):
	"""
	DAG node for a generate_files_job.
	"""
	def __init__(self, job, dag, options, outputfiles, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		options = sequence of (name, value) pairs, each passed to add_var_opt
		outputfiles = entries registered with add_output_file
		p_node = parent nodes, each registered with add_parent
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		self.add_macro("macroid", dag.id)
		for pair in options:
			self.add_var_opt(pair[0], pair[1])
		for produced in outputfiles:
			self.add_output_file(produced)
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)

class convert_files_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	ConvertSprToFann option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='CONVERT_FILES'):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/
		"""
		self.__prog__ = 'ConvertSprToFann'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','ConvertSprToFann').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class convert_files_node(pipeline.CondorDAGNode):
	"""
	DAG node that runs a convert_files_job on one training file.
	"""
	def __init__(self, job, dag, trainingfile, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		trainingfile = file to convert, registered as the node's input
		p_node = parent nodes, each registered with add_parent (the default
		list is only iterated, never modified)

		The number of subsets is read from the module-level opts.
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		self.add_macro("macroid", dag.id)
		self.add_input_file(trainingfile)
		self.trainingfile = self.get_input_files()[0]
		self.add_file_arg("-t %s --subsets %i" % (self.trainingfile, opts.subsets))
		# This used to read add_output_file(file).  `file` is not a local
		# name, so it resolved to whatever a module-level `file` happened to
		# hold at call time (or the builtin `file` type).  Use the argument
		# this node was given, which is the value the calling loop supplied.
		# NOTE(review): the converter's real product is presumably a .ann
		# file (see --link-converted-files); confirm its name and register
		# that instead.
		self.add_output_file(trainingfile)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)

class train_network_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	FannTrainNeuralNetwork option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='TRAIN_NETWORK'):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/
		"""
		self.__prog__ = 'FannTrainNeuralNetwork'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','FannTrainNeuralNetwork').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class train_network_node(pipeline.CondorDAGNode):
	"""
	DAG node that runs a train_network_job on one training file.
	"""
	def __init__(self, job, dag, trainingfile, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		trainingfile = training file, registered as the node's input
		p_node = parent nodes, each registered with add_parent

		The network file name (self.trainednetwork) is the training file
		name with '.pat' replaced by run_tag + '.net'.  The network
		parameters on the command line come from the module-level opts.
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		self.add_macro("macroid", dag.id)
		self.add_input_file(trainingfile)
		self.trainingfile = self.get_input_files()[0]
		self.trainednetwork = self.trainingfile.replace('.pat', run_tag + ".net")
		# values in the order the format string below consumes them
		arg_values = (
			self.trainingfile,
			self.trainednetwork,
			opts.numbers_of_neurons,
			opts.max_epochs,
			opts.increase_factor,
			opts.steep_hidden,
			opts.steep_out,
			opts.subsets,
		)
		self.add_file_arg("-t %s -s %s -n %s -m %i -y %f -f %f -g %f --subsets %i" % arg_values)
		self.add_output_file(self.trainednetwork)
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)

class use_network_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	FannEvaluation option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='USE_NETWORK'):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/
		"""
		self.__prog__ = 'FannEvaluation'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','FannEvaluation').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class use_network_node(pipeline.CondorDAGNode):
	"""
	DAG node that runs a use_network_job: a trained network is applied to
	one file.
	"""
	def __init__(self, job, dag, trainednetwork, file_to_rank, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		trainednetwork = network file, registered as the first input
		file_to_rank = file to evaluate, registered as the second input
		p_node = parent nodes, each registered with add_parent

		The result file name (self.ranked_file) is file_to_rank with '.pat'
		replaced by run_tag + '.dat'.
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		#FIXME add tmp file space
		self.add_macro("macroid", dag.id)
		for needed in (trainednetwork, file_to_rank):
			self.add_input_file(needed)
		# read the names back in the order they were registered above
		self.trainednetwork = self.get_input_files()[0]
		self.file_to_rank = self.get_input_files()[1]
		self.ranked_file = self.file_to_rank.replace('.pat', run_tag + '.dat')
		arg_values = (self.file_to_rank, self.trainednetwork, self.ranked_file)
		self.add_file_arg("-e %s -n %s -s %s" % arg_values)
		self.add_output_file(self.ranked_file)
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)

class result_plots_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	auxmvc_result_plots option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='RESULT_PLOTS'+run_tag):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/

		run_tag is appended to tag_base again for the stdout/stderr names,
		so with the default tag_base those names contain run_tag twice.
		"""
		self.__prog__ = 'auxmvc_result_plots'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','auxmvc_result_plots').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class result_plots_node(pipeline.CondorDAGNode):
	"""
	DAG node for a result_plots_job.
	"""
	def __init__(self, job, dag, datfiles, tag, options, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		datfiles = sequence of sequences; the first element of each entry is
		added as a file argument
		tag = value of the "tag" option
		options = sequence of (name, value) pairs, each passed to add_var_opt
		p_node = parent nodes, each registered with add_parent
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		self.add_macro("macroid", dag.id)
		for entry in datfiles:
			self.add_file_arg(entry[0])
		self.add_var_opt("tag", tag)
		for pair in options:
			self.add_var_opt(pair[0], pair[1])
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)

class combined_plot_job(pipeline.CondorDAGJob):
	"""
	Vanilla-universe Condor job for the executable named by the
	auxmvc_ROC_combiner option of the [condor] section.
	"""
	def __init__(self, cp, tag_base='COMBINED_PLOT'+run_tag):
		"""
		cp = configuration object; cp.get('condor', ...) supplies the executable
		tag_base = stem of the submit file name and of the stdout/stderr
		file names under logs/

		run_tag is appended to tag_base again for the stdout/stderr names,
		so with the default tag_base those names contain run_tag twice.
		"""
		self.__prog__ = 'auxmvc_ROC_combiner'
		self.__universe = "vanilla"
		self.__executable = cp.get('condor','auxmvc_ROC_combiner').strip()
		pipeline.CondorDAGJob.__init__(self, self.__universe, self.__executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# stdout and stderr share one stem and differ only in extension
		log_stem = 'logs/' + tag_base + run_tag + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')

class combined_plot_node(pipeline.CondorDAGNode):
	"""
	DAG node for a combined_plot_job.
	"""
	def __init__(self, job, dag, picklefiles, run_tag, options, p_node=[]):
		"""
		job = job this node is an instance of
		dag = DAG the node adds itself to; dag.id at construction time is
		stored as the "macroid" macro
		picklefiles = entries added, converted with str(), as file arguments
		run_tag = value of the "tag" option (this parameter hides the
		module-level run_tag inside this method)
		options = sequence of (name, value) pairs, each passed to add_var_opt
		p_node = parent nodes, each registered with add_parent
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		self.add_macro("macroid", dag.id)
		for entry in picklefiles:
			self.add_file_arg(str(entry))
		self.add_var_opt("tag", run_tag)
		for pair in options:
			self.add_var_opt(pair[0], pair[1])
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)

###############################################################################
## MAIN #######################################################################
###############################################################################


### SET UP THE DAG

# The job classes write their stdout/stderr files under logs/.
try:
	os.mkdir("logs")
except OSError:
	# Normally this just means logs/ already exists.  Stay best-effort, but
	# say so when the directory really is missing.
	if not os.path.isdir("logs"):
		sys.stderr.write("warning: could not create directory 'logs'\n")

cp = ConfigParser.ConfigParser()
#FIXME don't assume file name
ininame = opts.ini_file
# read() skips files it cannot open and returns the list of files it parsed;
# an empty list means every later cp.get() would fail, so stop here with a
# clear message.
if not cp.read(ininame):
	sys.exit("error: could not read ini file %r" % (ininame,))
dag = auxmvc_DAG(ininame, opts.log_path)

#generate_files
generate_job = generate_files_job(cp)
# File names expected from the generation step, keyed on DQ category.  For
# each category there is one training and one evaluation file per
# round-robin set:
#   <cat>_<output-tag>_set_<i>_training.pat / ..._evaluation.pat
training_files = {}
evaluation_files = {}
for cat in cp.get("generate_spr_files","DQ-cats").split(','):
	n_sets = int(cp.get("generate_spr_files","roundrobin-number"))
	stems = [cat+'_'+cp.get("generate_spr_files","output-tag")+'_set_'+str(i)+'_' for i in range(n_sets)]
	training_files[cat] = [stem+'training.pat' for stem in stems]
	evaluation_files[cat] = [stem+'evaluation.pat' for stem in stems]

def _link_matching_files(source_dir, extension):
	"""
	Symlink every entry of source_dir whose extension equals `extension`
	into the current working directory.

	source_dir = directory to scan
	extension = extension to match, including the leading dot

	Names that already exist in the current directory are left untouched.
	"""
	here = os.getcwd()
	for name in os.listdir(source_dir):
		if os.path.splitext(name)[1] != extension:
			continue
		link_name = os.path.join(here, name)
		# lexists() is also true for dangling symlinks and directories;
		# isfile() is not, and os.symlink() would then raise because the
		# name is taken.
		if not os.path.lexists(link_name):
			os.symlink(os.path.abspath(os.path.join(source_dir, name)), link_name)

if opts.link_generated_files:
	_link_matching_files(opts.link_generated_files, '.pat')

#convert_files
if opts.link_converted_files:
	_link_matching_files(opts.link_converted_files, '.ann')
	
# Containers for the jobs and nodes created while the DAG is set up.

#ConvertSprToFann
convert_job, convert_node = {}, {}

#FannTrainNeuralNetwork
train_job, train_node = {}, {}

#FannEvaluation
rank_job, rank_node = {}, {}
zl_rank_node = {}
zl_rank_job = use_network_job(cp)

#result_plots
results_node = {}
# pickle file names and the nodes that write them, collected per category
picklenodesdone, picklefiles = [], []
results_job = result_plots_job(cp)

#combined_plot
combined_node = {}
combined_job = combined_plot_job(cp)

#set up the nodes
# A stage is left out of the DAG when its files are said to exist already,
# either in place (--skip-*) or linked from another directory (--link-*).
skip_generation = opts.skip_file_generation or opts.link_generated_files
skip_conversion = opts.skip_file_conversion or opts.link_converted_files

if skip_generation:
	generate_node = None
else:
	# training_files/evaluation_files map a category to a list of names.
	# Flatten them so every file name is registered as an output, rather
	# than one list object per category.
	expected_files = [name for per_cat in training_files.values()+evaluation_files.values() for name in per_cat]
	generate_node = generate_files_node(generate_job, dag, cp.items("generate_spr_files"), expected_files)

for cat in cp.get("generate_spr_files","DQ-cats").split(','):
	print(cat)
	catdatfiles=[]
	train_node[cat]={}
	rank_node[cat]={}
	convert_node[cat]={}
	convert_job[cat] = convert_files_job(cp, tag_base="CONVERT_FILES_"+cat+run_tag)
	train_job[cat] = train_network_job(cp, tag_base="TRAIN_NETWORK_"+cat+run_tag)
	rank_job[cat] = use_network_job(cp, tag_base="USE_NETWORK_"+cat+run_tag)
	# The loop variable keeps the name `file`: it is a module-level global
	# here, and other code in this file may look it up by that name.
	for i,file in enumerate(training_files[cat]):
		# Chain per training file: [generate ->] [convert ->] train -> rank.
		# Nodes are created in that order because each one records dag.id
		# as its macroid when it is constructed.
		if skip_conversion:
			train_parents = []
		else:
			if skip_generation:
				convert_parents = []
			else:
				convert_parents = [generate_node]
			convert_node[cat][file] = convert_files_node(convert_job[cat], dag, file, p_node=convert_parents)
			train_parents = [convert_node[cat][file]]
		train_node[cat][file] = train_network_node(train_job[cat], dag, file, p_node=train_parents)
		rank_node[cat][file] = use_network_node(rank_job[cat], dag, train_node[cat][file].trainednetwork, file.replace('_training','_evaluation'), p_node=[train_node[cat][file]])
		# One list of output files per rank node; result_plots_node uses the
		# first element of each entry.
		catdatfiles.append(rank_node[cat][file].get_output_files())
	roctag = cat+'_'+cp.get("generate_spr_files","output-tag")+run_tag
	results_node[cat] = result_plots_node(results_job, dag, catdatfiles, roctag, cp.items("auxmvc_result_plots"), p_node=rank_node[cat].values())
	picklefiles.append(cp.get("auxmvc_result_plots","output-dir")+'/ROC_'+roctag+'.pickle')
	picklenodesdone.append(results_node[cat])
# The combined plot waits for every per-category results node.
combined_node = combined_plot_node(combined_job, dag, picklefiles, run_tag, cp.items("auxmvc_ROC_combiner"), p_node=picklenodesdone)


dag.write_sub_files()
dag.write_dag()
dag.write_script()
## test
