#!/usr/bin/python
import sys, os
from optparse import OptionParser
from pylal import rate
from pylal import SimInspiralUtils
import scipy
import numpy
import matplotlib
matplotlib.use('Agg')
import pylab
from math import *
import sys
import glob
import copy
from glue.ligolw import ligolw, utils
from glue import lal

from pylal import git_version
__author__ = "Chad Hanna <channa@ligo.caltech.edu>, Kipp Cannon <kipp.cannon@ligo.org>, Satya Mohapatra <satya@physics.umass.edu>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

# FIXME:  the plots should use LaTeX to generate the labels, but some don't
# work yet
matplotlib.rcParams.update({
	"text.usetex": False
})

# FIXME: the binned-array tables are currently located by their position among
# the document's child nodes.  How do I find a table without specifying the
# childnode number?

def get_combined_array(tablename, childnode, files):
	"""
	Stack the binned array called tablename from every file in files.

	childnode is the index, among the children of each document's first
	child node, of the node holding the array.  The array shape and the bin
	definitions are taken from files[0].
	FIXME assumes that all the xml files have the same binned array tables

	Returns (bins, out, A): the NDBins read from files[0], an array with a
	leading axis of length len(files) holding each file's data, and the
	BinnedArray read from files[0] with its .array replaced by zeros of the
	same shape.
	"""
	def top_element(path):
		# a ".gz" suffix selects compressed reading
		doc = utils.load_filename(path, verbose=True, gz = (path or "stdin").endswith(".gz"))
		return doc.childNodes[0]

	# The first file serves as the template for shape and binning
	template_node = top_element(files[0]).childNodes[childnode]
	A = rate.BinnedArray.from_xml(template_node, tablename)
	bins = rate.NDBins.from_xml(template_node)

	# One slab per input file, in the order the files were given
	out = numpy.zeros((len(files),) + A.array.shape, dtype="float")
	for slot, path in enumerate(files):
		node = top_element(path).childNodes[childnode]
		out[slot] = rate.BinnedArray.from_xml(node, tablename).array

	# Hand back an empty array of the right shape for the caller to fill
	A.array = numpy.zeros(A.array.shape)
	return bins, out, A


def istrue(arg):
	"""Return True regardless of arg."""
	return True


def posterior(VT, sigmasq, Lambda):
	"""
	Rate posterior using the analytic marginalization of
	Biswas, Creighton, Brady, Fairhurst (25).

	VT, sigmasq and Lambda are arrays with one entry per result to combine.
	Returns (mu, post): the grid of trial rates and the unnormalized
	posterior evaluated on that grid (the product of the individual terms).
	"""
	n_points = 100000

	#FIXME does this need to be an integer???
	K = VT**2 / sigmasq

	#FIXME, drew said this was cool?
	# rate grid runs from 0 up to (just under) 100 / sum(VT)
	mu = numpy.arange(n_points) * 100.0 / VT.sum() / n_points

	post = numpy.ones(len(mu), dtype="float")

	for vt, k, lam in zip(VT, K, Lambda):
		# FIXME maybe this should be smarter?
		# a total lack of injections can screw this up, so entries with
		# no volume (or a non-positive k) contribute nothing
		if vt <= 0 or k <= 0:
			continue
		scaled = mu * vt
		base = 1.0 + scaled / k
		first_term = base**(-k-1)
		second_term = scaled * lam * (1.0 + 1.0/k) / base**(k+2)
		post *= vt / (1.0 + lam) * (first_term + second_term)

	return mu, post

def integrate_posterior(mu, post, conf):
	"""
	Return the rate at which the normalized cumulative posterior first
	reaches conf.

	mu is the grid of rates and post the (unnormalized) posterior sampled on
	that grid; conf is the requested confidence, e.g. 0.90.  When the
	cumulative posterior never reaches conf -- for instance because the
	posterior sums to zero and the normalization yields NaN -- mu[0] is
	returned.
	"""
	cumpost = post.cumsum() / post.sum()
	# Vectorized search for the first grid point at or above conf.  NaN
	# entries compare False and are therefore never selected.
	reached = numpy.flatnonzero(cumpost >= conf)
	if len(reached) == 0:
		# if you can do it, maybe you can't cause the volume is zero
		return mu[0]
	return mu[reached[0]]

def get_spin_ranges(bins, mbA):
	"""
	Return the extent of the unmasked bins along each of the two dimensions.

	bins provides lower() and upper(), each returning one sequence of bin
	edges per dimension.  mbA[i][j] is true for bins that must be ignored.
	Returns ([min, max], [min, max]) for the first and second dimension
	respectively.  Raises ValueError when every bin is masked.
	"""
	# The edges do not change while scanning, so fetch them only once
	lower = bins.lower()
	upper = bins.upper()
	edges2 = list(zip(lower[1], upper[1]))

	spin1z = []
	spin2z = []
	for i, (lo1, hi1) in enumerate(zip(lower[0], upper[0])):
		for j, (lo2, hi2) in enumerate(edges2):
			if mbA[i][j]:
				continue
			spin1z.extend((lo1, hi1))
			spin2z.extend((lo2, hi2))

	if not spin1z:
		raise ValueError("no unmasked spin bins: cannot determine the spin ranges")
	return [min(spin1z), max(spin1z)], [min(spin2z), max(spin2z)]

def parse_command_line():
	"""
	Parse the command line.

	Returns (options, filenames).  filenames is a list holding the
	positional arguments followed by the path of every entry in the
	--input-cache file, when one was given.
	"""
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg
	)
	parser.add_option("","--input-cache", help="input cache containing only the databases you want to run on (you can also list them as arguments, but you should use a cache if you are afraid that the command line will be too long.)")
	options, filenames = parser.parse_args()
	filenames = list(filenames or [])

	if options.input_cache:
		cache = open(options.input_cache)
		# release the file handle even if one of the lines fails to parse
		try:
			filenames.extend([lal.CacheEntry(l).path for l in cache])
		finally:
			cache.close()

	return options, filenames
				

opts, files = parse_command_line()

# Every step below reads the first input file, so stop early with a readable
# message when no file was given at all.
if not files:
	sys.stderr.write("no input files: list XML files as arguments or pass --input-cache\n")
	sys.exit(1)

# get_combined_array returns (bins, per-file data, empty BinnedArray).
# bins are the same for each call and ulA is an empty binnedArray that has the
# right shape to hold the upper limit when we get around to computing it
# later, so it is okay that bins and ulA are overwritten in each call.  The
# per-file arrays (vA, vA2, dvA, vAD, ltA) are the important ones.
bins, vA, ulA = get_combined_array("2DsearchvolumeFirstMoment", 0, files)
#FIXME Hack to give volume that is zero a value = 0.01
vA[vA==0] = 0.01

bins, vA2, ulA = get_combined_array("2DsearchvolumeSecondMoment", 1, files)
bins, dvA, ulA = get_combined_array("2DsearchvolumeDerivative", 2, files)
bins, vAD, ulA = get_combined_array("2DsearchvolumeDistance", 3, files)
bins, ltA, tmp= get_combined_array("2DsearchvolumeLiveTime", 4, files)
numfiles = len(files)

# track livetime bins that are 0, they are outside of the search space
# this is a useful array for computing only meaningful things
# (the mask is built from the first file's livetime only)
mbA = numpy.zeros(ltA[0].shape)
mbA[ltA[0] == 0] = 1


s1range, s2range = get_spin_ranges(bins, mbA)

# bin edges: number of bins + 1 entries, as needed by pcolor
X = numpy.array( list(bins.lower()[0]) + [bins.upper()[0][-1]] )
Y = numpy.array( list(bins.lower()[1]) + [bins.upper()[1][-1]] )

f = pylab.figure(1)
###############################################################################
# loop over all the filenames and spin and compute the posteriors separately
###############################################################################

def _save_spin_map(values, vmin, vmax, title, filename):
	"""
	Draw values as a colour map on the X, Y bin-edge grid with axes labelled
	"Spin 2" (x) and "Spin 1" (y), save the figure to filename and clear it
	for the next plot.
	"""
	pylab.pcolor(X,Y, values, vmin=vmin, vmax=vmax)
	pylab.colorbar()
	pylab.ylim(s1range)
	pylab.xlim(s2range)
	pylab.title(title,fontsize=14)
	pylab.xlabel("Spin 2",fontsize=14)
	pylab.ylabel("Spin 1",fontsize=14)
	pylab.gca().set_aspect(1)
	pylab.grid()
	pylab.savefig(filename)
	pylab.clf()


for i, f in enumerate(files):
	legend_str = []
	lines = []
	#FIXME we shouldn't get ifos from filenames, we should put it in the xml :(
	ifos = os.path.split(f)[1].replace('.xml','').replace("2Dsearchvolume-","")
	#FIXME it is stupid to pull out names this way
	combined_ifos = "_".join(ifos.split("_")[:-1])

	# start from an empty upper-limit array for this file
	ulA.array *= 0

	wiki = open(ifos+'_range_summary.txt',"w")
	# make sure the summary is flushed and closed even if a bin fails
	try:
		wiki.write("||'''Spin'''||'''Range(Mpc)'''||'''Time'''||'''UL @ 90%'''||'''Fractional error'''||\n")
		# loop over spin bins
		for j, s1 in enumerate(bins.centres()[0]):
			for k, s2 in enumerate(bins.centres()[1]):
				spins = bins[s1,s2]
				# bins without livetime are outside the search space
				if mbA[spins[0], spins[1]]: continue
				legend_str.append("%.1f, %.1f" % (s1, s2))
				# i:i+1 keeps a length-one array so posterior() can iterate over it
				mu,post = posterior(vA[i:i+1,spins[0],spins[1]], vA2[i:i+1,spins[0],spins[1]], dvA[i:i+1,spins[0],spins[1]])
				lines.append(pylab.loglog(mu,post/post.max()))
				ulA.array[j][k] = integrate_posterior(mu, post, 0.90)
				wiki.write("||%.2f,%.2f||%.2f|| %.2f || %.2e || %.2f ||\n" % (s1, s2, vAD[i,spins[0],spins[1]], ltA[i, spins[0], spins[1]], ulA.array[j][k], vA2[i, spins[0], spins[1]]**0.5 / (vA[i, spins[0], spins[1]])) )
	finally:
		wiki.close()

	# NOTE: min() raises ValueError here when no bin has a nonzero upper limit
	fudge = 0.01 * min (ulA.array[ulA.array !=0])
	log_vol = pylab.log10(vA[i])
	distance = vAD[i]
	#HACKS FOR LOG PLOT :(
	log_ul = pylab.log10(ulA.array + fudge)
	vol_error = vA2[i]**0.5 / (vA[i] + 0.0001)
	log_der = pylab.log10(dvA[i] + 0.000000001)
	# set points outside spin space to fixed values: 0 for the upper limit
	# (log(1) = 0) and -3 for the derivative
	log_ul[mbA == 1] = 0
	log_der[mbA == 1] = -3

	##
	# Make posterior plots
	##

	#f = pylab.figure(i)
	pylab.title("%s \nposteriors for a few spin bins" % (ifos,),fontsize=14)
	leg = pylab.figlegend(lines, legend_str, 'lower right')
	leg.prop.set_size(6)
	leg.pad = 0
	pylab.ylabel("Prob (unnormalized)",fontsize=14)
	pylab.xlabel("Rate",fontsize=14)
	pylab.ylim([0.0001, 1])
	pylab.grid()
	#FIXME hardcoded rate limits are bad for advanced ligo
	pylab.xlim([1e-8, 1])
	pylab.savefig(ifos + '_posterior.png')
	pylab.clf()

	##
	# Make log volume x time plot
	##
	#FIXME make it gray for pub?
	#FIXME setting these limits on the scale won't work in adv LIGO :)
	_save_spin_map(log_vol, 0, 10, ifos + " \nLog10[< Volume * Time>] in Mpc^3 yr", ifos+'_volume_time.png')

	##
	# Make log distance plot
	##
	#FIXME make it gray for pub?
	#FIXME setting these limits on the scale won't work in adv LIGO :)
	_save_spin_map(numpy.log10(distance), 0, 3, ifos + " \nlog10[< distance >] in Mpc", ifos+'_distance.png')

	##
	# Make vol error plot
	##
	#FIXME we don't show errors greater than 100%
	_save_spin_map(vol_error, 0, 1, ifos + " \nFractional Error on Volume * Time [std/mean]", ifos+'_fractional_error.png')

	##
	# Make lambda plot
	##
	#FIXME hard limits
	_save_spin_map(log_der, -3, 3, ifos + " \nLog10[fg/bg likelihood ratio(Lambda)]", ifos + '_lambda.png')

	##
	# Make UL plot
	##
	#FIXME hard limits at log(rate) of -7, 0
	_save_spin_map(log_ul, -7, 0, ifos + " \nLog10[90% upper limit] in mergers/Mpc^3/yr", ifos + '_upper_limit.png')


def _pipeline_label(path):
	"""Reduce an input file path to a short label by dropping the directory and the fixed name tags."""
	label = os.path.split(path)[1]
	for tag in ('.xml', "2Dsearchvolume-", "-H1H2L1", "s1zs2z-"):
		label = label.replace(tag, "")
	return label


# Plot the ratio of the first-moment arrays for every pair of inputs (i < j);
# nothing happens when fewer than two files were given.
for i in range(vA.shape[0]):
	for j in range(i + 1, vA.shape[0]):
		pipelines = [_pipeline_label(files[i]), _pipeline_label(files[j])]
		# Work on copies: vA[i] is a view, and overwriting its minimum in
		# place would corrupt vA for later pairs and for the combined
		# posterior computed afterwards.
		vA0 = vA[i].copy()
		sys.stdout.write("%s\n" % (vA0.min(),))
		# replace the smallest entries by 1.0 in both arrays before dividing
		vA0[vA0 == vA0.min()] = 1.0
		vA1 = vA[j].copy()
		vA1[vA1 == vA1.min()] = 1.0
		sys.stdout.write("%s\n" % (vA1.min(),))
		#pylab.pcolor(X,Y, pylab.log10(vA0 /  vA1), vmin=0, vmax=3)
		pylab.pcolor(X,Y, vA0 /  vA1, vmin=0.1, vmax=2)
		pylab.colorbar()
		pylab.ylim(s1range)
		pylab.xlim(s2range)
		pylab.title( pipelines[0] + "/" + pipelines[1] ,fontsize=14)
		pylab.xlabel("Spin 2",fontsize=14)
		pylab.ylabel("Spin 1",fontsize=14)
		#pylab.gca().set_aspect(1)
		pylab.grid()
		pylab.savefig("s1zs2z" + "_" + pipelines[0] + "_" + pipelines[1] + "_" + 'distance_ratio.png')
		pylab.clf()




###############################################################################
# now write out the special combined case
###############################################################################
lines = []
legend_str = []
# start from an empty upper-limit array for the combined result
ulA.array *= 0

for j, s1 in enumerate(bins.centres()[0]):
	for k, s2 in enumerate(bins.centres()[1]):
		spins = bins[s1,s2]
		# bins without livetime are outside the search space
		if mbA[spins[0], spins[1]]: continue
		legend_str.append("%.1f, %.1f" % (s1, s2))
		# "..." selects the entries of every input file for this bin
		mu,post = posterior(vA[...,spins[0],spins[1]], vA2[...,spins[0],spins[1]], dvA[...,spins[0],spins[1]])
		lines.append(pylab.loglog(mu,post/post.max()))
		ulA.array[j][k] = integrate_posterior(mu, post, 0.90)

#HACKS FOR LOG PLOT :(
# fudge is 1% effect
# NOTE: min() raises ValueError here when no bin has a nonzero upper limit
fudge = 0.01 * min (ulA.array[ulA.array !=0])
log_ul = pylab.log10(ulA.array + fudge)

# set points outside spin space to 1 (log(1) = 0)
log_ul[mbA == 1] = 0

# NOTE(review): ifos still holds the value set for the last input file by the
# per-file loop, so the combined plots below are written to the same names as
# that file's own posterior and upper-limit plots and replace them.  The
# commented-out combined_ifos names suggest separate files were once
# intended -- confirm before changing.

# Make posterior plots
pylab.title("Combined posteriors for a few spin bins",fontsize=14)
leg = pylab.figlegend(lines, legend_str, 'lower right')
leg.prop.set_size(6)
leg.pad = 0
leg.ncol = 2
pylab.ylabel("Prob (unnormalized)",fontsize=14)
pylab.xlabel("Rate",fontsize=14)
pylab.ylim([0.0001, 1])
pylab.grid()
#FIXME hardcoded rate limits are bad for advanced ligo
pylab.xlim([1e-8, 1])
#pylab.savefig(combined_ifos + 's1s2_posterior.png')
pylab.savefig(ifos + '_posterior.png')
pylab.clf()


##
# Make UL plot
##
#FIXME hard limits at log(rate) of -7, 0
pylab.pcolor(X,Y, log_ul, vmin=-7.0, vmax=0)
pylab.colorbar()
pylab.ylim(s1range)
pylab.xlim(s2range)
pylab.title("Combined Log10[90% upper limit] in mergers/Mpc^3/yr",fontsize=14)
pylab.xlabel("Spin 2",fontsize=14)
pylab.ylabel("Spin 1",fontsize=14)
pylab.gca().set_aspect(1)
pylab.grid()
#pylab.savefig(combined_ifos + 's1s2_upper_limit.png')
pylab.savefig(ifos + '_upper_limit.png')
pylab.clf()

# Write the log10 upper limit of every bin inside the search space as a table
f = open(ifos + '_upper_limit.txt','w')
# close the table even if formatting or writing a row fails
try:
	for i, s1 in enumerate(bins.centres()[0]):
		for j, s2 in enumerate(bins.centres()[1]):
			spins = bins[s1,s2]
			if mbA[spins[0], spins[1]]: continue
			f.write("||%.2f||%.2f||%.2e||\n" % (s1,s2,log_ul[i,j]))
finally:
	f.close()

sys.stderr.write("ALL FINNISH!\n")
sys.exit(0)
