#!/usr/bin/python
try:
	import sqlite3
except ImportError:
	# pre 2.5.x
	from pysqlite2 import dbapi2 as sqlite3
from glue.ligolw import dbtables 
from pylal import SnglInspiralUtils
from pylal import db_thinca_rings
from pylal import git_version
from time import clock,time
from optparse import *
import glob
import sys
import numpy
import pdb
from glue import iterutils

# Module metadata; the version and date strings are taken from pylal's git_version module.
__author__ = "Kari Hodge <kari.hodge@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

# Help text handed to OptionParser as the program description.
# NOTE(review): the text below names the coinc_inspiral table, but the UPDATE
# statements in this script write the likelihood column of coinc_event.
description="""
put the MVSC rank back into the SQL tables, after it's been turned into a likelihood, we put it in the likelihood column of the coinc_inspiral table
use it like:
mvsc_update_sql --all-instruments=H1,L1,V1 db1.sqlite db2.sqlite evaluation_file1.dat evalution_infofile1.pat etc
pretty much: feed it all your databases, *evaluation*.dat *info.pat *zerolag*.dat *zerolag_info*.pat files
please note: this code has only been tested in the case that all your triggers come from the same database (as is then norm in S6)
if you are working with more than one database, please do some sanity checks!
"""

time1=time()

parser=OptionParser(usage="%prog [options] [file ...]", description = description, version="Name: %%prog\n%s" % git_version.verbose_msg)
parser.add_option("-a","--all-instruments", help = "the list of all instruments from which you want to study the double-coincident triggers. e.g. H1,H2,L1,V1")
parser.add_option("-t","--output-tag", help = "tag to identify, for example, which category your zerolag average rank file is for")
parser.add_option("-o","--dont-reset-likelihoods", default = False, action="store_false", help="provide this option if you don't want to reset all the likelihoods to 1 before updating this time")
parser.add_option("-s","--tmp-space",help="necessary for sqlite calls, for example /usr1/khodge")
(opts,inputfiles)=parser.parse_args()

all_ifos = opts.all_instruments.strip().split(',')
ifo_combinations = list(iterutils.choices(all_ifos,2))
for i,comb in enumerate(ifo_combinations):
	ifo_combinations[i] = ''.join(comb)
print ifo_combinations

files = []
infofiles = []
zerolag_files = []
zerolag_infofiles = []
databases = []
for file in inputfiles:
	extension = file.split('.')[-1]
	name = file.split('.')[0]
	if extension == 'dat': 
		if 'zerolag' in name:
			zerolag_files.append(file)
		else:
			files.append(file)
	elif extension == 'pat': 
		if 'zerolag' in name:
			zerolag_infofiles.append(file)
		else:
			infofiles.append(file)
	elif extension == 'sqlite': 
		databases.append(file)
	else:
		print >> sys.stderr, "Error: args contain %s which is not a database, .dat, or .pat file" % file
		sys.exit(1)

print databases

# Match every rank file with its info file.  Both lists are sorted first, so
# the pairing is purely positional: the n-th rank file goes with the n-th info file.
files.sort()
infofiles.sort()
if len(infofiles) != len(files):
	print>> sys.stderr, "the number of infofiles does not match the number of files"
	sys.exit(1)
# each entry of file_map is a two-element list: [rank file, info file]
file_map=[]
for rank_name, info_name in zip(files, infofiles):
	file_map.append([rank_name, info_name])

def files_to_dict(f1, f1_info, databases, ldict):
	"""
	Read one rank file and its info file and record the likelihoods in ldict.

	f1        -- path of the rank file; its first line is a header and the last
	             whitespace-separated field of every other line is the rank
	f1_info   -- path of the info file; line i describes data line i of f1, its
	             first two fields being the coinc id and the database name
	databases -- database names; each gets an (initially empty) entry in ldict
	ldict     -- dictionary updated in place: ldict[database][coinc id] = likelihood

	A rank r is stored as r / (1 - r), and as infinity when r equals 1.

	Returns ldict.  Exits with status 1 if the two files do not describe the
	same number of rows.  A KeyError is raised if the info file names a
	database that is not listed in databases.
	"""
	for f in databases:
		ldict.setdefault(f, {})
	# context managers make sure both file handles are closed again
	with open(f1) as rank_file:
		f1lines = rank_file.readlines()
	#the first line of the rank file is a header
	del f1lines[:1]
	with open(f1_info) as info_file:
		f1_infolines = info_file.readlines()
	if len(f1lines) != len(f1_infolines):
		print>> sys.stderr, "file " + f1 + " and " +f1_info + " have different lengths"
		sys.exit(1)
	for f1row, inforow in zip(f1lines, f1_infolines):
		likelihood = float(f1row.split()[-1])
		if likelihood == 1:
			# r / (1 - r) would divide by zero
			likelihood = float("inf")
		else:
			likelihood = likelihood / (1.0 - likelihood)
		cid, dfile = inforow.split()[0:2]
		ldict[dfile][cid] = likelihood
	return ldict

# The zerolag ranks produced by the different trees have to be averaged, so
# first group the zerolag rank files and zerolag info files by the instrument
# combination whose name appears in the file name.
zerolag_files.sort()
zerolag_files_comb_dict={}
zerolag_infofiles_comb_dict={}
for comb in ifo_combinations:
	zerolag_files_comb_dict[comb]=[zlname for zlname in zerolag_files if comb in zlname]
	zerolag_infofiles_comb_dict[comb]=[zlname for zlname in zerolag_infofiles if comb in zlname]
for comb in ifo_combinations:                                  
	tmp = None
	zerolag_header = []
	tmplist=zerolag_files_comb_dict[comb]
	if tmplist: 
		for zlfile in tmplist:
			if tmp != None:
				data = numpy.vstack((data,numpy.loadtxt(zlfile,dtype=float,skiprows=1,usecols=[-1])))
				zerolag_header.append(str(zlfile))
			else:
				tmp = numpy.loadtxt(zlfile,dtype=float,skiprows=1,usecols=[-1])
				data = tmp
				zerolag_header.append(str(zlfile))
		stddevs=numpy.std(data,axis=0)
		zerolag_header.append('stddev')
		means=numpy.mean(data,axis=0)
		zerolag_header.append('mean_rank')
		#make a new array where each row has the ranks given by the different forests, then the mean, then the standard deviation
		zerolag_header=numpy.array(zerolag_header)
		zerolag_array=numpy.vstack((data,stddevs,means)).transpose()
		prefix_string = zlfile.split('_')[0]+'_'
		with open(prefix_string+opts.output_tag+'_zerolag_ranks_stddev_mean.txt','w') as fzl:
			for word in zerolag_header:
				fzl.write(word+' ')
			fzl.write('\n')
			numpy.savetxt(fzl,zerolag_array)
		# add zerolag to the file map
		file_map.append([prefix_string+opts.output_tag+'_zerolag_ranks_stddev_mean.txt',zerolag_infofiles_comb_dict[comb][0]])
print file_map


# first, initialize the databases, so that all of the likelihood values are set to 1
if opts.dont_reset_likelihoods==False:
	print "presetting likelihoods to 1"
	for database in databases:
		filename = database
		local_disk = opts.tmp_space
		working_filename = dbtables.get_connection_filename(filename, tmp_path = local_disk, verbose = True)
		connection = sqlite3.connect(working_filename)
		xmldoc = dbtables.get_xml(connection)
		cursor = connection.cursor()
		connection.cursor().execute("""
		UPDATE coinc_event
			SET likelihood = 1
		""")
		connection.commit()
		dbtables.put_connection_filename(filename, working_filename, verbose = True)

ldict = {}
for pair in file_map:
	files_to_dict(str(pair[0]),str(pair[1]),databases, ldict)

for f in databases:
	filename = f
	local_disk = opts.tmp_space
	working_filename = dbtables.get_connection_filename(filename, tmp_path = local_disk, verbose = True)
	connection = sqlite3.connect(working_filename)
	xmldoc = dbtables.get_xml(connection)
	cursor = connection.cursor()
	def likelihood_from_coinc_id(id, ldict_f = ldict[f]):
		try: 
			return ldict_f[id]
		except: return 1
	if ldict[f]:	
		connection.create_function("likelihood_from_coinc_id", 1, likelihood_from_coinc_id)		
		cursor.execute("""
		UPDATE coinc_event
			SET likelihood = likelihood * likelihood_from_coinc_id(coinc_event.coinc_event_id)
		""")
		connection.commit() 
	dbtables.put_connection_filename(filename, working_filename, verbose = True)

time2=time()
elapsed_time=time2-time1
print elapsed_time


