#!/usr/bin/python

try:
  import sqlite3
except ImportError:
  # pre 2.5.x
  from pysqlite2 import dbapi2 as sqlite3

import subprocess
import sys
import glob
import os
from glue import lal

from optparse import OptionParser

from glue.ligolw import dbtables
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils
from pylal import git_version
from pylal import ligolw_cafe
from pylal import mvsc_queries
from glue import pipeline
import ConfigParser
import tempfile
import string
from glue import iterutils

# Command-line interface.  Positional arguments are the databases to process.
parser = OptionParser(version = git_version.verbose_msg, usage = "%prog [options] [databases]")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
parser.add_option("-n", "--number-of-trees", default=100, type="int", help = "value passed as -n to the forest training executable (default %default)")
parser.add_option("-c", "--criterion-for-optimization", default=5, type="int", help = "value passed as -c to the forest training executable (default %default)")
parser.add_option("-l", "--leaf-size", default=4, type="int", help = "value passed as -l to the forest training executable (default %default)")
parser.add_option("-s", "--sampled-parameters",default=6, type="int", help = "value passed as -s to the forest training executable (default %default)")
parser.add_option("-i", "--ini-file", help = "ini file providing the [condor] and [mvsc_get_doubles] sections (required)")
parser.add_option("-k", "--skip-file-generation", action = "store_true", help = "provide this option if you already have your .pat files and don't need to generate them again")
parser.add_option("-p","--log-path", default='logs/', help = "set dagman log path")
parser.add_option("-a","--all-instruments", help = "the list of all instruments from which you want to study the double-coincident triggers. e.g. H1,H2,L1,V1")
parser.add_option("-u","--user-tag", default='CBC', help= "help you keep track of multiple runs in a single directory, i.e. CAT_3 or CAT_4")

(opts, databases) = parser.parse_args()

# --ini-file has no default; report the omission as a usage error rather than
# letting ConfigParser fail on a None file name.
if not opts.ini_file:
  parser.error("an ini file must be given with --ini-file")

# Tag built from the user tag and the forest parameters; it is embedded in the
# names of the DAG, submit, log and trained-forest files.
run_tag = opts.user_tag+'_n'+str(opts.number_of_trees)+'_l'+str(opts.leaf_size)+'_s'+str(opts.sampled_parameters)+'_c'+str(opts.criterion_for_optimization)

cp = ConfigParser.ConfigParser()
ininame = opts.ini_file
# ConfigParser.read() skips files it cannot open and returns the list of files
# it actually parsed, so an empty result means the ini file was not read.
if not cp.read(ininame):
  parser.error("could not read ini file %s" % ininame)

class mvsc_dag_DAG(pipeline.CondorDAG):
  """
  CondorDAG for one MVSC run.

  The DAG file name is the ini file name with '.ini' removed, followed by
  '_mvsc_' and the module-level run_tag.  The instance also keeps a counter,
  id, that node constructors read for their "macroid" macro.
  """
  def __init__(self, config_file, log_path):
    """
    config_file is the ini file name the DAG is named after; log_path is the
    directory in which the DAG log file is created (it must already exist).
    """
    self.config_file = str(config_file)
    self.basename = self.config_file.replace('.ini','')+'_mvsc_'+run_tag
    # Kept from the original code: anything else in this process that asks the
    # tempfile module for its default directory gets log_path.
    tempfile.tempdir = log_path
    # mkstemp() creates the file atomically, unlike mktemp() followed by
    # open(), and takes the prefix explicitly (assigning tempfile.template
    # does not change the names mktemp() generates).  Only the final path
    # component of basename is used, because a prefix must not contain a
    # directory part.
    fd, logfile = tempfile.mkstemp(prefix = os.path.basename(self.basename) + '.dag.log.', dir = log_path)
    os.close(fd)
    pipeline.CondorDAG.__init__(self,logfile)
    self.set_dag_file(self.basename)
    self.jobsDict = {}
    self.id = 0
  def add_node(self, node):
    """
    Add node to the DAG and advance the id counter, so that the next node
    constructed reads a different dag.id.
    """
    self.id+=1
    pipeline.CondorDAG.add_node(self, node)
    
class mvsc_get_doubles_job(pipeline.CondorDAGJob):
  """
  Condor job for the mvsc_get_doubles executable.

  The executable path comes from the 'mvsc_get_doubles' entry of the
  [condor] section of the ini file; the job runs in the vanilla universe.
  """
  def __init__(self, cp, tag_base='MVSC_GET_DOUBLES'):
    """
    cp is the parsed ini file.  tag_base is combined with the module-level
    run_tag to name the submit file and the per-node stdout/stderr files,
    which are placed under opts.log_path.
    """
    self.__prog__ = 'mvsc_get_doubles'
    self.__executable = cp.get('condor','mvsc_get_doubles').strip()
    self.__universe = "vanilla"
    pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
    self.tag_base = tag_base
    for command, value in (('getenv','True'), ('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")):
      self.add_condor_cmd(command, value)
    # Every file written for this job shares the '<tag_base>_<run_tag>' stem.
    stem = tag_base+'_'+run_tag
    self.set_sub_file(stem+'.sub')
    node_log = opts.log_path+'/'+stem+'-$(macroid)-$(process)'
    self.set_stdout_file(node_log+'.out')
    self.set_stderr_file(node_log+'.err')

class mvsc_get_doubles_node(pipeline.CondorDAGNode):
  """
  DAG node that runs a mvsc_get_doubles job for one instrument combination.
  """
  def __init__(self, job, dag, options, instruments, databases, outputfiles, p_node=[]):
    """
    options is a sequence of (name, value) pairs passed through as variable
    options; databases are added as file arguments; instruments becomes the
    "instruments" option and opts.user_tag the "output-tag" option.  Every
    node in p_node is added as a parent.  The node adds itself to dag.

    NOTE(review): outputfiles is accepted but not used by this constructor.
    """
    pipeline.CondorDAGNode.__init__(self,job)
    #FIXME add tmp file space
    # dag.id is read before dag.add_node() increments it.
    self.add_macro("macroid", dag.id)
    for name, value in options:
      self.add_var_opt(name, value)
    for db_file in databases:
      self.add_file_arg(db_file)
    for name, value in (("instruments", instruments), ("output-tag", opts.user_tag)):
      self.add_var_opt(name, value)
    for parent in p_node:
      self.add_parent(parent)
    dag.add_node(self)

class train_forest_job(pipeline.CondorDAGJob):
  """
  Condor job for the forest training executable (SprBaggerDecisionTreeApp).

  The executable path comes from the 'mvsc_train_forest' entry of the
  [condor] section of the ini file; the job runs in the vanilla universe.
  """
  def __init__(self, cp, tag_base='TRAIN_FOREST'):
    """
    cp is the parsed ini file.  tag_base is combined with the module-level
    run_tag to name the submit file and the per-node stdout/stderr files,
    which are placed under opts.log_path.
    """
    self.__prog__ = 'SprBaggerDecisionTreeApp'
    self.__executable = cp.get('condor','mvsc_train_forest').strip()
    self.__universe = "vanilla"
    pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
    self.tag_base = tag_base
    for command, value in (('getenv','True'), ('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")):
      self.add_condor_cmd(command, value)
    # Every file written for this job shares the '<tag_base>_<run_tag>' stem.
    stem = tag_base+'_'+run_tag
    self.set_sub_file(stem+'.sub')
    node_log = opts.log_path+'/'+stem+'-$(macroid)-$(process)'
    self.set_stdout_file(node_log+'.out')
    self.set_stderr_file(node_log+'.err')

class train_forest_node(pipeline.CondorDAGNode):
  """
  DAG node that trains one forest from one training file.

  After construction, trainingfile holds the node's first input file and
  trainedforest the name of the forest file the job is told to write.
  """
  def __init__(self, job, dag, trainingfile, p_node=[]):
    """
    trainingfile is registered as an input file.  The trained forest name is
    the training file name with '_training.pat' replaced by '_<run_tag>.spr'
    and is registered as an output file.  Every node in p_node is added as a
    parent.  The node adds itself to dag.
    """
    pipeline.CondorDAGNode.__init__(self,job)
    #FIXME add tmp file space
    # dag.id is read before dag.add_node() increments it.
    self.add_macro("macroid", dag.id)
    self.add_input_file(trainingfile)
    self.trainingfile = self.get_input_files()[0]
    forest_suffix = '_'+run_tag+'.spr'
    self.trainedforest = self.trainingfile.replace('_training.pat', forest_suffix)
    # The whole command line of the trainer is handed over as one argument
    # string, filled in from the command-line options of this script.
    trainer_values = (
      opts.number_of_trees,
      opts.leaf_size,
      opts.sampled_parameters,
      opts.criterion_for_optimization,
      self.trainedforest,
      self.trainingfile,
    )
    self.add_file_arg("-a 4 -z sngl_gps_time_a,sngl_gps_time_b -n %s -l %s -s %s -c %s -g 1 -i -d 1 -f %s %s" % trainer_values)
    self.add_output_file(self.trainedforest)
    for parent in p_node:
      self.add_parent(parent)
    dag.add_node(self)

class use_forest_job(pipeline.CondorDAGJob):
  """
  Condor job for the executable that applies a trained forest
  (SprOutputWriterApp).

  The executable path comes from the 'mvsc_use_forest' entry of the
  [condor] section of the ini file; the job runs in the vanilla universe.
  """
  def __init__(self, cp, tag_base='USE_FOREST'):
    """
    cp is the parsed ini file.  tag_base is combined with the module-level
    run_tag to name the submit file and the per-node stdout/stderr files,
    which are placed under opts.log_path.
    """
    self.__prog__ = 'SprOutputWriterApp'
    self.__executable = cp.get('condor','mvsc_use_forest').strip()
    self.__universe = "vanilla"
    pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
    self.tag_base = tag_base
    for command, value in (('getenv','True'), ('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")):
      self.add_condor_cmd(command, value)
    # Every file written for this job shares the '<tag_base>_<run_tag>' stem.
    stem = tag_base+'_'+run_tag
    self.set_sub_file(stem+'.sub')
    node_log = opts.log_path+'/'+stem+'-$(macroid)-$(process)'
    self.set_stdout_file(node_log+'.out')
    self.set_stderr_file(node_log+'.err')

class use_forest_node(pipeline.CondorDAGNode):
  """
  DAG node that ranks the events of one .pat file with one trained forest.

  After construction, trainedforest and file_to_rank hold the node's first
  and second input files, and ranked_file the name of the output file.
  """
  def __init__(self, job, dag, trainedforest, file_to_rank,  p_node=[]):
    """
    trainedforest and file_to_rank are registered as input files, in that
    order.  The output file name is the name of the file to rank with '.pat'
    replaced by '.dat' and is registered as an output file.  Every node in
    p_node is added as a parent.  The node adds itself to dag.
    """
    pipeline.CondorDAGNode.__init__(self,job)
    #FIXME add tmp file space
    # dag.id is read before dag.add_node() increments it.
    self.add_macro("macroid", dag.id)
    for input_name in (trainedforest, file_to_rank):
      self.add_input_file(input_name)
    registered_inputs = self.get_input_files()
    self.trainedforest = registered_inputs[0]
    self.file_to_rank = registered_inputs[1]
    self.ranked_file = self.file_to_rank.replace('.pat','.dat')
    # The whole command line is handed over as one argument string.
    writer_values = (self.trainedforest, self.file_to_rank, self.ranked_file)
    self.add_file_arg("-A -a 4 -z sngl_gps_time_a,sngl_gps_time_b %s %s %s" % writer_values)
    self.add_output_file(self.ranked_file)
    for parent in p_node:
      self.add_parent(parent)
    dag.add_node(self)

class mvsc_update_sql_job(pipeline.CondorDAGJob):
  """
  Condor job for the mvsc_update_sql executable.

  The executable path comes from the 'mvsc_update_sql' entry of the
  [condor] section of the ini file; the job runs in the vanilla universe.
  """
  def __init__(self, cp, tag_base='MVSC_UPDATE_SQL'):
    """
    cp is the parsed ini file.  tag_base is combined with the module-level
    run_tag to name the submit file and the per-node stdout/stderr files,
    which are placed under opts.log_path.
    """
    self.__prog__ = 'mvsc_update_sql'
    self.__executable = cp.get('condor','mvsc_update_sql').strip()
    self.__universe = "vanilla"
    pipeline.CondorDAGJob.__init__(self,self.__universe,self.__executable)
    self.tag_base = tag_base
    for command, value in (('getenv','True'), ('environment',"KMP_LIBRARY=serial;MKL_SERIAL=yes")):
      self.add_condor_cmd(command, value)
    # Every file written for this job shares the '<tag_base>_<run_tag>' stem.
    stem = tag_base+'_'+run_tag
    self.set_sub_file(stem+'.sub')
    node_log = opts.log_path+'/'+stem+'-$(macroid)-$(process)'
    self.set_stdout_file(node_log+'.out')
    self.set_stderr_file(node_log+'.err')

class mvsc_update_sql_node(pipeline.CondorDAGNode):
  """
  DAG node that runs a mvsc_update_sql job.
  """
  def __init__(self, job, dag, options, databases, p_node=[]):
    """
    options is a sequence of (name, value) pairs passed through as variable
    options; every entry of databases is added as a variable argument;
    opts.all_instruments becomes the "all-instruments" option and
    opts.user_tag the "output-tag" option.  Every node in p_node is added as
    a parent.  The node adds itself to dag.
    """
    pipeline.CondorDAGNode.__init__(self,job)
    #FIXME add tmp file space
    # dag.id is read before dag.add_node() increments it.
    self.add_macro("macroid", dag.id)
    for name, value in options:
      self.add_var_opt(name, value)
    for argument in databases:
      self.add_var_arg(argument)
    for name, value in (("all-instruments", opts.all_instruments), ("output-tag", opts.user_tag)):
      self.add_var_opt(name, value)
    for parent in p_node:
      self.add_parent(parent)
    dag.add_node(self)



###############################################################################
## MAIN #######################################################################
###############################################################################


# Echo the databases named on the command line.
print(databases)

# --all-instruments has no default, so opts.all_instruments is None when the
# option is omitted; stop with a usage message instead of an AttributeError.
if not opts.all_instruments:
  parser.error("--all-instruments is required, e.g. --all-instruments H1,H2,L1,V1")

all_ifos = opts.all_instruments.strip().split(',')
ifo_combinations = list(iterutils.choices(all_ifos,2))
print(ifo_combinations)

### First, count the injection, timeslide and zerolag events available for
### each ifo combination.  The DAG set-up that follows reads these counts to
### decide which combinations get training and zerolag files.

exact_tag = cp.get("mvsc_get_doubles","exact-tag")

# The four counting queries, assembled once from the pieces in mvsc_queries.
all_injections_query = ''.join([mvsc_queries.CandidateEventQuery.select_count,mvsc_queries.CandidateEventQuery.add_from_injections,mvsc_queries.CandidateEventQuery.add_where_all,])
exact_injections_query = ''.join([mvsc_queries.CandidateEventQuery.select_count,mvsc_queries.CandidateEventQuery.add_from_injections,mvsc_queries.CandidateEventQuery.add_where_exact,])
timeslides_query = ''.join([mvsc_queries.CandidateEventQuery.select_count,mvsc_queries.CandidateEventQuery.add_from_fulldata," AND experiment_summary.datatype == 'slide'"])
zerolag_query = ''.join([mvsc_queries.CandidateEventQuery.select_count,mvsc_queries.CandidateEventQuery.add_from_fulldata," AND experiment_summary.datatype == 'all_data'"])

# Counts are keyed by the comma-joined combination (e.g. "H1,L1") and start at
# zero, so every combination has an entry even when no database is given.
count_all_injections={}
count_exact_injections={}
count_timeslides={}
count_zerolag={}
for pair in ifo_combinations:
  key = ','.join(pair)
  count_all_injections[key] = 0
  count_exact_injections[key] = 0
  count_timeslides[key] = 0
  count_zerolag[key] = 0

for database in databases:
  local_disk = None #"/tmp"
  working_filename = dbtables.get_connection_filename(database, tmp_path = local_disk, verbose = True)
  connection = sqlite3.connect(working_filename)
  try:
    # Call kept from the original script; its return value is not used here.
    dbtables.get_xml(connection)
    for pair in ifo_combinations:
      key = ','.join(pair)
      ifo_a, ifo_b = pair[0], pair[1]
      # Each query yields exactly one row holding the count.  The counts are
      # summed over the databases, so events spread over several files all
      # contribute (previously only the last database was reflected).
      row, = connection.cursor().execute(all_injections_query, (ifo_a,ifo_b,) )
      count_all_injections[key] += row[0]
      row, = connection.cursor().execute(exact_injections_query, (ifo_a,ifo_b,exact_tag,) )
      count_exact_injections[key] += row[0]
      row, = connection.cursor().execute(timeslides_query, (ifo_a,ifo_b,) )
      count_timeslides[key] += row[0]
      row, = connection.cursor().execute(zerolag_query, (ifo_a,ifo_b,) )
      count_zerolag[key] += row[0]
  finally:
    # Release the database even if one of the queries fails.
    connection.close()

print(count_all_injections)
for pair in ifo_combinations:
  key = ','.join(pair)
  print("for the ifo combination "+str(key)+" there are "+str(count_all_injections[key])+" total found injections")
  print("for the ifo combination "+str(key)+" there are "+str(count_exact_injections[key])+" exactly found injections")
  print("for the ifo combination "+str(key)+" there are "+str(count_timeslides[key])+" timeslides")
  print("for the ifo combination "+str(key)+" there are "+str(count_zerolag[key])+" zerolag events")
  print("note, these counts include triples")
  # "there must be more than 0 zerolag events and at least as many exactly found injections and timeslides as the number of round-robins"

### SET UP THE DAG

# The job stdout/stderr files and the DAG log file are all written under
# opts.log_path, so that is the directory that has to exist.
try:
  os.makedirs(opts.log_path)
except OSError:
  # Already existing is fine; anything else (e.g. permissions) is an error.
  if not os.path.isdir(opts.log_path):
    raise

#mvsc_get_doubles
get_job = mvsc_get_doubles_job(cp)
get_node = {}
training_files = {}
evaluation_files = {}
evaluation_info_files = {}
zerolag_files = {}
zerolag_info_files = {}

#SprBaggerDecisionTreeApp
train_job = train_forest_job(cp)
train_node = {}

#SprOutputWriterApp
rank_job = use_forest_job(cp)
rank_node = {}
zl_rank_job = use_forest_job(cp)
zl_rank_node = {}

#mvsc_update_sql
update_job = mvsc_update_sql_job(cp)
update_node = {}

# Files handed to mvsc_update_sql after the databases.
update_files=[]

#Assemble the DAG
dag = mvsc_dag_DAG(ininame, opts.log_path+'/')
number_of_round_robins=int(cp.get("mvsc_get_doubles","number"))
# comment[comb] is None when comb can be trained, otherwise the reason it is skipped.
comment={}
for comb in ifo_combinations:
  combstr = ''.join(comb)
  comb = ','.join(comb)
  training_files[comb] = []
  evaluation_files[comb] = []
  evaluation_info_files[comb] = []
  zerolag_files[comb] = []
  zerolag_info_files[comb] = []
  # Neither test depends on the round-robin index, so both are made once per
  # combination.  This also guarantees comment[comb] is set when the number
  # of round-robins is 0.
  enough_training = number_of_round_robins < count_exact_injections[comb] and number_of_round_robins < count_timeslides[comb]
  have_zerolag = count_zerolag[comb] > 0
  if enough_training:
    comment[comb]=None
    for i in range(number_of_round_robins):
      set_stem = combstr+'_'+opts.user_tag+'_set'+str(i)+'_'
      training_files[comb].append(set_stem+'training.pat')
      evaluation_files[comb].append(set_stem+'evaluation.pat')
      evaluation_info_files[comb].append(set_stem+'evaluation_info.pat')
      if have_zerolag:
        zerolag_files[comb].append(set_stem+'zerolag.pat')
        zerolag_info_files[comb].append(set_stem+'zerolag_info.pat')
    if not have_zerolag:
      # This message was a bare string expression before and never shown.
      print("there are no zerolag events for "+comb+". Note, this means there are also no triples")
  else:
    comment[comb]="there are not enough training events for "+comb+". This means we cannot rank zerolag events for this combination"
  # One mvsc_get_doubles node per trainable combination, unless the .pat
  # files already exist.
  get_node[comb] = None
  if not opts.skip_file_generation:
    if comment[comb] is None:
      get_node[comb] = mvsc_get_doubles_node(get_job, dag, cp.items("mvsc_get_doubles"), comb, databases, training_files[comb]+evaluation_files[comb]+evaluation_info_files[comb]+zerolag_files[comb]+zerolag_info_files[comb])
    else:
      print(comment[comb])
  train_node[comb] = {}
  rank_node[comb] = {}
  zl_rank_node[comb] = {}
  # Training waits for the file-generation node when there is one.
  train_parents = []
  if get_node[comb] is not None:
    train_parents = [get_node[comb]]
  for i,training_file in enumerate(training_files[comb]):
    update_files.append(evaluation_info_files[comb][i])
    train_node[comb][i] = train_forest_node(train_job, dag, training_file, p_node = train_parents)
    rank_node[comb][i] = use_forest_node(rank_job, dag, train_node[comb][i].trainedforest, str(evaluation_files[comb][i]), p_node=[train_node[comb][i]])
    update_files.append(rank_node[comb][i].ranked_file)
    # The zerolag lists are empty when the combination has no zerolag events;
    # test for that explicitly so real errors raised while building the node
    # are not swallowed.
    if i < len(zerolag_files[comb]):
      update_files.append(zerolag_info_files[comb][i])
      zl_rank_node[comb][i] = use_forest_node(zl_rank_job, dag, train_node[comb][i].trainedforest, str(zerolag_files[comb][i]), p_node=[train_node[comb][i]])
      update_files.append(zl_rank_node[comb][i].ranked_file)
    else:
      print("omitting "+comb+" from zerolag update files")

# The database update runs after every ranking node.
finished_rank_nodes=[]
for key in rank_node:
  finished_rank_nodes.extend(rank_node[key].values())
  finished_rank_nodes.extend(zl_rank_node[key].values())

# The [mvsc_update_sql] section is optional in the ini file.
update_options=[]
if cp.has_section("mvsc_update_sql"):
  update_options = cp.items("mvsc_update_sql")

update_node['all'] = mvsc_update_sql_node(update_job, dag, update_options, databases+update_files, p_node=finished_rank_nodes)
print(update_files)
dag.write_sub_files()
dag.write_dag()
dag.write_script()
