#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import os,sys,glob,re
from optparse import OptionParser
from pylal import SimInspiralUtils,MultiInspiralUtils,git_version
from glue.ligolw import ilwd,ligolw,table,lsctables,utils
from glue import segments,segmentsUtils,lal
from pylal.dq import dqSegmentUtils
from pylal.coh_PTF_pyutils import append_process_params,remove_bad_injections,identify_bad_injections,sim_inspiral_get_theta

# Module metadata. The version string and date are taken from pylal's
# git_version module; they are handed to OptionParser (for --version) and to
# append_process_params further down in this file.
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date

# =============================================================================
# Parse command line
# =============================================================================

def parse_command_line():
  """
  Parse and validate the command line options.

  Returns the (opts, args) pair produced by OptionParser.parse_args().

  The parser exits through parser.error() when --inj-cache, --inj-string or
  --max-inclination is missing, or when the maximum inclination is zero.
  """

  # The usage text lists the options this parser actually requires.
  usage = """usage: %prog [options]

%prog reads the FOUND and MISSED injection files listed in a cache and writes a single FOUND and a single MISSED file holding only the injections below the given maximum inclination

Required options:

--inj-cache
--inj-string
--max-inclination
"""

  parser = OptionParser( usage, version=__version__ )

  parser.add_option( "-i", "--inj-cache", action="store", type="string",
                     default=None,
                     help="read injection files from this cache." )
  parser.add_option( "-o", "--output-dir", action="store", type="string",
                     default=os.getcwd(), help="output directory, "+
                                               "default: %default" )
  parser.add_option( "-s", "--inj-string", action="store", type="string",
                     default=None, help="Injection type to parse, e.g. NSNS." )
  # default is None (not 0) so that "option missing" can be told apart from
  # an explicit value of zero
  parser.add_option( "-I", "--max-inclination", action="store", type="float",
                     default=None, help="Create an injection set with injections"+
                     " uniformly distributed out to this inc !!in degrees!!")
  parser.add_option( "-v", "--verbose", action="store_true", default=False,
                     help="verbose output, default: %default" )

  (opts,args) = parser.parse_args()

  if not opts.inj_cache:
    parser.error( "--inj-cache must be given" )

  if opts.max_inclination is None:
    parser.error( "Maximum inclination must be given." )

  # zero was always rejected; report it for what it is
  if not opts.max_inclination:
    parser.error( "Maximum inclination must be non-zero." )

  if not opts.inj_string:
    parser.error( "Injection string must be given." )

  return opts, args

# =============================================================================
# Main function
# =============================================================================

def _description_angle( entry, injString ):
  """
  Return the integer tag held in the third '_'-separated field of a cache
  entry description, i.e. what remains of that field once the text
  'injectionsAstro' + injString has been removed.
  """
  field = entry.description.split('_')[2]
  return int( field.replace( 'injectionsAstro' + injString, '' ) )


def _set_load_columns( ifos ):
  """
  Set the loadcolumns attribute of MultiInspiralTable and SimInspiralTable to
  the columns wanted here, keeping only names present in each table's
  validcolumns.
  """
  # per-detector suffix used in the multi_inspiral column names
  ifoAtt = { 'G1':'g', 'H1':'h1', 'H2':'h2', 'L1':'l', 'V1':'v', 'T1':'t' }

  # MultiInspiralTable
  cols = ['end_time', 'end_time_ns', 'ifos', 'process_id', 'ra', 'dec', 'snr',
          'null_statistic', 'null_stat_degen','mass1','mass2','eta','mchirp']
  # add snr
  cols.extend(['snr_%s' % ifoAtt[i] for i in ifos])
  # add chisq
  cols.extend(['chisq', 'chisq_dof'])
  cols.extend(['chisq_%s' % ifoAtt[i] for i in ifos])
  # add bank chisq
  cols.extend(['bank_chisq', 'bank_chisq_dof'])
  # add auto chisq
  cols.extend(['cont_chisq', 'cont_chisq_dof'])
  # add sigmasq
  cols.extend(['sigmasq_%s' % ifoAtt[i] for i in ifos])
  # add amplitude terms
  cols.extend(['amp_term_%d' % i for i in xrange(1,11)])
  # set columns
  lsctables.MultiInspiralTable.loadcolumns =\
      [c for c in cols if c in lsctables.MultiInspiralTable.validcolumns]

  # SimInspiralTable
  cols = ['geocent_end_time', 'geocent_end_time_ns', 'mass1', 'mass2',
          'mchirp', 'f_lower', 'inclination', 'spin1x', 'spin1y', 'spin1z',
          'spin2x', 'spin2y', 'spin2z', 'longitude', 'latitude', 'distance',
          'eff_dist_h','eff_dist_l','eff_dist_v']
  lsctables.SimInspiralTable.loadcolumns =\
      [c for c in cols if c in lsctables.SimInspiralTable.validcolumns]


def _pair_found_missed( foundcache, missedcache, injString ):
  """
  Return a dictionary mapping each found file's description angle to the
  list [found entry, missed entry] with the same angle.

  Raises ValueError if a found entry has no missed entry with the same angle.
  """
  # index the missed entries once instead of rescanning them per found entry;
  # as before, the last missed entry with a given angle wins
  missedByAngle = {}
  for missed in missedcache:
    missedByAngle[_description_angle( missed, injString )] = missed

  pairs = {}
  for found in foundcache:
    angle = _description_angle( found, injString )
    if angle not in missedByAngle:
      raise ValueError( "Cannot match up found and missed files" )
    pairs[angle] = [found, missedByAngle[angle]]
  return pairs


def main( injCacheFile, outdir, maxIncl , injString , verbose=False ):
  """
  Collect the found and missed injections below a maximum inclination and
  write them to one FOUND and one MISSED xml file in outdir.

  injCacheFile : path of the cache file listing the injection files
  outdir       : directory in which the two output files are written
  maxIncl      : maximum inclination, in degrees
  injString    : injection type used to sieve the cache, e.g. NSNS
  verbose      : if True, print progress and the output file names

  Raises ValueError if the sieved cache holds no FOUND or no MISSED files, if
  their numbers differ, or if a FOUND file cannot be paired with a MISSED
  file.
  """
  import math

  # get maxIncl in radians
  maxInclRad = math.radians( maxIncl )

  lsctables.SimInspiral.get_theta = sim_inspiral_get_theta

  if verbose:
    print >>sys.stdout
    print >>sys.stdout, 'Loading cache file...'

  # close the cache file as soon as it has been read
  cacheFile = open( injCacheFile, 'r' )
  try:
    injcache = lal.Cache.fromfile( cacheFile )
  finally:
    cacheFile.close()

  sieved = injcache.sieve(description=injString, exact_match=False)
  foundcache = sieved.sieve(description='FOUND', exact_match=False)
  missedcache = sieved.sieve(description='MISSED', exact_match=False)

  # validate before foundcache[0] is touched anywhere below
  if len(foundcache)<1 or len(missedcache)<1:
    raise ValueError( 'No injection files found.' )

  if len(foundcache) != len(missedcache):
    raise ValueError( "Length of found and missed not equal!" )

  ifos = lsctables.instrument_set_from_ifos(foundcache[0].observatory)

  # set up the columns to be loaded from the input files
  _set_load_columns( ifos )

  cache_dict = _pair_found_missed( foundcache, missedcache, injString )

  foundInjs = lsctables.New(lsctables.SimInspiralTable,
                            columns=lsctables.SimInspiralTable.loadcolumns)
  missedInjs = lsctables.New(lsctables.SimInspiralTable,
                             columns=lsctables.SimInspiralTable.loadcolumns)
  foundTrigs = lsctables.New(lsctables.MultiInspiralTable,
                             columns=lsctables.MultiInspiralTable.loadcolumns)

  for angle,(found,missed) in cache_dict.items():
    # only files whose angle tag is at least maxIncl are read
    if angle < maxIncl:
      continue

    currFInjs = SimInspiralUtils.ReadSimInspiralFromFiles([found.path])
    currMInjs = SimInspiralUtils.ReadSimInspiralFromFiles([missed.path])
    currFTrigs = MultiInspiralUtils.ReadMultiInspiralFromFiles([found.path])

    for inj in currMInjs:
      if inj.get_theta() < maxInclRad:
        missedInjs.append(inj)

    # NOTE(review): pairing by position assumes the injection and trigger
    # rows of a FOUND file are stored in matching order - confirm
    for inj,trig in zip(currFInjs,currFTrigs):
      if inj.get_theta() < maxInclRad:
        foundInjs.append(inj)
        foundTrigs.append(trig)

  # prepare xmldocument
  xmldoc = ligolw.Document()
  xmldoc.appendChild(ligolw.LIGO_LW())

  # append process params table
  xmldoc = append_process_params( xmldoc, sys.argv, __version__, __date__ )

  # every table below is attached to (and detached from) this one element
  root = xmldoc.childNodes[-1]

  # get search summary table from old file
  oldxml = utils.load_filename( foundcache[0].path,
               gz = foundcache[0].path.endswith("gz"),
               contenthandler = lsctables.use_in(ligolw.LIGOLWContentHandler) )
  oldSearchSummTable = table.get_table( oldxml, "search_summary" )
  root.appendChild( oldSearchSummTable )

  # construct output filenames; the FOUND/MISSED tag is inserted explicitly
  # so that an output directory containing 'FOUND' cannot be altered
  ifoTag = foundcache[0].observatory
  start, end = foundcache.to_segmentlistdict().extent_all()

  userTag = injcache[0].description.rsplit( '_', 2 )[0]
  namePrefix = '%s/%s-%s_injectionsAstro%s_FILTERED_%d'\
               % ( outdir, ifoTag, userTag, injString, maxIncl )
  nameSuffix = '%d-%d.xml' % ( start, end-start )
  foundFile = '%s_FOUND-%s' % ( namePrefix, nameSuffix )
  missedFile = '%s_MISSED-%s' % ( namePrefix, nameSuffix )

  # write found injections
  root.appendChild(foundTrigs)
  root.appendChild(foundInjs)
  utils.write_filename( xmldoc, foundFile, gz=foundFile.endswith('gz') )
  if verbose:
    print >>sys.stdout, foundFile
  root.removeChild(foundTrigs)
  root.removeChild(foundInjs)

  # write missed injections
  root.appendChild(missedInjs)
  utils.write_filename( xmldoc, missedFile, gz=missedFile.endswith('gz') )
  if verbose:
    print >>sys.stdout, missedFile

if __name__ == '__main__':

  # Script entry point: read the command line, turn the cache and output
  # locations into absolute paths and pass everything on to main().
  options, _ = parse_command_line()

  main( os.path.abspath( options.inj_cache ),
        os.path.abspath( options.output_dir ),
        options.max_inclination,
        options.inj_string,
        verbose=options.verbose )
