#!/usr/bin/python

import os, sys, exceptions
from optparse import *

import matplotlib
matplotlib.use('Agg')
from pylab import *

from pylal import SnglInspiralUtils
from pylal import SimInspiralUtils
from pylal import CoincInspiralUtils
from glue import lal
from glue import segments
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils

from pylal import webUtils      
from glue import markup
from glue.markup import oneliner as extra


##############################################################################
# half-width of the time window around the injection that investigateThinca
# selects and plots (the x-axis of the plot is labelled in seconds)
deltaTime = 20.0
# the same half-width, used by investigateInspiral
deltaTimeInspiral = 20.0
# ifo name -> one-letter colour code; the code is prepended to the marker
# character to build the format string passed to plot()
ifoList = {'H1':'r','H2':'b','L1':'g','V1':'m','G1':'c'}
# default verbosity; overwritten with opts.verbose once the command line
# has been parsed
verbose = False

##############################################################################
def print_inj( inj, number, injID, opts ):
  """
  Print a short two-line summary of one injection to stdout.

  @param inj:    the injection; mass1, mass2, geocent_end_time,
                 geocent_end_time_ns, distance, eff_dist_h and eff_dist_l
                 are read from it
  @param number: running number of the injection
  @param injID:  ID of the injection, only printed if opts.ext_trig is set
  @param opts:   the parsed options; only opts.ext_trig is read
  """

  if opts.ext_trig:
    header = "\nInjection details for injection %d with injID %s: " % (number, injID)
  else:
    header = "\nInjection details for injection %d:" % (number)
  print( header )

  # the nanosecond part has to be zero-padded to nine digits, otherwise
  # e.g. 5 ns would be shown as '.5' instead of '.000000005'
  print( "m1: %.2f  m2:%.2f  | end_time: %d.%09d | "\
         "distance: %.2f  eff_dist_h: %.2f eff_dist_l: %.2f" % \
         ( inj.mass1, inj.mass2, inj.geocent_end_time, inj.geocent_end_time_ns,\
           inj.distance, inj.eff_dist_h, inj.eff_dist_l ) )



##############################################################################
def findInjections( missedInjections, injections ):
  """
  Find the injection and the file-number that corresponds to each
  missed injection.

  Two injections are considered identical if both geocent_end_time and
  geocent_end_time_ns agree.

  @param missedInjections: sequence of missed injections
  @param injections: dictionary mapping an injection ID to the sequence
                     of injections belonging to that ID
  @return: tuple (vectorID, vectorInj) of two lists of equal length holding
           the ID and the matching injection for every match found.
           A missed injection without any match contributes no entry.
  """

  # index all injections by their end time once, so that every missed
  # injection needs a single lookup instead of a scan over all groups
  injByEndTime = {}
  for injID, groupInj in injections.items():
    for inj in groupInj:
      key = ( inj.geocent_end_time, inj.geocent_end_time_ns )
      injByEndTime.setdefault( key, [] ).append( (injID, inj) )

  vectorID = []
  vectorInj = []
  for missedInj in missedInjections:
    key = ( missedInj.geocent_end_time, missedInj.geocent_end_time_ns )
    for injID, inj in injByEndTime.get( key, [] ):
      vectorID.append( injID )
      vectorInj.append( inj )

  return vectorID, vectorInj

##############################################################################
def getTimeTrigger( trig ):
  """
  Return the end time of a trigger as a single float, i.e.
  end_time plus end_time_ns scaled by 1e-9.
  """
  seconds = float(trig.end_time)
  nanoseconds = float(trig.end_time_ns)
  return seconds + nanoseconds * 1.0e-9

##############################################################################
def getTimeSim( sim, ifo=None ):
  """
  Return the end time of an injection as a single float.

  Without an ifo the geocentric end time is used. With an ifo the values
  are fetched through sim.get() under the names '<x>_end_time' and
  '<x>_end_time_ns', where <x> is the lower-cased first letter of the ifo.
  """
  if ifo:
    prefix = ifo[0].lower()
    seconds = sim.get( prefix+'_end_time' )
    nanoseconds = sim.get( prefix+'_end_time_ns' )
  else:
    seconds = sim.geocent_end_time
    nanoseconds = sim.geocent_end_time_ns

  return  float(seconds) + float(nanoseconds) * 1.0e-9

##############################################################################
def investigateInspiral( triggerFiles, inj, injTime, label, number ):
  """
  Investigate inspiral triggers
  """

  # read the inspiral file(s)
  if verbose: print "Processing INSPIRAL triggers from files ", triggerFiles
  snglTriggers = SnglInspiralUtils.ReadSnglInspiralFromFiles( triggerFiles )

  # selection segment
  timeInjection = getTimeSim( inj )
  segSmall =  segments.segment( timeInjection-injTime, timeInjection+injTime )
  segLarge =  segments.segment( timeInjection-deltaTimeInspiral, timeInjection+deltaTimeInspiral )

  foundAny = False
  fig=figure()
  for ifo in ifoList.keys():

    # get the singles 
    snglInspiral = snglTriggers.ifocut(ifo)

    # select a range of triggers
    selectedLarge = snglInspiral.vetoed( segLarge )
    timeLarge = [ getTimeTrigger( sel )-timeInjection for sel in selectedLarge ]
    
    selectedSmall = snglInspiral.vetoed( segSmall )
    timeSmall = [ getTimeTrigger( sel )-timeInjection for sel in selectedSmall ]
    
    if len(timeLarge)==0:
      continue

    if len(timeSmall)>0:
      foundAny=True


    # plot the triggers
    plot( timeLarge, selectedLarge.get_column('snr'), ifoList[ifo]+'o')
    plot( timeSmall, selectedSmall.get_column('snr'), ifoList[ifo]+'s', \
          label=ifo)
    


  # draw the injection times and other stuff
  ylims=axes().get_ylim()
  plot( [0,0], ylims, 'g--')
  plot( [-injTime, -injTime], ylims, 'c:')
  plot( [+injTime, +injTime], ylims, 'c:')

  # save the plot
  grid(True)
  legend()
  axis([-deltaTimeInspiral, +deltaTimeInspiral, ylims[0], ylims[1]])
  xlabel('time [s]')
  ylabel('SNR')
  savefig('Pictures/plot'+label.upper()+'_'+str(number)+'.png')

  close(fig)
  
  return foundAny

##############################################################################
def investigateThinca( triggerFiles, inj, injTime, label, number ):
  """
  Investigate inspiral triggers
  """

  # read and construct the coincident events
  if verbose: print "Processing THINCA triggers from files ", triggerFiles
  snglInspiral = SnglInspiralUtils.ReadSnglInspiralFromFiles( triggerFiles, mangle_event_id=True )
  coincInspiral = CoincInspiralUtils.coincInspiralTable( snglInspiral, \
                                                         CoincInspiralUtils.coincStatistic("snr") )
          
  # selection segment
  timeInjection = getTimeSim( inj )
  segSmall =  segments.segment( timeInjection-injTime, timeInjection+injTime )
  segLarge =  segments.segment( timeInjection-deltaTime, timeInjection+deltaTime )

  plotAny = False
  foundAny= False
  fig=figure()
  selected = dict()  
  for ifo in ifoList.keys():

    # get the singles 
    snglInspiral = coincInspiral.getsngls(ifo)

    # select a range of triggers
    selectedLarge = snglInspiral.vetoed( segLarge )
    timeLarge = [ getTimeTrigger( sel )-timeInjection for sel in selectedLarge ]
    
    selectedSmall = snglInspiral.vetoed( segSmall )
    timeSmall = [ getTimeTrigger( sel )-timeInjection for sel in selectedSmall ]
    
    if len(timeLarge)==0:
      continue

    if len(timeSmall)>0:
      foundAny=True

    # plot the triggers
    plotAny = True
    plot( timeLarge, selectedLarge.get_column('snr'), ifoList[ifo]+'o')
    plot( timeSmall, selectedSmall.get_column('snr'), ifoList[ifo]+'s', \
          label=ifo)
    

  # draw the injection times and other stuff
  if plotAny:
    ylims=axes().get_ylim()
  else:
    ylims=[0,1]
  plot( [0,0], ylims, 'g--')
  plot( [-injTime, -injTime], ylims, 'c:')
  plot( [+injTime, +injTime], ylims, 'c:')

  # save the plot
  grid(True)
  if plotAny:
    legend()
  axis([-deltaTime, +deltaTime, ylims[0], ylims[1]])
  xlabel('time [s]')
  ylabel('SNR')
  savefig('Pictures/plot'+label.upper()+'_'+str(number)+'.png')
  close(fig)

  return foundAny

##############################################################################
def investigateTrigbank( triggerFiles, inj, injMass, label, number ):
  """
  Investigate inspiral triggers
  """
  
  # read the trigbank
  if verbose: print "Processing TRIGBANK triggers from files ", triggerFiles
  snglTriggers = SnglInspiralUtils.ReadSnglInspiralFromFiles( triggerFiles )
          

  fig=figure()
  selected = dict()  
  for ifo in ifoList.keys():

    # get the singles 
    snglInspiral = snglTriggers.ifocut(ifo)

    # plot the templatebank
    plot( snglInspiral.get_column('mass1'), snglInspiral.get_column('mass2'), ifoList[ifo]+'x', \
          label=ifo)

  # plot the missed trigger and save the plot
  xm = injMass[0]
  ym = injMass[1]
  plot( [xm], [ym],  'ko', ms=10.0, mfc='None', mec='k', lw=3)
  grid(True)
  legend()
  xlabel('mass1')
  ylabel('mass2')
  savefig('Pictures/plot'+label.upper()+'_'+str(number)+'.png')
  
  axis([ max(xm-5,0), xm+5.0, max(ym-5.0, 0.0), ym+5])
  savefig('Pictures/plot'+label.upper()+'-zoom_'+str(number)+'.png')

  close(fig)


def fillTable( page, contents ):
  """
  Add one table row to the page: every item of contents is converted to a
  string and put into its own <td> cell. The items are added as they are,
  so they may contain markup themselves.
  """
  page.add('<tr>')
  for cell in contents:
    for piece in ( '<td>', str(cell), '</td>' ):
      page.add( piece )
  page.add('</tr>')
    

##############################################################################
usage = """The code will examine why injections were missed.  
Provide the output and intermediate cache from an injection run of 
lalapps_inspiral_hipe.  The code will then look to see if any of the 
injections which are missed at the end of the pipeline were found at
the inspiral, coire or inspiral veto stages and print out the relevant
parameters for the triggers.
"""

# definition of the command line options
parser = OptionParser( usage )

parser.add_option("-c", "--cache-file", action = "store", \
    type = "string", default = None, metavar = " CACHE", \
    help = "cache of data from the inspiral hipe injections")
parser.add_option("-a", "--min-eff-dist-l", action = "store", \
    type = "float", default = None, metavar = " MINEFFDISTL", \
    help = "cut on this minimum effective distance for Livingston")
# the metavar refers to Hanford here (was a copy of the Livingston one)
parser.add_option("-b", "--min-eff-dist-h", action = "store", \
    type = "float", default = None, metavar = " MINEFFDISTH", \
    help = "cut on this minimum effective distance for Hanford")
parser.add_option("-o", "--output-path", action = "store", \
    type = "string", default = './', metavar = " OUTPUT", \
    help = "output path for putting plots etc. ")
parser.add_option("-e", "--enable-output", action="store_true", default=False,
       help="enable output of html files")
parser.add_option("-x", "--ext-trig", action="store_true", default=False,
       help="sets exttrigger flag")
parser.add_option("-v", "--verbose", action="store_true", default=False,
       help="sets verbosity")


(opts,args) = parser.parse_args()
# the module-level flag is read by the investigate* functions
verbose = opts.verbose

# the cache file is the only mandatory argument
if not opts.cache_file:
  print >>sys.stderr, "Must specify input cache file!"
  sys.exit(1)

##############################################################################

###########################################################
# get the intermediate found injections
triggerCache = dict()
interCache = lal.Cache(map(lal.CacheEntry, open( opts.cache_file ) ))
triggerCache['firstInspiral'] = lal.Cache([c for c in interCache if "INSPIRAL_FIRST" in c.description])
triggerCache['secondInspiral'] = lal.Cache([c for c in interCache if "INSPIRAL_SECOND" in c.description])
triggerCache['firstThinca'] = lal.Cache([c for c in interCache if "THINCA_FIRST" in c.description])
triggerCache['secondThinca'] = lal.Cache([c for c in interCache if "THINCA_SECOND" in c.description])
triggerCache['Trigbank'] = lal.Cache([c for c in interCache if "TRIGBANK" in c.description])
injectionCache = lal.Cache([c for c in interCache if "INJECTION" in c.description])

coireVetoMissedCache =  lal.Cache([c for c in interCache if 
  ("COIRE_INJECTION" in c.description) and ("MISSED" in c.description) ])

orderDict = {'firstInspiral':0,'firstThinca':1,'secondInspiral':2, 'secondThinca':3  }
typeDict = {0:'firstInspiral',1:'firstThinca',2:'secondInspiral', 3:'secondThinca'  }

# need a dictionary if injection files for the ext-trig analysis
# because they are splitted
if opts.ext_trig:
  injections=dict()
  for file in injectionCache.pfnlist():
    basename = os.path.basename(file)
    injectionID =  basename.split('_')[1].split('-')[0]
    injections[injectionID] = SimInspiralUtils.ReadSimInspiralFromFiles( [file] )

if verbose:
  print "parsing of cache file done..."

###########################################################
# retrieve the time window used from on of the COIRE files
coireFile = coireVetoMissedCache.pfnlist()[0]
doc = utils.load_filename( coireFile, gz=(coireFile).endswith(".gz") )
try:
  processParams = table.get_table(doc, lsctables.ProcessParamsTable.tableName)
except:
  print >>sys.stderr, "Error while reading process_params table from file ", coireFils
  print >>sys.stderr, "Probably files does not exist"
  sys.exit(1)

for table in processParams:
  if table.param=='--injection-window':
    injectionWindow = float(table.value)/1000.0

if verbose:
  print "Injection-window set to %.0f ms" % (1000*injectionWindow)

###########################################################
## start the main loop
number = 0
for missed in coireVetoMissedCache:

  ## read the file containing the missed injections
  try: missedInj = SimInspiralUtils.ReadSimInspiralFromFiles( [missed.path] )
  except:
    print >>sys.stderr, "Error while reading sim_inspiral table from file ", missed.path
    print >>sys.stderr, "Probably files does not exist"
    sys.exit(1)

  # get the number of missed injections in this object
  if missedInj: numInj = len(missedInj)
  else: numInj=0
  if verbose:
    print "Found %d missed injections in file %s" % ( numInj, missed.path)

  # see if we have any missed injection in this round
  if numInj>0:

    # get the injections thar were missed
    if opts.ext_trig:
      injIDs, injList = findInjections( missedInj , injections )
    else:
      injList = missedInj
      injIDs = ones(len(injList))

    # loop over the single missed injections
    for injID, inj in zip( injIDs, injList):

      # select the range in effective distance
      if opts.min_eff_dist_l and inj.eff_dist_l>opts.min_eff_dist_l:
        continue
      if opts.min_eff_dist_h and inj.eff_dist_h>opts.min_eff_dist_h:
        continue

      # increase number:
      number+=1
      
      ## create the web-page
      page = markup.page()
      page.h1("Followup missed injection #"+str(number))
      page.hr()

      # create a table
      page.add('<table border="2" >')
      fillTable( page, ['<b>parameter','<b>value'] )
      fillTable( page, ['Number',str(number)] )
      fillTable( page, ['inj ID', injID] )
      fillTable( page, ['mass1', inj.mass1] )
      fillTable( page, ['mass2', inj.mass2] )
      fillTable( page, ['end_time', inj.geocent_end_time_ns] )
      fillTable( page, ['distance', inj.distance] )
      fillTable( page, ['eff_dist_h', inj.eff_dist_h] )
      fillTable( page, ['eff_dist_l', inj.eff_dist_l] )
      page.add('</table>')
        
      print_inj( inj, number, injID, opts )

      
      # now we can retrieve all the other files for this particular
      # missed injection
      foundList = zeros(4)
      for type, cache in triggerCache.iteritems():

        if opts.ext_trig:
          trigCache = lal.Cache(
            [c for c in cache if ('_'+str(injID)+'-' in os.path.basename(c.url)\
                                  and inj.geocent_end_time in c.segment)])
        else:
          trigCache = lal.Cache(
            [c for c in cache if inj.geocent_end_time in c.segment])

        # check if the pfnlist is empty. 
        if trigCache.pfnlist()==[]:
          print >>sys.stderr, "### WARNING ### No files found for %s in the cache for ID %s and time %d " % \
                ( type, injID, inj.geocent_end_time)
          continue
       
        # now create the several plots
        if 'Inspiral' in type:
          found = investigateInspiral( trigCache.pfnlist(), inj, injectionWindow, type, number )
          foundList[orderDict[type]]=found
          
        elif 'Thinca' in type:
            
          found = investigateThinca( trigCache.pfnlist(), inj, injectionWindow , type, number )
          foundList[orderDict[type]]=found
          
        elif 'Trigbank' in type:

          investigateTrigbank( trigCache.pfnlist(), inj, [inj.mass1, inj.mass2 ], type, number )
          
        else:
          print >>sys.stderr, "Error: Unknown type ", type
          sys.exit(1)


      # print out the result for this particular injection
      page.add('<table border="2" >')
      fillTable( page, ['<b>step','<b>F/M'] )
      for index in range(4):
        text="Injection %s in step %s was " % \
             (number, typeDict[index])
        if foundList[index]:
          fillTable( page, [typeDict[index],  'FOUND'])          
          text+="found"
        else:
          fillTable( page, [typeDict[index],  '<font color="red">MISSED'])          
          text+="MISSED"
        print text
      page.add('</table>')            

      ##################################
      ## add the pictures to the webpage
      for plotName in ['FIRSTINSPIRAL','FIRSTTHINCA','SECONDINSPIRAL','SECONDTHINCA','TRIGBANK','TRIGBANK-zoom']:
        fname = opts.output_path+'/plot'+plotName+'_'+str(number)+'.png'
        page.a(extra.img(src=[fname], width=400, \
                         alt=fname, border="2"), title=fname, href=[ fname])

      # and write the html file
      if opts.enable_output:
        file = open('missed'+str(number)+'.html','w')      
        file.write(page(False))
        file.close()         




