#!/usr/bin/python
from __future__ import division

# Script metadata: author contact, program name and descriptive title.
__author__ = "Thomas Dent <thomas.dent@ligo.org>"
__prog__ = "pylal_cbc_ulmc"
__title__ = "Loudest Event UL Monte Carlo simulator"

import sys
import warnings

import math
import numpy
with warnings.catch_warnings():
  warnings.simplefilter("ignore")
  import matplotlib as mpl
  mpl.use("Agg")
  #from matplotlib import pyplot as plt
import random

from optparse import *

from pylal import rate
from pylal import git_version
from pylal import upper_limit_utils as ululs

# Help text handed to OptionParser as its usage string in parse_command_line().
usage = """
Program to investigate the behaviour of the Loudest Event upper limit procedure
when splitting up the total analysis time into many smaller times and when Lambda
estimation in each has a statistical and/or systematic error.
"""

def parse_command_line():
  """Parse and validate the command-line options of the simulator.

  Returns:
    A tuple (options, argv): options is the optparse values object and
    argv is the raw argument list sys.argv[1:] (option flags included).

  Raises:
    ValueError: if plots are requested without --output-path, or if
      --ntimes is smaller than 1.
  """
  parser = OptionParser(usage=usage, version=git_version.verbose_msg)

  parser.add_option("-m", "--n-montecarlo", action="store", type="int", default=10,
      help="Number of MC trials to do")

  parser.add_option("-n", "--ntimes", action="store", type="int", default=20,
      help="Number of times to split total analysis into: must be >1")

  parser.add_option("-i", "--far-index", action="store", type="float", default=numpy.log(100),
      help="Index of rhoc in exponential model for FAR vs. SNR: default is ln(100)")

  parser.add_option("-l", "--lambda-error", action="store", type="float", default=0,
      help="Simulated fractional error on Lambda estimation (log-normal distribution)")

  parser.add_option("-b", "--lambda-bias", action="store", type="float", default=1,
      help="Simulated multiplicative bias on Lambda estimation")

  parser.add_option("-r", "--print-trials", action="store_true", default=False,
      help="Print summary UL values to screen for each trial")

  parser.add_option("-P", "--posterior-plots", action="store_true", default=False,
      help="Output plots of the posterior pdfs for each MC trial. Output-path is required. Don't use this if nMC is large!")

  parser.add_option("-v", "--verbose", action="store_true", default=False,
      help="Print FAR, Lambda, volume and effective numerator values to screen")

  parser.add_option("-p", "--make-plots", action="store_true", default=False,
      help="Scatter plots of UL values versus various parameters. Output-path is required.")

  parser.add_option("-T", "--user-tag", action="store", default="",
      help="Tag string for output plots")

  parser.add_option("-o", "--output-path", action="store", type="string", default=False,
      help="Where to output plots to (path ending in a backslash)")

  # positional arguments are not used by this program
  options = parser.parse_args()[0]

  # plot file names are built by prefixing the output path, so it must be set
  if (options.make_plots or options.posterior_plots) and not options.output_path:
    raise ValueError("You must specify an output-path to get plots.")

  # the total time is divided by ntimes, so zero or negative values can
  # never work; fail here with a clear message instead of later on
  if options.ntimes < 1:
    raise ValueError("ntimes must be a positive integer.")

  return options, sys.argv[1:]

##################################
# convenience functions

def rhoC_from_FAR(lambdaF, index):
  """Convert a false-alarm rate into a combined SNR.

  lambdaF is the FAR in 1/yr; index is the index of the exponential
  FAR-vs-SNR model.  With index set to approximately log(100) the
  relation is a fit to the S6 BIC lowmass background estimation.
  """
  log_rate_ratio = numpy.log(2000/lambdaF)
  return 8 + log_rate_ratio/index

def FAN_from_FAP(pF):
  """Return the false-alarm number -log(1 - pF) for a FAP pF."""
  survival = 1 - pF
  return -1*numpy.log(survival)

def FAP_from_FAN(nF):
  """Return the false-alarm probability 1 - exp(-nF) for a FAN nF."""
  survival = numpy.exp(-nF)
  return 1 - survival

def volume_above_rhoC(rhoC, sensemon=20):
  """Return the volume of a sphere of radius sensemon*(8*sqrt(2)/rhoC).

  The result is in cubed units of sensemon (callers in this file pass a
  range described as being in Mpc).
  """
  reach = sensemon*(8*numpy.sqrt(2)/rhoC)
  return 4*math.pi*reach**3/3

def exact_Lambda(rhoC, FAN, index):
  """Return Lambda = (3/index) / (rhoC * FAN).

  Steve's formula using volume \\propto rhoC^-3.
  """
  numerator = 3./index
  denominator = rhoC * FAN
  return numerator / denominator


###################################
############ MAIN #################
###################################

opts, args = parse_command_line()

# truthy whenever a Lambda estimation error (scatter and/or bias) is simulated
errorflag =  (opts.lambda_error or (opts.lambda_bias-1))

Ttot = 1./2  # total analysis time: half a year
Ti = Ttot/opts.ntimes  # duration of each of the split analysis times
Rsense = 20  # sensemon range in Mpc

if opts.verbose:
  print("Sensemon range is %s Mpc" % (Rsense,))
  print("Splitting %s years science run into %s analysis times" % (Ttot, opts.ntimes))
  print("Expected loudest event FAR from noise in each time is %s /yr" % (1./Ti,))
  print(" ")
  print("Expected volume / upper limit from single loudest event FAP=0.5 is")
  # reference case: one analysis of the whole time whose loudest event has FAP 0.5
  ref_FAR = FAN_from_FAP(0.5)/Ttot
  ref_rhoC = rhoC_from_FAR(ref_FAR, opts.far_index)
  ref_vol = volume_above_rhoC(ref_rhoC, sensemon=Rsense)
  ref_Lambda = exact_Lambda(ref_rhoC, ref_FAR*Ttot, opts.far_index)
  # rate samples spanning two decades either side of 1/VT
  ref_logVT = numpy.log10(ref_vol*Ttot)
  ref_mu = numpy.logspace(-ref_logVT-2, -ref_logVT+2, 10**5)
  ref_likely = ululs.margLikelihood([ref_vol*Ttot], [ref_Lambda], ref_mu, calerr=0, mcerrs=None)
  ref_UL = ululs.compute_upper_limit(ref_mu, ref_likely, alpha=0.9)
  print("\t %.0d, %.3e" % (ref_vol, ref_UL))

# set up outputs: one entry is appended to each list per MC trial
singleFAP = []
singleVT = []
UL90_exactLambda = []
UL90_Lambdaerrs = []
UL90_single = []

if opts.posterior_plots:
  # "import matplotlib" on its own does not load the pyplot submodule, so
  # import it explicitly before the first plotting call
  from matplotlib import pyplot as plt

for nmc in range(opts.n_montecarlo):

  # per-analysis-time quantities for this trial
  Lambdas = []
  estLambdas = []
  Volumes = []
  rhoCs = []
  FARs = []
  for i in range(opts.ntimes):
    # FAP of loudest event in time i
    pF = random.random()
    lambdaF = FAN_from_FAP(pF)/Ti
    FARs.append(lambdaF)
    # convert FAR to a background combined SNR
    rhoC = rhoC_from_FAR(lambdaF, opts.far_index)
    rhoCs.append(rhoC)
    vol = volume_above_rhoC(rhoC, sensemon=Rsense)
    Volumes.append(vol)
    # use analytic Lambda formula, given that the background FAR
    # drops off as exp(-far_index*rhoC)
    exactLambda = exact_Lambda(rhoC, Ti*lambdaF, opts.far_index)
    Lambdas.append(exactLambda)
    # scatter and bias the estimated Lambdas
    estLambda = exactLambda*opts.lambda_bias*random.lognormvariate(0, opts.lambda_error)
    estLambdas.append(estLambda)

  # aggregates that are needed several times below
  sumVol = sum(Volumes)
  minVol = min(Volumes)
  minFAR = min(FARs)
  VTs = [Vol*Ti for Vol in Volumes]

  # sample values of rate to cover the smallest and largest possible values of 1/VT
  mu = numpy.logspace(-numpy.log10(sumVol*Ti)-2, -numpy.log10(minVol*Ti)+2, 10**5)
  likely = ululs.margLikelihood(VTs, Lambdas, mu, calerr=0, mcerrs=None)
  likely_errs = ululs.margLikelihood(VTs, estLambdas, mu, calerr=0, mcerrs=None)
  upperLim = ululs.compute_upper_limit(mu, likely, alpha=0.9)
  estupperLim = ululs.compute_upper_limit(mu, likely_errs, alpha=0.9)

  # now do the calculation for the single loudest event with the smallest FAR value
  exactLambda_single = exact_Lambda(max(rhoCs), Ttot*minFAR, opts.far_index)
  # appropriate sensitive volume is the minimum among the individual times
  likely_single = ululs.margLikelihood([minVol*Ttot], [exactLambda_single], mu, calerr=0, mcerrs=None)
  upperLim_single = ululs.compute_upper_limit(mu, likely_single, alpha=0.9)

  # update trial summary info
  loudestFAP = FAP_from_FAN(minFAR*Ttot)
  singleFAP.append(loudestFAP)
  singleVT.append(minVol*Ttot)
  UL90_single.append(upperLim_single)
  UL90_exactLambda.append(upperLim)
  UL90_Lambdaerrs.append(estupperLim)

  print("MC trial %s" % (nmc+1,))

  if opts.print_trials or opts.verbose:
    print(" ")
    if opts.verbose:
      print("FARs /yr^-1: %s" % ([str(FAR)[0:5] for FAR in FARs],))
      print("Volumes /Mpc^3: %s" % ([int(Vol) for Vol in Volumes],))
      print("Exact Lambda values: %s" % ([str(Lam)[0:5] for Lam in Lambdas],))
      if errorflag:
        print("Lambda values with errors: %s" % ([str(eLam)[0:5] for eLam in estLambdas],))
    print("90%%UL from %s separate analyses: %.2e" % (opts.ntimes, upperLim))
    if opts.verbose:
      print("\t Effective numerator: %.3g" % (upperLim*sumVol*Ti))
    if errorflag:
      print("Factor change in 90%%UL due to Lambda errors: %.3g" % (estupperLim/upperLim))
      if opts.verbose:
        print("\t Effective numerator with Lambda errors: %.3g" % (estupperLim*sumVol*Ti))
    if opts.verbose:
      print("Lambda value from one single analysis: %.3f" % exactLambda_single)
    print("90%%UL from one single analysis would be %.2e" % upperLim_single)
    if opts.verbose:
      print("\t Effective numerator: %.3g" % (upperLim_single*minVol*Ttot))

  if opts.posterior_plots:
    # only the second value returned by normalize_pdf is plotted
    _, plotlikely = ululs.normalize_pdf(mu, likely)
    _, plotsingl = ululs.normalize_pdf(mu, likely_single)
    plt.semilogx(mu, plotlikely, 'r-', label=str(opts.ntimes)+" analysis times, exact Lambda")
    if errorflag:
      _, plotestl = ululs.normalize_pdf(mu, likely_errs)
      plt.semilogx(mu, plotestl, 'm-', label=str(opts.ntimes)+" analysis times, Lambda errors")
    plt.semilogx(mu, plotsingl, 'g-', label="Single analysis, exact Lambda")
    plt.xlabel("Rate (Mpc^-3 yr^-1)", size="large")
    plt.ylabel(r"Posterior pdf (Mpc^3 yr)", size="large")
    plt.legend(loc="upper right")
    plt.title("MC trial "+str(nmc+1)+", loudest event FAP="+str(loudestFAP)[0:5])
    plt.savefig(opts.output_path+"posterior_trial_"+str(nmc+1)+"_ntimes_"+str(opts.ntimes)+".png")
    plt.close()

# Summarise, over all MC trials, how Lambda errors rescale the 90% upper limit.
if errorflag:
  ul_with_errs = numpy.array(UL90_Lambdaerrs)
  ul_exact = numpy.array(UL90_exactLambda)
  Lambdaerr_logratios = numpy.log(ul_with_errs/ul_exact)
  geometric_mean = numpy.exp(numpy.mean(Lambdaerr_logratios))
  log_scatter = numpy.std(Lambdaerr_logratios)
  print(" ")
  print("Geometric mean scaling of 90%%UL due to Lambda errors is %s" % (geometric_mean,))
  print("MC standard deviation of fractional error (log of scaling factor) is %s" % (log_scatter,))

# Scatter plots of the upper limits against the FAP / SNR of the loudest
# event of the single (unsplit) analysis.
if opts.make_plots:
  # "import matplotlib" on its own does not load the pyplot submodule, so
  # import it explicitly before the first plotting call
  from matplotlib import pyplot as plt

  output_prefix = opts.output_path+"ulmc"+"_ntimes_"+str(opts.ntimes)+"_"+opts.user_tag

  # UL values vs. FAP of loudest event in single analysis
  FAPbins = rate.LinearBins(0.0, 1.0, 20)
  FAPindices = numpy.array([FAPbins[pF] for pF in singleFAP])
  ULs_exact = numpy.array(UL90_exactLambda)
  ULs_Lerr = numpy.array(UL90_Lambdaerrs)
  fapplot = []
  meansplot = []
  stdevplot = []
  effnumplot = []
  effnumstdplot = []
  meansplot_Lerr = []
  stdevplot_Lerr = []
  effnumplot_Lerr = []
  effnumstdplot_Lerr = []
  for i, FAPbin in enumerate(FAPbins.centres()):
    inBin = (FAPindices == i)
    ULs_inBin = ULs_exact[inBin]
    if not len(ULs_inBin):
      # no MC trial landed in this FAP bin
      continue
    # VT of a single analysis whose loudest event has the bin-centre FAP,
    # evaluated with the same sensemon range as the simulation itself
    VTsingle_inBin = Ttot*volume_above_rhoC(rhoC_from_FAR(FAN_from_FAP(FAPbin)/Ttot, opts.far_index), sensemon=Rsense)
    fapplot.append(FAPbin)
    meansplot.append(numpy.mean(ULs_inBin))
    stdevplot.append(numpy.std(ULs_inBin))
    effnumplot.append(numpy.mean(ULs_inBin)*VTsingle_inBin)
    effnumstdplot.append(numpy.std(ULs_inBin)*VTsingle_inBin)
    if errorflag:
      # same trials, so the same bin membership as the exact-Lambda limits
      ULs_inBin_Lerr = ULs_Lerr[inBin]
      meansplot_Lerr.append(numpy.mean(ULs_inBin_Lerr))
      stdevplot_Lerr.append(numpy.std(ULs_inBin_Lerr))
      effnumplot_Lerr.append(numpy.mean(ULs_inBin_Lerr)*VTsingle_inBin)
      effnumstdplot_Lerr.append(numpy.std(ULs_inBin_Lerr)*VTsingle_inBin)

  # combined SNR corresponding to each FAP, for the SNR-based plots below
  single_rhoC = [rhoC_from_FAR(FAN_from_FAP(pF)/Ttot, opts.far_index) for pF in singleFAP]
  binned_rhoC = [rhoC_from_FAR(FAN_from_FAP(pF)/Ttot, opts.far_index) for pF in fapplot]

  plt.plot(singleFAP, UL90_single, "b+", label="Single analysis")
  plt.errorbar(fapplot, meansplot, yerr=stdevplot, marker="o", linestyle="None", color="g", label="Split analysis (MC mean/std)")
  if errorflag:
    plt.errorbar(fapplot, meansplot_Lerr, yerr=stdevplot_Lerr, marker="x", linestyle="None", color="m", label="Split analysis w/ Lambda errors")
  plt.xlabel("Loudest event FAP in single analysis", size="large")
  plt.ylabel(r"90% upper limit", size="large")
  plt.legend()
  plt.title(str(opts.ntimes)+" analysis times")
  plt.savefig(output_prefix+"_ul_vs_FAP.png")
  plt.close()

  # UL*(VT for single analysis) vs. FAP of loudest single event
  # this is "effective numerator" for the single analysis
  # and allows us to factor out the dependence of volume on loudest event SNR
  plt.plot(singleFAP, numpy.array(UL90_single)*numpy.array(singleVT), "b+", label="Single analysis")
  plt.errorbar(fapplot, effnumplot, yerr=effnumstdplot, marker="o", linestyle="None", color="g", label="Split analysis (MC mean/std)")
  if errorflag:
    plt.errorbar(fapplot, effnumplot_Lerr, yerr=effnumstdplot_Lerr, marker="x", linestyle="None", color="m", label="Split analysis w/ Lambda errors")
  plt.xlabel("Loudest event FAP", size="large")
  plt.ylabel(r"90%UL * VT(single analysis)", size="large")
  plt.legend()
  plt.title(str(opts.ntimes)+" analysis times")
  plt.savefig(output_prefix+"_effnum_vs_FAP.png")
  plt.close()

  # rhoC vs. FAP of loudest single event for comparison to UL vs. FAP
  plt.plot(singleFAP, single_rhoC, "b+")
  plt.xlabel("Loudest event FAP", size="large")
  plt.ylabel("Loudest event combined SNR", size="large")
  plt.savefig(output_prefix+"_rhoC_vs_FAP.png")
  plt.close()

  # UL values vs. SNR of loudest event in single analysis
  plt.plot(single_rhoC, UL90_single, "b+", label="Single analysis")
  plt.errorbar(binned_rhoC, meansplot, yerr=stdevplot, marker="o", linestyle="None", color="g", label="Split analysis times (MC mean/stdev)")
  plt.xlabel("Loudest event combined SNR", size="large")
  plt.ylabel(r"90% upper limit", size="large")
  plt.legend(loc="upper left")
  plt.title(str(opts.ntimes)+" analysis times")
  plt.savefig(output_prefix+"_ul_vs_rhoC.png")
  plt.close()


# The bare exit() helper is installed by the site module for interactive use
# and may be missing (e.g. under "python -S"); sys.exit is the supported way
# for a script to terminate with a success status.
sys.exit(0)
