#!/usr/bin/python
#
# Copyright (C) 2009 Tomoki Isogai
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
%prog --ifo=IFO --trigger_files=File --segment_file=File --channel=Channel --positive_window=Second --negative_window=Second [options]

Tomoki Isogai (isogait@carleton.edu)

For a given channel, this program calculates the following veto statistics at each KW significance threshold specified:
Used percentage, # of coincident KW trigs above threshold, # of total KW trigs above threshold, # of vetoed GW triggers, veto efficiency at threshold, dead time, dead time percentage, used percentage over random used percentage, efficiency over dead time percentage.
Random used percentage is calculated as:
(total GW trigger number) x (dead time percentage) / (total KW triggers above the threshold)
Thresholds at which the code calculates those values are specified by --min_thresh, --max_thresh, --resolution. For example, with --min_thresh=50, --max_thresh=500 and --resolution=50, the code will calculate at each of 50, 100, 150, 200, 250, 300, 350, 400, 450, 500 KW significance threshold.
The supported ifo is H1, H2, L1 and V1. 
Unless you specify --KW_location, the code assumes that KW triggers are in the default location specified in KW_veto_getKW.py.
You can supply multiple GW trigger files and the code will take a union.
Both KW and GW triggers will be cut by --segment_file and --min_snr/--min_thresh. GW triggers are considered to be coincident if it falls into +(--positive_window), -(--negative_window) of KW triggers.
Supported output format is txt, txt.gz, pickle, pickle.gz, mat, db (sqlite database). The sqlite database containing all the data will be created in addition.
"""

# =============================================================================
#
#                              PREAMBLE
#
# =============================================================================


from __future__ import division
import sys
import os
import optparse
from math import floor, ceil

try:
    import sqlite3
except ImportError:
    # pre 2.5.x
    from pysqlite2 import dbapi2 as sqlite3

from glue.segments import segment, segmentlist
from glue import segmentsUtils

from pylal import git_version
from pylal import KW_veto_getTriggers
from pylal import KW_veto_getKW
from pylal import KW_veto_utils

__author__ = "Tomoki Isogai <isogait@carleton.edu>"
__date__ = "7/10/2009"
__version__ = "2.0"

def parse_commandline():
    """
    Parse the options given in the command-line and do error checking.

    Returns the optparse values object with the following normalizations
    applied:
      - ifo is upper-cased,
      - extension always starts with ".",
      - trigger_files, segment_file, out_dir, scratch_dir and (when given)
        injection_file are converted to absolute paths.

    Side effects:
      - creates --out_dir when it does not exist yet,
      - exits via parser.error() when a required option is missing or an
        option value is invalid, and via sys.exit(1) when --injection_file
        is given but not found,
      - prints all the parameters to stderr when --verbose is given.
    """
    parser = optparse.OptionParser(usage=__doc__,version=git_version.verbose_msg)
    parser.add_option("-i", "--ifo", help="IFO to be analyze. Required.")
    parser.add_option("-t", "--trigger_files", action="append", default=[],
                      help="Files containing GW triggers. Can be provided multiple times to specify more than one files. A union will be taken. Required.")
    parser.add_option("-K", "--KW_location",default=None,
                      help="Location of KW trigger folder if you are not using the folders specified in --help in KW_veto_getKW.py.")
    parser.add_option("-S", "--segment_file",
                      help="File which contains segments to be analyzed. Required.")
    parser.add_option("-c", "--channel", help="Channel to be analyzed. Required.")
    parser.add_option("-C", "--critical_usedPer", type="int", default=50,
                      help="Code recognizes channels above this used percentage as veto candidate channel. (Default: 50)")
    parser.add_option("-s","--min_snr", default=8, type="float",
                      help="Minimum SNR to filter triggers. (Default: 8)")
    parser.add_option("-p","--positive_window",type="float",
                      help="Positive veto-window size in second. Required.")
    parser.add_option("-n","--negative_window",type="float",
                      help="Negative veto-window size in second. Required.")
    parser.add_option("-m", "--min_thresh",default=50,type="int",
                      help="Minimum threshold to be analyzed. (Default: 50)")
    parser.add_option("-M", "--max_thresh",default=5000,type="int",
                      help="Maximum threshold to be analyzed. (Default: 5000)")
    parser.add_option("-r", "--resolution",default=50,type="int",
                      help="The code raise threshold at this interval. (Default: 50)")
    parser.add_option("-I", "--injection_file", default=None,
                      help="File that contains HW injection times. If given, code will check safety against those HW injections.")
    parser.add_option("-x", "--safety_thresh",type="int",default=-5,
                      help="Threshold 10**x will be applied for safety check if given. (Default: -5)")
    parser.add_option("-o","--out_dir",default="result",
                      help="Output directory. (Default: result)")
    parser.add_option("-l","--scratch_dir",default=".",
                      help="Scratch directory to be used for database engine. Specify local scratch directory for better performance and less fileserver load. (Default: current directory)")
    parser.add_option("-N","--name_tag",default="",
                      help="Name tag for output file. (Default: )")
    parser.add_option("-e","--extension",default=".txt",
                      help="Extension for output file. See --help for supported formats. (Default: .txt)")
    parser.add_option("-v", "--verbose", action="store_true",
                      default=False, help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    ########################## sanity check ####################################

    # check if necessary input exists
    for o in ("ifo","segment_file","channel","positive_window","negative_window"):
        if getattr(opts,o) is None:
            parser.error("Error: --%s is a required parameter"%o)

    # --trigger_files uses action="append", so "not given" is an empty list
    if opts.trigger_files == []:
        parser.error("Error: --trigger_files is a required parameter")

    # check if ifo is supported
    opts.ifo = opts.ifo.upper()
    if opts.ifo not in ['H1','H2','L1','V1']:
        parser.error("Error: ifo has to be H1, H2, L1 or V1")

    # create output directory if not exist yet
    if not os.path.exists(opts.out_dir):
        os.mkdir(opts.out_dir)

    # check if scratch directory exists
    if not os.path.exists(opts.scratch_dir):
        parser.error("Error: %s does not exist"%opts.scratch_dir)

    # check if necessary files exist
    for f in opts.trigger_files + [opts.segment_file]:
        if not os.path.isfile(f):
            parser.error("Error: %s not found"%f)

    # check the output extension; accept it with or without the leading dot
    if not opts.extension.startswith("."):
        opts.extension = "." + opts.extension
    if opts.extension not in (".txt",".txt.gz",".pickle",".pickle.gz",".mat",".db"):
        parser.error("Error: unsupported file extension %s. See -help for supported formats."%opts.extension)

    # veto_window 0 is not interesting and might cause trouble in the code
    if opts.positive_window == 0 and opts.negative_window == 0:
        parser.error("Error: veto window size cannot be 0.")

    # the injection file is optional, but has to exist when it is given
    if opts.injection_file is not None:
        if not os.path.isfile(opts.injection_file):
            # the file name has to be read from opts: there is no local
            # variable called injection_file in this function
            print >> sys.stderr, "--injection_file %s is not found."%opts.injection_file
            sys.exit(1)

    #
    # Convert paths to absolute paths so that it can later be referenced by
    # other codes without trouble
    #

    # a real list (not a lazy map object) because it is joined/stored later
    opts.trigger_files = [os.path.abspath(f) for f in opts.trigger_files]
    opts.segment_file = os.path.abspath(opts.segment_file)
    opts.out_dir = os.path.abspath(opts.out_dir)
    opts.scratch_dir = os.path.abspath(opts.scratch_dir)
    if opts.injection_file is not None:
        opts.injection_file = os.path.abspath(opts.injection_file)

    ######################### show parameters ##################################

    if opts.verbose:
        print >> sys.stderr, "Running KW_veto_calc..."
        print >> sys.stderr, git_version.verbose_msg
        print >> sys.stderr, ""
        print >> sys.stderr, "***************** PARAMETERS ********************"
        for o in opts.__dict__.items():
            print >> sys.stderr, o[0]+":"
            print >> sys.stderr, o[1]
        print >> sys.stderr, ""

    return opts
    
def find_coincidence(master_cur, KWcur, GWcur):
    """
    Find coincident triggers and fill the tables 'KWtrigs' and 'GWtrigs'
    (for KW and GW triggers respectively) through master_cur.

    Parameters:
    master_cur: cursor on which the two tables are created and filled
    KWcur: cursor from which KW triggers are read
           ("select * from KWtrigs order by GPSTime asc")
    GWcur: cursor from which GW triggers are read
           ("select * from GWtrigs order by GPSTime asc")

    The veto window sizes are read from the module-level opts
    (opts.positive_window, opts.negative_window).

    Returns master_cur.

    Both tables contain columns:
    ID: integer ID of each event
    GPSTime: GPS time of each event

    'KWtrigs' table also contain columns:
    KWSignificance: KW significance of each event 
    frequency: frequency of each event
    CoincidenceGWTrigID: ID number of GW events that the trigger is coincident
                         to for coincident event, and 'No Coincidence' for 
                         non-coincident event
   
    'GWtrigs' table also contain columns:
    SNR: SNR of each event
    coincidence: 1 for coincident event and 0 for non-coincident event
    """
    
    def CoincidenceGWTrigID(IDs):
      """
      return comma separated IDs of coincident GW trigs to the coincident
      KW trig;
      return a string 'No Coincidence' for non-coincident KW trig.
      """
      # "cond and a or b" idiom: a non-empty ID list always joins to a
      # non-empty (truthy) string, so the fallback is used only for an
      # empty list
      return len(IDs) > 0 and ",".join(map(str, IDs)) or 'No Coincidence'
          
    def get_next(index, trigList, trigs):
      """
      Get the next trigger only when necessary so that we can maintain the 
      KWlist short.
      A trigger is fetched only when index points past the end of trigList.
      fetchone() returns None when no more trigger is available - don't add
      None to the list.
      Returns trigList (the same list object, appended in place).
      """
      if index >= len(trigList):
        next = trigs.fetchone()
        if next is not None:
          trigList.append(KWtrig(next))
      return trigList
 
    class KWtrig:
      """
      Class that represents KW trigger.
      IDs is IDs of coincident GW triggers.
      """
      def __init__(self, trig):
        """
        trig is assumed to be a tuple (GPSTime, KWSignificance, frequency) from fetchone()
        """
        self.GPSTime = trig[0]
        self.KWSignificance = trig[1]
        self.frequency = trig[2]
        # IDs of the GW triggers found coincident with this KW trigger so far
        self.IDs = []

      def add_ID(self, ID):
         """
         Record ID as a GW trigger coincident with this KW trigger.
         """
         self.IDs.append(ID)
   
    ## initialize tables to put data in
    master_cur.execute("create table KWtrigs (ID integer primary key, GPSTime double, KWSignificance double, frequency double, CoincidenceGWTrigID text)")
    master_cur.execute("create table GWtrigs (ID integer primary key, GPSTime double, SNR double, coincidence int)")

    ## get KW triggers
    KWtrigs = KWcur.execute("select * from KWtrigs order by GPSTime asc")
 
    # KWlist is a list that contains KWtrig instances and used to hold
    # the necessary KW trigs to revisit.
    # get_next() function ensures that this list would stay in the minimum size
    # possible; (the number of KW coincidences to one GW trig) + 1
    first =  KWtrigs.fetchone()
    if first != None:
      KWlist = [KWtrig(first)]
    else: # case where we don't have any KW
      KWlist = []

    # The strategy is to go through each trigger the minimum number of times
    # possible.
    # It visits each GW trigger exactly once, and it revisits only coincident
    # KW triggers and the one right next to them (needed because one trigger
    # can be coincident to multiple triggers.)
    # GW[0] is GPS time of the GW trigger, GW[1] is SNR of the GW trigger

    for GW_id, GW in enumerate(GWcur.execute("select * from GWtrigs order by GPSTime asc")):
      i = 0 # track index of current trigger in KWlist
      GWcoinc = False # track if this GW trigger is in coincidence
      while True: # loop over KW triggers
        ## End Case: when KW triggers finish first
        if i == len(KWlist):
          break

        KW = KWlist[i]
        # the window edges are rounded outwards to whole numbers (ceil for
        # the upper edge, floor for the lower edge) before comparing with the
        # GW trigger time
        if ceil(KW.GPSTime + opts.positive_window) >= GW[0]:
          if floor(KW.GPSTime - opts.negative_window) <= GW[0]:
            ## Case 1: Coincidence
            GWcoinc = True
            # add 1 to GW_id because actual ID starts from 1 and not 0
            KWlist[i].add_ID(GW_id + 1)
            # move on to the next KW trigger: it may be coincident with the
            # same GW trigger as well
            i += 1
            KWlist = get_next(i, KWlist, KWtrigs)

          else:
            # Case 2: No coincidence and  GW < KW in GPS time
            # GW trigger is not coincident with this particular KW trigger
            break

        else:
          # Case 3: No coincidence and  KW < GW in GPS time
          # the oldest buffered KW trigger is dropped from the list: its
          # coincidence information is final at this point
          t = KWlist.pop(0)
          # add the KW trig in the table
          master_cur.execute("insert into KWtrigs (GPSTime, KWSignificance, frequency, CoincidenceGWTrigID) values (?,?,?,?)",(t.GPSTime, t.KWSignificance, t.frequency, CoincidenceGWTrigID(t.IDs)))
          KWlist = get_next(0, KWlist, KWtrigs)

      ######################### end of while loop #############################
      
      # add the GW trig in the table
      master_cur.execute("insert into GWtrigs (GPSTime, SNR, Coincidence) values (?,?,?)",(GW[0], GW[1], GWcoinc))

    ## End Case:  Dump the rest of KW triggers when GW triggers finish first
    # (KW triggers still buffered in KWlist plus the ones not fetched yet;
    # the list concatenation relies on map() returning a list, i.e. Python 2)
    for KW in KWlist + map(lambda x: KWtrig(x), KWtrigs.fetchall()):
      master_cur.execute("insert into KWtrigs (GPSTime, KWSignificance, frequency, CoincidenceGWTrigID) values (?,?,?,?)", (KW.GPSTime, KW.KWSignificance, KW.frequency, CoincidenceGWTrigID(KW.IDs)))
    
    return master_cur


def calc_stats():
  """
  Calculate the statistics of the veto at each threshold and store the result
  in the 'result' table of the database.

  This function takes no argument; it reads the module-level names
  opts (min_thresh, max_thresh, resolution, positive_window,
  negative_window), analyzed_segs, totalGWNum and cursor, and writes through
  cursor. Nothing is returned.

  Thresholds are visited from high to low while KW triggers are read in
  descending KW significance, so the counters, the set of vetoed GW triggers
  and the veto segments are accumulated only once.

  Columns of the 'result' table:
  threshold, usedPercentage, coincidentKW, totalKW, vetoedGW, efficiency,
  deadTime, deadTimePercentage, usedPer_over_randomUsedPer,
  efficiency_over_deadTimePer
  """

  ## initialize
  totalTime = abs(analyzed_segs)
  totalKW_counter = 0
  coinKW_counter =  0
  vetoedGW = set()
  vetoed_segs = segmentlist()
  result = []

  # Guard the GW trigger number the same way as the other denominators
  # below ("or 1"): without any GW trigger every numerator divided by it is
  # 0 anyway, so the ratios come out as 0 instead of raising
  # ZeroDivisionError.
  GWNum = totalGWNum or 1

  ## get KW triggers
  KWtrigs = cursor.execute("select GPSTime, KWSignificance, CoincidenceGWTrigID from KWtrigs order by KWSignificance desc")
  KW = KWtrigs.fetchone()
 
  # t is the threshold
  # NOTE(review): range() excludes opts.max_thresh itself, while the usage
  # text lists the maximum threshold among the analyzed ones - confirm which
  # one is intended.
  for t in sorted(range(opts.min_thresh, opts.max_thresh, opts.resolution),reverse=True):
    # consume all the KW triggers at or above this threshold; triggers
    # consumed for higher thresholds are already included in the counters
    while KW is not None and KW[1] >= t:
      if KW[2] != 'No Coincidence': 
        # Coincident Event
        coinKW_counter += 1
        # KW[2] is comma separated IDs of the coincident GW triggers
        vetoedGW |= set(KW[2].split(","))
      # All Event
      totalKW_counter += 1
      vetoed_segs.append(KW_veto_utils.get_segment(KW[0],
                           opts.positive_window,opts.negative_window))

      # get the next trigger
      KW = KWtrigs.fetchone()

    ## store the result
    deadTime = abs(vetoed_segs.coalesce()) # avoid calculate it again and again 
    # round to 2 decimal points for float values
    # (divisions are true divisions: "from __future__ import division")
    result.append((
    t, # threshold
    "%.2f"%(100 * coinKW_counter / (totalKW_counter or 1)), # used percentage
    coinKW_counter, # number of coincident KW events above the threshold
    totalKW_counter, # number of KW events above the threshold
    len(vetoedGW), # number of vetoed GW triggers
    "%.2f"%(100 * len(vetoedGW) / GWNum), # efficiency
    deadTime, # dead time
    "%.4f"%(100 * deadTime / totalTime), # dead time percentage
    # used percentage / random used percentage
    "%.2f"%(coinKW_counter * totalTime / GWNum / (deadTime or 1)),
    # efficiency / dead time percentage
    "%.2f"%(len(vetoedGW) * totalTime / GWNum / (deadTime or 1))
    ))
  
  ######################### End of for Loop  #################################
 
  ## set up the table for result
  cursor.execute("create table result (threshold int, usedPercentage double, coincidentKW int, totalKW int, vetoedGW int, efficiency double, deadTime int, deadTimePercentage double, usedPer_over_randomUsedPer double, efficiency_over_deadTimePer double)")

  ## insert the result
  for r in sorted(result): # sort by threshold from low to high
    cursor.execute("insert into result (threshold, usedPercentage, coincidentKW, totalKW, vetoedGW, efficiency, deadTime, deadTimePercentage, usedPer_over_randomUsedPer, efficiency_over_deadTimePer) values (?,?,?,?,?,?,?,?,?,?)", r)
    if r[4] == 0: 
      # r[4] is the number of vetoed GW triggers: once no GW trigger is
      # vetoed, the higher thresholds are not interesting, so stop
      # (the row that hits 0 itself is still stored).
      # NOTE(review): this comment used to describe r[4] as the total KW
      # number, which is actually r[3] - confirm that stopping on the
      # vetoed GW number is the intended condition.
      break
       
def get_veto_segs(veto_thresh):
  """
  Build the veto segment list for the given KW significance threshold.

  One veto window (KW_veto_utils.get_segment with opts.positive_window and
  opts.negative_window) is made for every KW trigger whose significance is
  at or above veto_thresh. The windows are coalesced and then intersected
  with the analyzed segments, so that the returned veto segments all stay
  within the analyzed segments.
  """
  rows = cursor.execute("select GPSTime from KWtrigs where KWSignificance >= ?",(veto_thresh,))

  # one veto window per selected KW trigger
  windows = segmentlist()
  for row in rows:
    windows.append(KW_veto_utils.get_segment(row[0], opts.positive_window, opts.negative_window))

  # merge the overlapping windows in place, then clip to the analyzed segments
  windows.coalesce()
  windows &= analyzed_segs
  return windows

def injection_check(cur, injection_file, veto_segments, totalTime):
  """
  Check the safety of the veto segments against hardware injections.

  Parameters:
  cur: database cursor; the table 'injections' (GPSTime, vetoed) is created
       and filled through it
  injection_file: path to a text file with one injection time per line;
                  blank lines and lines starting with '#' are skipped
  veto_segments: veto segments to be checked
  totalTime: total analyzed time, used to derive the expected number of
             accidentally vetoed injections

  Only injections inside the module-level analyzed_segs are considered.
  The expected number of vetoed injections is
    (number of injections) x (veto time) / totalTime
  and the probability is the Poisson probability of vetoing at least the
  observed number of injections given that expectation. The channel passes
  when that probability is above 10**opts.safety_thresh, or when there is
  no injection at all.

  Returns a tuple:
  (cur, number of vetoed injections, expected number formatted with "%.2f",
   probability formatted with "%.2e", "Pass" or "Unsafe",
   total number of injections)

  Raises ValueError (after printing a message to stderr) when the file
  contains a line that is not a number.
  """
  # Read the injection times. try/finally closes the file in any case.
  all_injection_times = []
  inj_file = open(injection_file)
  try:
    for line in inj_file:
      # strip so that whitespace-only lines and "\r\n" endings are handled
      line = line.strip()
      if line == "" or line.startswith("#"):
        continue
      try:
        all_injection_times.append(float(line))
      except ValueError:
        print >> sys.stderr, ("--injection_file contains non-number.")
        raise
  finally:
    inj_file.close()

  # keep only the injections in the analyzed segments, in time order
  injection_times = sorted([t for t in all_injection_times if t in analyzed_segs])

  cur.execute("create table injections (GPSTime float, vetoed txt)")
  Nvetoed = 0
  for t in injection_times:
    if t in veto_segments:
      Nvetoed += 1
      status = "Vetoed"
    else:
      status = "Not Vetoed"
    cur.execute("insert into injections (GPSTime, vetoed) values (?,?)", (t,status))
  
  totalInjNum = len(injection_times)
  Nexp = totalInjNum * abs(veto_segments) / totalTime
 
  # imported here so that scipy is required only when the check is performed
  import scipy.stats
  # P(N >= Nvetoed) for a Poisson distribution with mean Nexp; the survival
  # function is 1 - cdf, but keeps its precision for very small tails
  prob = scipy.stats.poisson.sf(Nvetoed - 1, Nexp)
  if prob > 10**opts.safety_thresh or totalInjNum == 0:
    safety = "Pass"
  else:
    safety = "Unsafe"

  return cur, Nvetoed, "%.2f"%Nexp, "%.2e"%prob, safety, totalInjNum

#==============================================================================
#
#                                MAIN
#
#==============================================================================
    
## parse command line
opts = parse_commandline()

# segments to be analyzed  
analyzed_segs = KW_veto_utils.read_segfile(opts.segment_file)

# Overall flow:
# 1. read in GW triggers
# 2. read in KW triggers
# 3. check coincident events
# 4. calculate statistical values (used percentage etc.)
# 5. Save the result

try: # wrapped in try/finally so that tmp files are erased in any case

  ########################################################################
  # 1. Read in the inspiral triggers
  ########################################################################

  if opts.verbose: print >> sys.stderr, "reading in GW triggers..."

  # triggers are filtered by seg file and minimum threshold
  GWdbname, GW_working_filename, GWconnection, GWcursor, totalGWNum = \
     KW_veto_getTriggers.get_trigs(opts.trigger_files,analyzed_segs,
     opts.min_snr,opts.name_tag,opts.scratch_dir,verbose=opts.verbose)

  ########################################################################
  # 2. Get KW triggers for the channel
  ########################################################################
  
  if opts.verbose: print >> sys.stderr, "reading in KW triggers..."
  # KW trigs are filtered by seg file and minimum threshold
  KWdbname, KW_working_filename, KWconnection, KWcursor, opts.channel = \
          KW_veto_getKW.get_trigs(opts.channel, analyzed_segs, opts.min_thresh,
           trigs_loc=opts.KW_location,name_tag=opts.name_tag,
            scratch_dir=opts.scratch_dir,verbose=opts.verbose)
  
  ########################################################################
  # 3. Check coincident events
  ########################################################################

  ## initialize master database
  ## define other file names as well for later use
  # common prefix:
  # (tag)-(IFO)-(channel name)-(start time)-(duration)

  start_time = int(analyzed_segs[0][0])
  end_time = int(analyzed_segs[-1][1])
  duration = end_time - start_time
  filePrefix = os.path.join(opts.out_dir, "-".join([opts.name_tag, 
                         opts.ifo.upper(), opts.channel,
                         str(start_time), str(duration)]))
  dbFile = filePrefix + "-data.db"
  resultFile  = '%s-result%s'%(filePrefix,opts.extension)  
  GWtrigsFile = '%s-GWtrigs%s'%(filePrefix,opts.extension) 
  KWtrigsFile = '%s-KWtrigs%s'%(filePrefix,opts.extension)
  vetoSegFile = '%s-vetoSegments%s'%(filePrefix,opts.extension)

  # if the file already exists, rename the old one
  KW_veto_utils.rename(dbFile)

  # working_filename is a module-level name (this is top-level code), so the
  # clean up in the finally clause can find it through globals() when an
  # error occurs
  working_filename = KW_veto_utils.get_connection_filename(\
                     dbFile, tmp_path=opts.scratch_dir, verbose=opts.verbose)

  connection = sqlite3.connect(working_filename)
  cursor = connection.cursor()
  
  if opts.verbose: 
      print >> sys.stderr, "checking KW triggers for coincidence..."
  # mark coincident triggers
  cursor = find_coincidence(cursor,KWcursor,GWcursor)
  
  connection.commit()

  # the trigger databases are not needed any more
  KWconnection.close()
  del KWcursor
  del KWconnection
  GWconnection.close()
  del GWcursor
  del GWconnection
  
  ########################################################################
  # 4. calculate statistics of veto
  # - for candidate channel, make veto segments.
  ########################################################################
  
  if opts.verbose: print >> sys.stderr, "calculating veto statistics..."
  
  # calc_stats is the function that calculates all the values
  calc_stats()
  connection.commit()

  ## for candidate channel, get veto segs and check injection....   
  candidate_data = KW_veto_utils.get_candidate(cursor,opts.critical_usedPer)
  if candidate_data is not None:
      if opts.verbose:
        print >> sys.stderr, "candidate channel!"
        print >> sys.stderr, "creating veto segments..."
      veto_segments = get_veto_segs(candidate_data[0])
      # store in the database
      KW_veto_utils.write_segs_db(cursor,veto_segments,"veto_segments")
      connection.commit()
      # save in a file; the handle is closed explicitly so that the content
      # is flushed to disk right away
      vetoSegFileObj = open(vetoSegFile,'w')
      try:
        segmentsUtils.tosegwizard(vetoSegFileObj,veto_segments)
      finally:
        vetoSegFileObj.close()
      if opts.injection_file is not None:
        # announce the safety check only when it is actually performed
        if opts.verbose:
          print >> sys.stderr, "checking safety with HW injections..."
        cursor, Nvetoed, Nexp, probability, safety, totalInjNum = \
                         injection_check(cursor, opts.injection_file,\
                         veto_segments, abs(analyzed_segs))
  
  if candidate_data is None or opts.injection_file is None:
    # no safety check was performed
    Nvetoed = "n/a"
    Nexp = "n/a"
    probability = "n/a"
    safety = "n/a"
    totalInjNum = "n/a"

  ############################################################################
  # 5. save the result
  ############################################################################

  # for the sake of keeping track of variables for later use, store them in the database
  opts.trigger_files = ",".join(opts.trigger_files)
  params = list(opts.__dict__.items())
  params.append(("filePrefix",os.path.basename(filePrefix)))
  params.append(("totalGWtrigsNum",totalGWNum))
  params.append(("totalTime",int(abs(analyzed_segs))))
  params.append(("start_time",start_time))
  params.append(("end_time",end_time))
  params.append(("resultFile",resultFile))
  params.append(("GWtrigsFile",GWtrigsFile))
  params.append(("KWtrigsFile",KWtrigsFile))
  params.append(("vetoSegFile",vetoSegFile))
  params.append(("HWInjNvetoed",Nvetoed))
  params.append(("HWInjNexp",Nexp))
  params.append(("HWInjProbability",probability))
  params.append(("safety",safety))
  params.append(("totalInjNum",totalInjNum))
  cursor.execute("create table params (var_name txt, value)")
  cursor.executemany("insert into params values (?,?)", params)
  connection.commit()

  # also save the analyzed segment into the database
  cursor = KW_veto_utils.write_segs_db(cursor,analyzed_segs,"analyzed_segs")  
  connection.commit()

  if opts.verbose: 
    print >> sys.stderr, "saving the result in %s, %s, %s..."\
                         %(resultFile, GWtrigsFile, KWtrigsFile)

  ## write in files
  KW_veto_utils.save_db(cursor, 'ALL', dbFile, working_filename, order_by = None, verbose = opts.verbose)
  if opts.extension != '.db':
    KW_veto_utils.save_db(cursor, 'result', resultFile, working_filename, order_by = 'threshold asc', verbose = opts.verbose)
    KW_veto_utils.save_db(cursor, 'GWtrigs', GWtrigsFile, working_filename, order_by = 'GPSTime asc', verbose = opts.verbose)
    KW_veto_utils.save_db(cursor, 'KWtrigs', KWtrigsFile, working_filename, order_by = 'GPSTime asc', verbose = opts.verbose)

finally:
  # clean up all the tmp databases
  # a name is defined only when the corresponding step has been reached, so
  # look each one up in globals() instead of referring to it directly
  for var_name in ("KW_working_filename","GW_working_filename","working_filename"):
    if var_name in globals():
      tmp_db = globals()[var_name]
      if opts.verbose:
        print >> sys.stderr, "removing temporary workspace '%s'..."%tmp_db
      try:
        os.remove(tmp_db)
      except OSError:
        # a failed clean up must not hide the error (if any) that brought
        # us here, and must not prevent removing the other tmp files
        print >> sys.stderr, "warning: failed to remove temporary workspace '%s': %s"%(tmp_db, sys.exc_info()[1])
  
if opts.verbose: 
  print >> sys.stderr, "KW_veto_calc for %s triggers * %s done!"\
                                                %(opts.ifo,opts.channel)    
