#!/usr/bin/python

import numpy

import sqlite3

import matplotlib
matplotlib.use("agg")
from matplotlib import pyplot as plt

import os
import sys
from optparse import OptionParser

from glue import segments, segmentsUtils, lal, iterutils
from glue.ligolw import ligolw, lsctables, dbtables, utils, array

from pylal import ligolw_cbc_compute_durations as compute_dur
from pylal import rate
from pylal import InspiralUtils
from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS

import lal as constants
from pylal import upper_limit_utils, imr_utils

from pylal import git_version
# Module metadata; version and date come from the pylal git build information.
__author__ = ", ".join([
    "Stephen Privitera <sprivite@caltech.edu>",
    "Chad Hanna <channa@perimeterinstitute.ca>",
    "Kipp Cannon <kipp.cannon@ligo.org>",
])
__version__ = "git id %s" % (git_version.id,)
__date__ = git_version.date

# Plot styling shared by every figure this program draws: LaTeX-rendered
# serif (Computer Modern) text and 200 dpi output.
matplotlib.rcParams.update({
    # text sizes
    "font.size": 14.0,
    "font.family": "serif",
    "font.serif": "Computer Modern Roman",
    "axes.titlesize": 18.0,
    "axes.labelsize": 16.0,
    "xtick.labelsize": 14.0,
    "ytick.labelsize": 14.0,
    "legend.fontsize": 16.0,
    # resolution
    "figure.dpi": 200,
    "savefig.dpi": 200,
    # rendering
    "text.usetex": True,
    "path.simplify": True,
})


# Make lsctables use the LIGOTimeGPS class imported from pylal.xlal above.
lsctables.LIGOTimeGPS = LIGOTimeGPS


def chirp_mass(m1, m2):
  """
  Return the chirp mass of a binary with component masses m1 and m2.

  Uses Mc = mu**(3/5) * M**(2/5), where M = m1 + m2 is the total mass and
  mu = m1*m2/M the reduced mass. The inputs are passed through numpy.array,
  so scalars and array-likes are both accepted (element-wise for arrays).
  """
  mass1 = numpy.array(m1)
  mass2 = numpy.array(m2)
  total = mass1 + mass2
  reduced = (mass1 * mass2) / total
  return reduced ** (3. / 5.) * total ** (2. / 5.)


class upper_limit(object):
  """
  The upper_limit class organizes the calculation of the sensitive search volume
  for a search described by the input database.

  Constructing an instance reads from the database everything the calculation
  needs (instrument sets, inspinj mass ranges, segments, live times, FAR
  thresholds and injections), closes the database and then builds the mass
  binnings selected by the --bin-by-* options into self.mass_bins.
  """

  def __init__(self, database, opts):
    '''
    Gather statistics for the search stored in database.

    database -- filename of the sqlite database holding the search results
    opts     -- parsed command line options (see parse_command_line)
    '''
    self.opts = opts

    if opts.verbose:
      print >> sys.stdout, "Gathering stats from: %s...." % (database,)

    # open a connection to (a working copy of) the input database
    working_filename = dbtables.get_connection_filename(database, tmp_path=opts.tmp_space, verbose=opts.verbose)
    connection = sqlite3.connect(working_filename)
    try:
      # find out which instruments were on and when during search; the steps
      # use attributes set by the earlier ones, so their order matters
      self.set_instruments(connection)
      self.get_inspinj_mass_ranges(connection)
      self.get_inspinj_mtotal_range(connection)
      self.get_injected_mchirp_range(connection)
      self.get_segments(connection)
      self.get_zero_lag_segments()
      self.get_livetime()
      self.get_far_thresholds(connection)
      self.get_gps_times_duration(connection)
      self.get_all_injections(connection)
      self.get_distance_bins()
      connection.commit()
    finally:
      # done with db: always close the connection and release the working
      # copy, even when one of the steps above raised
      connection.close()
      dbtables.discard_connection_filename(database, working_filename, verbose=opts.verbose)

    # Set up mass bins
    self.mass_bins = {}
    if self.opts.bin_by_total_mass:
      self.mass_bins["Total_Mass"] = self.set_total_mass_bins()
    if self.opts.bin_by_component_mass:
      self.mass_bins["Component_Mass"] = self.set_component_mass_bins()
    if self.opts.bin_by_chirp_mass:
      self.mass_bins["Chirp_Mass"] = self.set_chirp_mass_bins()
    if self.opts.bin_by_bns_bbh:
      self.mass_bins["BNS_BBH"] = self.set_bns_bbh_bins()
    if self.opts.bin_by_mass1_mass2:
      self.mass_bins["Mass1_Mass2"] = self.set_m1m2_bins()


  def get_livetime(self):
    '''
    Compute live times from zero lag segments.

    Sets self.livetime[inst] to the live time, in Julian years, of each
    instrument set in self.instruments.
    '''
    self.livetime = {}
    for inst in self.instruments:
      # float() keeps the sub-second part of the total duration (same
      # conversion as used for the verbose report in get_zero_lag_segments)
      duration = float(abs(self.zero_lag_segments[inst].coalesce()))
      self.livetime[inst] = duration / float(constants.YRJUL_SI)

    return


  def _inspinj_param_extremum(self, connection, aggregate, param):
    '''
    Return the aggregate ("MIN" or "MAX") of the numeric values passed to the
    inspinj command line argument param, or None if no inspinj process used it.
    '''
    if aggregate not in ("MIN", "MAX"):
      raise ValueError("aggregate must be 'MIN' or 'MAX', got %r" % (aggregate,))
    # the SQL function name cannot be a bound parameter, hence the check above;
    # the program name and argument are bound properly
    query = ('SELECT %s(CAST(process_params.value AS REAL)) FROM process_params '
             'JOIN process ON process_params.process_id = process.process_id '
             'WHERE process.program = ? AND process_params.param = ?') % (aggregate,)
    return connection.cursor().execute(query, ("inspinj", param)).fetchone()[0]


  def get_inspinj_mass_ranges(self, connection):
    '''
    Find the component mass ranges for injections made using inspinj.

    Sets self.minmass to the smaller of the minimum --min-mass1/--min-mass2
    values and self.maxmass to the larger of the maximum --max-mass1/--max-mass2
    values. Each stays None unless both of its arguments were found.
    '''
    self.minmass = None
    self.maxmass = None

    minmass1 = self._inspinj_param_extremum(connection, "MIN", "--min-mass1")
    minmass2 = self._inspinj_param_extremum(connection, "MIN", "--min-mass2")
    maxmass1 = self._inspinj_param_extremum(connection, "MAX", "--max-mass1")
    maxmass2 = self._inspinj_param_extremum(connection, "MAX", "--max-mass2")

    # if inspinj did not have min / max-mass1,2 arguments the attributes remain None
    if minmass1 is not None and minmass2 is not None:
      self.minmass = min(float(minmass1), float(minmass2))
    if maxmass1 is not None and maxmass2 is not None:
      self.maxmass = max(float(maxmass1), float(maxmass2))

    return True


  def get_inspinj_mtotal_range(self, connection):
    '''
    Find the mtotal range for injections made using inspinj.

    Sets self.mintotal / self.maxtotal from the smallest --min-mtotal and the
    largest --max-mtotal inspinj arguments, or None where an argument is absent.
    '''
    self.mintotal = None
    self.maxtotal = None

    mintotal = self._inspinj_param_extremum(connection, "MIN", "--min-mtotal")
    maxtotal = self._inspinj_param_extremum(connection, "MAX", "--max-mtotal")

    # if inspinj did not have min / max-mtotal arguments the attributes remain None
    if mintotal is not None:
      self.mintotal = float(mintotal)
    if maxtotal is not None:
      self.maxtotal = float(maxtotal)

    return True


  def get_injected_mchirp_range(self, connection):
    '''
    Find the mchirp range actually injected.

    Sets self.minchirp / self.maxchirp from the sim_inspiral table.
    Raises RuntimeError if the table holds no mchirp values.
    '''
    minchirp, maxchirp = connection.cursor().execute('SELECT MIN(mchirp), MAX(mchirp) FROM sim_inspiral').fetchone()

    # Injections must have chirp masses!
    if minchirp is None or maxchirp is None:
      raise RuntimeError("Can't find any injection mchirp values in database!")

    self.minchirp = float(minchirp)
    self.maxchirp = float(maxchirp)
    return True


  def get_segments(self, connection):
    '''
    Retrieve raw single IFO segments from the database and apply vetoes.

    The veto segments named by --veto-segments-name are removed; without that
    option, whichever entry the veto dictionary yields first is used.
    Sets and returns self.segments.
    '''
    if self.opts.verbose:
      print >>sys.stdout,"\nQuerying database for single IFO segments..."
    self.segments = compute_dur.get_single_ifo_segments(
          connection,
          program_name=self.opts.livetime_program,
          usertag="FULL_DATA")
    if self.opts.verbose:
      for ifo in self.segments.keys():
        print >>sys.stdout, "\t%s was on for %d seconds" % (ifo, abs(self.segments[ifo]))

    # getting vetoes
    xmldoc = dbtables.get_xml(connection)
    veto_segments = compute_dur.get_veto_segments(xmldoc, self.opts.verbose)
    if self.opts.veto_segments_name:
      veto_segments = veto_segments[self.opts.veto_segments_name]
    else:
      veto_segments = list(veto_segments.values())[0]

    # applying vetoes
    if self.opts.verbose:
      print >>sys.stdout,"\nApplying vetoes to single IFO zero lag segments..."
    self.segments -= veto_segments
    if self.opts.verbose:
      for ifo in self.segments.keys():
        print >>sys.stdout, "\t%s was on for %d seconds (after vetoes)" % (ifo, abs(self.segments[ifo]))

    return self.segments


  def get_zero_lag_segments(self):
    '''
    Compute multi-ifo (coincident) segment list from single ifo segments.

    For every combination of two or more instruments returned by
    compute_dur.get_allifo_combos, the stored segments are the intersection of
    those instruments' segments minus the union of the segments of the
    corresponding excluded instruments, and minus S2 playground time unless
    --include-play is given. Sets and returns self.zero_lag_segments, keyed by
    frozensets of instrument names.
    '''
    if self.opts.verbose:
      print >>sys.stdout,"\nForming coincident segments from single IFO segments..."

    self.zero_lag_segments = segments.segmentlistdict()
    on_ifos_dict, excluded_ifos_dict = compute_dur.get_allifo_combos(self.segments, 2)

    # the playground segments do not depend on the instrument combination
    if self.opts.include_play:
      playground = None
      play_status = 'includes playground'
    else:
      playground = segmentsUtils.S2playground(self.segments.extent_all())
      play_status = 'excludes playground'

    for on_ifos_key, combo in on_ifos_dict.items():
      ifos = frozenset(lsctables.instrument_set_from_ifos(on_ifos_key))
      self.zero_lag_segments[ifos] = self.segments.intersection( combo )
      self.zero_lag_segments[ifos] -= self.segments.union( excluded_ifos_dict[on_ifos_key] )

      if playground is not None: # subtract playground segments
        self.zero_lag_segments[ifos] -= playground

      if self.opts.verbose:
        print >> sys.stdout, "\t%s were on for %d seconds (%s)" % (on_ifos_key, float(abs(self.zero_lag_segments[ifos])), play_status)

    return self.zero_lag_segments


  def get_distance_bins(self):
    '''
    Determine distance bins to use for the calculation.

    self.dbins[inst] gets --dist-bins evenly spaced distances from 0 to the
    largest distance among the injections found at the instrument set's FAR
    threshold. Instrument sets with no found injections are dropped from
    self.instruments. Returns self.dbins.
    '''
    self.dbins = {}
    for instr in self.instruments[:]: # iterate over a copy: items may be removed
      found = self.get_injections(instr)[0]
      if not found:
        print >>sys.stderr, "No injections found in %s time... skipping." % (instr,)
        self.instruments.remove(instr)
        continue
      dmax = max(sim.distance for sim in found)
      self.dbins[instr] = numpy.linspace(0, dmax, self.opts.dist_bins)

    return self.dbins


  @staticmethod
  def _parse_bin_edges(text):
    '''Convert a comma separated list of numbers such as "1,3,8" to a float array.'''
    return numpy.array([float(k) for k in text.split(',')])


  def set_total_mass_bins(self):
    '''
    Set the total mass bins to use based on what mass ranges were injected or user options.

    User-given --total-mass-bins edges also overwrite self.mintotal/self.maxtotal.
    Raises RuntimeError if neither user edges nor inspinj mtotal limits exist.
    '''
    if self.opts.total_mass_bins is not None: # use user-specified bins if given
      binEdges = self._parse_bin_edges(self.opts.total_mass_bins)
      self.mintotal = numpy.min(binEdges)
      self.maxtotal = numpy.max(binEdges)
    elif self.mintotal is not None and self.maxtotal is not None: # otherwise, make bins based on inspinj arguments
      binEdges = numpy.linspace(self.mintotal, self.maxtotal, self.opts.mass_bins+1)
    else:
      raise RuntimeError("Can't find injection minimum and maximum total masses! Please specify --total-mass-bins")

    self.total_mass_bins = rate.NDBins((rate.IrregularBins(binEdges),))
    return self.total_mass_bins


  def set_component_mass_bins(self):
    '''
    Set the component mass bins to use based on what mass ranges were injected or user options.

    User-given --component-mass1-bins edges also overwrite self.minmass/self.maxmass.
    Raises RuntimeError if neither user edges nor inspinj mass limits exist.
    '''
    if self.opts.component_mass1_bins is not None: # use user-specified bins if given
      binEdges1 = self._parse_bin_edges(self.opts.component_mass1_bins)
      self.minmass = numpy.min(binEdges1)
      self.maxmass = numpy.max(binEdges1)
    elif self.minmass is not None and self.maxmass is not None: # otherwise, make bins based on inspinj arguments
      binEdges1 = numpy.linspace(self.minmass, self.maxmass, self.opts.mass_bins+1)
    else:
      raise RuntimeError("Can't find injection minimum and maximum component masses! Please specify --component-mass-bins")

    self.component_mass_bins = rate.NDBins((rate.IrregularBins(binEdges1),))
    return self.component_mass_bins


  def set_chirp_mass_bins(self):
    '''
    Set the chirp mass bins to use based on the injections present or user options.
    '''
    if self.opts.chirp_mass_bins is not None: # use user-specified bins if given
      binEdges = self._parse_bin_edges(self.opts.chirp_mass_bins)
    else: # otherwise, make bins based on actual injected range with some small padding
      binEdges = numpy.linspace(self.minchirp-0.01, self.maxchirp+0.01, self.opts.mass_bins+1)
    self.chirp_mass_bins = rate.NDBins((rate.IrregularBins(binEdges),))

    return self.chirp_mass_bins


  def set_bns_bbh_bins(self):
    '''
    Set three mass bins with edges [min NS, max NS, min BH, max BH] from the
    --min/max-nsmass and --min/max-bhmass options; get_volume_derivative labels
    them NS-NS, NS-BH and BH-BH. (How injections are assigned to these bins is
    decided in upper_limit_utils.compute_volume_vs_mass — not visible here.)
    '''
    binEdges = [self.opts.min_nsmass, self.opts.max_nsmass, self.opts.min_bhmass, self.opts.max_bhmass]
    self.bns_bbh_bins = rate.NDBins((rate.IrregularBins(binEdges),))

    return self.bns_bbh_bins


  def set_m1m2_bins(self):
    '''
    Set two-dimensional (mass1, mass2) bins spanning the inspinj component mass
    range; the mass2 edges are the leading subset of the mass1 edges.
    Raises RuntimeError if the inspinj component mass limits are unknown.
    '''
    if self.minmass is None or self.maxmass is None:
      raise RuntimeError("Can't find injection minimum and maximum component masses! Without these I can't do --bin-by-mass1-mass2")

    numedges = self.opts.mass_bins + 1
    bin1Edges = numpy.linspace(self.minmass, self.maxmass, numedges)
    m2bins = int(numedges / 2.) + numpy.mod(numedges, 2) # ensures equal bin width in m2 as in m1
    bin2Edges = bin1Edges[:m2bins+1]
    mass1Bin = rate.IrregularBins(bin1Edges)
    mass2Bin = rate.IrregularBins(bin2Edges)
    self.m1m2_bins = rate.NDBins((mass1Bin, mass2Bin))

    return self.m1m2_bins


  def get_far_thresholds(self, connection):
    """
    Set self.far[inst], the false alarm rate (per year) used for computing the
    search volume in each instrument set.

    * --open-box: the FAR of the loudest zero-lag event (including or
      excluding playground according to --include-play);
    * otherwise --far if given;
    * otherwise the expected loudest-event FAR, 1/livetime.

    Instrument sets for which no usable threshold is found are removed from
    self.instruments.
    """
    if self.opts.verbose:
      print >>sys.stdout, "\nGetting FAR thresholds for finding injections..."

    self.far = {}

    if self.opts.open_box:
      # for open box use true loudest event
      if self.opts.include_play:
        datatype = "all_data"
      else:
        datatype = "exclude_play"

      for inst, far in upper_limit_utils.get_loudest_event(connection, self.opts.coinc_table, datatype):
        if inst not in self.instruments:
          continue

        # cannot determine efficiency above the loudest event
        # when the loudest event has FAR=0
        if far == 0:
          print >> sys.stderr, "Could not determine FAR threshold for %s time. Skipping..." % ','.join(sorted(list(inst)))
          self.instruments.remove(inst)
          continue

        self.far[inst] = far

    elif self.opts.far:
      # far option takes precedence for closed box
      for inst in self.instruments:
        self.far[inst] = self.opts.far

    else:
      # if no far specified, use expected-loudest-event for closed box;
      # a zero live time gives no threshold (it is reported and skipped below)
      for inst in self.instruments:
        if self.livetime.get(inst):
          self.far[inst] = 1. / self.livetime[inst]

    # check that all instrument sets have an associated loudest event
    for inst in self.instruments[:]:
      if inst not in self.far:
        print >> sys.stderr, "Could not determine FAR threshold for %s time. Skipping..." % ','.join(sorted(list(inst)))
        self.instruments.remove(inst)

    # report the far threshold being used
    if self.opts.verbose:
      for inst in self.instruments:
        print >>sys.stdout,"\tFAR threshold used in %s time is %g/yr" % (','.join(sorted(list(inst))), self.far[inst])

    return


  def get_gps_times_duration(self, connection):
    '''Set self.start_time / self.end_time to the GPS span of the experiment table.'''
    start, end = connection.cursor().execute('SELECT MIN(gps_start_time), MAX(gps_end_time) FROM experiment').fetchone()
    self.start_time = int(start)
    self.end_time = int(end)
    return True


  def set_instruments(self, connection):
    '''
    Set self.instruments, the instrument sets of two or more detectors found
    in the coinc_event table. If --instruments was given, only that set is
    kept, and the list is left empty if the set is not in the database.
    '''
    # looking for instrument sets of two or more detectors
    self.instruments = [inst for inst in imr_utils.get_instruments_from_coinc_event_table(connection) if len(inst) > 1]

    requested = self.opts.instruments
    if requested is None:
      return

    # NOTE(review): the membership test assumes opts.instruments was already
    # converted to the same type as the entries above (presumably a frozenset
    # of instrument names) — confirm in parse_command_line
    if requested in self.instruments:
      self.instruments = [requested]
    else:
      print >> sys.stderr, "Instruments %s do not exist in DB, nothing will be calculated" % (str(requested))
      self.instruments = []

    return


  def get_volume_derivative(self, instruments, bin_type, FAR, fnameList=None, tagList=None, InspiralUtilsOpts=None):
    """
    Compute the derivative of the search volume at the FAR of the loudest event

    The volume-time is evaluated at ten FARs log-spaced over a range around FAR
    that is widened until the number of found injections changes across it.
    For each mass bin, upper_limit_utils.log_volume_derivative_fit is then fit
    against log(FAN), where FAN = FAR * livetime. The returned BinnedArray over
    self.mass_bins[bin_type] holds coeffs[0] / FAN for each bin.

    instruments       -- instrument set (key of self.injections / self.livetime)
    bin_type          -- key of self.mass_bins selecting the mass binning
    FAR               -- false alarm rate at which the derivative is wanted
    fnameList/tagList -- lists collecting the file names / tags of the fit
                         plots made with --plot-lambda-fits
    InspiralUtilsOpts -- options passed to InspiralUtils.set_figure_name
    """
    if fnameList is None:
      fnameList = []
    if tagList is None:
      tagList = []

    def n_found(far):
      # number of injections (passing the filters) found below far
      return len(self.get_injections(instruments, far, bin_type=bin_type)[0])

    # check for infinite FAR. if no found injections there, return
    if n_found(float('inf')) == 0:
      return rate.BinnedArray(self.mass_bins[bin_type])

    # guess an initial range of fars at which to evaluate the volume, and
    # expand it until the number of found injections differs at its ends
    FARh = FAR * 2.
    FARl = FAR / 4.
    while n_found(FARl) == n_found(FARh):
      FARh *= 2.
      FARl /= 2.

    # compute volume for the chosen FAR values
    nbins = 10
    fars = numpy.logspace(numpy.log10(FARl), numpy.log10(FARh), nbins)
    if self.opts.verbose:
      print >>sys.stdout, "\nComputing Lambda at FAR=%f binning by %s" % (FAR,bin_type)
      print >>sys.stdout, "FARs = ", fars

    mass_bins = self.mass_bins[bin_type]
    livetime = self.livetime[instruments]

    # compute the volume in each mass bin for a few fars
    vA = []
    vA2 = []
    for far in fars:
      found, missed = self.get_injections(instruments, far, bin_type=bin_type)
      vAt, vA2t, _f, _m, _junk1, _junk2 = upper_limit_utils.compute_volume_vs_mass(found, missed, mass_bins, bin_type, dbins = self.dbins[instruments])
      vA.append(vAt)
      vA2.append(vA2t)

    # compute the volume derivative in each mass bin
    volDeriv = rate.BinnedArray(mass_bins)
    fans = fars * livetime
    FAN = FAR * livetime

    for j, mbin in enumerate(iterutils.MultiIter(*mass_bins.centres())):
      vtmp = livetime * numpy.array([v[mbin] for v in vA])
      v2tmp = livetime * numpy.array([v[mbin] for v in vA2])
      coeffs = upper_limit_utils.log_volume_derivative_fit(numpy.log(fans), vtmp)
      volDeriv[mbin] = coeffs[0] / FAN

      if self.opts.verbose:
        print >> sys.stdout, "mass bin %s" %str(mbin)
        print >> sys.stdout, "vols = ", vtmp
        print >> sys.stdout, "\tLambda = d logV/d FAN = %f" % (volDeriv[mbin],)

      if not self.opts.plot_lambda_fits:
        continue

      tag = "%s-VOLUME_DERIVATIVE_FIT" % ("".join(sorted(list(instruments))),)
      if vtmp.max() == 0:
        print >> sys.stdout, "All estimated volumes were zero for", bin_type, "bin", str(mbin), "in", tag
        print >> sys.stdout, "- not making a plot!"
        continue

      if bin_type == "BNS_BBH":
        label = ["NS-NS", "NS-BH", "BH-BH"][j]
      else:
        label = '-'.join(["%.2f" % ma for ma in mbin]) + r"M$_\odot$"
      plt.errorbar(fans, vtmp, yerr=v2tmp, color='k')
      plt.plot(fans, numpy.exp(coeffs[0] * numpy.log(fans) + coeffs[1]), lw=2, label="%s\nLambda=%.3f" % (label,volDeriv[mbin],))
      plt.legend(loc="lower left")
      plt.xlabel("FAN")
      plt.ylabel("VT")
      plt.title("Volume derivative at FAN = %.3f" % FAN)
      plt.gca().set_xscale("log", nonposx='clip')
      plt.gca().set_yscale("log", nonposy='clip')
      name = InspiralUtils.set_figure_tag(tag+"_BIN_" + str(j), open_box=self.opts.open_box)
      fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
      InspiralUtils.savefig_pylal(filename=fname)
      fnameList.append(fname)
      tagList.append(name)
      plt.close()

    return volDeriv


  @staticmethod
  def _has_spin(sim):
    '''Return True if any of the six spin components of sim is nonzero.'''
    return max(abs(numpy.array([sim.spin1x, sim.spin1y, sim.spin1z, sim.spin2x, sim.spin2y, sim.spin2z]))) != 0


  @staticmethod
  def _swap_components(sim):
    '''Exchange, in place, the masses and spin vectors of the two bodies of sim.'''
    sim.mass1, sim.mass2 = sim.mass2, sim.mass1
    sim.spin1x, sim.spin2x = sim.spin2x, sim.spin1x
    sim.spin1y, sim.spin2y = sim.spin2y, sim.spin1y
    sim.spin1z, sim.spin2z = sim.spin2z, sim.spin1z


  def filter_injection(self,sim,bin_type):
    '''
    Return True if injection sim must be left out of the volume calculation
    for the mass binning bin_type, False if it is to be used.

    Injections are rejected by waveform family (--exclude-sim-type), spin
    content (--disable-spin / --spin-only), aligned spin chi, total mass and
    mass ratio. For "Component_Mass" and "BNS_BBH" the component masses are
    also required to lie in the configured ranges.

    Side effect: for "Component_Mass", when only mass1 lies strictly between
    --min-mass2 and --max-mass2, the two components of sim are swapped in place
    (masses and spins together) so that mass2 is the restricted component.
    '''
    if sim.waveform in self.opts.exclude_sim_type:
      return True

    if self.opts.disable_spin:
      # throw out if any spin component is nonzero
      if self._has_spin(sim): return True
    elif self.opts.spin_only:
      # throw out injections having all zero spin components
      if not self._has_spin(sim): return True

    chi = sim.mass1 / (sim.mass1 + sim.mass2) * sim.spin1z + sim.mass2 / (sim.mass1 + sim.mass2) * sim.spin2z
    if not self.opts.min_aligned_spin < chi < self.opts.max_aligned_spin: return True

    mtotal = sim.mass1 + sim.mass2
    if self.opts.min_mtotal is not None and mtotal < self.opts.min_mtotal:
      return True
    if self.opts.max_mtotal is not None and mtotal > self.opts.max_mtotal:
      return True

    mratio = max(sim.mass1 / sim.mass2, sim.mass2 / sim.mass1)
    if mratio < self.opts.min_mass_ratio:
      return True
    if self.opts.max_mass_ratio and (self.opts.max_mass_ratio < mratio):
      return True

    if bin_type == "Component_Mass":
      lo, hi = self.opts.min_mass2, self.opts.max_mass2
      mass1_ok = lo < sim.mass1 < hi
      mass2_ok = lo < sim.mass2 < hi
      if not (mass1_ok or mass2_ok):
        return True # neither component in the right range, throw out injection
      if not mass2_ok:
        # the right mass component was there but in an inconvenient order.
        # The swap persists on the shared sim object, so the spins must move
        # with the masses; otherwise chi (above) would change on later calls.
        self._swap_components(sim)

    if bin_type == "BNS_BBH":
      masses = (sim.mass1, sim.mass2)
      bhs = sum(self.opts.min_bhmass < m < self.opts.max_bhmass for m in masses)
      nss = sum(self.opts.min_nsmass < m < self.opts.max_nsmass for m in masses)
      if bhs + nss < 2: return True # components not in the right range, throw out injection

    return False


  def get_injections(self, instruments, FAR=None, bin_type=None):
    '''
    Split the injections made in instruments time into (found, missed) lists.

    Injections rejected by filter_injection(sim, bin_type) are dropped. An
    injection is found when its FAR is strictly below FAR, which defaults to
    self.far[instruments]. Injections never found are stored with FAR=inf by
    get_all_injections, so they always land in missed.
    '''
    if FAR is None:
      FAR = self.far[instruments]

    found = []
    missed = []
    for far, sim in self.injections[instruments]:
      if self.filter_injection(sim, bin_type):
        continue
      (found if far < FAR else missed).append(sim)

    return found, missed


  def get_all_injections(self, connection):
    """
    This method separates injections into found and missed categories. An injection which
    is coincident with two or more single inspiral triggers (sngl1<-->sim and sngl2<-->sim)
    which are themselves coincident (sngl1<-->sngl2) triggers are considered "found". An
    injection in triple time therefore is considered found if it matches any double-IFO
    coincident event. All other injections are considered "missed".

    Sets self.injections[inst] to a list of (far, sim) pairs: the found pairs as
    returned by imr_utils.get_min_far_inspiral_injections (presumably the
    minimum FAR of the matching coincs — verify), plus (inf, sim) for every
    missed injection.
    """
    if self.opts.verbose:
      print >>sys.stdout, "\nFinding parameters of the injections performed ..."

    self.injections = {}

    for inst in self.instruments:
      found, total, missed = imr_utils.get_min_far_inspiral_injections(connection, segments = self.zero_lag_segments[inst], table_name = "%s:table"%self.opts.coinc_table)
      self.injections[inst] = found + [(float('inf'), sim) for sim in missed]

    # report how many injections were performed
    if self.opts.verbose:
      for inst in self.instruments:
        print >>sys.stdout,"\tFound %d injections performed in %s time" % (len(self.injections[inst]), ','.join(sorted(list(inst))))

    return True


  def live_time_array(self,instruments,mass_bins):
    """
    return an array of live times, note every bin will be the same :) it is just a
    convenience.
    """
    live_time = rate.BinnedArray(mass_bins)
    live_time.array += self.livetime[instruments]
    return live_time


def parse_command_line():

  description = '''
   description:

The program lalapps_cbc_svim computes the sensitive volume of a CBC search from a database containing triggers from simulation experiments. These triggers need to be ranked by false alarm rate, the detection statistic used in S6 searches. Then injections which register a trigger louder than the loudest event, by false alarm rate, are considered found. All others are considered missed. The efficiency of detecting an event depends on the source parameters, such as its component masses, distance, spin, inclination, sky position, etc. However, lalapps_cbc_svim only considers the dependency of the efficiency on distance and mass, marginalizing over the other parameters. Injections are binned in distance and mass and the estimated efficiency is integrated over distance to convert the efficiency into a physical volume.

The output of lalapps_cbc_svim is input for lalapps_cbc_sink, which uses the loudest event statistic method (Biswas, Creighton, Fairhurst, 2007) to compute a posterior on the rate, given the observed volumes.

implementation and workflow:

1. Extract information about the particular search that was performed.
  a. which instruments were on in the search
  b. the segments of time for which there is coincident data
  c. the significance of the loudest event (for blinded searches, use a dummy FAR value instead)
  d. set up mass and distance bins
2. Compute search volume above loudest event
  a. separate injections by found and missed based on the loudest event
  b. ensure that the injection was made within the recorded segments
  c. compute the efficiency of the search as a function of distance and mass
  d. integrate the efficiency over distance and time to compute the search volume.
3. Compute volume derivative (lambda) at the cFAR of the loudest event
  a. Choose a number of FARs linearly-spaced around the FAR of the loudest event
  b. Compute search volume below each FAR (calling the same routines as in 2)
  c. Linear least-squares fit to volume vs FAR points to estimate the volume derivative
4. Write the computed search volume and search volume derivative to disk
'''

  usage = '''lalapps_cbc_svim [options] database

The input to svim is a single database containing simulation experiments with triggers ranked by false alarm rate. The bare-bones command

  lalapps_cbc_svim H1L1V1-FULL_DATA_CAT_4_VETO_CLUSTERED_CBC_RESULTS-968803143-1209744.sqlite

will compute from the results database the search volume, volume derivative, and volume uncertainty as a function of total mass for each IFO set found in the database. You can override the default mass binning with, for instance,

  lalapps_cbc_svim --bin-by-component-mass --component-mass1-bins=\'1,3,8,13,18,23\' --min-mass2 1 --max-mass2 3 H1L1V1-FULL_DATA_CAT_4_VETO_CLUSTERED_CBC_RESULTS-968803143-1209744.sqlite

will compute the search volume using all injections with at least one component in the mass range 1-3Msun as a function of the other component mass, using the specified boundaries for the mass of the first object.
'''


  parser = OptionParser(version = git_version.verbose_msg, usage = usage + description)

  #
  # tell svim where to find injections, segments, and triggers
  parser.add_option("--sim-table", default="sim_inspiral", metavar="TBL", help="Set the name of the table containing information regarding the injected signal parameters.", choices = ["sim_inspiral", "sim_ringdown"])
  parser.add_option("--coinc-table", default="coinc_inspiral", metavar="TBL", help="Set the name of the table containing information regarding coincident triggers.", choices = ["coinc_inspiral", "coinc_ringdown"])
  parser.add_option("--livetime-program", default=None, metavar="name", type="string", help="Set the name of the program whose segments or rings will be extracted from the search_summary table: usually thinca or inspiral.")
  parser.add_option("--veto-segments-name", default=None, metavar="name", type="string", help="Set the name of the veto segments to use from the XML document.")

  #
  # The code supports five mass binning modes.
  #
  parser.add_option("--bin-by-mass1-mass2", default=False, action="store_true", help="Bin injections by component mass in two dimensions when estimating the search efficiency.")

  # Bin injections by total mass
  parser.add_option("--bin-by-total-mass", default=False, action="store_true", help="Bin injections by total mass when estimating the search efficiency.")
  parser.add_option("--total-mass-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the total mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. If given, --bin-by-total-mass will be set to True.")

  # Bin injections by total mass
  parser.add_option("--bin-by-chirp-mass", default=False, action="store_true", help="Bin injections by chirp mass when estimating the search efficiency.")
  parser.add_option("--chirp-mass-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the chirp mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. If given, --bin-by-chirp-mass will be set to True.")

  # Bin injections by component mass, with one mass fixed
  parser.add_option("--bin-by-component-mass", default=False, action="store_true", help="Bin injections by the first component's mass with the second component restricted to a single bin specified by --min/max-mass2.")
  parser.add_option("--component-mass1-bins", default=None, metavar="\'m0,m1,...,mk''", help="Specify the boundaries of the first component's mass bins. Input should be a comma separated list of masses. Injections outside the range of the bins will be ignored. If given, --bin-by-component-mass will be set to True.")
  parser.add_option("--min-mass2", metavar="m", type='float', default=1.0, help="Minimum second component mass in Msun when using --bin-by-component-mass. Default 1.")
  parser.add_option("--max-mass2", metavar="m", type='float', default=3.0, help="Maximum second component mass in Msun when using --bin-by-component-mass. Default 3.")

  # Bin injections into three
  parser.add_option("--bin-by-bns-bbh", default=False, action="store_true", help="Compute the sensitive volume for three mass bins: BNS, NSBH, and BBH.")
  parser.add_option("--min-nsmass", metavar="m", type='float', default=1.0, help="Specify the minimum mass for a neutron star (default is 1Msun)")
  parser.add_option("--max-nsmass", metavar="m", type='float', default=3.0, help="Specify the maximum mass for a neutron star (default is 3Msun)")
  parser.add_option("--min-bhmass", metavar="m", type='float', default=4.0, help="Specify the minimum mass for a black hole (default is 4Msun)")
  parser.add_option("--max-bhmass", metavar="m", type='float', default=6.0, help="Specify the maximum mass for a black hole (default is 6Msun)")

  # Specify which instrument set to look for
  parser.add_option("--min-mtotal", metavar="m", type='float', default=None, help="Specify the minimum total mass to consider among the injections found in the DB. This filters all injections outside this total mass range, even if binning by another method.")
  parser.add_option("--max-mtotal", metavar="m", type='float', default=None, help="Specify the maximum total mass to consider among the injections found in the DB. This filters all injections outside this total mass range, even if binning by another method.")
  parser.add_option("--instruments", default=None, metavar="IFO[,IFO,...]", help="Specify the on-instruments sets for computing the search volume.  Example \"H1,L1,V1\"")

  # Control the injection population used to measure the efficiency
  parser.add_option("--min-mass-ratio", metavar="q", type='float', default=1, help="Specify the minimum allowed mass ratio. Must be >= 1.")
  parser.add_option("--max-mass-ratio", metavar="q", type='float', default=None, help="Specify the maximum allowed mass ratio. Should be >= min-mass-ratio >= 1.")
  parser.add_option("--min-aligned-spin", metavar="chi", type='float', default=-1.0, help="Specify the minimum allowed value for the aligned spin parameter chi.")
  parser.add_option("--max-aligned-spin", metavar="chi", type='float', default=1.0, help="Specify the maximum allowed value for the aligned spin parameter chi.")
  parser.add_option("--exclude-sim-type", default=[], action="append", metavar="SIM", help="When computing the search volume, exclude injections made using the SIM waveform family. Example: SpinTaylorthreePointFivePN. Use this option multiple times to exclude more than one injection type.")
  parser.add_option("--disable-spin", default=False, action="store_true", help="When computing the search volume, exclude injections with any nonzero spin components.")
  parser.add_option("--spin-only", default=False, action="store_true", help="When computing the search volume, exclude injections with all spin components equal to zero. --disable-spin takes precedence.")

  # Binning options. A large number of distance bins is needed for accurate numerical integration.
  parser.add_option("--dist-bins", default=15, metavar="integer", type="int", help="Space distance bins evenly and specify the number of distance bins to use.")
  parser.add_option("--mass-bins", default=5, metavar="integer", type="int", help="If mass bin boundaries are not explicitly set, specify the number of mass bins to use along a single dimension. Default 5 bins")

  # Options to control the FAR used for computing the search volume. The default behavior is to use
  # the FAR of the loudest non-playground event.
  parser.add_option("--far", type="float", help="Specify a cFAR threshold for classifying injections into found or missed, rather than using cFAR of the actual or expected loudest event. Overridden by the --open-box option.")
  parser.add_option("--open-box", default=False, action="store_true", help="Use the observed cFAR for the loudest event: only appropriate when the full search result is known. If neither --far nor --open-box is given, default action is to use cFAR=1/livetime.")
  parser.add_option("--include-play", default=False, action="store_true", help="Include playground data in computing the livetime and volume.")

  # options to control output
  parser.add_option("--user-tag", default="", metavar="name", help="Add a descriptive tag to the names of output files.")
  parser.add_option("--output-cache", default=None, help="Name of output cache file. If not specified, then no cache file will be written.")
  parser.add_option("--output-path", type="string", default="./", action="store", help="Choose directory to save output files.")
  parser.add_option("--verbose", action="store_true", help="Be verbose.")
  parser.add_option("--plot-efficiency", action="store_true", help="Plot efficiency as a function of distance.")
  parser.add_option("--plot-volume-vs-far", action="store_true", help="Compute and plot the 4-volume searched as a function of FAR.  Overridden by the --open-box option.")
  parser.add_option("--plot-lambda-fits", action="store_true", help="Plot the data points used to compute lambda, and the fit obtained from these points.")
  parser.add_option("-t", "--tmp-space", metavar="path", help="Path to a directory used as work area for the database file.  The database file will be worked on in this directory then moved to the final location when complete.  Intended to improve performance in a networked environment: eg local disk with higher available bandwidth than the filesystem where the database lives.")

  opts, filenames = parser.parse_args()

  if opts.livetime_program is None:
    print >>sys.stderr, "Error: no livetime-program option was given, exiting!"
    sys.exit(1)

  if opts.open_box and opts.plot_volume_vs_far:
    print >>sys.stderr, "I will not plot V*T vs FAR if the --open-box option is given! Proceeding without V*T plots..."
    opts.plot_volume_vs_far = False

  if opts.instruments is not None:
    opts.instruments = frozenset(lsctables.instrument_set_from_ifos(opts.instruments))

  if len(filenames) != 1:
    print >>sys.stderr, "Error: must specify exactly one database file"
    sys.exit(1)

  opts.enable_output = True

  if opts.chirp_mass_bins:
    opts.bin_by_chirp_mass = True
  if opts.component_mass1_bins:
    opts.bin_by_component_mass = True
  if opts.total_mass_bins:
    opts.bin_by_total_mass = True

  if (not opts.bin_by_total_mass) and (not opts.bin_by_chirp_mass) and (not opts.bin_by_component_mass) and (not opts.bin_by_bns_bbh) and (not opts.bin_by_mass1_mass2):
    raise RuntimeError("Must specify at least one mass binning method!")

  if opts.min_mass_ratio < 1:
    raise ValueError, "The maximum mass ratio must be >=1!"
  if opts.max_mass_ratio and not opts.min_mass_ratio <= opts.max_mass_ratio:
    raise ValueError, "The maximum mass ratio must be >= minimum mass ratio!"

  opts.output_path = opts.output_path if opts.output_path[-1] == "/" else opts.output_path + "/"

  return opts, filenames[0]


############################ MAIN PROGRAM #####################################
###############################################################################
###############################################################################

# Collects one glue.lal.CacheEntry per output file written by the main loop;
# its contents are written to the --output-cache file at the end of the run.
cache_list = list()


def write_cache(cache_list, fileout):
    """
    Write a cache file listing one entry per line.

    Parameters
    ----------
    cache_list : iterable
        Entries to record (in this program, glue.lal.CacheEntry objects);
        each is converted with str() and written on its own line.
    fileout : str or None
        Path of the cache file to create (overwritten if it exists).  When
        None, nothing is written.
    """
    # Decide based on the argument actually used for writing rather than on a
    # module-level global, so the function is self-contained.
    if fileout is None:
        return
    # ``with`` guarantees the file is closed even if a write fails part way.
    with open(fileout, "w") as fd:
        for entry in cache_list:
            fd.write(str(entry) + "\n")

#
# MAIN
#


# Parse the options and the single input database named on the command line.
opts, database = parse_command_line()

# Read the search results from the database into an upper_limit object.
UL = upper_limit(database, opts)

# Expose the analysis span on opts for use by InspiralUtils.
opts.gps_start_time, opts.gps_end_time = UL.start_time, UL.end_time

#
# plot the sensitive volume and range as a function of FAR threshold
#
if opts.plot_volume_vs_far:
    # Collect the FARs at which injections were recovered.  Only finite,
    # nonzero values can be placed on the logarithmic FAR axes used below.
    injection_fars = [f for instr in UL.instruments
                      for f, sim in UL.injections[instr]
                      if 0 < f < float('inf')]
    if not injection_fars:
        # min()/max() of an empty sequence would otherwise raise ValueError
        # here and abort the run before any search volume is written.
        sys.stderr.write("No injections were recovered with a finite, nonzero FAR! Proceeding without V*T plots...\n")
        opts.plot_volume_vs_far = False

if opts.plot_volume_vs_far:

    # Make a log-spaced grid of FAR thresholds to loop over.  It is computed
    # once and for all so that every plot gets the same x-limits.
    nfars = 25
    fars = numpy.logspace(numpy.log10(min(injection_fars)),
                          numpy.log10(max(injection_fars)), nfars)
    far_min, far_max = fars[0], fars[-1]  # logspace output is ascending

    for instr in UL.instruments:

        ifos = "".join(sorted(instr))
        lt = UL.livetime[instr]  # preferred unit is Mpc^3*yr

        for bin_type in UL.mass_bins:

            binning = UL.mass_bins[bin_type]
            minvol = 1e20  # dummy initialization for V*T plot axis limit
            maxvol = 0

            # Mass-binned volume (vA) and its companion array (vA2, drawn
            # below as the +/- band) at every FAR threshold in the grid.
            vAs, vA2s = [], []
            for far in fars:
                found, missed = UL.get_injections(instr, far, bin_type)
                vA, vA2, _, _, _, _ = upper_limit_utils.compute_volume_vs_mass(found, missed, binning, bin_type, dbins=UL.dbins[instr])
                vAs.append(vA)
                vA2s.append(vA2)

            for j, mbin in enumerate(iterutils.MultiIter(*binning.centres())):

                if bin_type == "BNS_BBH":
                    label = ["NS-NS", "NS-BH", "BH-BH"][j]
                elif bin_type == "Mass1_Mass2":
                    label = '-'.join(["%.2f" % ma for ma in mbin]) + r"M$_\odot$"
                else:
                    lo = binning.lower()
                    hi = binning.upper()
                    label = ','.join(r"%.1f-%.1f M$_\odot$" % (lo[k][j], hi[k][j]) for k in range(len(lo)))

                vols = numpy.array([v[mbin] for v in vAs])
                vols_errs = numpy.array([v2[mbin] for v2 in vA2s])
                if opts.verbose:
                    sys.stdout.write("%s %s\n" % (bin_type, vols))

                # figure 1: sensitive 4-volume (V*T) vs FAR threshold
                plt.figure(1)
                line, = plt.plot(fars, lt * vols, label=label)
                plt.fill_between(fars, lt * (vols - vols_errs), lt * (vols + vols_errs), alpha=0.5, color=line.get_color())
                minvol = min(minvol, min(lt * (vols - vols_errs)))  # update minimum plotted value
                maxvol = max(maxvol, max(lt * (vols + vols_errs)))

                # figure 2: radius R of a sphere with volume V = 4/3 pi R^3;
                # the error propagates as dR/R = dV / (3 V)
                dists = (3. * vols / (4. * numpy.pi)) ** (1. / 3.)
                dists_errs = dists * vols_errs / (3. * vols)

                plt.figure(2)
                line, = plt.plot(fars, dists, label=label)
                plt.fill_between(fars, dists - dists_errs, dists + dists_errs, alpha=0.5, color=line.get_color())

            plt.figure(1)
            plt.axvline(UL.far[instr], ls='--', c='k')
            tag = "%s-VOLUME_VS_FAR_BINNED_BY_%s_%s" % (ifos, bin_type.upper(), opts.user_tag)
            plt.xlabel("False Alarm Rate (yr$^{-1}$)")
            plt.ylabel(r"Mean Sensitive 4-Volume (Mpc$^3\,$yr)")
            # prefer log y axis for V*T plot but if one or more volumes is <=0 then loglog() will fail
            if minvol > 0 and maxvol > 0:
                plt.loglog()
                plt.ylim(ymin=0.7 * minvol)
            elif maxvol > 0:
                plt.semilogx()
            else:
                plt.ylim(0, 1)
            plt.grid()
            plt.xlim(far_min, far_max)
            plt.legend(loc="lower right")
            plt.savefig(opts.output_path + tag + ".png")
            plt.close()

            plt.figure(2)
            plt.axvline(UL.far[instr], ls='--', c='k')
            tag = "%s-RANGE_VS_FAR_BINNED_BY_%s_%s" % (ifos, bin_type.upper(), opts.user_tag)
            plt.xlabel("False Alarm Rate (yr$^{-1}$)")
            plt.ylabel("Mean Sensitive Distance (Mpc)")
            plt.semilogx()
            plt.grid()
            plt.xlim(far_min, far_max)
            plt.ylim(ymin=0)
            plt.legend(loc="lower right")
            plt.savefig(opts.output_path + tag + ".png")
            plt.close()


#
# Loop over the mass bin types and the requested instrument sets; compute the
# search volume and volume derivative at the chosen FAR for each mass bin and
# write them, together with the bin edges, to one XML file per combination.
#
for bin_type in UL.mass_bins:
    binning = UL.mass_bins[bin_type]

    for instruments in UL.instruments:
        ifos = "".join(sorted(instruments))
        dbins = UL.dbins[instruments]
        far_threshold = UL.far[instruments]

        # initialize plotting util helper
        __prog__ = "lalapps_cbc_svim_by_" + bin_type.lower()
        opts.ifo_times = ifos
        InspiralUtilsOpts = InspiralUtils.initialise(opts, __prog__, git_version.verbose_msg)
        fnameList = []
        tagList = []

        # compute volume first and second moments above the loudest event
        found, missed = UL.get_injections(instruments, far_threshold, bin_type)
        vA, vA2, f, m, eff, err = upper_limit_utils.compute_volume_vs_mass(found, missed, binning, bin_type, dbins=dbins)

        # plot efficiency vs distance, one figure per mass bin
        if opts.plot_efficiency:
            tag = "%s-SEARCH_EFFICIENCY_BINNED_BY_%s" % (ifos, bin_type.upper())
            dist_centres = (dbins[:-1] + dbins[1:]) / 2
            for j, mbin in enumerate(iterutils.MultiIter(*binning.centres())):
                if bin_type == "BNS_BBH":
                    label = ["NS-NS", "NS-BH", "BH-BH"][j]
                else:
                    label = '-'.join(["%.2f" % ma for ma in mbin]) + r"M$_\odot$"
                plt.errorbar(dist_centres, eff[j], yerr=err[j], label=label)
                plt.legend(loc='upper right')
                plt.ylim([0, 1])
                plt.xlim([0, max(dbins)])
                plt.xlabel("Distance (Mpc)")
                plt.ylabel("Efficiency")
                name = InspiralUtils.set_figure_tag(tag + "_BIN_" + str(j + 1), open_box=opts.open_box)
                fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
                InspiralUtils.savefig_pylal(filename=fname)  # returned thumbnail name is unused
                fnameList.append(fname)
                tagList.append(name)
                plt.close()

        # compute volume derivative at loudest event
        dvA = UL.get_volume_derivative(instruments, bin_type, far_threshold, fnameList, tagList, InspiralUtilsOpts)

        # make a live time table
        ltA = UL.live_time_array(instruments, binning)

        # scale volumes by the live time; preferred unit is Mpc^3*yr
        vA.array *= UL.livetime[instruments]
        vA2.array *= UL.livetime[instruments]

        # write out the results to the xml file
        xmldoc = ligolw.Document()
        child = xmldoc.appendChild(ligolw.LIGO_LW())

        # write out mass bins: for each dimension, the lower edges followed by
        # the last upper edge
        edges = tuple(numpy.concatenate((l, u[-1:])) for l, u in zip(binning.lower(), binning.upper()))
        for j, dim_edges in enumerate(edges):
            xml = ligolw.LIGO_LW({u"Name": u"mass%d_bins:%s" % (j + 1, bin_type)})
            xml.appendChild(array.from_array(u"array", dim_edges))
            child.appendChild(xml)

        output_arrays = {"SearchVolumeFirstMoment": vA.array,
                         "SearchVolumeSecondMoment": vA2.array,
                         "SearchVolumeDerivative": dvA.array,
                         "SearchVolumeFoundInjections": f.array,
                         "SearchVolumeMissedInjections": m.array,
                         "SearchVolumeLiveTime": ltA.array}
        # sorted() fixes the element order so the output file is reproducible
        for arr_name in sorted(output_arrays):
            xml = ligolw.LIGO_LW({u"Name": u"binned_array:%s" % arr_name})
            xml.appendChild(array.from_array(u"array", output_arrays[arr_name]))
            child.appendChild(xml)

        output_filename = opts.output_path + "%s-SEARCH_VOLUME_BINNED_BY_%s_%s-%d-%d.xml" % \
            (ifos, bin_type.upper(), opts.user_tag, UL.start_time, UL.end_time - UL.start_time)
        utils.write_filename(xmldoc, output_filename)

        # os.path.abspath yields a correct file URL whether --output-path is
        # relative or absolute (prefixing the cwd breaks absolute paths).
        cache_entry = lal.CacheEntry(ifos,
                                     bin_type,
                                     segments.segment(UL.start_time, UL.end_time),
                                     "file://localhost%s" % os.path.abspath(output_filename))
        cache_list.append(cache_entry)

        plothtml = InspiralUtils.write_html_output(InspiralUtilsOpts, [database], fnameList,
                                                   tagList, add_box_flag=False)
        InspiralUtils.write_cache_output(InspiralUtilsOpts, plothtml, fnameList)


# Record every output file generated by this program in the requested cache
# file; skipped when --output-cache was not given.
cache_path = opts.output_cache
if cache_path:
    write_cache(cache_list, cache_path)
