#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import os,sys,glob,re,time,numpy
from optparse import OptionParser
from pylal import SimInspiralUtils,MultiInspiralUtils,git_version
from glue.ligolw import ilwd,ligolw,table,lsctables,utils
from glue import segments,segmentsUtils,lal
from pylal.dq import dqSegmentUtils
from pylal.coh_PTF_pyutils import append_process_params,remove_bad_injections,identify_bad_injections

# Module metadata. The version string and date are taken from pylal's
# git_version module; __version__ is handed to OptionParser in
# parse_command_line() and both values are passed to append_process_params()
# in main().
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date

# set up timer: record the moment the module was loaded, in whole
# microseconds (time.time() is in seconds, hence the factor of 10**6)
jobstart = int(time.time()*10**6)

def elapsed_time():
  """Return the whole number of microseconds elapsed since jobstart."""
  now = time.time()*10**6
  return int(now-jobstart)

# =============================================================================
# Parse command line
# =============================================================================

def parse_command_line():
  """Build the option parser, parse sys.argv and validate the result.

  Returns the (opts, args) pair produced by OptionParser.parse_args().

  --cache and --inj-cache are mandatory and --time-window must be strictly
  positive; any violation is reported through parser.error().
  """

  # NOTE: the usage text is spelled with explicit newlines so that its
  # trailing whitespace survives editors that strip line endings
  usage = ("usage: %prog [options] \n"
           "  \n"
           "coh_PTF_injfinder will find and record found and missed "
           "injections for the given injection run\n"
           "\n"
           "--inj-cache\n"
           "--cache\n"
           "--time-window\n")

  parser = OptionParser(usage, version=__version__)

  parser.add_option("-c", "--cache", action="store", type="string",
                    default=None, help="read trigger files from this cache.")

  parser.add_option("-i", "--inj-cache", action="store", type="string",
                    default=None,
                    help="read injection files from this cache.")

  parser.add_option("-o", "--output-dir", action="store", type="string",
                    default=os.getcwd(),
                    help="output directory, default: %default")

  parser.add_option("-W", "--time-window", action="store", type="float",
                    default=0, help="The found time window")

  parser.add_option("-b", "--exclude-segments", action="append",
                    type="string", default=[],
                    help="ignore injections in segments found within "
                         "these files, e.g. buffer segments (may be given "
                         "more than once)")

  parser.add_option("-l", "--log-dir", action="store", type="string",
                    default=None,
                    help="Location of log files to determine "
                         "if any SpinTaylor injections failed and remove "
                         "them. If not set, no injections will be removed.")

  parser.add_option("-v", "--verbose", action="store_true", default=False,
                    help="verbose output, default: %default")

  (opts, args) = parser.parse_args()

  if not opts.cache:
    parser.error("--cache must be given")

  if not opts.inj_cache:
    parser.error("--inj-cache must be given")

  # written as "not > 0" rather than "<= 0" so that the default of 0,
  # negative values and NaN are all rejected
  if not opts.time_window > 0:
    parser.error("time window must be given and greater than 0")

  return opts, args

# =============================================================================
# Main function
# =============================================================================

def main(trigCacheFile, injCacheFile, timeWindow, outdir,
         exclude_segments=None, verbose=False, log_dir=None):
  """Sort injections into found and missed and write both sets to XML.

  Each entry of the injection cache is paired with the trigger cache entries
  whose description ends in the same '_<number>' suffix. The triggers of
  that injection run are sorted by decreasing snr and an injection counts as
  found when a trigger ends within timeWindow of its geocentric end time;
  the first (i.e. highest snr) such trigger is associated with it.

  Two files are written to outdir:
    <ifos>-<userTag>_FOUND-<start>-<duration>.xml   found injections and
                                                    their triggers
    <ifos>-<userTag>_MISSED-<start>-<duration>.xml  missed injections
  Both also carry the output of append_process_params() and the search
  summary table copied from the first trigger file that was used.

  Parameters:
    trigCacheFile     path of the cache file listing the trigger files
    injCacheFile      path of the cache file listing the injection files;
                      every description must end in '_<integer>'
    timeWindow        an injection is found if |trigger end - injection
                      time| is strictly smaller than this value
                      (presumably seconds -- confirm against the trigger
                      and injection time units)
    outdir            directory the two output files are written to
    exclude_segments  paths of segment files; '.txt' files are read with
                      segmentsUtils.fromsegwizard and 'xml'/'xml.gz' files
                      with dqSegmentUtils.fromsegmentxml. Injections whose
                      geocentric end time lies in these segments are dropped
    verbose           write progress messages to stdout
    log_dir           directory handed to identify_bad_injections(); the
                      injections it reports are removed. When None, the
                      module-level opts.log_dir is used if it exists

  Raises:
    ValueError  if either cache is empty, or if no injection file could be
                paired with existing trigger files

  Side effect: sets the loadcolumns attribute of
  lsctables.MultiInspiralTable and lsctables.SimInspiralTable.
  """

  if exclude_segments is None:
    exclude_segments = []

  # Backwards compatibility: this function used to read the module-level
  # opts directly, so callers that do not pass log_dir still get the value
  # parsed from the command line (and None when there is no opts at all).
  if log_dir is None:
    log_dir = getattr(globals().get('opts'), 'log_dir', None)

  #
  # load cache files
  #

  with open(trigCacheFile, 'r') as cacheobj:
    trigcache = lal.Cache.fromfile(cacheobj)
  with open(injCacheFile, 'r') as cacheobj:
    injcache = lal.Cache.fromfile(cacheobj)
  # process injection runs in order of their numeric suffix
  injcache.sort(key=lambda e: int(e.description.split('_')[-1]))

  if len(trigcache) < 1:
    raise ValueError("No trigger files found.")
  if len(injcache) < 1:
    raise ValueError("No injection files found.")

  # get ifos
  ifos = lsctables.instrument_set_from_ifos(trigcache[-1].observatory)

  if verbose: sys.stdout.write("Cache files loaded at %d.\n" % elapsed_time())

  #
  # load exclusion segments
  #

  excludes = segments.segmentlist([])
  for filename in exclude_segments:
    if filename.endswith('.txt'):
      with open(filename, 'r') as segobj:
        excludes.extend(segmentsUtils.fromsegwizard(segobj))
    elif filename.endswith(('xml.gz', 'xml')):
      with open(filename, 'r') as segobj:
        excludes.extend(dqSegmentUtils.fromsegmentxml(segobj))
    else:
      # do not drop an exclusion file without telling the user
      sys.stdout.write("WARNING: unrecognised segment file type, "
                       "ignoring %s.\n" % filename)
  excludes = excludes.coalesce()

  if verbose: sys.stdout.write("Exclusion segments loaded at %d.\n"
                               % elapsed_time())

  #
  # identify columns to load
  #

  # MultiInspiralTable

  # per-detector column suffixes
  ifoAtt = { 'G1':'g', 'H1':'h1', 'H2':'h2', 'L1':'l', 'V1':'v', 'T1':'t' }

  # set up columns
  cols = ['end_time', 'end_time_ns', 'ifos', 'process_id', 'ra', 'dec', 'snr',
          'null_statistic', 'null_stat_degen','mass1','mass2','eta','mchirp']
  # add snr
  cols.extend(['snr_%s' % ifoAtt[i] for i in ifos])
  # add chisq
  cols.extend(['chisq', 'chisq_dof'])
  cols.extend(['chisq_%s' % ifoAtt[i] for i in ifos])
  # add bank chisq
  cols.extend(['bank_chisq', 'bank_chisq_dof'])
  #cols.extend(['bank_chisq_%s' % ifoAtt[i] for i in ifos])
  # add auto chisq
  cols.extend(['cont_chisq', 'cont_chisq_dof'])
  #cols.extend(['cont_chisq_%s' % ifoAtt[i] for i in ifos])
  # add sigmasq
  cols.extend(['sigmasq_%s' % ifoAtt[i] for i in ifos])
  # add amplitude terms
  cols.extend(['amp_term_%d' % i for i in range(1,11)])
  # keep only the columns this version of the table definition knows about
  cols = [c for c in cols if c in lsctables.MultiInspiralTable.validcolumns]
  lsctables.MultiInspiralTable.loadcolumns = cols

  # SimInspiralTable
  cols = ['geocent_end_time', 'geocent_end_time_ns', 'mass1', 'mass2',
          'mchirp', 'f_lower', 'inclination', 'spin1x', 'spin1y', 'spin1z',
          'spin2x', 'spin2y', 'spin2z', 'longitude', 'latitude', 'distance',
          'eff_dist_h','eff_dist_l','eff_dist_v']
  cols = [c for c in cols if c in lsctables.SimInspiralTable.validcolumns]
  lsctables.SimInspiralTable.loadcolumns = cols

  #
  # Identify bad injections
  #

  badInjs = None
  if log_dir:
    badInjs = identify_bad_injections(log_dir)

  if verbose: sys.stdout.write("Bad injections identified at %d.\n"
                               % elapsed_time())

  #
  # test found/missed
  #

  if verbose: sys.stdout.write("\nTesting missed/found...\n")

  objectList = []
  exampleFile = None

  for e in injcache:

    injFile = e.path
    num = int(e.description.split('_')[-1])

    numtrigcache = trigcache.sieve(description='*_%d' % num, exact_match=True)

    if not numtrigcache:
      sys.stdout.write("WARNING: cannot find any files matching "
                       "injection %d.\n" % num)

    trigFiles = numtrigcache.pfnlist()

    if os.path.isfile(injFile) and trigFiles:

      # remember the first trigger file used; its search summary table and
      # observatory tag are reused for the output files
      if exampleFile is None:
        exampleFile = numtrigcache[0]

      # get injections (only columns we need), skipping excluded times
      currInjs = lsctables.New(lsctables.SimInspiralTable,
                               columns=lsctables.SimInspiralTable.loadcolumns)
      sims = SimInspiralUtils.ReadSimInspiralFromFiles([injFile])
      # guard against a None result in the same way as for the triggers
      if sims is not None:
        currInjs.extend(i for i in sims
                        if i.get_time_geocent() not in excludes)

      # get trigs (only columns we need)
      currTrigs = lsctables.New(lsctables.MultiInspiralTable,
                              columns=lsctables.MultiInspiralTable.loadcolumns)
      trigs = MultiInspiralUtils.ReadMultiInspiralFromFiles(trigFiles)
      if trigs is not None:
        currTrigs.extend(trigs)
      # loudest first, so the first trigger inside the window is the loudest
      if len(currTrigs):
        currTrigs.sort(key=lambda t: t.snr, reverse=True)

      for trig in currTrigs:
        # Temporary hack to allow the code to work with tables that dont have
        # the new time slide ID column.
        try:
          trig.time_slide_id
        except AttributeError:
          trig.time_slide_id = ilwd.ilwdchar("multi_inspiral:time_slide_id:0")

      if log_dir:
        currInjs = remove_bad_injections(currInjs, badInjs)

      # loop over injections
      for currInj in currInjs:
        injTime = currInj.get_time_geocent()

        match = None
        for t in currTrigs:
          if abs(t.get_end()-injTime) < timeWindow:
            match = t
            break

        # construct injection dict
        objectList.append({'inj': currInj,
                           'trig': match,
                           'found': match is not None})

    if verbose: sys.stdout.write("Injection %d processed at %d.\n"
                                 % (num, elapsed_time()))

  # without a trigger file there is no search summary table to copy
  if exampleFile is None:
    raise ValueError("No injection file could be matched with trigger "
                     "files, cannot write output.")

  #
  # construct new tables
  #

  foundInjs  = lsctables.New(lsctables.SimInspiralTable,
                             columns=lsctables.SimInspiralTable.loadcolumns)
  missedInjs = lsctables.New(lsctables.SimInspiralTable,
                             columns=lsctables.SimInspiralTable.loadcolumns)
  foundTrigs = lsctables.New(lsctables.MultiInspiralTable,
                             columns=lsctables.MultiInspiralTable.loadcolumns)

  # an injection and its trigger share the same running number in their ids
  for eventCount, entry in enumerate(objectList):
    entry['inj'].simulation_id =\
        ilwd.ilwdchar("sim_inspiral:simulation_id:%d" % eventCount)

    if not entry['found']:
      missedInjs.append(entry['inj'])
    else:
      entry['trig'].event_id =\
          ilwd.ilwdchar("multi_inspiral:event_id:%s" % eventCount)
      foundInjs.append(entry['inj'])
      foundTrigs.append(entry['trig'])

  if verbose:
    sys.stdout.write("\nNumber of found injections: %d\n"
                     "Number of triggers associated to found injections: %d\n"
                     "Number of missed injections: %d\n"
                     % (len(foundInjs), len(foundTrigs), len(missedInjs)))

  #
  # write combined injections to file
  #

  if verbose: sys.stdout.write("\nWriting new xml files...\n")

  # prepare xmldocument
  xmldoc = ligolw.Document()
  xmldoc.appendChild(ligolw.LIGO_LW())

  # append process params table
  xmldoc = append_process_params(xmldoc, sys.argv, __version__, __date__)

  # every table below is attached to this element
  root = xmldoc.childNodes[-1]

  # get search summary table from old file
  oldxml = utils.load_filename(
      exampleFile.path, gz=exampleFile.path.endswith("gz"),
      contenthandler=lsctables.use_in(ligolw.LIGOLWContentHandler))
  oldSearchSummTable = table.get_table(oldxml, "search_summary")
  root.appendChild(oldSearchSummTable)

  # construct output filenames; both are built from the same pieces rather
  # than by string replacement, so a 'FOUND' inside outdir or the user tag
  # is left alone
  ifoTag = exampleFile.observatory
  start, end = trigcache.to_segmentlistdict().extent_all()
  userTag = injcache[0].description.rsplit('_', 1)[0]
  nameArgs = (ifoTag, userTag, start, end-start)
  foundFile  = os.path.join(outdir, '%s-%s_FOUND-%d-%d.xml' % nameArgs)
  missedFile = os.path.join(outdir, '%s-%s_MISSED-%d-%d.xml' % nameArgs)

  # write found injections
  root.appendChild(foundTrigs)
  root.appendChild(foundInjs)
  utils.write_filename(xmldoc, foundFile, gz=foundFile.endswith('gz'))
  if verbose: sys.stdout.write("%s written at %d\n"
                               % (foundFile, elapsed_time()))

  # swap the found tables for the missed one and reuse the document
  root.removeChild(foundTrigs)
  root.removeChild(foundInjs)

  # write missed injections
  root.appendChild(missedInjs)
  utils.write_filename(xmldoc, missedFile)
  if verbose: sys.stdout.write("%s written at %d\n"
                               % (missedFile, elapsed_time()))

if __name__ == '__main__':

  # NOTE: opts has to remain a module-level name, main() looks up
  # opts.log_dir through it
  opts, args = parse_command_line()

  verbose = opts.verbose

  # hand every user-supplied path to main() in absolute form
  main(os.path.abspath(opts.cache),
       os.path.abspath(opts.inj_cache),
       opts.time_window,
       os.path.abspath(opts.output_dir),
       exclude_segments=[os.path.abspath(f) for f in opts.exclude_segments],
       verbose=verbose)

  if verbose:
    sys.stdout.write("Done at %d.\n" % (elapsed_time()))
