#!/usr/bin/python
import sys, os
from optparse import OptionParser
from pylal import rate
from pylal import SimInspiralUtils
import scipy
import numpy
import matplotlib
matplotlib.use('Agg')
import pylab
from math import *
import sys
import glob
import copy
from glue.ligolw import ligolw, utils
from glue import lal
from scipy import interpolate

from pylal import git_version
__author__ = "Satya Mohapatra <satyanarayan.raypitambarmohapatra@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


# Render figure text without LaTeX (text.usetex disabled).
matplotlib.rcParams["text.usetex"] = False

def get_combined_array(tablename, childnode, files):
	'''
	Stack the binned array called tablename from every file in files.

	For each file, the document is loaded, its first top-level child is
	taken, and the array is read from that element's child number
	childnode.

	Returns a tuple (bins, out, A):
	  bins -- rate.NDBins read from the first file
	  out  -- float array of shape (len(files),) + <array shape>, where
	          out[i] holds the array read from files[i]
	  A    -- rate.BinnedArray read from the first file, with its array
	          replaced by zeros of the same shape

	files must contain at least one entry (files[0] is always read).
	'''
	def load_node(path):
		# Element holding the requested table inside one document.
		xmldoc = utils.load_filename(path, verbose=True, gz = (path or "stdin").endswith(".gz"))
		return xmldoc.childNodes[0].childNodes[childnode]

	# The first file supplies the binning, the array shape and the
	# template BinnedArray; it is parsed once and its data reused for out[0]
	# instead of being loaded a second time.
	first_node = load_node(files[0])
	A = rate.BinnedArray.from_xml(first_node, tablename)
	bins = rate.NDBins.from_xml(first_node)
	out = numpy.zeros((len(files),)+A.array.shape, dtype="float")
	out[0] = A.array
	# Read the data from the remaining files
	for i, f in enumerate(files[1:], 1):
		out[i] = rate.BinnedArray.from_xml(load_node(f), tablename).array
	# Hand back an empty (all-zero) array of the right shape for reuse
	A.array = numpy.zeros(A.array.shape)
	return bins, out, A


def istrue(arg):
	'''
	Predicate that accepts everything: returns True whatever arg is.
	'''
	return True


def posterior(VT, sigmasq, Lambda):
	'''
	Evaluate the product of the analytically marginalized per-entry terms
	(Biswas, Creighton, Brady, Fairhurst) on a fixed grid.

	VT, sigmasq and Lambda are arrays that are walked in lockstep; every
	entry whose VT or VT**2/sigmasq is <= 0 is left out of the product.

	Returns (mu, post): mu is a grid of 100000 points starting at 0 with
	spacing 100 / VT.sum() / 100000, and post is the (unnormalized) product
	evaluated at each grid point.
	'''
	npoints = 100000
	k_all = VT**2 / sigmasq

	mu = numpy.arange(npoints) * 100.0 / VT.sum() / npoints
	post = numpy.ones(len(mu), dtype="float")

	for vt, k, lam in zip(VT, k_all, Lambda):
		if not (vt <= 0 or k <= 0):
			scaled = mu * vt
			base = 1.0 + scaled / k
			leading = base**(-k-1)
			correction = scaled * lam * (1.0 + 1.0/k) / base**(k+2)
			post *= vt / (1.0 + lam) * (leading + correction)

	return mu, post

def integrate_posterior(mu, post, conf):
	'''
	Return the entry of mu at which the normalized cumulative sum of post
	first reaches conf.

	mu and post are arrays of equal length; conf is the target fraction
	(for example 0.90).  If the normalized cumulative sum never reaches
	conf, mu[0] is returned.
	'''
	cumpost = post.cumsum()/post.sum()
	# Vectorized search for the first sample at or above the requested
	# level; this replaces a Python-level scan wrapped in a bare except.
	reached = numpy.flatnonzero(cumpost >= conf)
	val = reached[0] if len(reached) else 0
	return mu[val]

def get_mass_ranges(bins, mbA):
	'''
	Find the extent in each dimension of the bins that are not masked.

	bins must provide lower() and upper(), each returning one sequence of
	bin edges per dimension; mbA[i][j] is true for masked bins.

	Returns ([min1, max1], [min2, max2]) taken over the lower and upper
	edges of every unmasked bin.  Raises ValueError (from min()) when every
	bin is masked.
	'''
	# The edge arrays do not change while looping, so fetch them once.
	lower = bins.lower()
	upper = bins.upper()
	mass1 = []
	mass2 = []
	for i, (lo1, hi1) in enumerate(zip(lower[0], upper[0])):
		for j, (lo2, hi2) in enumerate(zip(lower[1], upper[1])):
			if not mbA[i][j]:
				mass1.extend((lo1, hi1))
				mass2.extend((lo2, hi2))
	return [min(mass1), max(mass1)], [min(mass2), max(mass2)]

def parse_command_line():
	'''
	Parse the command line.

	Returns (options, filenames).  filenames holds the positional
	arguments followed, when --input-cache is given, by the path of every
	entry (one per line) of that cache file.
	'''
	parser = OptionParser(
		version = "Name: %%prog\n%s" % git_version.verbose_msg
	)
	parser.add_option("","--input-cache", help="input cache containing only the databases you want to run on (you can also list them as arguments, but you should use a cache if you are afraid that the command line will be too long.)")
	options, filenames = parser.parse_args()
	filenames = filenames or []
	if options.input_cache:
		# Read the cache inside a with-block so the file handle is closed
		# once its lines have been consumed.
		with open(options.input_cache) as cache_file:
			filenames.extend([lal.CacheEntry(l).path for l in cache_file])

	return options, filenames
				

# Command-line options and the list of input files.
opts, files = parse_command_line()

# Each call stacks one named table over all input files; the first index of
# the returned data array runs over files.
bins, vA, ulA = get_combined_array("2DsearchvolumeFirstMoment", 0, files)
# Entries that are exactly zero are replaced by 0.01.
vA[vA == 0] = 0.01

bins, vA2, ulA = get_combined_array("2DsearchvolumeSecondMoment", 1, files)
bins, dvA, ulA = get_combined_array("2DsearchvolumeDerivative", 2, files)
bins, vAD, ulA = get_combined_array("2DsearchvolumeDistance", 3, files)
bins, ltA, tmp = get_combined_array("2DsearchvolumeLiveTime", 4, files)
numfiles = len(files)

# Mask array: 1 where the first file's live time is zero, 0 everywhere else.
mbA = (ltA[0] == 0).astype("float")

m1range, m2range = get_mass_ranges(bins, mbA)

# Bin edges in each dimension: all lower edges followed by the last upper
# edge, i.e. number of bins + 1 values (the layout pcolor expects).
lower_edges = bins.lower()
upper_edges = bins.upper()
X = numpy.array(list(lower_edges[0]) + [upper_edges[0][-1]])
Y = numpy.array(list(lower_edges[1]) + [upper_edges[1][-1]])

# Midpoint of every pair of neighbouring X edges, and the 2D grids built
# from those midpoints.
totalMass = numpy.array([(lo + hi) / 2.0 for lo, hi in zip(X[:-1], X[1:])], dtype=float)
mass = numpy.meshgrid(totalMass, totalMass)

fig = pylab.figure(1)
###############################################################################
# loop over all the filenames and masses and compute the posteriors separately
###############################################################################

# Bin centres do not change from file to file, so fetch them once.
bin_centres = bins.centres()

# Index pairs (row, col) with col >= row, in row-major order.  The loops
# below assume the grid is square with len(X)-1 bins on each side.
upper_rows, upper_cols = numpy.triu_indices(len(X) - 1)

for file_index, filename in enumerate(files):
	# Output file prefix: the input's base name without ".xml" and without
	# the "2Dsearchvolume-" prefix.
	ifos = os.path.split(filename)[1].replace('.xml','').replace("2Dsearchvolume-","")

	# Start from an all-zero array for this file.
	ulA.array *= 0

	# loop over mass bins
	for j, m1 in enumerate(bin_centres[0]):
		for k, m2 in enumerate(bin_centres[1]):
			masses = bins[m1,m2]
			# Skip masked bins.
			if mbA[masses[0], masses[1]]: continue
			mu, post = posterior(vA[file_index:file_index+1,masses[0],masses[1]], vA2[file_index:file_index+1,masses[0],masses[1]], dvA[file_index:file_index+1,masses[0],masses[1]])
			ulA.array[j][k] = integrate_posterior(mu, post, 0.90)

	# NOTE(review): log_vol, log_ul, log_der and volumebar/volumeerrorbar
	# below are computed but do not feed into the figure that is saved;
	# they are kept so the script fails or warns in the same situations
	# as before.
	fudge = 0.01 * min (ulA.array[ulA.array !=0])
	log_vol = numpy.log10(vA[file_index])
	distance = vAD[file_index]
	log_ul = numpy.log10(ulA.array + fudge)
	vol_error = vA2[file_index]**0.5 / (vA[file_index] + 0.0001)
	log_der = numpy.log10(dvA[file_index] + 0.000000001)
	# set points outside mass space to 1 (log(1) = 0)
	log_ul[mbA == 1] = 0
	log_der[mbA == 1] = -3

	# Flatten the col >= row part of each grid into 1D arrays in one
	# indexing step (replaces element-by-element numpy.append calls).
	dist_error = ((3./4.)*vol_error)**(1/3.)
	width = X[1]-X[0]
	massbar = mass[0][upper_rows, upper_cols]
	distancebar = distance[upper_rows, upper_cols]
	volumebar = log_vol[upper_rows, upper_cols]
	volumeerrorbar = vol_error[upper_rows, upper_cols]
	distanceerrorbar = dist_error[upper_rows, upper_cols]

	# For every distinct mass value: average the distances and take the
	# largest distance error.
	massbarunique = numpy.unique(massbar)
	distancebaraveraged = numpy.empty(len(massbarunique))
	distanceerrorbaraveraged = numpy.empty(len(massbarunique))

	for idx, unique_mass in enumerate(massbarunique):
		# Plain boolean mask; wrapping the mask in a list is not a valid
		# index on current NumPy.
		selected = massbar == unique_mass
		distancebaraveraged[idx] = numpy.average(distancebar[selected])
		distanceerrorbaraveraged[idx] = max(distanceerrorbar[selected])

	#do interpolation
	interpolation = interpolate.interp1d(massbarunique, distancebaraveraged)
	# This grid gets its own name: reusing "mass" would overwrite the
	# meshgrid that the next file's iteration still needs.
	fine_mass = numpy.arange(min(massbarunique), massbarunique[-2], 0.1)
	#apply interpolation
	interpolationAPP = numpy.empty(len(fine_mass))
	interpolationAPP[:] = interpolation(fine_mass)

	# error-bar plot with the interpolated curve on top
	pylab.errorbar(massbarunique,distancebaraveraged,yerr=distanceerrorbaraveraged,xerr=width/4.,fmt='.')
	pylab.plot(fine_mass,interpolationAPP,'r-')
	pylab.yscale('log')
	pylab.xlim([min(massbar),max(massbar)])
	pylab.ylim([1,1e3])
	# Raw string: same text as before, without relying on "\o" being an
	# unrecognised escape sequence.
	pylab.xlabel(r"Mass [$M_\odot$]",fontsize=14)
	pylab.ylabel("Sensitivity Distance [Mpc]",fontsize=14)
	pylab.savefig(ifos+'_SensitivityDistance_curve.png')
	pylab.clf()

print >> sys.stderr, "ALL FINNISH!"
sys.exit(0)
