#!/usr/bin/python
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
from glue.ligolw import dbtables 
from glue.ligolw import table
from glue.ligolw import ilwd
from glue import segments
from pylal.xlal import tools as xlaltools
from pylal.xlal.datatypes import snglinspiraltable
from pylal import SnglInspiralUtils
from pylal import db_thinca_rings
from pylal import git_version
from pylal import mvsc_queries
from time import clock,time
import matplotlib
matplotlib.use('Agg')
import pylab
from optparse import *
import glob
import sys
import random
import bisect
import pickle
from pylal import ligolw_thinca

# Usage text handed to OptionParser further down; optparse substitutes %prog
# with the name of the running script.
usage="""
example command line:
%prog --instruments=H1,L1 database1.sqlite database2.sqlite etc.
or, if you are brave:
%prog --instruments=H1,L1 *.sqlite
this code turns sqlite databases into .pat files for MVSC
each row in the .pat file contains a vector of parameters that characterize the double coincident trigger
"""

__author__ = "Kari Hodge <khodge@ligo.caltech.edu>"

# Ask sqlite3 to print a traceback to stderr when a user-defined SQL function
# raises; this script registers two such functions with create_function().
sqlite3.enable_callback_tracebacks(True)

# Command-line interface.  The positional arguments are the sqlite databases to process.
parser=OptionParser(usage=usage,version=git_version.verbose_msg)
parser.add_option("", "--number", default=10, type="int", help="number for round robin")
parser.add_option("", "--factor", default=50.0, type="float", help="the value of the magic number factor in the effective snr formula, should be 50 for highmass and 250 for lowmass")
parser.add_option("", "--instruments", help="pair that you want to get like H1,L1")
parser.add_option("", "--output-tag", default="CBC", help="a string added to all filenames to help you keep track of things")
parser.add_option("", "--apply-weights", action="store_true", default=False, help="calculates weight for all found injections, saves in .pat file (all bkg events get weight=1), if this option is not supplied, all events get a weight of 1, so still use -a 4 in the Spr* executables")
parser.add_option("", "--check-weights", action="store_true", default=False,help="turn on if you want a plot of cumulative weights v. distance")
parser.add_option("", "--exact-tag", default="ring_exact", help="this is the dbinjfind tag stored in the sqlite database for the exactly found injections - the ones you want to use for training")
parser.add_option("-s","--tmp-space",help="necessary for sqlite calls, for example /usr1/khodge")
#parser.add_option("", "--nearby-tag", default="ring_nearby", help="this is the dbinjfind tag stored in the sqlite for the nearby injections - we will still rank all of these")

(opts,databases)=parser.parse_args()

# --instruments has no default: without this check a missing option shows up as
# an AttributeError on None instead of a usage message.
if opts.instruments is None:
	parser.error("--instruments is required, e.g. --instruments=H1,L1")
ifos=opts.instruments.strip().split(',')
# the database queries bind ifos[0] and ifos[1], so a pair is the minimum
if len(ifos) < 2:
	parser.error("--instruments must name a pair of detectors, e.g. --instruments=H1,L1")
ifos.sort()

# wall-clock start, used for the elapsed-time report at the end of the script
time1=time()


class SnglInspiral(snglinspiraltable.SnglInspiralTable):
	"""
	sngl_inspiral row type: snglinspiraltable.SnglInspiralTable (the C version)
	extended with get_end()/set_end().  The C version lacks these methods, and
	they are needed to slide the triggers on the ring, which in turn is needed
	for a correct calculation of ethinca.
	"""
	__slots__ = ()

	def get_end(self):
		# the end time as a single LIGOTimeGPS built from the two end-time attributes
		seconds = self.end_time
		nanoseconds = self.end_time_ns
		return dbtables.lsctables.LIGOTimeGPS(seconds, nanoseconds)

	def set_end(self, gps):
		# read both parts first, then store them in end_time / end_time_ns
		seconds, nanoseconds = gps.seconds, gps.nanoseconds
		self.end_time = seconds
		self.end_time_ns = nanoseconds

# Use the SnglInspiral subclass as the row type of sngl_inspiral tables, so the
# rows built later carry get_end()/set_end().
dbtables.lsctables.SnglInspiralTable.RowType = SnglInspiral

# String naming the MVSC dimensions; it is also written as the second header
# line of the training, evaluation and zerolag .pat files.
parameters = mvsc_queries.CandidateEventQuery.parameters
print "your MVSC analysis will use the following dimensions: "+parameters
 
# Accumulators filled while looping over the databases.  For each event class NAME:
#   NAME                 one tuple per coincidence (ethinca first, then query columns)
#   NAME_info            [first column of the query row, database filename]
#   NAME_distance        [last column of the query row, database filename] (injections only)
#   NAME_sngl_gps_times  the two columns split off every row after the loop
exact_injections = []
exact_injections_info = []
exact_injections_distance = []
exact_injections_sngl_gps_times = []
all_injections = []
all_injections_info = []
all_injections_distance = []
all_injections_sngl_gps_times = []
# NOTE(review): normalization is never read or filled in the visible code
normalization = []
zerolag = []
zerolag_info = []
zerolag_sngl_gps_times = []
timeslides = []
timeslides_info = []
timeslides_sngl_gps_times = []

# One pass per database given on the command line.
for database in databases:
	# --tmp-space is passed as tmp_path; presumably the database is copied there
	# and working_filename names the copy -- confirm against glue.ligolw.dbtables
	local_disk = opts.tmp_space
	working_filename = dbtables.get_connection_filename(database, tmp_path = local_disk, verbose = True)
	connection = sqlite3.connect(working_filename)
	dbtables.DBTable_set_connection(connection)
	xmldoc = dbtables.get_xml(connection)
	# NOTE(review): this cursor is not used in the visible code; each query opens its own
	cursor = connection.cursor()
	# number of sngl_inspiral columns; the query rows are sliced in chunks of this width
	num_sngl_cols = len(table.get_table(xmldoc, dbtables.lsctables.SnglInspiralTable.tableName).dbcolumnnames)
	rings = db_thinca_rings.get_thinca_rings_by_available_instruments(connection)
	# indexed as offset_vectors[time_slide_id][ifo] by calc_delta_t
	offset_vectors = dbtables.lsctables.table.get_table(dbtables.get_xml(connection), dbtables.lsctables.TimeSlideTable.tableName).as_dict()
	# callable used by calc_ethinca to turn a tuple of column values into a row object
	sngl_inspiral_row_from_cols = table.get_table(xmldoc, dbtables.lsctables.SnglInspiralTable.tableName).row_from_cols
	def calc_effective_snr(snr, chisq, chisq_dof, fac=opts.factor):
		"""
		Return snr / (1 + snr^2/fac)^(1/4) / (chisq/(2*chisq_dof - 2))^(1/4).

		fac defaults to the --factor command-line value.  The function is
		registered with the database connection as a 3-argument SQL function.
		"""
		snr_term = (1 + snr**2/fac)**(0.25)
		chisq_term = (chisq/(2*chisq_dof - 2) )**(0.25)
		return snr / snr_term / chisq_term
	
	def calc_ethinca(rowA,rowB,time_slide_id,rings=rings,offset_vectors=offset_vectors):
		"""
		Return the ethinca parameter of two triggers given as tuples of
		sngl_inspiral column values.

		Each tuple is turned into a row object and slid on the rings by the
		offsets stored for time_slide_id before XLALCalculateEThincaParameter
		is evaluated.  rings and offset_vectors default to the values of the
		database being processed when the function is defined.
		"""
		# the sliding routine is given one flat segment list holding every ring
		all_rings = segments.segmentlist()
		for ring_list in rings.values():
			all_rings.extend(ring_list)
		slid_triggers = []
		for column_values in (rowA, rowB):
			trigger = sngl_inspiral_row_from_cols(column_values)
			SnglInspiralUtils.slideTriggersOnRings([trigger],all_rings,offset_vectors[time_slide_id])
			slid_triggers.append(trigger)
		first_trigger, second_trigger = slid_triggers
		return xlaltools.XLALCalculateEThincaParameter(first_trigger,second_trigger)
	
	def calc_delta_t(trigger1_ifo, trigger1_end_time, trigger1_end_time_ns, trigger2_ifo, trigger2_end_time, trigger2_end_time_ns, time_slide_id, rings = rings, offset_vectors = offset_vectors):
		"""
		Return, as a float, the absolute difference between the end times of
		two triggers after each has been slid on its thinca ring by the offset
		of its instrument.

		The function is registered with the database connection as a
		7-argument SQL function; rings and offset_vectors default to the
		values of the database being processed when the function is defined.

		If sliding fails, the difference of the end times as they stand is
		reduced modulo 1 and folded into [0, 0.5] instead (the "mod 1 hack").

		Raises ValueError when the ring lookup fails for trigger1 and the
		fallback lookup with trigger2 does not yield exactly one ring either.
		"""
		time_slide_id = ilwd.ilwdchar(time_slide_id)
		trigger1_true_end_time = dbtables.lsctables.LIGOTimeGPS(trigger1_end_time, trigger1_end_time_ns)
		trigger2_true_end_time = dbtables.lsctables.LIGOTimeGPS(trigger2_end_time, trigger2_end_time_ns)
		# find the instruments that were on at trigger 1's end time and
		# find the ring that contains this trigger; the unpacking raises
		# ValueError unless exactly one ring matches
		try:
			[ring] = [segs[segs.find(trigger1_end_time)] for segs in rings.values() if trigger1_end_time in segs]
		except ValueError:
			# FIXME THERE SEEMS TO BE A BUG IN THINCA! Occasionally thinca records a trigger on the upper boundary
			# of its ring. This would make it outside the ring which is very problematic. It needs to be fixed in thinca
			# for now we'll allow the additional check that the other trigger is in the ring and use it.
			print >>sys.stderr, "trigger1 found not on a ring, trying trigger2"
			[ring] = [segs[segs.find(trigger2_end_time)] for segs in rings.values() if trigger2_end_time in segs]
		# now we can unslide the triggers on the ring
		try:
			trigger1_true_end_time = SnglInspiralUtils.slideTimeOnRing(trigger1_true_end_time, offset_vectors[time_slide_id][trigger1_ifo], ring)
			trigger2_true_end_time = SnglInspiralUtils.slideTimeOnRing(trigger2_true_end_time, offset_vectors[time_slide_id][trigger2_ifo], ring)
			out = abs(trigger1_true_end_time - trigger2_true_end_time)
			return float(out)
		except Exception:
			# Exception rather than a bare except: KeyboardInterrupt and SystemExit
			# must not be turned into a delta t value.
			# NOTE(review): any error lands here, not only the boundary case named in the message.
			print >> sys.stderr, "calc delta t failed because one of the trigger's true end times landed on the upper boundary of the thinca ring. See: trigger 1: ", trigger1_true_end_time, "trigger 2: ", trigger2_true_end_time, "ring: ", ring
			out = float(abs(trigger1_true_end_time - trigger2_true_end_time)) % 1
			if out > 0.5:
				out = 1.0 - out
			print >> sys.stderr, "SO...delta t has been set to: ", out, "in accordance with the mod 1 hack"
			return out
	# make the two Python helpers callable from SQL (7 and 3 arguments)
	connection.create_function("calc_delta_t", 7, calc_delta_t)
	connection.create_function("calc_effective_snr", 3, calc_effective_snr)
	# in S6, the timeslides, zerolag, and injections are all stored in the same sqlite database, thus this database must include a sim inspiral table
	# if you provide a database that does not include injections, the code will still run as long as one of the databases you provide includes injections
	try:
		sim_inspiral_table = table.get_table(xmldoc, dbtables.lsctables.SimInspiralTable.tableName)
		is_injections = True
	except ValueError:
		is_injections = False

	# Layout of one query row, as sliced below: values[0] is stored in the *_info
	# lists, then two runs of num_sngl_cols sngl_inspiral columns, then the
	# time_slide_id, then the columns that are kept in the event rows.
	first_kept_col = 2*num_sngl_cols+2

	def ethinca_of(values):
		# ethinca of the two sngl_inspiral triggers embedded in one query row
		return calc_ethinca(values[1:num_sngl_cols+1],values[num_sngl_cols+1:2*num_sngl_cols+1],ilwd.ilwdchar(values[2*num_sngl_cols+1]),rings,offset_vectors)

	queries = mvsc_queries.CandidateEventQuery

	# The gps time of the coinc_inspiral travels with every injection row for
	# bookkeeping only; it is removed before the rows are written to file.
	# Injection rows get a trailing 1.
	if is_injections:
		select_injections = ''.join([queries.select_dimensions,queries.add_select_injections,queries.add_from_injections])
		for values in connection.cursor().execute(select_injections + queries.add_where_all, (ifos[0],ifos[1],) ):
			all_injections.append((ethinca_of(values),) + values[first_kept_col:] + (1,))
			all_injections_info.append([values[0], database])
			all_injections_distance.append([values[-1], database])
		for values in connection.cursor().execute(select_injections + queries.add_where_exact, (ifos[0],ifos[1],opts.exact_tag,) ):
			exact_injections.append((ethinca_of(values),) + values[first_kept_col:] + (1,))
			exact_injections_info.append([values[0], database])
			exact_injections_distance.append([values[-1], database])

	#FIXME: look up coinc_definer_id from definition in pylal
	# get the timeslide/full_data triggers; the last column tells them apart and
	# is replaced by a trailing (1, 0)
	select_fulldata = ''.join([queries.select_dimensions,queries.add_select_fulldata,queries.add_from_fulldata])
	for values in connection.cursor().execute(select_fulldata, (ifos[0],ifos[1],) ):
		datatype = values[-1]
		if datatype == 'slide':
			timeslides.append((ethinca_of(values),) + values[first_kept_col:-1] + (1,) + (0,))
			timeslides_info.append([values[0], database])
		elif datatype == 'all_data':
			zerolag.append((ethinca_of(values),) + values[first_kept_col:-1] + (1,) + (0,))
			zerolag_info.append([values[0], database])

	# every query has been consumed: release the sqlite handle before the
	# working file is handed back, so no connection is left open per database
	connection.close()
	dbtables.put_connection_filename(database, working_filename, verbose = True)

#let's remove the sngl gps times, these are in entries -5 and -4 in the injection lists, and -2 and -3 in the timeslide/zerolag lists, and put them in their own tables
print len(all_injections)
print len(exact_injections)
print len(timeslides)
print len(zerolag)
print len(all_injections[0])
print len(exact_injections[0])
print len(timeslides[0])
print len(zerolag[0])

print all_injections[0]
print timeslides[0]
tmplist=[]
for i,row in enumerate(exact_injections):
	tmprow=(row)
	exact_injections_sngl_gps_times.append([tmprow[-6],tmprow[-5]])
	tmplist.append((tmprow[0:-6]+tmprow[-4:]))
exact_injections=tmplist
tmplist=[]
for i,row in enumerate(all_injections):
	tmprow=list(row)
	all_injections_sngl_gps_times.append([tmprow[-6],tmprow[-5]])
	tmplist.append((tmprow[0:-6]+tmprow[-4:]))
all_injections=tmplist
tmplist=[]
for i,row in enumerate(timeslides):
	tmprow=list(row)
	timeslides_sngl_gps_times.append([tmprow[-4],tmprow[-3]])
	tmplist.append((tmprow[0:-4]+tmprow[-2:]))
timeslides=tmplist
tmplist=[]
for i,row in enumerate(zerolag):
	tmprow=list(row)
	zerolag_sngl_gps_times.append([tmprow[-4],tmprow[-3]])
	tmplist.append((tmprow[0:-4]+tmprow[-2:]))
zerolag=tmplist
print all_injections[0]
print timeslides[0]

print len(all_injections)
print len(exact_injections)
print len(timeslides)
print len(zerolag)
print len(all_injections[0])
print len(exact_injections[0])
print len(timeslides[0])
print len(zerolag[0])

# the weight given to each injection will be equal to 1/sqrt(snr_a^2+snr_b^2)
#FIXME: maybe there are better ways to implement the weighting, please think about it before applying 
newexact_injections=[]
exact_injections_vol=[]
exact_injections_lin=[]
exact_injections_log=[]
newall_injections=[]
if opts.apply_weights:
	print "applying weights for exact_injections"
	for i,row in enumerate(exact_injections):
		# so here we're using sngl SNRs to re-weight, but I had been storing the distane ...
		# just don't use weights right now!  
		injtmp = list(row)
		injtmp[-3]=str(injtmp[-3])
		if injtmp[-3] == 'uniform':
			injtmp[-2]=((8**2+8**2)**0.5/(injtmp[4]**2+injtmp[5]**2)**0.5)**2
			exact_injections_lin.append([exact_injections_distance[i][0],injtmp[4],injtmp[5],injtmp[-2]])
		if injtmp[-3] == 'log10':
			injtmp[-2]=((8**2+8**2)**0.5/(injtmp[4]**2+injtmp[5]**2)**0.5)**3
			exact_injections_log.append([exact_injections_distance[i][0],injtmp[4],injtmp[5],injtmp[-2]])
		if injtmp[-3] == 'volume':
			injtmp[-2]=1.0
			exact_injections_vol.append([exact_injections_distance[i][0],injtmp[4],injtmp[5],injtmp[-2]])
		newexact_injections.append(tuple(injtmp))
	for row in all_injections:
		injtmp = list(row)
		injtmp[-3]=str(injtmp[-3])
		injtmp[-2]=3*((8**2+8**2)**0.5/(injtmp[4]**2+injtmp[5]**2)**0.5)**2
		if injtmp[-3] == 'log10':
			injtmp[-2]=3*((8**2+8**2)**0.5/(injtmp[4]**2+injtmp[5]**2)**0.5)**3
		if injtmp[-3] == 'volume':
			injtmp[-2]=1.0
		newall_injections.append(tuple(injtmp))
else:
	for row in exact_injections:
		injtmp = list(row)
		injtmp[-2]=1
		newexact_injections.append(tuple(injtmp))
	for row in all_injections:
		injtmp = list(row)
		injtmp[-2]=1
		newall_injections.append(tuple(injtmp))
exact_injections=newexact_injections
all_injections=newall_injections

def open_file_write_pickle(filetype, datatype, ifos, object):
	"""
	Pickle object into the file {ifos}_{output tag}_{datatype}_{filetype}.pickle,
	where {ifos} is the concatenation of ifos and {output tag} is --output-tag.

	The file is opened in binary mode, which is what pickle expects, and is
	closed even when pickling fails.
	"""
	filename = ''.join(ifos) + '_' + opts.output_tag + '_' + str(datatype) + '_' + str(filetype) +  '.pickle'
	f=open(filename, 'wb')
	try:
		pickle.dump(object,f)
	finally:
		f.close()
def open_file_write_text(filetype, datatype, ifos, object):
	"""
	Write object, a sequence of rows, as space-separated text into the file
	{ifos}_{output tag}_{datatype}_{filetype}.txt, below a fixed header line
	naming the columns distance, snr_a, snr_b and weight.
	"""
	name_parts = [''.join(ifos), opts.output_tag, str(datatype), str(filetype)]
	out = open('_'.join(name_parts) + '.txt', 'w')
	out.write('distance snr_a snr_b weight \n')
	for entry in object:
		fields = [str(item) for item in entry]
		out.write(" ".join(fields) + "\n")
	out.close()


if opts.check_weights:
	open_file_write_text('weights_vol','exact_injections',ifos,exact_injections_vol)
	open_file_write_text('weights_lin','exact_injections',ifos,exact_injections_lin)
	open_file_write_text('weights_log','exact_injections',ifos,exact_injections_log)

#now let's get rid of the injection distribution tags so they're not in our table anymore
tmp = zip(*exact_injections)
del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
exact_injections = zip(*tmp)
tmp = zip(*all_injections)
del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
all_injections = zip(*tmp)
print all_injections[0]

##save all injections, timeslides, zerolag, gps files, and infofiles to pickles
#open_file_write_pickle('data','all_injections',ifos,all_injections)
#open_file_write_pickle('data','exact_injections',ifos,exact_injections)
#open_file_write_pickle('data','timeslides',ifos,timeslides)
#open_file_write_pickle('data','zerolag',ifos,zerolag)
#open_file_write_pickle('sngl_gps_times','all_injections',ifos,all_injections_sngl_gps_times)
#open_file_write_pickle('sngl_gps_times','exact_injections',ifos,exact_injections_sngl_gps_times)
#open_file_write_pickle('sngl_gps_times','timeslides',ifos,timeslides_sngl_gps_times)
#open_file_write_pickle('sngl_gps_times','zerolag',ifos,zerolag_sngl_gps_times)
#open_file_write_pickle('info','all_injections',ifos,all_injections_info)
#open_file_write_pickle('info','exact_injections',ifos,exact_injections_info)
#open_file_write_pickle('info','timeslides',ifos,timeslides_info)
#open_file_write_pickle('info','zerolag',ifos,zerolag_info)
#pickle.dump(parameters,open(''.join(ifos) + '_' + opts.output_tag + '_header' + '.pickle', 'w'))

#we will put the rest of this in another code to be called after I get all the aux info

# Shuffle the timeslides.  Re-seeding with the same value before the second
# shuffle applies the same permutation to timeslides_info (both lists are
# appended to together, so they have the same length), which keeps every
# timeslide row paired with its info entry.
random.seed(2)
random.shuffle(timeslides)
random.seed(2)
random.shuffle(timeslides_info)

# this part of the code writes the triggers' information into .pat files, in the format needed for SprBaggerDecisionTreeApp
# to get the MVSC rank for each timeslide and injection, we do a round-robin of training and testing, with the number of rounds determined by opts.number
# for example, if opts.number is 10, each round will train a random forest of bagged decision trees on 90% of the timeslides and injections
# then we'd run the remaining 10% through the trained forest to get their MVSC rank
# in this case, we'd do this 10 times, ensuring that every timeslide and injection gets ranked
Nrounds = opts.number
Ninj = len(all_injections)
Nslide = len(timeslides)

# labels that become part of the .pat file names
trstr = 'training' 
testr = 'evaluation'
zlstr = 'zerolag'

print "there are ", len(timeslides), " timeslide doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(all_injections), " injection doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(zerolag), " zerolag doubles in ", ''.join(ifos), " and triple coincidences"

if len(exact_injections) > Nrounds and len(timeslides) > Nrounds:
	Nparams = len(exact_injections[0]) - 3
	Nrounds = opts.number
	Ninj_exact = len(exact_injections)
	Nslide = len(timeslides)
	gps_times_for_all_injections = zip(*all_injections)[-3]
	gps_times_for_exact_injections = zip(*exact_injections)[-3]
	print min(gps_times_for_all_injections), max(gps_times_for_all_injections)
	print min(gps_times_for_exact_injections), max(gps_times_for_exact_injections)
	trstr = 'training'
	testr = 'evaluation'
	zlstr = 'zerolag'

	def open_file_write_headers(filetype, set_num, ifos, Nparams=Nparams):
		f = open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(set_num) + '_' + str(filetype) +	'.pat', 'w')
		f.write(str(Nparams) + '\n')
		f.write(parameters + "\n")
		return f
	
# first put in the header information
	for i in range(Nrounds):
		f_training = open_file_write_headers(trstr, i, ifos)
		f_testing = open_file_write_headers(testr, i, ifos)
		f_testing_info=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_' + str(testr) + '_info.pat', 'w')
# now let's do the injections - each training set will have (Nrounds-1)/Nrounds fraction of all exactly found injections
# we need to rig the evaluation sets (which include all exact AND nearby injections) to not be evaluated on a forest that was trained on themselves
# so, we divide up the set of exactly found injections (which is sorted by GPS time) into Nrounds, then use the GPS times at the boundaries to properly divide our evaluation sets
		exact_injections_tmp = list(exact_injections)
		set_i_exact_inj = exact_injections_tmp[i*Ninj_exact/Nrounds : (i+1)*Ninj_exact/Nrounds]
		divisions = [set_i_exact_inj[0][-3], set_i_exact_inj[-1][-3]] #these are the gps boundaries for what should go in the evaluation set
		del(exact_injections_tmp[i*Ninj_exact/Nrounds : (i+1)*Ninj_exact/Nrounds])
# exact_injections_tmp now contains our the exact injections we want to include in the ith training set, let's write them to file
		tmp = zip(*exact_injections_tmp)
		del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
		exact_injections_tmp = zip(*tmp)
		for row in exact_injections_tmp:
			f_training.write("%s\n" % " ".join(map(str,row)))
# now we need to construct the evaluation (aka testing) set (all injections, not just exact anymore) that pairs with this training set
		set_i_all_inj = []
		left_index = bisect.bisect_left(gps_times_for_all_injections,divisions[0])
		right_index = bisect.bisect_right(gps_times_for_all_injections,divisions[1])
		set_i_all_inj = all_injections[left_index:right_index]
		set_i_all_inj_info = all_injections_info[left_index:right_index]
		print "to add:", len(set_i_all_inj)
		tmp = zip(*set_i_all_inj)
		del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
		set_i_all_inj = zip(*tmp)
		for row in set_i_all_inj:
			f_testing.write("%s\n" % " ".join(map(str,row)))
		for row in set_i_all_inj_info:
			f_testing_info.write("%s\n" % " ".join(map(str,row)))
# now let's do the timeslides
		timeslides_tmp = list(timeslides)
		timeslides_info_tmp = list(timeslides_info)
# get (say) 10% of the timeslides and injections, which you will run through the forest that you've trained on the other 90%
		set_i_slide = timeslides_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds]
		set_i_slide_info = timeslides_info_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds]
		for row in set_i_slide:
			f_testing.write("%s\n" % " ".join(map(str,row)))
		for row in set_i_slide_info:
			f_testing_info.write("%s\n" % " ".join(map(str,row)))
# delete the 10%, and save the remaining 90% into the training file
		del(timeslides_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds])
		for row in timeslides_tmp:
			f_training.write("%s\n" % " ".join(map(str,row)))
		if len(zerolag) != 0:
			f_zerolag=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_'+ str(zlstr) + '.pat','w')
			f_zerolag.write(str(Nparams) + '\n')
			f_zerolag.write(parameters + "\n")
			for row in zerolag:
				f_zerolag.write("%s\n" % " ".join(map(str,row)))
			f_zerolag_info=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_' + str(zlstr) + '_info.pat', 'w')
			for row in zerolag_info:
				f_zerolag_info.write("%s\n" % " ".join(map(str,row)))
else: raise Exception, "There were no injections found for the specified ifo combination %s" % ifos


# report the wall-clock time that has passed since time1 was taken
time2=time()
elapsed_time=time2-time1
print "elapsed time:", elapsed_time
