#!/usr/bin/python

import numpy

import sqlite3

import matplotlib
matplotlib.use("agg")
from matplotlib import pyplot as plt

import os
import sys
from optparse import OptionParser

from glue import segments, segmentsUtils
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import dbtables
from glue.ligolw import utils
from glue.ligolw import array
from glue import lal
from glue import iterutils

from pylal import ligolw_cbc_compute_durations as compute_dur
from pylal import rate
from pylal import InspiralUtils
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS

import lal as constants
from pylal import upper_limit_utils
from pylal import imr_utils

from pylal import git_version
__author__ = "Stephen Privitera <sprivite@caltech.edu>, Chad Hanna <channa@perimeterinstitute.ca>, Kipp Cannon <kipp.cannon@ligo.org>. Thomas Dent <thomas.dent@ligo.org>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date

matplotlib.rcParams.update({
        "font.size": 14.0,
        "font.family":"serif",
        "font.serif":"Computer Modern Roman",
        "axes.titlesize": 18.0,
        "axes.labelsize": 16.0,
        "xtick.labelsize": 14.0,
        "ytick.labelsize": 14.0,
        "legend.fontsize": 16.0,
        "figure.dpi": 200,
        "savefig.dpi": 200,
        "text.usetex": True,
        "path.simplify": True
})


lsctables.LIGOTimeGPS = LIGOTimeGPS


def chirp_mass(m1,m2):
  """
  Return the chirp mass mu**(3/5) * mtotal**(2/5) of a binary, where
  mu = m1*m2/(m1+m2) is the reduced mass and mtotal = m1+m2.

  m1, m2 may be scalars or array-likes of matching (broadcastable) shape;
  the result has the broadcast shape of the inputs.
  """
  # Force floating point: with integer inputs the reduced-mass division
  # below would be a floor division under Python 2 (e.g. 1*1/(1+1) == 0).
  m1 = numpy.asarray(m1, dtype=float)
  m2 = numpy.asarray(m2, dtype=float)
  mtotal = m1+m2
  mu = (m1*m2)/mtotal
  return mu**(3./5) *mtotal**(2./5)


class upperLimit(object):
  """
  The upperLimit class organizes the calculation of the sensitive search volume
  for a search described by the input database.
  """
  def __init__(self, database, opts):
    """
    Load everything needed for the volume calculation from the database.

    database -- filename of the SQLite database to read
    opts -- parsed command-line options (see parse_command_line)

    The database is accessed through a working copy obtained from
    dbtables.get_connection_filename(); the connection is closed and the
    working copy discarded before returning, also when a loader raises.
    """
    self.opts = opts

    if opts.verbose:
        print >> sys.stdout, "Gathering stats from: %s...." % (database,)

    # open a connection to the input database
    working_filename = dbtables.get_connection_filename(database, tmp_path=opts.tmp_space, verbose = opts.verbose)
    connection = sqlite3.connect(working_filename)

    try:
        # The call order matters: each loader reads attributes set by the
        # ones before it (instruments -> segments -> zero-lag segments ->
        # livetime -> FAR thresholds -> injections -> distance bins).
        self.set_instruments(connection)
        self.get_injected_mass_ranges(connection)
        self.get_segments(connection)
        self.get_zero_lag_segments()
        self.get_livetime()
        self.get_far_thresholds(connection)
        self.get_gps_times_duration(connection)
        self.get_all_injections(connection)
        self.get_distance_bins()

        connection.commit()
    finally:
        # done with db: release the sqlite handle before the working copy
        # is discarded, and do so even if one of the loaders failed
        connection.close()
        dbtables.discard_connection_filename(database, working_filename, verbose = opts.verbose)

    # Set up mass bins, one entry per binning mode requested in opts
    self.mass_bins = {}
    if self.opts.bin_by_total_mass:
        self.mass_bins["Total_Mass"] = self.set_total_mass_bins()
    if self.opts.bin_by_component_mass:
        self.mass_bins["Component_Mass"] = self.set_component_mass_bins()
    if self.opts.bin_by_chirp_mass:
        self.mass_bins["Chirp_Mass"] = self.set_chirp_mass_bins()
    if self.opts.bin_by_bns_bbh:
        self.mass_bins["BNS_BBH"] = self.set_bns_bbh_bins()
    if self.opts.bin_by_mass1_mass2:
        self.mass_bins["Mass1_Mass2"] = self.set_m1m2_bins()


  def get_livetime(self):
      '''
      Compute live times from zero lag segments.
      '''
      self.livetime = {}
      for inst in self.instruments:
          # compute live time in years from zero lag segments
          self.livetime[inst] = abs(self.zero_lag_segments[inst].coalesce()).seconds/float(constants.YRJUL_SI)

      return


  def get_injected_mass_ranges(self, connection):
      '''
      Find the component and total mass ranges for the search.
      '''
      self.mintotal = float(connection.cursor().execute('SELECT MIN(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--min-mtotal"').fetchone()[0])
      self.maxtotal = float(connection.cursor().execute('SELECT MAX(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--max-mtotal"').fetchone()[0])
      minmass1 = float(connection.cursor().execute('SELECT MIN(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--min-mass1"').fetchone()[0])
      minmass2 = float(connection.cursor().execute('SELECT MIN(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--min-mass2"').fetchone()[0])
      maxmass1 = float(connection.cursor().execute('SELECT MAX(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--max-mass1"').fetchone()[0])
      maxmass2 = float(connection.cursor().execute('SELECT MAX(CAST(value as REAL)) FROM process_params JOIN process on process_params.process_id = process.process_id WHERE process.program = "inspinj" AND param == "--max-mass2"').fetchone()[0])

      self.minmass = min(minmass1,minmass2)
      self.maxmass = max(maxmass1,maxmass2)

      return True


  def get_segments(self, connection):
    '''
    Retrieve raw single IFO segments from the database and apply vetoes.
    '''
    if self.opts.verbose:
      print >>sys.stdout,"\nQuerying database for single IFO segments..."
    self.segments = compute_dur.get_single_ifo_segments(
      connection,
      program_name = self.opts.livetime_program,
      usertag = "FULL_DATA")
    if self.opts.verbose:
      for ifo in self.segments.keys():
        print "\t%s was on for %d seconds" % (ifo,abs(self.segments[ifo]))

    # getting vetoes
    xmldoc = dbtables.get_xml(connection)
    veto_segments = compute_dur.get_veto_segments(xmldoc, self.opts.verbose)
    if self.opts.veto_segments_name:
        veto_segments = veto_segments[self.opts.veto_segments_name]
    else:
        veto_segments = veto_segments.values()[0]

    # applying vetoes
    if self.opts.verbose:
      print >>sys.stdout,"\nApplying vetoes to single IFO zero lag segments..."
    self.segments -= veto_segments
    if self.opts.verbose:
      for ifo in self.segments.keys():
        print "\t%s was on for %d seconds (after vetoes)" % (ifo,abs(self.segments[ifo]))

    return self.segments


  def get_zero_lag_segments(self):
    '''
    Build self.zero_lag_segments, a segmentlistdict keyed by frozenset of
    instruments, from the single IFO segments in self.segments, and
    return it.

    For each combination of two or more instruments the entry is the
    intersection of the "on" instruments' segments minus the union of the
    excluded instruments' segments. Playground segments are subtracted as
    well unless --include-play was given.
    '''
    verbose = self.opts.verbose
    if verbose:
      print >>sys.stdout,"\nForming coincident segments from single IFO segments..."

    if self.opts.include_play:
      play_status = 'includes playground'
    else:
      play_status = 'excludes playground'

    self.zero_lag_segments = segments.segmentlistdict()
    on_ifos_dict, excluded_ifos_dict = compute_dur.get_allifo_combos(self.segments, 2)

    for on_ifos_key in on_ifos_dict:
      ifos = frozenset(lsctables.instrument_set_from_ifos(on_ifos_key))

      # on-instrument times, with the excluded instruments' times removed
      self.zero_lag_segments[ifos] = self.segments.intersection( on_ifos_dict[on_ifos_key] )
      self.zero_lag_segments[ifos] -= self.segments.union( excluded_ifos_dict[on_ifos_key] )

      if not self.opts.include_play: # subtract playground segments
        self.zero_lag_segments[ifos] -= segmentsUtils.S2playground(self.segments.extent_all())

      if verbose:
        print >> sys.stdout,"\t%s were on for %d seconds (%s)" % ( on_ifos_key, float(abs(self.zero_lag_segments[ifos])), play_status )

    return self.zero_lag_segments


  def get_distance_bins(self):
    '''Determine distance bins to use for the calculation.'''
    self.dbins = {}
    for instr in self.instruments[:]:
      f,m = self.get_injections(instr)
      found_dist = numpy.array([l.distance for l in f])
      if len(f) == 0:
        print >>sys.stderr,"No injections found in %s time... skipping."%(instr,)
        self.instruments.remove(instr)
      else:
        dmin = numpy.min(found_dist)
        dmax = numpy.max(found_dist)
        self.dbins[instr] = numpy.linspace(0, dmax, self.opts.dist_bins)

    return self.dbins


  def set_total_mass_bins(self):
    '''
    Build the one-dimensional total mass binning, store it in
    self.total_mass_bins and return it.

    Without --total-mass-bins the edges are opts.mass_bins evenly spaced
    values spanning the injected total mass range. With it, the given
    comma-separated edges are used and self.mintotal / self.maxtotal are
    narrowed to the overlap of the injected range and the given edges.
    '''
    edge_spec = self.opts.total_mass_bins
    if edge_spec is None:
      # no user input: bin the range found in the DB
      binEdges = numpy.linspace(self.mintotal,self.maxtotal,self.opts.mass_bins)
    else:
      # user-specified bin boundaries
      binEdges = numpy.array([float(edge) for edge in edge_spec.split(',')])
      self.mintotal = max(numpy.min(binEdges),self.mintotal)
      self.maxtotal = min(numpy.max(binEdges),self.maxtotal)

    self.total_mass_bins = rate.NDBins((rate.IrregularBins(binEdges),))
    return self.total_mass_bins


  def set_component_mass_bins(self):
    '''
    Build the one-dimensional binning in the first component's mass, store
    it in self.component_mass_bins and return it.

    Without --component-mass1-bins the edges are opts.mass_bins+1 evenly
    spaced values spanning the injected component mass range. With it,
    the given comma-separated edges are used and self.minmass /
    self.maxmass are narrowed to the overlap of the injected range and the
    given edges.
    '''
    edge_spec = self.opts.component_mass1_bins
    if edge_spec is None:
      # no user input: bin the range found in the DB
      binEdges1 = numpy.linspace(self.minmass,self.maxmass,self.opts.mass_bins+1)
    else:
      # user-specified bin boundaries
      binEdges1 = numpy.array([float(edge) for edge in edge_spec.split(',')])
      self.minmass = max(numpy.min(binEdges1),self.minmass)
      self.maxmass = min(numpy.max(binEdges1),self.maxmass)

    self.component_mass_bins = rate.NDBins((rate.IrregularBins(binEdges1),))
    return self.component_mass_bins


  def set_chirp_mass_bins(self):
    '''
    Build the one-dimensional chirp mass binning, store it in
    self.chirp_mass_bins and return it.

    With --chirp-mass-bins the given comma-separated edges are used.
    Otherwise the edges are opts.mass_bins evenly spaced values between
    the chirp mass of (minmass, mintotal-minmass) and the chirp mass of an
    equal-mass system with total mass maxtotal.
    '''
    edge_spec = self.opts.chirp_mass_bins
    if edge_spec is None:
      # derive the range from what we know about the injected mass ranges
      lowest = numpy.min(chirp_mass(self.minmass,self.mintotal-self.minmass))
      highest = numpy.max(chirp_mass(self.maxtotal/2,self.maxtotal/2))
      binEdges = numpy.linspace(lowest,highest,self.opts.mass_bins)
    else:
      # user-specified bin boundaries
      binEdges = numpy.array([float(edge) for edge in edge_spec.split(',')])

    self.chirp_mass_bins = rate.NDBins((rate.IrregularBins(binEdges),))
    return self.chirp_mass_bins


  def set_bns_bbh_bins(self):
    '''
    Build the binning whose edges are the neutron star and black hole mass
    limits given on the command line, store it in self.bns_bbh_bins and
    return it.
    '''
    limits = self.opts
    edges = [limits.min_nsmass, limits.max_nsmass, limits.min_bhmass, limits.max_bhmass]
    self.bns_bbh_bins = rate.NDBins((rate.IrregularBins(edges),))
    return self.bns_bbh_bins


  def set_m1m2_bins(self):
    '''
    Build the two-dimensional (mass1, mass2) binning, store it in
    self.m1m2_bins and return it.

    The mass1 edges are opts.mass_bins evenly spaced values spanning the
    injected component mass range; the mass2 edges are a leading subset of
    the mass1 edges.
    '''
    nedges = self.opts.mass_bins
    bin1Edges = numpy.linspace(self.minmass,self.maxmass,nedges)

    # reusing the leading mass1 edges ensures equal bin width in m2 as in m1
    m2bins = int(nedges/2) + numpy.mod(nedges,2)
    bin2Edges = bin1Edges[:m2bins+1]

    self.m1m2_bins = rate.NDBins( (rate.IrregularBins(bin1Edges),rate.IrregularBins(bin2Edges)) )
    return self.m1m2_bins


  def get_far_thresholds(self, connection):
      """
      Returns the false alarm rate to use for computing the search volume (in the typical case, this will
      be the FAR of the most rare zero-lag coinc).
      """
      if self.opts.verbose:
          print >>sys.stdout, "\nGetting FAR thresholds for finding injections..."

      self.far = {}

      # far option takes precedence for closed box
      if self.opts.far and not self.opts.open_box:
        for inst in self.instruments:
          self.far[inst] = self.opts.far
      # if no far specified, use expected-loudest-event for closed box
      elif not self.opts.open_box:
        for inst in self.instruments:
          if inst in self.livetime.keys():
            self.far[inst] = 1./self.livetime[inst]
      # for open box use true loudest event
      if self.opts.open_box:
        if self.opts.include_play:
          datatype="all_data"
        else:

          if self.opts.include_play:
              datatype="all_data"
          else:
              datatype="exclude_play"

          for inst, far in upper_limit_utils.get_loudest_event(connection, opts.coinc_table, datatype):

              if inst not in self.instruments:
                  continue

              # cannot determine efficiency above the loudest event
              # when the loudest event has FAR=0
              if far == 0:
                  print >> sys.stderr, "Could not determine FAR threshold for %s time. Skipping..."%','.join(sorted(list(inst)))
                  self.instruments.remove(inst)
                  continue

              else:
                  self.far[inst] = far

      # check that all instrument sets have an associated loudest event
      for inst in self.instruments[:]:
          if inst not in self.far.keys():
              print >> sys.stderr, "Could not determine FAR threshold for %s time. Skipping..."%','.join(sorted(list(inst)))
              self.instruments.remove(inst)

      # report the far threshold being used
      if self.opts.verbose:
          for inst in self.instruments:
              print >>sys.stdout,"\tFAR threshold used in %s time is %g/yr" % (','.join(sorted(list(inst))),self.far[inst])

      return


  def get_gps_times_duration(self, connection):

    self.start_time = int( connection.cursor().execute('SELECT MIN(gps_start_time) FROM experiment').fetchone()[0] )
    self.end_time = int( connection.cursor().execute('SELECT MAX(gps_end_time) FROM experiment').fetchone()[0] )

    return True


  def set_instruments(self, connection):
      '''
      Set self.instruments to the list of instrument sets (of two or more
      detectors) found in the coinc event table.

      If --instruments was given, the list is restricted to that single
      set; when that set is not in the database a message is written to
      stderr and the list is left empty.
      '''
      # looking for instrument sets of two or more detectors
      self.instruments = [inst for inst in imr_utils.get_instruments_from_coinc_event_table(connection) if len(inst) > 1]

      requested = self.opts.instruments
      if requested is None:
          return

      if requested in self.instruments:
          # use the options this object was built with, not a module global
          self.instruments = [requested]
      else:
          print >> sys.stderr, "Instruments %s do not exist in DB, nothing will be calculated" % (str(requested))
          self.instruments = []

      return


  def get_volume_derivative(self, instruments, bin_type, FAR, fnameList=None, tagList=None, InspiralUtilsOpts=None):
    """
    Compute the derivative of the search volume at the FAR of the loudest event

    instruments -- instrument set (key of self.injections / self.livetime)
    bin_type -- key of self.mass_bins selecting the mass binning
    FAR -- false alarm rate at which the derivative is evaluated
    fnameList, tagList -- optional lists; when --plot-lambda-fits is set,
        the file name and tag of every figure made are appended to them
    InspiralUtilsOpts -- passed to InspiralUtils.set_figure_name()

    Returns a rate.BinnedArray over self.mass_bins[bin_type]. It is left
    as constructed when no injection is found even at infinite FAR.
    """
    # check for infinite FAR. if no found injections there, return
    f,m = self.get_injections(instruments,float('inf'),bin_type=bin_type)
    if len(f) == 0: return rate.BinnedArray(self.mass_bins[bin_type])

    # guess an initial range of fars at which to evaluate the volume
    FARh = FAR*2.0
    FARl = FAR/4.0

    # get the found injections at the extreme FARs
    flo,_ = self.get_injections(instruments,FARl,bin_type=bin_type)
    fhi,_ = self.get_injections(instruments,FARh,bin_type=bin_type)

    # expand the FAR range until the number of found injections differs
    # between its two ends
    while len(flo) == len(fhi):
      FARh *= 2
      FARl /= 2
      flo,_ = self.get_injections(instruments,FARl,bin_type=bin_type)
      fhi,_ = self.get_injections(instruments,FARh,bin_type=bin_type)

    # compute volume for the chosen FAR values (log-spaced)
    nbins = 10
    fars = numpy.logspace( numpy.log10(FARl), numpy.log10(FARh), nbins)
    if self.opts.verbose:
      print >>sys.stdout, "\nComputing Lambda at FAR=%f binning by %s" % (FAR,bin_type)
      print >>sys.stdout, "FARs = ",fars

    # compute the volume in each mass bin for a few fars
    vA = []
    vA2 = []
    for far in fars:
      found, missed = self.get_injections(instruments,far,bin_type=bin_type)
      # only the first two of the six returned values are used here
      vAt, vA2t = upper_limit_utils.compute_volume_vs_mass(found, missed,
          self.mass_bins[bin_type], bin_type, dbins=self.dbins[instruments],
          distribution_param=self.opts.distr_param, distribution=self.opts.distribution,
          limits_param=self.opts.limits_param,
          max_param=self.opts.max_param, min_param=self.opts.min_param)[:2]
      vA.append(vAt)
      vA2.append(vA2t)

    # compute the volume derivative in each mass bin
    volDeriv = rate.BinnedArray(self.mass_bins[bin_type])
    # FAN = FAR * livetime
    fans = fars*self.livetime[instruments]
    FAN = FAR*self.livetime[instruments]

    # compute volume derivatives; use this object's own bins rather than
    # those of a module-level instance
    for j,mbin in enumerate(iterutils.MultiIter(*self.mass_bins[bin_type].centres())):

        vtmp = self.livetime[instruments]*numpy.array([v[mbin] for v in vA])
        v2tmp = self.livetime[instruments]*numpy.array([v[mbin] for v in vA2])
        coeffs = upper_limit_utils.log_volume_derivative_fit(numpy.log(fans), vtmp)
        volDeriv[mbin] = coeffs[0]/FAN

        if self.opts.verbose:
            print >> sys.stdout, "mass bin %s" %str(mbin)
            print >> sys.stdout, "vols = ", vtmp
            print >> sys.stdout, "\tLambda = d logV/d FAN = %f" % (volDeriv[mbin],)

        if self.opts.plot_lambda_fits:

            tag = "%s-VOLUME_DERIVATIVE_FIT" % ("".join(sorted(list(instruments))),)
            if vtmp.max() == 0:
                print >> sys.stdout, "All estimated volumes were zero for", bin_type, "bin", str(mbin), "in", tag
                print >> sys.stdout, "- not making a plot!"
                continue

            if bin_type == "BNS_BBH": label = ["NS-NS","NS-BH","BH-BH"][j]
            else: label = '-'.join(["%.2f"%ma for ma in mbin])+r"M$_\odot$"
            plt.errorbar(fans,vtmp,yerr=v2tmp,color='k')
            # overlay the fit: V = exp(coeffs[0]*log(FAN) + coeffs[1])
            plt.plot(fans,numpy.exp(coeffs[0]*numpy.log(fans)+coeffs[1]),lw=2,label="%s\nLambda=%.3f"%(label,volDeriv[mbin],))
            plt.legend(loc="lower left")
            plt.xlabel("FAN")
            plt.ylabel("VT")
            plt.title("Volume derivative at FAN = %.3f" % FAN)
            plt.gca().set_xscale("log", nonposx='clip')
            plt.gca().set_yscale("log", nonposy='clip')
            name = InspiralUtils.set_figure_tag(tag+"_BIN_"+str(j), open_box = self.opts.open_box)
            fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
            InspiralUtils.savefig_pylal( filename=fname )
            # the bookkeeping lists are optional
            if fnameList is not None:
                fnameList.append(fname)
            if tagList is not None:
                tagList.append(name)
            plt.close()

    return volDeriv


  def filter_injection(self,sim,bin_type):
    if sim.waveform in self.opts.exclude_sim_type:
      return True

    if self.opts.disable_spin:
      # throw out if any spin component is nonzero
      if not max(abs(numpy.array([sim.spin1x, sim.spin1y, sim.spin1z, sim.spin2x, sim.spin2y, sim.spin2z]))) == 0: return True
    elif self.opts.spin_only:
      # throw out injections having all zero spin components
      if max(abs(numpy.array([sim.spin1x, sim.spin1y, sim.spin1z, sim.spin2x, sim.spin2y, sim.spin2z]))) == 0: return True

    chi = sim.mass1/(sim.mass1+sim.mass2) *sim.spin1z + sim.mass2/(sim.mass1+sim.mass2) *sim.spin2z
    if not self.opts.min_aligned_spin < chi < self.opts.max_aligned_spin: return True

    if self.opts.min_mtotal is not None and sim.mass1+sim.mass2 < self.opts.min_mtotal:
      return True
    if self.opts.max_mtotal is not None and sim.mass1+sim.mass2 > self.opts.max_mtotal:
      return True

    mratio = max(sim.mass1/sim.mass2, sim.mass2/sim.mass1)
    if (mratio < self.opts.min_mass_ratio):
      return True
    if self.opts.max_mass_ratio and (self.opts.max_mass_ratio < mratio):
      return True

    if bin_type == "Component_Mass":
      if not (self.opts.min_mass2 < sim.mass1 < self.opts.max_mass2 or self.opts.min_mass2 < sim.mass2 < self.opts.max_mass2):
        return True # neither component in the right range, throw out injection
      elif not self.opts.min_mass2 < sim.mass2 < self.opts.max_mass2:
        mass1 = sim.mass1 # the right mass component was there but in an inconvenient order
        sim.mass1 = sim.mass2
        sim.mass2 = mass1

    if bin_type == "BNS_BBH":
      bhs = numpy.sum([self.opts.min_bhmass < sim.mass1 < self.opts.max_bhmass,
                       self.opts.min_bhmass < sim.mass2 < self.opts.max_bhmass])
      nss = numpy.sum([self.opts.min_nsmass < sim.mass1 < self.opts.max_nsmass,
                       self.opts.min_nsmass < sim.mass2 < self.opts.max_nsmass])
      if bhs + nss < 2: return True # components not in the right range, throw out injection

    return False

  def get_injections(self, instruments, FAR=None, bin_type=None):
      if FAR is None: FAR = self.far[instruments]

      found = []
      missed = []

      for far, sim in self.injections[instruments]:

          if self.filter_injection(sim, bin_type): continue
          if far < FAR:
              found.append(sim)
          else:
              missed.append(sim)

      return found, missed


  def get_all_injections(self, connection):
    """
    Collect, for every instrument set, the injections made during its zero
    lag segments and store them in self.injections[instruments] as a list
    of (far, sim) pairs. Returns True.

    An injection which is coincident with two or more single inspiral
    triggers (sngl1<-->sim and sngl2<-->sim) which are themselves
    coincident (sngl1<-->sngl2) is considered "found". An injection in
    triple time is therefore considered found if it matches any double-IFO
    coincident event. All other injections are considered "missed" and are
    stored with an infinite FAR.
    """
    verbose = self.opts.verbose
    if verbose:
        print >>sys.stdout, "\nFinding parameters of the injections performed ..."

    coinc_table_name = "%s:table"%self.opts.coinc_table
    never_found = float('inf')

    self.injections = {}
    for ifos in self.instruments:
        found, total, missed = imr_utils.get_min_far_inspiral_injections(connection, segments = self.zero_lag_segments[ifos], table_name = coinc_table_name)
        self.injections[ifos] = found + [(never_found, sim) for sim in missed]

    # report how many injections were performed
    if verbose:
        for ifos in self.instruments:
            print >>sys.stdout,"\tFound %d injections performed in %s time" % (len(self.injections[ifos]),','.join(sorted(list(ifos))))

    return True


  def live_time_array(self,instruments,mass_bins):
    """
    Return a rate.BinnedArray over mass_bins in which the live time of the
    given instrument set has been added to every bin. All bins hold the
    same value; the array form is only a convenience.
    """
    binned = rate.BinnedArray(mass_bins)
    binned.array += self.livetime[instruments]
    return binned


def parse_command_line():
  """
  Parse and validate the command line.

  Returns the tuple (opts, database): the optparse options object and the
  name of the single database file given as positional argument.

  Exits with status 1 if --livetime-program is missing or if the number of
  positional arguments is not exactly one. Raises ValueError for an
  inconsistent mass ratio range.
  """

  description = '''
   description:

The program lalapps_cbc_svim computes the sensitive volume of a CBC search from a database containing triggers from simulation experiments. These triggers need to be ranked by false alarm rate, the detection statistic used in S6 searches. Then injections which register a trigger louder than the loudest event, by false alarm rate, are considered found. All others are considered missed. The efficiency of detecting an event depends on the source parameters, such as its component masses, distance, spin, inclination, sky position, etc. However, lalapps_cbc_svim only considers the dependency of the efficiency on distance and mass, marginalizing over the other parameters. Injections are binned in distance and mass and the estimated efficiency is integrated over distance to convert the efficiency into a physical volume.

The output of lalapps_cbc_svim is input for lalapps_cbc_sink, which uses the loudest event statistic method (Biswas, Creighton, Fairhurst, 2007) to compute a posterior on the rate, given the observed volumes.

implementation and workflow:

1. Extract information about the particular search that was performed.
  a. which instruments were on in the search
  b. the segments of time for which there is coincident data
  c. the significance of the loudest event (for blinded searches, use a dummy FAR value instead)
  d. set up mass and distance bins
2. Compute search volume above loudest event
  a. separate injections by found and missed based on the loudest event
  b. ensure that the injection was made within the recorded segments
  c. compute the efficiency of the search as a function of distance and mass
  d. integrate the efficiency over distance and time to compute the search volume.
3. Compute volume derivative (lambda) at the cFAR of the loudest event
  a. Choose a number of FARs linearly-spaced around the FAR of the loudest event
  b. Compute search volume below each FAR (calling the same routines as in 2)
  c. Linear least-squares fit to volume vs FAR points to estimate the volume derivative
4. Write the computed search volume and search volume derivative to disk
'''

  usage = '''lalapps_cbc_svim [options] database

The input to svim is a single database containing simulation experiments with triggers ranked by false alarm rate. The bare-bones command

  lalapps_cbc_svim H1L1V1-FULL_DATA_CAT_4_VETO_CLUSTERED_CBC_RESULTS-968803143-1209744.sqlite

will compute from the results database the search volume, volume derivative, and volume uncertainty as a function of total mass for each IFO set found in the database. You can override the default mass binning with, for instance,

  lalapps_cbc_svim --bin-by-component-mass --component-mass1-bins=\'1,3,8,13,18,23\' --min-mass2 1 --max-mass2 3 H1L1V1-FULL_DATA_CAT_4_VETO_CLUSTERED_CBC_RESULTS-968803143-1209744.sqlite

will compute the search volume using all injections with at least one component in the mass range 1-3Msun as a function of the other component mass, using the specified boundaries for the mass of the first object.
'''


  parser = OptionParser(version = git_version.verbose_msg, usage = usage + description)

  #
  # tell svim where to find injections, segments, and triggers
  parser.add_option("--sim-table", default="sim_inspiral", metavar="TBL", help="Set the name of the table containing information regarding the injected signal parameters.", choices = ["sim_inspiral", "sim_ringdown"])
  parser.add_option("--coinc-table", default="coinc_inspiral", metavar="TBL", help="Set the name of the table containing information regarding coincident triggers.", choices = ["coinc_inspiral", "coinc_ringdown"])
  parser.add_option("--livetime-program", default=None, metavar="name", help="Set the name of the program whose segments or rings will be extracted from the search_summary table: usually thinca or inspiral.")
  parser.add_option("--veto-segments-name", default=None, metavar="name", help="Set the name of the veto segments to use from the XML document.")

  #
  # The code supports five mass binning modes.
  #
  # Bin injections by both component masses
  parser.add_option("--bin-by-mass1-mass2", default=False, action="store_true", help="Bin injections by component mass in two dimensions when estimating the search efficiency.")

  # Bin injections by total mass
  parser.add_option("--bin-by-total-mass", default=False, action="store_true", help="Bin injections by total mass when estimating the search efficiency.")
  parser.add_option("--total-mass-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the total mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. This option implies --bin-by-total-mass.")

  # Bin injections by chirp mass
  parser.add_option("--bin-by-chirp-mass", default=False, action="store_true", help="Bin injections by chirp mass when estimating the search efficiency.")
  parser.add_option("--chirp-mass-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the chirp mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. This option implies --bin-by-chirp-mass.")

  # Bin injections by component mass, with one mass fixed
  parser.add_option("--bin-by-component-mass", default=False, action="store_true", help="Bin injections by the first component's mass with the second component restricted to a small bin (specified by --min/max-mass2).")
  parser.add_option("--component-mass1-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the first component's mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. This option implies --bin-by-component-mass.")
  parser.add_option("--min-mass2", metavar="m", type='float', default=1.0, help="Specify the minimum of the second component's mass bins.")
  parser.add_option("--max-mass2", metavar="m", type='float', default=3.0, help="Specify the maximum of the second component's mass bins.")

  # Bin injections into three
  parser.add_option("--bin-by-bns-bbh", default=False, action="store_true", help="Compute the sensitive volume for three mass bins: BNS, NSBH, and BBH.")
  parser.add_option("--min-nsmass", metavar="m", type='float', default=1.0, help="Specify the minimum mass for a neutron star (default is 1Msun)")
  parser.add_option("--max-nsmass", metavar="m", type='float', default=3.0, help="Specify the maximum mass for a neutron star (default is 3Msun)")
  parser.add_option("--min-bhmass", metavar="m", type='float', default=4.0, help="Specify the minimum mass for a black hole (default is 4Msun)")
  parser.add_option("--max-bhmass", metavar="m", type='float', default=6.0, help="Specify the maximum mass for a black hole (default is 6Msun)")

  # Restrict the total mass range of the injections considered
  parser.add_option("--min-mtotal", metavar="m", type='float', default=None, help="Specify the minimum total mass to consider among the injections found in the DB. This filters all injections outside this total mass range, even if binning by another method.")
  parser.add_option("--max-mtotal", metavar="m", type='float', default=None, help="Specify the maximum total mass to consider among the injections found in the DB. This filters all injections outside this total mass range, even if binning by another method.")
  # Specify which instrument set to look for
  parser.add_option("--instruments", default=None, metavar="IFO[,IFO,...]", help="Specify the on-instruments sets for computing the search volume.  Example \"H1,L1,V1\"")

  # Control the injection population used to measure the efficiency
  parser.add_option("--min-mass-ratio", metavar="q", type='float', default=1, help="Specify the minimum allowed mass ratio. Must be >= 1.")
  parser.add_option("--max-mass-ratio", metavar="q", type='float', default=None, help="Specify the maximum allowed mass ratio. Should be >= min-mass-ratio >= 1.")
  parser.add_option("--min-aligned-spin", metavar="chi", type='float', default=-1.0, help="Specify the minimum allowed value for the aligned spin parameter chi.")
  parser.add_option("--max-aligned-spin", metavar="chi", type='float', default=1.0, help="Specify the maximum allowed value for the aligned spin parameter chi.")
  parser.add_option("--exclude-sim-type", default=[], action="append", metavar="SIM", help="When computing the search volume, exclude injections made using the SIM waveform family. Example: SpinTaylorthreePointFivePN. Use this option multiple times to exclude more than one injection type.")
  parser.add_option("--disable-spin", default=False, action="store_true", help="When computing the search volume, exclude injections with any nonzero spin components.")
  parser.add_option("--spin-only", default=False, action="store_true", help="When computing the search volume, exclude injections with all spin components equal to zero. --disable-spin takes precedence.")

  # Binning options. Distance bins are used to display efficiency vs. distance.
  parser.add_option("--dist-bins", default=15, metavar="integer", type="int", help="Space distance bins evenly and specify the number of distance bins to use.")
  parser.add_option("--mass-bins", default=6, metavar="integer", type="int", help="If mass bin boundaries are not explicitly set, you can specify here the number of mass bins to use along a single dimension. The number of bins is currently 1 fewer than the value you specify.")

  # Unbinned MC options. Needed to perform volume integral without distance binning
  parser.add_option("--distr-param", default=None, metavar="D", help="Parameter that injections are distributed over: may be 'distance' or 'chirp_distance'.")
  parser.add_option("--distribution", default=None, help="Distribution of injections over D: possible values 'log', 'uniform', 'distancesquared', 'volume'.")
  parser.add_option("--limits-param", default=None, help="Parameter specifying distance limits: may be 'distance' or 'chirp_distance'.")
  parser.add_option("--max-param", default=None, type="float", help="Maximum value of D. If not given, code will use the maximum value of D over injections performed")
  parser.add_option("--min-param", default=None, type="float", help="Minimum value of D. Only affects output for the 'log' distribution. If not given, code will use the minimum value of D over injections performed.")

  # Options to control the FAR used for computing the search volume. The default behavior is to use
  # the FAR of the loudest non-playground event.
  parser.add_option("--far", type="float", help="Specify a cFAR threshold for classifying injections into found or missed, rather than using cFAR of the actual or expected loudest event. Overridden by the --open-box option.")
  parser.add_option("--open-box", default=False, action="store_true", help="Use the observed cFAR for the loudest event: only appropriate when the full search result is known. If neither --far nor --open-box is given, default action is to use cFAR=1/livetime.")
  parser.add_option("--include-play", default=False, action="store_true", help="Include playground data in computing the livetime and volume.")

  # options to control output
  parser.add_option("--user-tag", default="", metavar="name", help="Add a descriptive tag to the names of output files.")
  parser.add_option("--output-cache", default=None, help="Name of output cache file. If not specified, then no cache file will be written.")
  parser.add_option("--output-path", default="./", action="store", help="Choose directory to save output files.")
  parser.add_option("--verbose", action="store_true", help="Be verbose.")
  parser.add_option("--plot-efficiency", action="store_true", help="Plot efficiency as a function of distance.")
  parser.add_option("--plot-volume-vs-far", action="store_true", help="Compute and plot the volume as a function of FAR.")
  parser.add_option("--plot-lambda-fits", action="store_true", help="Plot the data points used to compute lambda, and the fit obtained from these points.")
  parser.add_option("-t", "--tmp-space", metavar="path", help="Path to a directory used as work area for the database file.  The database file will be worked on in this directory then moved to the final location when complete.  Intended to improve performance in a networked environment: eg local disk with higher available bandwidth than the filesystem where the database lives.")

  opts, filenames = parser.parse_args()

  if opts.livetime_program is None:
    print >>sys.stderr, "Error: no livetime-program option was given, exiting!"
    sys.exit(1)

  if opts.instruments is not None:
    opts.instruments = frozenset(lsctables.instrument_set_from_ifos(opts.instruments))

  if len(filenames) != 1:
    print >>sys.stderr, "Error: must specify exactly one database file"
    sys.exit(1)

  opts.enable_output = True

  # giving explicit bin boundaries implies the corresponding binning mode
  if opts.chirp_mass_bins:
    opts.bin_by_chirp_mass = True
  if opts.component_mass1_bins:
    opts.bin_by_component_mass = True
  if opts.total_mass_bins:
    opts.bin_by_total_mass = True

  if not (opts.bin_by_chirp_mass or opts.bin_by_component_mass or opts.bin_by_bns_bbh or opts.bin_by_mass1_mass2):
    opts.bin_by_total_mass = True # choose a default binning method

  if opts.min_mass_ratio < 1:
    raise ValueError("The minimum mass ratio must be >=1!")
  if opts.max_mass_ratio and not opts.min_mass_ratio <= opts.max_mass_ratio:
    raise ValueError("The maximum mass ratio must be >= minimum mass ratio!")

  # make sure the output path ends with a separator; an empty path is left
  # empty (current directory) instead of raising IndexError
  opts.output_path = os.path.join(opts.output_path, "")

  return opts, filenames[0]


############################ MAIN PROGRAM #####################################
###############################################################################
###############################################################################

# Create an empty cache which will store the output files/metadata.  The main
# loop below appends one lal.CacheEntry per search-volume XML file it writes.
cache_list = []


def write_cache(cache_list, fileout):
    """
    Write the entries of cache_list to the file fileout, one per line.

    Each entry is converted to text with str() and terminated by a
    newline.  Any existing file at fileout is overwritten.

    Note that this function consults the module-level opts object: when
    opts.output_cache is None nothing is written and fileout is not
    created.  Always returns None.
    """
    if opts.output_cache is None:
        return

    # the with-block guarantees the file is closed even if str() or the
    # write itself raises part-way through the list
    with open(fileout, "w") as fd:
        fd.writelines(str(entry) + "\n" for entry in cache_list)

#
# MAIN
#


# read in command line opts; parse_command_line() returns the option object
# and the name of the single input database file
opts, database = parse_command_line()

# read in data from input database, store in upper limit object
UL = upperLimit(database, opts)

# Set up opts structure for use with InspiralUtils: the same opts object is
# handed to InspiralUtils.initialise() in the main loop further down
# (presumably it reads these GPS time attributes -- TODO confirm)
opts.gps_start_time = UL.start_time
opts.gps_end_time = UL.end_time

#
# plot the sensitive volume and range as a function of FAR threshold
#
if opts.plot_volume_vs_far:

    # Determine the range of FARs applicable to this search.  Do it once and
    # for all, pooling the injections of every instrument set, to give all
    # plots the same x-limits.  FARs of zero or infinity are dropped since
    # they cannot be placed on the logarithmic FAR axis.
    finite_fars = [f for instr in UL.instruments
                   for f, _sim in UL.injections[instr] if 0 < f < float('inf')]
    if not finite_fars:
        # min()/max() of an empty list would otherwise fail with an unhelpful
        # "arg is an empty sequence" message
        raise ValueError("cannot plot volume vs FAR: no injection has a finite, non-zero FAR")
    far_min = min(finite_fars)
    far_max = max(finite_fars)

    # make an array of logarithmically spaced FAR thresholds to loop over
    nfars = 25
    fars = numpy.logspace(numpy.log10(far_min), numpy.log10(far_max), nfars)

    for instr in UL.instruments:

        lt = UL.livetime[instr] #preferred unit is Mpc^3*yr
        ifos = "".join(sorted(instr))  # instrument-set prefix of the file names

        for bin_type in UL.mass_bins:
            mass_bins = UL.mass_bins[bin_type]

            # Dummy initialization for the minimum V*T plot boundary.  It is
            # reset for every bin type because each bin type is drawn on its
            # own figure; carrying the value over would let one figure's
            # y-limit depend on whichever bin types were plotted before it.
            minvol = 1e20

            # compute volume for each far, binned in mass; the second array
            # returned is used below as the error on the volume
            vAs, vA2s = [], []
            for far in fars:
                found, missed = UL.get_injections(instr, far, bin_type)
                vA, vA2, _, _, _, _ = upper_limit_utils.compute_volume_vs_mass(found, missed,
                    mass_bins, bin_type, dbins=UL.dbins[instr],
                    distribution_param=opts.distr_param,
                    distribution=opts.distribution, limits_param=opts.limits_param,
                    max_param=opts.max_param, min_param=opts.min_param)
                vAs.append(vA)
                vA2s.append(vA2)

            # bin edges are only needed for the generic "lo-hi" labels
            if bin_type not in ("BNS_BBH", "Mass1_Mass2"):
                lo = mass_bins.lower()
                hi = mass_bins.upper()

            # one curve (with error band) per mass bin on each of the two figures
            for j, mbin in enumerate(iterutils.MultiIter(*mass_bins.centres())):

                if bin_type == "BNS_BBH":
                    label = ["NS-NS","NS-BH","BH-BH"][j]
                elif bin_type == "Mass1_Mass2":
                    label = '-'.join(["%.2f"%ma for ma in mbin])+r"M$_\odot$"
                else:
                    label = ','.join(r"%.1f-%.1f M$_\odot$"%(lo[k][j],hi[k][j]) for k in range(len(lo)))

                vols = numpy.array([ v[mbin] for v in vAs ])
                vols_errs = numpy.array([ v2[mbin] for v2 in vA2s ])
                vt_lower = lt * (vols - vols_errs)
                vt_upper = lt * (vols + vols_errs)

                # figure 1: sensitive 4-volume vs FAR
                plt.figure(1)
                line, = plt.plot(fars, lt * vols, label=label)
                plt.fill_between(fars, vt_lower, vt_upper, alpha=0.5, color=line.get_color())
                minvol = min(minvol, min(vt_lower)) # update minimum plotted value

                # convert volume to the radius of the sphere of equal volume,
                # propagating the volume error to first order
                dists = (3. * vols / (4. * numpy.pi)) ** (1./3.)
                dists_errs = dists * vols_errs / (3. * vols)

                # figure 2: sensitive distance vs FAR
                plt.figure(2)
                line, = plt.plot(fars, dists, label=label)
                plt.fill_between(fars, dists - dists_errs, dists + dists_errs, alpha=0.5, color=line.get_color())

            # decorate and save the volume figure; the dashed vertical line
            # marks UL.far[instr], the FAR used for the main volume calculation
            plt.figure(1)
            plt.axvline(UL.far[instr],ls='--',c='k')
            tag = "%s-VOLUME_VS_FAR_BINNED_BY_%s_%s" % (ifos,bin_type.upper(),opts.user_tag)
            plt.xlabel("False Alarm Rate (yr$^{-1}$)")
            plt.ylabel(r"Mean Sensitive 4-Volume (Mpc$^3\,$yr)")
            plt.loglog()
            plt.grid()
            plt.xlim(min(fars), max(fars))
            plt.ylim(ymin=0.7 * minvol)
            plt.legend(loc="lower right")
            plt.savefig(opts.output_path+tag+".png")
            plt.close()

            # decorate and save the range figure
            plt.figure(2)
            plt.axvline(UL.far[instr],ls='--',c='k')
            tag = "%s-RANGE_VS_FAR_BINNED_BY_%s_%s" % (ifos,bin_type.upper(),opts.user_tag)
            plt.xlabel("False Alarm Rate (yr$^{-1}$)")
            plt.ylabel("Mean Sensitive Distance (Mpc)")
            plt.semilogx()
            plt.grid()
            plt.xlim(min(fars), max(fars))
            plt.ylim(ymin=0)
            plt.legend(loc="lower right")
            plt.savefig(opts.output_path+tag+".png")
            plt.close()


# loop over the requested instruments and mass bin types,
# compute the search volume and volume derivative at the specified
# FAR for the given mass bin
#
for bin_type in UL.mass_bins:
  for instr in UL.instruments:
    mass_bins = UL.mass_bins[bin_type]
    ifos = "".join(sorted(instr))  # instrument-set prefix used in all names

    # initialize plotting util helper
    __prog__ = "lalapps_cbc_svim_by_" + bin_type.lower()
    opts.ifo_times = ifos
    InspiralUtilsOpts = InspiralUtils.initialise( opts, __prog__, git_version.verbose_msg )
    fnameList = []; tagList = []

    #compute volume first and second moments above the loudest event
    found, missed = UL.get_injections(instr, UL.far[instr], bin_type)
    vA, vA2, f, m, eff, err = upper_limit_utils.compute_volume_vs_mass(found, missed,
        mass_bins, bin_type, dbins=UL.dbins[instr], distribution_param=opts.distr_param,
        distribution=opts.distribution, limits_param=opts.limits_param,
        max_param=opts.max_param, min_param=opts.min_param)

    # plot efficiencies: one figure per mass bin, efficiency vs distance
    if opts.plot_efficiency:
      tag = "%s-SEARCH_EFFICIENCY_BINNED_BY_%s" % (ifos, bin_type.upper())
      dbins = UL.dbins[instr]
      dcentres = (dbins[:-1] + dbins[1:]) / 2  # mid-points of the distance bins
      for j, mbin in enumerate(iterutils.MultiIter(*mass_bins.centres())):
        if bin_type == "BNS_BBH":
          label = ["NS-NS","NS-BH","BH-BH"][j]
        else:
          label = '-'.join(["%.2f"%ma for ma in mbin])+r"M$_\odot$"
        plt.errorbar(dcentres, eff[j], yerr=err[j], label=label)
        plt.legend(loc='upper right')
        plt.ylim([0,1])
        plt.xlim([0,max(dbins)])
        plt.xlabel("Distance (Mpc)")
        plt.ylabel("Efficiency")
        name = InspiralUtils.set_figure_tag(tag + "_BIN_"+str(j+1), open_box = opts.open_box)
        fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
        InspiralUtils.savefig_pylal( filename=fname )
        fnameList.append(fname)
        tagList.append(name)
        plt.close()

    #compute volume derivative at loudest event (the plot lists are handed
    #over as well; presumably the method appends its own figures -- confirm)
    dvA = UL.get_volume_derivative(instr,bin_type,UL.far[instr],fnameList,tagList,InspiralUtilsOpts)

    # make a live time table
    ltA = UL.live_time_array(instr, mass_bins)

    vA.array *= UL.livetime[instr] #preferred unit is Mpc^3*yr
    vA2.array *= UL.livetime[instr] #preferred unit is Mpc^3*yr

    #write out the results to the xml file
    xmldoc = ligolw.Document()
    child = xmldoc.appendChild(ligolw.LIGO_LW())

    # write out mass bins: for each dimension the lower edges of all bins
    # followed by the upper edge of the last bin.  The edges do not depend on
    # the dimension index, so build them once before the loop.
    dim = len(mass_bins)
    edges = tuple( numpy.concatenate((lower, upper[-1:])) for lower, upper in zip(mass_bins.lower(), mass_bins.upper()) )
    for j in range(dim):
      xml = ligolw.LIGO_LW({u"Name": u"mass%d_bins:%s" % (j+1,bin_type)})
      xml.appendChild(array.from_array(u"array", edges[j]))
      child.appendChild(xml)

    # write out the binned result arrays, each under its own name
    output_arrays = {"SearchVolumeFirstMoment":vA.array,
                     "SearchVolumeSecondMoment":vA2.array,
                     "SearchVolumeDerivative":dvA.array,
                     "SearchVolumeFoundInjections":f.array,
                     "SearchVolumeMissedInjections":m.array,
                     "SearchVolumeLiveTime":ltA.array}
    for arr_name, arr_data in output_arrays.items():
      xml = ligolw.LIGO_LW({u"Name": u"binned_array:%s" % arr_name})
      xml.appendChild(array.from_array(u"array", arr_data))
      child.appendChild(xml)

    output_filename = opts.output_path+"%s-SEARCH_VOLUME_BINNED_BY_%s_%s-%d-%d.xml" % \
        (ifos,bin_type.upper(),opts.user_tag,UL.start_time,UL.end_time-UL.start_time)
    utils.write_filename(xmldoc, output_filename)

    # Record the file in the cache.  abspath() yields a correct URL whether
    # the output path is relative or absolute; simply prefixing the current
    # directory would produce a bogus path for an absolute output path.
    cache_entry =  lal.CacheEntry( ifos,
                                   bin_type,
                                   segments.segment(UL.start_time, UL.end_time),
                                   "file://localhost%s" % os.path.abspath(output_filename) )

    cache_list.append(cache_entry)

    # write the html summary page and plot cache for this bin type / instrument set
    plothtml = InspiralUtils.write_html_output( InspiralUtilsOpts, [database], fnameList,
                                                tagList, add_box_flag = False )
    InspiralUtils.write_cache_output( InspiralUtilsOpts, plothtml, fnameList )


# write a cache file describing the files generated by this program; skipped
# when opts.output_cache is unset or empty
if opts.output_cache:
    write_cache(cache_list, opts.output_cache)
