#!/usr/bin/python

try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
from glue.ligolw import dbtables
from glue.ligolw import table
from glue.ligolw import ilwd
from glue import segments
from pylal.xlal import tools as xlaltools
from pylal.xlal.datatypes import snglringdowntable
from pylal import SnglInspiralUtils
from pylal import db_rinca_rings
from pylal import git_version
from pylal import mvsc_queries_ringdown
from time import clock,time
import matplotlib
matplotlib.use('Agg')
import pylab
from optparse import *
import glob
import sys
import random
import bisect
import pickle

# help text handed to OptionParser below; optparse substitutes the script name for %prog
usage="""
example command line:
%prog --instruments=H1,L1 database1.sqlite database2.sqlite etc.
or, if you are brave:
%prog --instruments=H1,L1 *.sqlite
this code turns sqlite databases into .pat files for MVSC
each row in the .pat file contains a vector of parameters that characterize the double coincident trigger
"""

__author__ = "Kari Hodge <khodge@ligo.caltech.edu>, Paul T Baker <paul.baker@ligo.org>"

# ask sqlite3 to print a traceback when one of the Python functions registered
# with create_function() further down raises inside a query
sqlite3.enable_callback_tracebacks(True)

# command line definition; the positional arguments are the sqlite databases to read
parser=OptionParser(usage=usage,version=git_version.verbose_msg)
parser.add_option("", "--number", default=10, type="int", help="number for round robin")
parser.add_option("", "--instruments", help="pair that you want to get like H1,L1")
parser.add_option("", "--output-tag", default="CBC", help="a string added to all filenames to help you keep track of things")
parser.add_option("", "--apply-weights", action="store_true", default=False, help="calculates weight for all found injections, saves in .pat file (all bkg events get weight=1), if this option is not supplied, all events get a weight of 1, so still use -a 4 in the Spr* executables")
parser.add_option("", "--start-time", default=0, type="int", help="the gps start time for window of desired coincs")
parser.add_option("", "--end-time", default=999999999, type="int", help="the gps end time for window of desired coincs")
parser.add_option("", "--check-weights", action="store_true", default=False,help="turn on if you want a plot of cumulative weights v. distance")
parser.add_option("", "--exact-tag", default="ring_exact", help="this is the dbinjfind tag stored in the sqlite database for the exactly found injections - the ones you want to use for training")
parser.add_option("", "--nearby-tag", default="ring_nearby", help="this is the dbinjfind tag stored in the sqlite for the nearby injections - we will still rank all of these")

(opts,databases)=parser.parse_args()
# --instruments has no default: fail with a usage message instead of an
# AttributeError on None
if opts.instruments is None:
	parser.error("--instruments is required, e.g. --instruments=H1,L1")
ifos=opts.instruments.strip().split(',')
# the queries below bind ifos[0] and ifos[1], so a pair is the minimum
if len(ifos) < 2:
	parser.error("--instruments must name two detectors separated by a comma, e.g. --instruments=H1,L1")
ifos.sort()

# start of the wall-clock timer reported at the end of the run
time1=time()

class SnglRingdown(snglringdowntable.SnglRingdownTable):
	"""
	Row type for sngl_ringdown triggers.

	Extends the C-implemented snglringdowntable.SnglRingdownTable with the
	get_end()/set_end() accessors, which the C version lacks and which are
	needed to slide the triggers on the ring (required for a correct
	calculation of ds_sq).  Both accessors operate on the trigger's
	start_time / start_time_ns fields.
	"""
	__slots__ = ()

	def get_end(self):
		"""Return start_time/start_time_ns packed into a LIGOTimeGPS."""
		seconds, nanoseconds = self.start_time, self.start_time_ns
		return dbtables.lsctables.LIGOTimeGPS(seconds, nanoseconds)

	def set_end(self, gps):
		"""Copy gps.seconds / gps.nanoseconds into start_time / start_time_ns."""
		seconds, nanoseconds = gps.seconds, gps.nanoseconds
		self.start_time = seconds
		self.start_time_ns = nanoseconds

dbtables.lsctables.SnglRingdownTable.RowType = SnglRingdown

parameters = mvsc_queries_ringdown.CandidateEventQuery.parameters
print "your MVSC analysis will use the following dimensions: "+parameters

exact_injections = []
exact_injections_info = []
exact_injections_distance = []
all_injections = []
all_injections_info = []
all_injections_distance = []
normalization = []
zerolag = []
zerolag_info = []
timeslides = []
timeslides_info = []

# process each input database in turn; everything down to put_connection_filename()
# runs once per database and appends to the accumulator lists defined above
for database in databases:
	# no scratch directory is used (the "/tmp" alternative is left commented out)
	local_disk = None #"/tmp"
	working_filename = dbtables.get_connection_filename(database, tmp_path = local_disk, verbose = True)
	connection = sqlite3.connect(working_filename)
	dbtables.DBTable_set_connection(connection)
	xmldoc = dbtables.get_xml(connection)
	# NOTE(review): this cursor is not used by the code visible in this loop;
	# the queries below each create their own cursor
	cursor = connection.cursor()
	# number of sngl_ringdown columns; used to slice the two single-detector
	# rows out of each query result
	num_sngl_cols = len(table.get_table(xmldoc, dbtables.lsctables.SnglRingdownTable.tableName).dbcolumnnames)
	# rings.values() is iterated below as a collection of segment lists
	rings = db_rinca_rings.get_rinca_rings_by_available_instruments(connection)
	# time slide table as a dict, indexed below as offset_vectors[time_slide_id][ifo]
	offset_vectors = dbtables.lsctables.table.get_table(dbtables.get_xml(connection), dbtables.lsctables.TimeSlideTable.tableName).as_dict()
	# converts a sequence of column values into a sngl_ringdown row object
	sngl_ringdown_row_from_cols = table.get_table(xmldoc, dbtables.lsctables.SnglRingdownTable.tableName).row_from_cols

	def gQQ(Qa,Qb):
		"""
		Evaluate the gQQ expression at the mean of Qa and Qb.

		Only the average of the two Q values enters the formula
		(presumably the Q-Q component of a ringdown metric -- TODO confirm).
		Raises ZeroDivisionError when the average is zero.
		"""
		Qmean = (Qa+Qb)/2.
		Qsq = Qmean*Qmean
		numerator = 1. + 28.*Qsq*Qsq + 128.*Qsq*Qsq*Qsq + 64.*Qsq*Qsq*Qsq*Qsq
		# this polynomial appears squared in the denominator
		poly = 1. + 6.*Qsq + 8.*Qsq*Qsq
		return numerator / ( 4. * Qsq * poly * poly )

	def gff(fa,fb,Qa,Qb):
		"""
		Evaluate the gff expression at the mean of (fa, fb) and of (Qa, Qb).

		Only the two averages enter the formula (presumably the f-f
		component of a ringdown metric -- TODO confirm).
		Raises ZeroDivisionError when the average of fa and fb is zero.
		"""
		# divide by 2. rather than 2 so that integer arguments are not
		# truncated by Python 2 integer division (same form as gQQ and gtf)
		f = (fa+fb)/2.
		Q = (Qa+Qb)/2.
		Q2 = Q*Q
		return ( 1. + 6.*Q2 + 16.*Q2*Q2) / ( 4. * f*f * ( 1. + 2.*Q2 ) )

	def gtt(fa,fb,Qa,Qb):
		"""
		Evaluate the gtt expression at the mean of (fa, fb) and of (Qa, Qb).

		Only the two averages enter the formula (presumably the t-t
		component of a ringdown metric -- TODO confirm).
		Raises ZeroDivisionError when the average of Qa and Qb is zero.
		"""
		# divide by 2. rather than 2 so that integer arguments are not
		# truncated by Python 2 integer division (same form as gQQ and gtf)
		f = (fa+fb)/2.
		Q = (Qa+Qb)/2.
		Q2 = Q*Q
		PI = 3.1415926535897932
		return ( PI*PI * f*f ) * ( 1. + 4.*Q2 ) / ( Q2 )
	def gQf(fa,fb,Qa,Qb):
		"""
		Evaluate the gQf expression at the mean of (fa, fb) and of (Qa, Qb).

		Only the two averages enter the formula (presumably the Q-f
		component of a ringdown metric -- TODO confirm).
		Raises ZeroDivisionError when either average is zero.
		"""
		# divide by 2. rather than 2 so that integer arguments are not
		# truncated by Python 2 integer division (same form as gQQ and gtf)
		f = (fa+fb)/2.
		Q = (Qa+Qb)/2.
		Q2 = Q*Q
		return - ( 1. + 2.*Q2 + 8.*Q2*Q2 ) / ( 4.*Q*f * ( 1. + 6.*Q2 + 8.*Q2*Q2 ) )

	def gtf(Qa,Qb):
		"""
		Evaluate the gtf expression at the mean of Qa and Qb.

		Only the average of the two Q values enters the formula
		(presumably the t-f component of a ringdown metric -- TODO confirm).
		"""
		PI = 3.1415926535897932
		Qmean = (Qa+Qb)/2.
		Qsq = Qmean*Qmean
		scale = PI * Qmean
		return -scale * ( 1. + 4.*Qsq ) / ( 1. + 2.*Qsq )

	def gtQ(fa,fb,Qa,Qb):
		"""
		Evaluate the gtQ expression at the mean of (fa, fb) and of (Qa, Qb).

		Only the two averages enter the formula (presumably the t-Q
		component of a ringdown metric -- TODO confirm).
		"""
		# divide by 2. rather than 2 so that integer arguments are not
		# truncated by Python 2 integer division (same form as gQQ and gtf)
		f = (fa+fb)/2.
		Q = (Qa+Qb)/2.
		Q2 = Q*Q
		PI = 3.1415926535897932
		return ( PI * f ) * ( 1. - 2.*Q2 ) / ( ( 1. + 2.*Q2 )*( 1. + 2.*Q2 ) )

	def calc_ds_sq(rowA,rowB,time_slide_id,rings=rings,offset_vectors=offset_vectors):
		"""
		Return xlaltools.XLAL3DRinca() of two sngl_ringdown triggers after
		sliding each of them on the rings.

		rowA, rowB     -- column values of the two triggers; each is turned
		                  into a row object with sngl_ringdown_row_from_cols
		time_slide_id  -- key into offset_vectors selecting the offsets to apply
		rings          -- mapping whose values are segment lists; all of them
		                  are merged into one list before sliding
		offset_vectors -- mapping from time_slide_id to the offsets handed to
		                  SnglInspiralUtils.slideTriggersOnRings
		"""
		# gather the segments of every ring list into one flat segmentlist
		all_rings = segments.segmentlist()
		for seglist in rings.values():
			all_rings.extend(seglist)
		triggerA = sngl_ringdown_row_from_cols(rowA)
		SnglInspiralUtils.slideTriggersOnRings([triggerA],all_rings,offset_vectors[time_slide_id])
		triggerB = sngl_ringdown_row_from_cols(rowB)
		SnglInspiralUtils.slideTriggersOnRings([triggerB],all_rings,offset_vectors[time_slide_id])
		return xlaltools.XLAL3DRinca(triggerA,triggerB)

	def calc_delta_t(trigger1_ifo, trigger1_time, trigger1_time_ns, trigger2_ifo, trigger2_time, trigger2_time_ns, time_slide_id, rings = rings, offset_vectors = offset_vectors):
		"""
		Return the absolute time difference (as a float) between two triggers
		after each has been slid on its ring by its instrument's offset.

		trigger*_ifo                  -- instrument name, key into offset_vectors[time_slide_id]
		trigger*_time, trigger*_time_ns -- seconds / nanoseconds of the trigger time
		time_slide_id                 -- string form of the time slide id (converted with ilwd.ilwdchar)
		rings                         -- mapping whose values are segment lists
		offset_vectors                -- mapping time_slide_id -> {ifo: offset}

		If sliding fails, the difference of the times as they stand is taken
		modulo 1 and folded into [0, 0.5] instead (the "mod 1 hack").
		Raises ValueError if neither trigger time lies on exactly one ring.
		"""
		time_slide_id = ilwd.ilwdchar(time_slide_id)
		trigger1_true_time = dbtables.lsctables.LIGOTimeGPS(trigger1_time, trigger1_time_ns)
		trigger2_true_time = dbtables.lsctables.LIGOTimeGPS(trigger2_time, trigger2_time_ns)
		# find the instruments that were on at trigger 1's time and
		# find the ring that contains this trigger; the one-element unpacking
		# raises ValueError unless exactly one ring matches
		try:
			[ring] = [segs[segs.find(trigger1_time)] for segs in rings.values() if trigger1_time in segs]
		except ValueError:
			# FIXME THERE SEEMS TO BE A BUG IN THINCA! Occasionally thinca records a trigger on the upper boundary
			# of its ring.  This would make it outside the ring which is very problematic.  It needs to be fixed in thinca
			# for now we'll allow the additional check that the other trigger is in the ring and use it.
			print >>sys.stderr, "trigger1 found not on a ring, trying trigger2"
			[ring] = [segs[segs.find(trigger2_time)] for segs in rings.values() if trigger2_time in segs]
		# now we can unslide the triggers on the ring
		try:
			trigger1_true_time = SnglInspiralUtils.slideTimeOnRing(trigger1_true_time, offset_vectors[time_slide_id][trigger1_ifo], ring)
			trigger2_true_time = SnglInspiralUtils.slideTimeOnRing(trigger2_true_time, offset_vectors[time_slide_id][trigger2_ifo], ring)
			out = abs(trigger1_true_time - trigger2_true_time)
			return float(out)
		except Exception:
			# Exception rather than a bare except, so that KeyboardInterrupt and
			# SystemExit are not swallowed by the fallback
			print >> sys.stderr, "calc delta t failed because one of the trigger's true times landed on the upper boundary of the thinca ring. See: trigger 1: ", trigger1_true_time, "trigger 2: ", trigger2_true_time, "ring: ", ring
			out = float(abs(trigger1_true_time - trigger2_true_time)) % 1
			if out > 0.5:
				out = 1.0 - out
			print >> sys.stderr, "SO...delta t has been set to: ", out, "in accordance with the mod 1 hack"
			return out
	# make the helpers above callable from SQL on this connection; the second
	# argument is the number of arguments the SQL function accepts
	connection.create_function("calc_delta_t", 7, calc_delta_t)
	# NOTE(review): the SQL name is "calc_ds_dq" while the Python function is
	# calc_ds_sq, and 5 SQL arguments would include rings/offset_vectors --
	# confirm against the queries in mvsc_queries_ringdown before renaming
	connection.create_function("calc_ds_dq", 5, calc_ds_sq)
	# indented with tabs like the rest of the loop body (these lines used
	# spaces, which python -tt and Python 3 reject)
	connection.create_function("gQQ", 2, gQQ)
	connection.create_function("gff", 4, gff)
	connection.create_function("gtt", 4, gtt)
	connection.create_function("gQf", 4, gQf)
	connection.create_function("gtf", 2, gtf)
	connection.create_function("gtQ", 4, gtQ)
# in S6, the timeslides, zerolag, and injections are all stored in the same sqlite database, thus this database must include a sim inspiral table
# if you provide a database that does not include injections, the code will still run as long as one of the databases you provide includes injections
	# a database without a sim_inspiral table (get_table raises ValueError)
	# contributes no injections
	try:
		sim_inspiral_table = table.get_table(xmldoc, dbtables.lsctables.SimInspiralTable.tableName)
	except ValueError:
		is_injections = False
	else:
		is_injections = True
	# please note that the third to last entry of every injection row is the gps time of the coinc.
	# It is only used for bookkeeping and does not stay in the array
	if is_injections:
		query_parts = mvsc_queries_ringdown.CandidateEventQuery
		# (where-clause attribute, dbinjfind tag, row list, info list, distance list):
		# first every injection found under the nearby tag, then the exactly found ones
		injection_passes = (
			('add_where_all', opts.nearby_tag, all_injections, all_injections_info, all_injections_distance),
			('add_where_exact', opts.exact_tag, exact_injections, exact_injections_info, exact_injections_distance),
		)
		for where_name, found_tag, rows, infos, distances in injection_passes:
			injection_query = ''.join([query_parts.select_dimensions,query_parts.add_select_injections,query_parts.add_from_injections,getattr(query_parts, where_name),])
			bind_values = (ifos[0],ifos[1],opts.start_time,opts.end_time,found_tag,)
			for values in connection.cursor().execute(injection_query, bind_values):
				# values = [id, first sngl row, second sngl row, time_slide_id, dimensions..., last column]
				first_sngl = values[1:num_sngl_cols+1]
				second_sngl = values[num_sngl_cols+1:2*num_sngl_cols+1]
				slide_id = ilwd.get_ilwdchar(values[2*num_sngl_cols+1])
				ds_sq = calc_ds_sq(first_sngl,second_sngl,slide_id,rings,offset_vectors,)
				# the trailing 1 is appended to every injection row
				rows.append((ds_sq,) + values[2*num_sngl_cols+2:-1] + (1,))
				infos.append([values[0], database])
				distances.append([values[-1], database])

	#FIXME: look up coinc_definer_id from definition in pylal
	# get the timeslide/full_data triggers; the last column of each result
	# decides whether the row is background ('slide') or zero lag ('all_data')
	for values in connection.cursor().execute(''.join([mvsc_queries_ringdown.CandidateEventQuery.select_dimensions,mvsc_queries_ringdown.CandidateEventQuery.add_select_fulldata,mvsc_queries_ringdown.CandidateEventQuery.add_from_fulldata]), (ifos[0],ifos[1],opts.start_time,opts.end_time,) ):
		if values[-1] == 'slide':
			fulldata_rows, fulldata_info = timeslides, timeslides_info
		elif values[-1] == 'all_data':
			fulldata_rows, fulldata_info = zerolag, zerolag_info
		else:
			# any other label is ignored, as before
			continue
		# rows end with (1, 0): the 1 and 0 appended to every non-injection row
		fulldata_rows.append((calc_ds_sq(values[1:num_sngl_cols+1],values[num_sngl_cols+1:2*num_sngl_cols+1],ilwd.ilwdchar(values[2*num_sngl_cols+1]),rings,offset_vectors),) + values[2*num_sngl_cols+2:-1] + (1,) + (0,))
		fulldata_info.append([values[0], database])
	# release the sqlite handle before the working copy is handed back;
	# otherwise one connection per input database stays open until exit
	connection.close()
	dbtables.put_connection_filename(database, working_filename, verbose = True)

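# with --apply-weights, each found injection is weighted by a power of sqrt(8^2+8^2)/sqrt(snr_a^2+snr_b^2); the power depends on the label in the row's second-to-last entry ('volume' rows get weight 1)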
#FIXME: maybe there are better ways to implement the weighting, please think about it before applying 
newexact_injections=[]
exact_injections_vol=[]
exact_injections_lin=[]
exact_injections_log=[]
newall_injections=[]
if opts.apply_weights:
	print "applying weights for exact_injections"
	for i,row in enumerate(exact_injections):
		injtmp = list(row)
		injtmp[-2]=str(injtmp[-2])
		if injtmp[-2] == 'Ringdown':
			injtmp[-2]=((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**3
			exact_injections_log.append([exact_injections_distance[i][0],injtmp[10],injtmp[11],injtmp[-2]])
		if injtmp[-2] == 'uniform':
			injtmp[-2]=((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**2
			exact_injections_lin.append([exact_injections_distance[i][0],injtmp[10],injtmp[11],injtmp[-2]])
		if injtmp[-2] == 'log10':
			injtmp[-2]=((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**3
			exact_injections_log.append([exact_injections_distance[i][0],injtmp[10],injtmp[11],injtmp[-2]])
		if injtmp[-2] == 'volume':
			injtmp[-2]=1.0
			exact_injections_vol.append([exact_injections_distance[i][0],injtmp[10],injtmp[11],injtmp[-2]])
		newexact_injections.append(tuple(injtmp))
	for row in all_injections:
		injtmp = list(row)
		injtmp[-2]=str(injtmp[-2])
		if injtmp[-2] == 'Ringdown':
			injtmp[-2]=3*((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**3
		if injtmp[-2] == 'uniform':
			injtmp[-2]=3*((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**2
		if injtmp[-2] == 'log10':
			injtmp[-2]=3*((8**2+8**2)**0.5/(injtmp[10]**2+injtmp[11]**2)**0.5)**3
		if injtmp[-2] == 'volume':
			injtmp[-2]=1.0
		newall_injections.append(tuple(injtmp))
else:
	for row in exact_injections:
		injtmp = list(row)
		injtmp[-2]=1
		newexact_injections.append(tuple(injtmp))
	for row in all_injections:
		injtmp = list(row)
		injtmp[-2]=1
		newall_injections.append(tuple(injtmp))
exact_injections=newexact_injections
all_injections=newall_injections

# save the per-label weight records; each file is closed explicitly so the
# data is on disk before the script carries on
for pickle_suffix, pickle_payload in (
	('exact_injections_vol.p', exact_injections_vol),
	('exact_injections_lin.p', exact_injections_lin),
	('exact_injections_log.p', exact_injections_log),
):
	pickle_file = open(''.join(ifos) + '_' + opts.output_tag + pickle_suffix,'wb')
	try:
		pickle.dump(pickle_payload,pickle_file)
	finally:
		pickle_file.close()

# distance -> [cumulative weight of all exact injections at or below that distance, row[10], row[11]]
weight_dictionary={}
if opts.check_weights:
	for i, distance_entry in enumerate(exact_injections_distance):
		distance_i = distance_entry[0]
		# named cumulative_weight so the builtin sum() is not shadowed
		cumulative_weight = 0.0
		for j, other_entry in enumerate(exact_injections_distance):
			if other_entry[0] <= distance_i:
				cumulative_weight = cumulative_weight + exact_injections[j][-2]
		weight_dictionary[distance_i]=[cumulative_weight,exact_injections[i][10],exact_injections[i][11]]
		pylab.loglog(distance_i,cumulative_weight,'b.')
		pylab.axis('equal')
		pylab.hold(1)
	# close the pickle file explicitly instead of leaving it to the garbage collector
	weight_file = open(''.join(ifos) + '_' + opts.output_tag + 'weight_dictionary.p','wb')
	try:
		pickle.dump(weight_dictionary,weight_file)
	finally:
		weight_file.close()
	pylab.savefig(''.join(ifos) + '_' + opts.output_tag + 'weight_v_distance_cumulative')
print "there are ", len(timeslides), " timeslide doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(exact_injections), " exactly found injection doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(all_injections), " total found injection doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(zerolag), " zerolag doubles in ", ''.join(ifos), " and triple coincidences"

random.seed(2)
random.shuffle(timeslides)
random.seed(2)
random.shuffle(timeslides_info)

# this part of the code writes the triggers' information into .pat files, in the format needed for SprBaggerDecisionTreeApp
# to get the MVSC rank for each timeslide and injection, we do a round-robin of training and evaluation, with the number of rounds determined by opts.number
# for example, if opts.number is 10, each round will train a random forest of bagged decision trees on 90% of the timeslides and exact_injections
# then we'd run the remaining 10% through the trained forest to get their MVSC rank
# in this case, we'd do this 10 times, ensuring that every timeslide and injection gets ranked
Nrounds = opts.number
Ninj = len(exact_injections)
Nslide = len(timeslides)

trstr = 'training'
evstr = 'evaluation'
zlstr = 'zerolag'

print "there are ", len(timeslides), " timeslide doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(exact_injections), "exactly found injection doubles in ", ''.join(ifos), " and triple coincidences"
print "there are ", len(zerolag), " zerolag doubles in ", ''.join(ifos), " and triple coincidences"

if len(exact_injections) > Nrounds and len(timeslides) > Nrounds:
	Nparams = len(exact_injections[0]) - 3
	Nrounds = opts.number
	Ninj_exact = len(exact_injections)
	Nslide = len(timeslides)
	gps_times_for_all_injections = zip(*all_injections)[-3]
	gps_times_for_exact_injections = zip(*exact_injections)[-3]
	print min(gps_times_for_all_injections), max(gps_times_for_all_injections)
	print min(gps_times_for_exact_injections), max(gps_times_for_exact_injections)
	trstr = 'training' 
	testr = 'evaluation'
	zlstr = 'zerolag'

	def open_file_write_headers(filetype, set_num, ifos, Nparams=Nparams):
		f = open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(set_num) + '_' + str(filetype) +  '.pat', 'w')
		f.write(str(Nparams) + '\n')
		f.write(parameters + "\n")
		return f

# first put in the header information
	for i in range(Nrounds):
		f_training = open_file_write_headers(trstr, i, ifos)
		f_testing = open_file_write_headers(testr, i, ifos)
		f_testing_info=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_' + str(testr) + '_info.pat', 'w')
# now let's do the injections - each training set will have (Nrounds-1)/Nrounds fraction of all exactly found injections
# we need to rig the evaluation sets (which include all exact AND nearby injections) to not be evaluated on a forest that was trained on themselves
# so, we divide up the set of exactly found injections (which is sorted by GPS time) into Nrounds, then use the GPS times at the boundaries to properly divide our evaluation sets
		exact_injections_tmp = list(exact_injections)
		set_i_exact_inj = exact_injections_tmp[i*Ninj_exact/Nrounds : (i+1)*Ninj_exact/Nrounds]
		divisions = [set_i_exact_inj[0][-3], set_i_exact_inj[-1][-3]] #these are the gps boundaries for what should go in the evaluation set
		del(exact_injections_tmp[i*Ninj_exact/Nrounds : (i+1)*Ninj_exact/Nrounds])
# exact_injections_tmp now contains our the exact injections we want to include in the ith training set, let's write them to file
		tmp = zip(*exact_injections_tmp)
		del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
		exact_injections_tmp = zip(*tmp)
		for row in exact_injections_tmp:
			f_training.write("%s\n" % " ".join(map(str,row)))
# now we need to construct the evaluation (aka testing) set (all injections, not just exact anymore) that pairs with this training set
		set_i_all_inj = []
		left_index = bisect.bisect_left(gps_times_for_all_injections,divisions[0])
		right_index = bisect.bisect_right(gps_times_for_all_injections,divisions[1])
		set_i_all_inj = all_injections[left_index:right_index]
		set_i_all_inj_info = all_injections_info[left_index:right_index]
		print "to add:", len(set_i_all_inj)
		tmp = zip(*set_i_all_inj)
		del(tmp[-3]) # we have to remove the GPS times from the array before writing to file
		set_i_all_inj = zip(*tmp)
		for row in set_i_all_inj:
			f_testing.write("%s\n" % " ".join(map(str,row)))
		for row in set_i_all_inj_info:
			f_testing_info.write("%s\n" % " ".join(map(str,row)))
# now let's do the timeslides
		timeslides_tmp = list(timeslides)
		timeslides_info_tmp = list(timeslides_info)
# get (say) 10% of the timeslides and injections, which you will run through the forest that you've trained on the other 90%
		set_i_slide = timeslides_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds]
		set_i_slide_info = timeslides_info_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds]
		for row in set_i_slide:
			f_testing.write("%s\n" % " ".join(map(str,row)))
		for row in set_i_slide_info:
			f_testing_info.write("%s\n" % " ".join(map(str,row)))
# delete the 10%, and save the remaining 90% into the training file
		del(timeslides_tmp[i*Nslide/Nrounds : (i+1)*Nslide/Nrounds])
		for row in timeslides_tmp:
			f_training.write("%s\n" % " ".join(map(str,row)))
		if len(zerolag) != 0:
			f_zerolag=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_'+ str(zlstr) + '.pat','w')
			f_zerolag.write(str(Nparams) + '\n')
			f_zerolag.write(parameters + "\n")
			for row in zerolag:
				f_zerolag.write("%s\n" % " ".join(map(str,row)))
			f_zerolag_info=open(''.join(ifos) + '_' + opts.output_tag + '_set' + str(i) + '_' + str(zlstr) + '_info.pat', 'w')
			for row in zerolag_info:
				f_zerolag_info.write("%s\n" % " ".join(map(str,row)))
else: raise Exception, "There were no injections found for the specified ifo combination %s" % ifos

# report the wall-clock time since time1, which was taken right after option parsing
time2=time()
elapsed_time=time2-time1
print "elapsed time:", elapsed_time
