#!/usr/bin/python
from math import *
import glue
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import table
from glue.ligolw.utils import ligolw_add
from glue.ligolw import utils
from glue import segments
import sys

from optparse import OptionParser
from glue.lal import LIGOTimeGPS

from pylal import git_version
__author__ = "Chad Hanna <channa@ligo.caltech.edu>, Satya Mohapatra <satyanarayan.raypitambarmohapatra@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


def create_tables(xmldoc, verbose=False):
	"""
	Make sure xmldoc holds every table this program writes to.

	The sngl_burst, search_summary, multi_burst, coinc_event_map,
	coinc_event, process, coinc_definer and time_slide tables are looked
	up in turn with table.get_table().  Any table that lookup fails for is
	created empty and appended to the first child node of xmldoc;  tables
	that are found are left untouched.

	If verbose is true a progress message is written to stderr.
	"""
	if verbose: print >> sys.stderr, "Creating tables..."

	# (table name, table class), listed in the order the tables are appended
	required_tables = (
		("sngl_burst", lsctables.SnglBurstTable),
		("search_summary", lsctables.SearchSummaryTable),
		("multi_burst", lsctables.MultiBurstTable),
		("coinc_event_map", lsctables.CoincMapTable),
		("coinc_event", lsctables.CoincTable),
		("process", lsctables.ProcessTable),
		("coinc_definer", lsctables.CoincDefTable),
		("time_slide", lsctables.TimeSlideTable),
	)

	for name, table_class in required_tables:
		try:
			table.get_table(xmldoc, name)
		except ValueError:
			# get_table() reports a failed lookup with ValueError;  catching
			# only that keeps unrelated errors (and Ctrl-C) visible
			xmldoc.childNodes[0].appendChild(lsctables.New(table_class))
	
def get_instruments_from_segments(xmldoc):
	"""
	Return the ifos attribute of the single row of the segment_definer
	table in xmldoc.

	A message is written to stderr and the program exits with status 1
	if the segment_definer table cannot be retrieved, or if it does not
	hold exactly one row.
	"""

	#TODO make sure that segments have sensible *on* instruments
	try: st = table.get_table(xmldoc,"segment_definer")
	except ValueError:
		print >> sys.stderr, """
segment definer table not found, cannot continue.  
Hint you need to run ligolw_segments on the omega segments file
to make an xml file that you can put on the command line
		"""
		sys.exit(1)
	if len(st) > 1:
		print >> sys.stderr, "Too many segment definer rows, I cannot figure out the intsruments analyzed"
		sys.exit(1)
	if len(st) == 0:
		# without this check st[0] below dies with a bare IndexError
		print >> sys.stderr, "No segment definer rows, I cannot figure out the instruments analyzed"
		sys.exit(1)
	return st[0].ifos

def populate_search_summary_from_segments(xmldoc,ifos=None, verbose=False):
	"""
	Append one search_summary row to xmldoc for every row of its segment
	table.

	Each new row gets its own process ID from add_process(), has ifos
	stored in its ifos column, and has both its in and out segments set
	to the segment returned by the segment row's get() method.

	A message is written to stderr and the program exits with status 1
	if the segment table cannot be retrieved.  If verbose is true,
	progress is reported on stderr.
	"""
	t = table.get_table(xmldoc,"search_summary")
	try: st = table.get_table(xmldoc,"segment")
	except ValueError:
		print >> sys.stderr, """
segment table not found, cannot continue.  
Hint you need to run ligolw_segments on the omega segments file
to make an xml file that you can put on the command line
		"""
		sys.exit(1)
	n_segments = len(st)
	for i,s in enumerate(st):
		if verbose: print >>sys.stderr, "reconstructing search summary: %.1f%%\r" % (100.0 * float(i+1) / n_segments,),
		row = t.RowType()
		# one process row per segment, so that triggers can later be tied
		# to a segment through their process_id
		row.process_id = add_process(xmldoc, ifos)
		row.shared_object = None
		row.lalwrapper_cvs_tag = None
		row.lal_cvs_tag = None
		row.comment = "Awesome"
		row.ifos = ifos
		seg = s.get()
		#FIXME adjust for what omega actually analyzed
		row.set_in(seg)
		row.set_out(seg)
		row.nevents = None
		row.nnodes = None
		t.append(row)
	# finish the "\r"-terminated progress line
	if verbose: print >> sys.stderr, ""

def match_sngl_burst_to_search_summary(xmldoc, verbose=False):
	"""
	Set the process_id of every sngl_burst row in xmldoc to the
	process_id of the search_summary row whose out segment contains the
	burst's peak time.

	If a peak time is not found in any out segment, a message is written
	to stderr and the program exits with status 1.  If verbose is true,
	progress is reported on stderr.
	"""
	search_summary = table.get_table(xmldoc,"search_summary")
	proc_ids = [r.process_id for r in search_summary]
	sngl_burst = table.get_table(xmldoc,"sngl_burst")
	# indexes into seg_list are used to index proc_ids below
	seg_list = search_summary.get_outlist()
	n_bursts = len(sngl_burst)

	for j, brow in enumerate(sngl_burst):
		if verbose: print >>sys.stderr, "matching sngl burst to search summary: %.1f%%\r" % (100.0 * float(j + 1) / n_bursts, ),
		try:
			idx = seg_list.find(brow.get_peak())
		except ValueError:
			# find() raises ValueError when no segment contains the time
			print >> sys.stderr, "Trigger %d found outside of segment!! Aborting!" % (j,)
			sys.exit(1)
		brow.process_id = proc_ids[idx]

	# finish the "\r"-terminated progress line
	if verbose: print >>sys.stderr, ""

def create_time_slide_row(xmldoc, offset):
	"""
	Add one time slide to the time_slide table of xmldoc and return its
	time_slide_id.

	One row is appended per instrument:  H1 and H2 with offset 0.0 and L1
	with the given offset.  All of the rows carry the same time_slide_id
	so that together they describe a single time slide.
	"""
	t = table.get_table(xmldoc, "time_slide")
	#FIXME FOR S6 we need different scheme
	offsets = {'H1':0.0, 'H2':0.0, 'L1':offset}
	#offsets = {'H1':0.0, 'L1':offset}
	# draw the ID once:  asking for a new one per instrument would split
	# the offsets across unrelated time slides
	time_slide_id = t.get_next_id()
	# sorted so the row order does not depend on dictionary ordering
	for instrument in sorted(offsets):
		row = t.RowType()
		row.time_slide_id = time_slide_id
		row.instrument = instrument
		row.offset = offsets[instrument]
		row.process_id = None
		t.append(row)
	return time_slide_id

def create_coinc_def_row(xmldoc):
	"""
	Append the row describing omega sngl_burst<-->sngl_burst coincidences
	to the coinc_definer table of xmldoc and return its coinc_def_id.

	The existing coinc_definer table is used when xmldoc has one;
	otherwise a new table is created and appended to the first child node
	of xmldoc.
	"""
	try:
		t = table.get_table(xmldoc,"coinc_definer")
	except ValueError:
		# no usable table in the document yet, make one
		t = lsctables.New(lsctables.CoincDefTable)
		xmldoc.childNodes[0].appendChild(t)
	coinc_def = lsctables.CoincDef(search = u"omega", search_coinc_type = 0, description = u"sngl_burst<-->sngl_burst coincidences")
	coinc_def.coinc_def_id = t.get_next_id()
	t.append(coinc_def)
	return coinc_def.coinc_def_id

class Coinc(object):
	"""
	One coincident event read from a single line of a trigger text file.

	The line is split on white space into consecutive groups of five
	columns (time, frequency, duration, bandwidth, z).  The first group is
	stored under the instrument combination string (e.g. "H1,H2,L1") and
	the following groups under each individual instrument, in the order
	the instruments appear in that string.  The values are kept in the
	time, frequency, duration, bandwidth and z dictionaries.
	"""

	# number of text columns stored under each key
	columns_per_key = 5

	def __init__(self, row, xmldoc, ifos, on_instruments, time_slide_id, coinc_def_id):
		"""
		row is one line of trigger text, xmldoc the document whose tables
		are filled in, ifos the comma-delimited instruments taking part in
		the coinc, on_instruments the comma-delimited instruments recorded
		in the coinc_event row, and time_slide_id / coinc_def_id the IDs
		recorded in the coinc_event row.
		"""
		self.xmldoc = xmldoc
		self.time = {}
		self.duration = {}
		self.frequency = {}
		self.z = {}
		self.bandwidth = {}
		self.zinc = {}
		self.offset = {}
		self.on_instruments = on_instruments
		#FIXME use utilities in glue to sort out instrument keys
		self.instrument_combos = ifos
		self.instruments = ifos.split(',')
		self.time_slide_id = time_slide_id
		self.coinc_def_id = coinc_def_id
		self._add_coinc(row, ifos)

	def add_to_coinc_event_table(self):
		"""
		Record this coinc in xmldoc:  one coinc_event row, one multi_burst
		row, and one sngl_burst row plus one coinc_event_map row for every
		instrument taking part.
		"""
		t = table.get_table(self.xmldoc,"coinc_event")
		row = t.RowType()
		row.coinc_event_id = t.get_next_id()
		row.set_instruments(self.on_instruments.split(','))
		row.likelihood = 1.0
		row.nevents = len(self.instruments)
		#FIXME add the rest of the colums
		row.process_id = None #table.get_table(self.xmldoc,"process")[0].process_id
		row.coinc_def_id = self.coinc_def_id
		row.time_slide_id = self.time_slide_id
		# the row is complete before it goes into the table
		t.append(row)
		self.add_to_multi_burst_table(self.instrument_combos, row.coinc_event_id)
		for i in self.instruments:
			event_id = self.add_to_sngl_burst_table(i)
			self.add_to_coinc_event_map(row.coinc_event_id, event_id)

	def add_to_multi_burst_table(self,instrument_combos, coinc_event_id):
		"""
		Append a multi_burst row holding the values stored under the
		instrument_combos key and tagged with coinc_event_id.
		"""
		c = table.get_table(self.xmldoc,'multi_burst')
		row = c.RowType()
		row.coinc_event_id = coinc_event_id
		row.creator_db = None
		row.process_id = None
		row.ifos = instrument_combos
		row.start_time = None
		row.start_time_ns = None
		row.duration = self.duration[instrument_combos]
		row.set_peak(self.time[instrument_combos])
		row.central_freq = self.frequency[instrument_combos]
		row.bandwidth = self.bandwidth[instrument_combos]
		row.amplitude = None
		row.snr = self._snr(instrument_combos)
		row.confidence = None
		row.ligo_axis_ra = None
		row.ligo_axis_dec = None
		row.ligo_angle = None
		row.ligo_angle_sig = None
		row.filter_id = None
		row.false_alarm_rate = None
		c.append(row)

	def add_to_coinc_event_map(self, coinc_event_id, event_id, table_name='sngl_burst'):
		"""
		Append a coinc_event_map row linking coinc_event_id to event_id of
		the table called table_name.
		"""
		t = table.get_table(self.xmldoc,"coinc_event_map")
		row = t.RowType()
		row.coinc_event_id = coinc_event_id
		row.table_name = table_name
		row.event_id = event_id
		t.append(row)

	def add_to_sngl_burst_table(self, instrument):
		"""
		Append a sngl_burst row holding the values stored under the
		instrument key and return the new row's event_id.  Columns this
		program has no information for are set to None.
		"""
		b = table.get_table(self.xmldoc, 'sngl_burst')
		row = b.RowType()
		#hint event_id have a table_name attribute, use this for coinc_map
		row.event_id = b.get_next_id()
		row.ifo = instrument
		row.creator_db = None
		row.filter_id = None
		row.search = "omega"
		#FIXME figure out the channel???
		row.channel = None
		row.start_time = None
		row.start_time_ns = None
		row.stop_time = None
		row.stop_time_ns = None
		row.duration = self.duration[instrument]
		#FIXME figure this out?
		row.flow = None
		row.fhigh = None
		row.central_freq = self.frequency[instrument]
		row.bandwidth = self.bandwidth[instrument]
		row.amplitude = None
		row.snr = self._snr(instrument)
		row.confidence = None
		row.chisq = None
		row.chisq_dof = None
		#FIXME is this right?
		row.tfvolume = row.bandwidth * row.duration
		row.hrss = None
		row.time_lag = None
		row.set_peak(self.time[instrument])
		row.peak_frequency = None
		row.peak_strain = None
		row.peak_time_error = None
		row.peak_frequency_error = None
		row.peak_strain_error = None
		row.ms_start_time = None
		row.ms_start_time_ns = None
		row.ms_stop_time = None
		row.ms_stop_time_ns = None
		row.ms_duration = None
		row.ms_flow = None
		row.ms_fhigh = None
		row.ms_bandwidth = None
		row.ms_snr = None
		row.ms_confidence = None
		row.ms_hrss = None
		row.param_one_name = None
		row.param_one_value = None
		row.param_two_name = None
		row.param_two_value = None
		row.param_three_name = None
		row.param_three_value = None
		#FIXME WE NEED PROCESS TABLE INFO
		row.process_id = None
		b.append(row)
		return row.event_id

	def add_row_dict(self, key, row_part):
		"""
		Store the five text columns in row_part (time, frequency,
		duration, bandwidth, z) under key.  The time is kept as a
		LIGOTimeGPS, everything else as a float.
		"""
		self.time[key] = LIGOTimeGPS(float(row_part[0]))
		self.frequency[key] = float(row_part[1])
		self.duration[key] = float(row_part[2])
		self.bandwidth[key] = float(row_part[3])
		self.z[key] = float(row_part[4])

	def _snr(self, key):
		"""
		Return the value written to the snr column for key:  the square
		root of twice the z stored under key.
		"""
		return (2.0 * self.z[key])**(0.5)

	def _add_coinc(self, row, ifocombos=None):
		"""
		Parse the text line row for the comma-delimited instrument
		combination ifocombos.  Nothing is stored when ifocombos is None
		or names fewer than two instruments.
		"""
		#FIXME ASSUMES TXT FILES HAVE FIXED COLUMNS
		if ifocombos is None: return
		instruments = ifocombos.split(',')
		if len(instruments) < 2: return
		columns = row.split()
		width = self.columns_per_key
		# the combination's own columns come first, then one group of
		# columns per instrument
		for i, key in enumerate([ifocombos] + instruments):
			self.add_row_dict(key, columns[i*width:(i+1)*width])

def add_process(xmldoc, ifos):
	"""
	Append a row describing this program to the process table of xmldoc
	and return the process_id given to that row.  ifos is stored in the
	row's ifos column.
	"""
	process_table = table.get_table(xmldoc,"process")
	proc = process_table.RowType()
	proc.process_id = process_table.get_next_id()

	# what is known about this run
	proc.program = "omega_to_coinc"
	proc.comment = "a simple tool to convert omega output to coinc tables"
	proc.ifos = ifos

	# columns that are left unset
	for column in (
		"version", "cvs_repository", "cvs_entry_time", "is_online",
		"node", "username", "unix_procid", "start_time", "end_time",
		"jobid", "domain",
	):
		setattr(proc, column, None)

	process_table.append(proc)
	return proc.process_id

def parse_command_line():
	"""
	Parse the command line and return (options, urls, trigger_files).

	options is the optparse result, urls the positional arguments
	containing ".xml", and trigger_files the positional arguments
	containing ".txt".  Positional arguments containing neither are
	dropped.

	Raises ValueError if --instruments is not given.
	"""
	parser = OptionParser(
		version = "%prog CVS $Id$",
		usage = '%prog --instruments="instrument1,instrument2,.." [options] [<trigger_file.txt> <sim_inspiral>.xml <segments>.xml ...]',
		description = """
%prog is a format converter from omega output to coinc tables useful for IMR 
searches.  It assumes a standard Omega txt file with entries 
time, frequency, duration, bandwidth, z white space delimited.
It should have 5 columns per ifo combination and single ifo involved. Hint, 
you should give it the sim_inspiral table used for injections if injections 
were done and you MUST give a segment xml file for the live times....
	""")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	parser.add_option("-o", "--offset", type=float, default = 0, help = "Set offset value for L1 default 0 (FIXME for S6)")
	parser.add_option("-f", "--output-file", default = "omega_to_coinc.xml.gz", help = "Set the output file name")
	parser.add_option("-i", "--instruments", help="Set the instruments involved in coinc, example H1,H2,L1")
	options, files = parser.parse_args()

	if not options.instruments:
		raise ValueError("--instruments required")

	# substring tests, so compressed names such as "segments.xml.gz" match
	urls = [f for f in files if ".xml" in f]
	trigger_files = [f for f in files if '.txt' in f]

	return options, urls, trigger_files


### MAIN ###
###############################################################################

opts, xmls, trigger_files = parse_command_line()

#FIXME this needs to be a command line argument
ifos = opts.instruments

xmldoc = ligolw_add.ligolw_add(ligolw.Document(), xmls, verbose=opts.verbose)

#Check for valid segments before continuing
on_instruments = get_instruments_from_segments(xmldoc)
print on_instruments

create_tables(xmldoc, opts.verbose)

coinc_def_id = create_coinc_def_row(xmldoc)

populate_search_summary_from_segments(xmldoc, ifos, opts.verbose)

time_slide_id = create_time_slide_row(xmldoc, opts.offset)

#add_process(xmldoc, ifos)

for f in trigger_files:
	for i,row in enumerate(open(f).readlines()):
		if '#' in row: continue
		if opts.verbose: print>> sys.stderr, "processing trigger:\t %d\r" % (i,),
		#FIXME instruments and offset need to be command line arguments
		coinc = Coinc(row, xmldoc, ifos, on_instruments, time_slide_id, coinc_def_id)
		coinc.add_to_coinc_event_table()



#match sngls to search summary table
match_sngl_burst_to_search_summary(xmldoc, opts.verbose)  

utils.write_filename(xmldoc, opts.output_file, verbose = opts.verbose, gz = (opts.output_file or "stdout").endswith(".gz"))

xmldoc.unlink()

