#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division
import os,sys,time
from optparse import OptionParser
import numpy
from pylal import MultiInspiralUtils,git_version
from glue.ligolw import ligolw,table,lsctables,utils
from pylal.coh_PTF_pyutils import append_process_params
from glue import lal
      
# Module metadata; the version string and date are taken from pylal's
# git_version module and are reused for the --version option and for the
# process table written by main().
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date

# Wall-clock reference point, in integer microseconds, taken once when the
# module is loaded; all progress messages report time relative to it.
start = int(1000000 * time.time())

def elapsed_time():
  """Return the integer number of microseconds elapsed since `start`."""
  now = 1000000 * time.time()
  return int(now - start)

# =============================================================================
# Parse command line
# =============================================================================

def parse_command_line():
  """
  Build the option parser, parse the command line and validate the result.

  Recognised options are --output-dir/-o (defaults to the current working
  directory), --verbose/-v, --trig-file/-t and --time-window/-W.  A missing
  trigger file, or a time window that is not greater than zero, is reported
  through the parser's error() method.

  Returns the (options, positional arguments) pair from OptionParser.
  """

  # Spelled out line by line so that the significant trailing whitespace in
  # the usage text is visible.
  usage = ("usage: %prog [options] \n"
           "  \n"
           "coh_PTF_trig_cluster is designed to cluster a "
           "MultiInspiralTable in time\n"
           "\n"
           "--trig-file\n"
           "--time-window\n")

  optParser = OptionParser(usage, version=__version__)

  optParser.add_option(
      "-o", "--output-dir", action="store", type="string",
      default=os.getcwd(),
      help="output directory, default: %default")

  optParser.add_option(
      "-v", "--verbose", action="store_true", default=False,
      help="verbose output, default: %default")

  optParser.add_option(
      "-t", "--trig-file", action="store", type="string", default=None,
      help="The location of the trigger file")

  optParser.add_option(
      "-W", "--time-window", action="store", type="float", default=0,
      help="The cluster time window")

  options, positional = optParser.parse_args()

  # Both the trigger file and a positive window are mandatory
  if not options.trig_file:
    optParser.error("must provide trig file")

  windowIsPositive = options.time_window > 0
  if not windowIsPositive:
    optParser.error("time window must be given and greater than 0")

  return options, positional

# =============================================================================
# Main function
# =============================================================================

def _trigger_end_time(trig):
  """Return the end time of trig in seconds (end_time plus end_time_ns)."""
  return trig.end_time + trig.end_time_ns * 1E-9


def _bin_triggers(trigs, segStart, numBins, timeWindow):
  """
  Sort trigs into numBins consecutive bins of width timeWindow starting at
  segStart.

  Returns (timeBins, loudestSNR, loudestTime): the list of triggers in each
  bin, and the SNR and end time of the loudest trigger in each bin (None for
  bins that hold no triggers).
  """
  timeBins = [[] for _ in range(numBins)]
  loudestSNR = [None] * numBins
  loudestTime = [None] * numBins

  for trig in trigs:
    t = _trigger_end_time(trig)
    # NOTE(review): a trigger before segStart yields a negative index, which
    # wraps around to a bin at the end of the list; a trigger beyond the last
    # bin raises IndexError.  Both behaviours are inherited unchanged.
    idx = int(float(t - segStart) // timeWindow)
    timeBins[idx].append(trig)
    # Compare against None explicitly so that an SNR of 0 is not mistaken for
    # "no trigger seen yet".
    if loudestSNR[idx] is None or loudestSNR[idx] < trig.snr:
      loudestSNR[idx] = trig.snr
      loudestTime[idx] = t

  return timeBins, loudestSNR, loudestTime


def _beaten_by_bin(trig, t, binTrigs, binLoudestSNR, binLoudestTime,
                   timeWindow):
  """
  Return True if some trigger in a neighbouring bin lies within timeWindow of
  t and has a larger SNR than trig.
  """
  if binLoudestTime is None:
    # Neighbouring bin is empty
    return False

  if abs(binLoudestTime - t) < timeWindow:
    # The neighbour's loudest trigger is in range; nothing else in that bin
    # can be louder, so this single comparison decides.
    return trig.snr < binLoudestSNR

  # The neighbour's loudest trigger is out of range, so look at every trigger
  # in that bin that does fall inside the window.
  for other in binTrigs:
    if abs(_trigger_end_time(other) - t) < timeWindow:
      if trig.snr < other.snr:
        return True
  return False


def _cluster_bins(timeBins, loudestSNR, loudestTime, timeWindow):
  """
  Select at most one trigger per bin: the first trigger that is loudest in
  its own bin and is not beaten by a trigger within timeWindow in either
  adjacent bin.

  Returns the list of surviving triggers, in bin order.
  """
  numBins = len(timeBins)
  kept = []

  for i, binTrigs in enumerate(timeBins):
    if not binTrigs:
      continue

    # Only look at neighbours that exist; the first bin has no predecessor
    # and the last bin has no successor (both may hold for a single bin).
    neighbours = [j for j in (i - 1, i + 1) if 0 <= j < numBins]

    for trig in binTrigs:
      # must be the loudest trigger in its own bin
      if trig.snr < loudestSNR[i]:
        continue

      t = _trigger_end_time(trig)
      if any(_beaten_by_bin(trig, t, timeBins[j], loudestSNR[j],
                            loudestTime[j], timeWindow)
             for j in neighbours):
        continue

      # loudest trigger in its vicinity: keep it and move to the next bin
      kept.append(trig)
      break

  return kept


def main(trigFile, timeWindow, outdir, verbose=False):
  """
  Cluster the multi_inspiral triggers of trigFile in time and write the
  survivors to a new XML file.

  trigFile   -- path of the trigger file; its name is parsed with
                lal.CacheEntry.from_T050017 to obtain the time segment and
                description
  timeWindow -- width of the clustering window, in the same units as the
                trigger end times
  outdir     -- accepted for interface compatibility.
                NOTE(review): this value is never used; the output is written
                to the input path with "_CLUSTERED" appended to the
                description -- confirm whether --output-dir should be
                honoured.
  verbose    -- if True, write progress messages to stdout

  Triggers are clustered independently for each time_slide_id.  The
  search_summary, time_slide (if present), segment and
  time_slide_segment_map tables are copied over from the input document.

  Raises RuntimeError if only one time slide is expected but some triggers
  carry a different time_slide_id.
  """

  # read file
  cacheFile = lal.CacheEntry.from_T050017(trigFile)
  segStart, segEnd = map(int, cacheFile.segment)

  # get triggers
  if verbose:
    sys.stdout.write("Loading triggers...\n")

  oldxml = utils.load_filename(
      trigFile, gz=trigFile.endswith("gz"),
      contenthandler=lsctables.use_in(ligolw.LIGOLWContentHandler))
  tmp = table.get_table(oldxml, "multi_inspiral")

  # Restrict the new tables to the columns actually populated in the input
  if len(tmp) >= 1:
    lsctables.MultiInspiralTable.loadcolumns =\
        [slot for slot in tmp[0].__slots__ if hasattr(tmp[0], slot)]
  else:
    lsctables.MultiInspiralTable.loadcolumns =\
        lsctables.MultiInspiralTable.validcolumns.keys()
  columns = lsctables.MultiInspiralTable.loadcolumns

  currTrigs = lsctables.New(lsctables.MultiInspiralTable, columns=columns)
  currTrigs.extend(tmp)
  clstTrigs = lsctables.New(lsctables.MultiInspiralTable, columns=columns)

  # Allow for some backwards compatibility: older files may lack a time_slide
  # table.  get_table raises ValueError unless exactly one table is found.
  try:
    timeSlideList = table.get_table(oldxml, "time_slide")
  except ValueError:
    timeSlideList = None

  # These tables should be present in all output files going forward, so a
  # missing table is allowed to propagate as an error.
  segmentTable = table.get_table(oldxml, "segment")
  timeSlideMapTable = table.get_table(oldxml, "time_slide_segment_map")

  if verbose:
    sys.stdout.write("%d total unclustered triggers read at %d.\n"\
                     %(len(currTrigs),elapsed_time()))

  # How many time slides are there?
  if timeSlideList:
    maxID = max([0] + [int(slide.time_slide_id) for slide in timeSlideList])
    numSlides = maxID + 1

    if verbose:
      sys.stdout.write("Identified %d distinct time slides.\n" %(numSlides))
  else:
    numSlides = 1

  # Sort the triggers into time slides
  trigSlideDict = {}
  for trig in currTrigs:
    currID = int(trig.time_slide_id)
    if currID not in trigSlideDict:
      trigSlideDict[currID] = lsctables.New(lsctables.MultiInspiralTable,
                                            columns=columns)
    trigSlideDict[currID].append(trig)

  if verbose:
    sys.stdout.write("Separated triggers by slide ID at %d.\n"
                     % (elapsed_time()))

  # The binning depends only on the segment and the window, not on the slide
  numBins = int((segEnd - segStart)//timeWindow + 1)

  for slideID in range(numSlides):
    # Skip slides that hold no triggers
    if slideID not in trigSlideDict:
      continue
    slideTrigs = trigSlideDict[slideID]

    if numSlides == 1 and (len(slideTrigs) != len(currTrigs)):
      raise RuntimeError("only one time slide is expected, but %d of %d "
                         "triggers belong to other slides"
                         % (len(currTrigs) - len(slideTrigs), len(currTrigs)))

    if verbose:
      sys.stdout.write("%d triggers identified in slide %d at %d.\n"\
                       % (len(slideTrigs),slideID,elapsed_time()))
      sys.stdout.write("Binning triggers in time...\n")

    # bin all triggers in time
    timeBins, loudestTrigSNR, loudestTrigTime = _bin_triggers(
        slideTrigs, segStart, numBins, timeWindow)

    if verbose:
      sys.stdout.write("%d bins generated at %d\nClustering triggers...\n"\
                       % (numBins, elapsed_time()))

    # keep the loudest trigger in each neighbourhood
    kept = _cluster_bins(timeBins, loudestTrigSNR, loudestTrigTime,
                         timeWindow)
    clstTrigs.extend(kept)
    numAdded = len(kept)

    if verbose:
      sys.stdout.write("%d Triggers added from slide %d at %d.\n\n"\
                       % (numAdded, slideID, elapsed_time()))

  #
  # write clustered xml file
  #

  if verbose:
    sys.stdout.write("%d triggers remaining at %d.\n"\
                     "\nWriting triggers to new xml file...\n"\
                     % (len(clstTrigs), elapsed_time()))

  # prepare xmldocument
  xmldoc = ligolw.Document()
  xmldoc.appendChild(ligolw.LIGO_LW())

  # append process params table
  xmldoc = append_process_params(xmldoc, sys.argv, __version__, __date__)

  # get search summary table from old file
  oldSearchSummTable = table.get_table(oldxml, "search_summary")
  xmldoc.childNodes[-1].appendChild(oldSearchSummTable)
  xmldoc.childNodes[-1].appendChild(clstTrigs)
  # Empty (or absent) tables are not carried over
  if timeSlideList:
    xmldoc.childNodes[-1].appendChild(timeSlideList)
  if segmentTable:
    xmldoc.childNodes[-1].appendChild(segmentTable)
  if timeSlideMapTable:
    xmldoc.childNodes[-1].appendChild(timeSlideMapTable)

  # generate filename and write
  outdesc = '%s_%s' % (cacheFile.description, 'CLUSTERED')
  xmlFile = cacheFile.path.replace(cacheFile.description, outdesc)

  utils.write_filename(xmldoc, xmlFile, gz = xmlFile.endswith("gz"))
  if verbose:
    sys.stdout.write("%s written at %d.\n" % (xmlFile, elapsed_time()))

if __name__ == '__main__':

  # Script entry point: parse the options, normalise the paths and run the
  # clustering.
  opts, args = parse_command_line()

  outdir, trigFile = map(os.path.abspath, (opts.output_dir, opts.trig_file))
  timeWindow, verbose = opts.time_window, opts.verbose

  main(trigFile, timeWindow, outdir, verbose=verbose)

  if verbose:
    sys.stdout.write("Done at %d.\n" % (elapsed_time()))
