#!/usr/bin/env python

__prog__ = "plotsnrchisq_pipe"
__title__ = "SNR/CHISQ time series plots"

import matplotlib
matplotlib.use('Agg')
import copy
import math
import sys
import os
import socket, time
from optparse import *
import re, string
import exceptions
import glob
import tempfile
import ConfigParser
import urlparse
from types import *
from pylab import *
import numpy
# sys.path.append('/archive/home/channa/opt/pylal/lib64/python2.4/site-packages')
# sys.path.append('/archive/home/channa/opt/glue/lib64/python2.4/site-packages')
from pylal import viz
from pylal import Fr
from glue import segments
from glue import segmentsUtils
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils
from pylal import CoincInspiralUtils
from pylal import InspiralUtils
from pylal import git_version
#from pylal import webUtils
from glue import pipeline
from glue import lal


# NOTE(review): '@PYTHONLIBDIR@' looks like a placeholder that is substituted
# with the real library directory when the script is installed -- confirm.
# The path is appended after the pylal/glue imports above have already run.
sys.path.append('@PYTHONLIBDIR@')

# Turn off matplotlib's 'text.usetex' setting for all figures.
rc('text', usetex=False)

def whiteTmp(newtemplate, ASD_vector, ASD_freq):
    """Weight the template spectrum by the inverse ASD and transform it back.

    newtemplate -- template time series
    ASD_vector  -- spectrum values used as the per-bin divisor
    ASD_freq    -- frequency belonging to each entry of ASD_vector

    Returns the tuple
      (fft_template, whiteffttemplate, tmpFreq, stix, stpix, whitetemplate)
    where fft_template is the FFT of the input, whiteffttemplate the FFT
    divided bin by bin, tmpFreq the frequency assigned to every FFT bin,
    stix the first non-zero bin whose frequency exceeds 40, stpix the first
    non-zero bin whose frequency exceeds 1024 (half the FFT length when there
    is none) and whitetemplate the inverse FFT with its imaginary part zeroed.
    """
    # FIXME Hardcoded to a 256 second segment length
    # Assumes sample rate is <= 4096
    segLen = 256.0
    fsamp = 1.0 / (segLen / len(newtemplate))
    # Not used below; evaluating it still rejects a frequency vector with
    # fewer than two entries, as before.
    deltaF = ASD_freq[1] - ASD_freq[0]

    fft_template = numpy.fft.fft(newtemplate)
    maxFreqIndex = int(len(ASD_freq) * (fsamp/2) / (max(ASD_freq) + 1./segLen)) - 1
    maxASD = max(ASD_vector)

    tmpFreq = []
    weighted = []
    stix = 0
    stpix = 0
    for n, coefficient in enumerate(fft_template):
        # Bins below maxFreqIndex read the ASD arrays directly; the remaining
        # bins index them from the end and get a negated frequency.
        if n >= maxFreqIndex:
            k = maxFreqIndex - n
            divisor = ASD_vector[k]
            f = 0.0 - ASD_freq[k]
        else:
            k = n
            divisor = ASD_vector[k]
            f = ASD_freq[k]
        tmpFreq.append(f)

        # Outside 30..1024 (either sign) every bin is divided by the largest
        # ASD value instead of its own.
        if abs(f) < 30 or abs(f) > 1024:
            divisor = maxASD

        if not stix and f > 40:
            stix = n
        if not stpix and f > 1024:
            stpix = n
        weighted.append(coefficient / divisor)

    if not stpix:
        stpix = len(fft_template) // 2

    whiteffttemplate = numpy.array(weighted)
    whitetemplate = numpy.fft.ifft(whiteffttemplate)
    # keep the complex dtype but discard the imaginary part
    whitetemplate.imag = 0.0

    return fft_template, whiteffttemplate, tmpFreq, stix, stpix, whitetemplate

# FUNCTION TO RETURN THE TEMPLATE TIME SERIES
def getTemplate(template_tuple, shift=20000, cl=200, peakFraction=20.0):
    """Return the rotated template and the sample range worth plotting.

    template_tuple -- sequence whose first element is the template vector
    shift          -- number of samples by which the vector is rotated
                      circularly, so that output sample n is input sample
                      (n + shift) modulo the length; any length is accepted
    cl             -- length in samples of the window scanned for the
                      start and the end of the signal (must be positive)
    peakFraction   -- the threshold is the peak amplitude divided by this

    Returns (newtemplate, stix, stpix).  stix is the first index n > cl
    for which some sample of newtemplate[n-cl:n] exceeds the threshold
    (0 when there is none).  stpix is n - cl for the first n >= stix at
    which every sample of that window is below the threshold; when the
    signal never dies away stpix is the template length, so that
    newtemplate[stix:stpix] is never empty for a non-empty template.
    An empty template yields (empty array, 0, 0).
    """
    template_vector = numpy.asarray(template_tuple[0])
    nSamples = len(template_vector)
    if nSamples == 0:
        return template_vector, 0, 0

    # circular shift; the modulo keeps short templates from overrunning
    newtemplate = numpy.roll(template_vector, -(shift % nSamples))

    # don't consider parts of the template with amplitudes of 1/20th of the
    # peak (by default) for the plot
    envelope = numpy.abs(newtemplate)
    maxTV = envelope.max() / peakFraction

    stix = 0
    stpix = 0
    # candidate window ends; the window for end n is newtemplate[n-cl:n]
    ends = numpy.arange(cl + 1, nSamples)
    if len(ends):
        # running counts: entry n holds the number of samples before index n
        # that are above (resp. at or above) the threshold
        above = numpy.concatenate(([0], numpy.cumsum(envelope > maxTV)))
        atOrAbove = numpy.concatenate(([0], numpy.cumsum(envelope >= maxTV)))
        # window maximum > threshold  <=>  at least one sample above it
        loud = (above[ends] - above[ends - cl]) > 0
        # window maximum < threshold  <=>  no sample at or above it
        quiet = (atOrAbove[ends] - atOrAbove[ends - cl]) == 0

        loudIdx = numpy.flatnonzero(loud)
        if len(loudIdx):
            first = loudIdx[0]
            stix = int(ends[first])
            quietIdx = numpy.flatnonzero(quiet[first:])
            if len(quietIdx):
                stpix = int(ends[first + quietIdx[0]]) - cl

    if not stpix:
        # the signal never fell below the threshold: keep everything to the end
        stpix = nSamples
    return newtemplate, stix, stpix

# process-params options without which the plots cannot be built
_REQUIRED_PROC_PARAMS = (
    "--channel-name", "--segment-length", "--sample-rate",
    "--segment-overlap", "--gps-start-time", "--snr-threshold",
    "--chisq-threshold", "--chisq-bins", "--low-frequency-cutoff",
)


def _parseNumber(text):
    """Turn a numeric option string into an int when possible, else a float."""
    text = str(text)
    try:
        return int(text)
    except ValueError:
        return float(text)


def _plotWindow(position, duration, deltaT):
    """Return the slice of samples lying within duration/2 of position."""
    first = int((position - duration/2.) / deltaT)
    last = int((position + duration/2.) / deltaT)
    # a negative index would silently wrap around to the end of the series
    return slice(max(first, 0), max(last, 0))


def _plotTimeSeries(figNumber, times, values, threshold, halfWidth, yLabel, figTitle):
    """Draw one time series, plus a horizontal threshold line unless it is None."""
    figure(figNumber)
    plot(times, values)
    if threshold is not None:
        hold(1)
        # sized from the plotted slice so both curves always match in length
        plot(times, len(times) * [threshold])
        hold(0)
    xlim(-halfWidth, halfWidth)
    xlabel('time (s)',size='x-large')
    ylabel(yLabel,size='x-large')
    grid(True)
    title(figTitle)


def _plotTemplate(figNumber, samples, figTitle):
    """Draw a template segment against its sample index."""
    figure(figNumber)
    plot(samples)
    xlim(0, len(samples))
    xlabel('time (samples)',size='x-large')
    ylabel(r'amplitude',size='x-large')
    grid(True)
    title(figTitle)


def _saveFigure(opts, plotName, tag, fnameList, tagList):
    """Save the current figure and record its file name and description."""
    figName = InspiralUtils.set_figure_name(opts, plotName)
    InspiralUtils.savefig_pylal(figName)
    fnameList.append(figName)
    tagList.append(tag)


def plotsnrchisq(opts,args,gpsTime,frameFile,outputPath,inspProcParams,user_tag=None):
    """Plot the SNR, r^2, chi-square, PSD and template figures for a trigger.

    opts           -- parsed options; plot_width and enable_output are read
                      here and the object is passed on to InspiralUtils
    args           -- positional command line arguments, forwarded to
                      InspiralUtils.write_html_output
    gpsTime        -- trigger time as a string
    frameFile      -- frame file holding the <channel>_SNRSQ_<n>,
                      <channel>_CHISQ_<n>, <channel>_PSD and (optionally)
                      <channel>_TEMPLATE channels
    outputPath     -- accepted for interface compatibility, not used here
    inspProcParams -- rows with .param and .value attributes describing
                      the inspiral job; when an option occurs several times
                      the last occurrence is used
    user_tag       -- accepted for interface compatibility, not used here

    Raises ValueError when a required process parameter is missing.

    The option values are parsed as plain int/float literals rather than
    being passed through eval().
    """
    params = {}
    for row in inspProcParams:
        params[row.param] = row.value
    missing = [name for name in _REQUIRED_PROC_PARAMS if name not in params]
    if missing:
        raise ValueError("process params table lacks: " + ", ".join(missing))

    chanStringBase = params["--channel-name"]
    ifoName = str(chanStringBase).split(":",1)
    segLen = _parseNumber(params["--segment-length"])
    sampleRate = _parseNumber(params["--sample-rate"])
    segOverlap = _parseNumber(params["--segment-overlap"])
    gpsStart = _parseNumber(params["--gps-start-time"])
    snrThreshold = _parseNumber(params["--snr-threshold"])
    chisqThreshold = _parseNumber(params["--chisq-threshold"])
    chisqBins = _parseNumber(params["--chisq-bins"])
    flow = _parseNumber(params["--low-frequency-cutoff"])
    # optional entries: 0 means "not given"
    rsqThreshold = _parseNumber(params.get("--rsq-veto-threshold", 0))
    trigStart = _parseNumber(params.get("--trig-start-time", 0))
    gpsValue = _parseNumber(gpsTime)

    segLenSec = segLen / sampleRate
    segOverlapSec = segOverlap / sampleRate

    trigPosition = 0.0
    if trigStart:
        trig_position = float(trigStart - gpsStart - segOverlapSec) / float(segLenSec - segOverlapSec)
        #ceil usage: if trig_position = 3.0, then number must be 2.0
        #if trig_position = 3.1, then number must be 3.0
        trigPosition = math.ceil(trig_position) - 1.0
    if trigPosition < 0:
        trigPosition = 0.0

    gpsPosition = int((gpsValue - gpsStart - segOverlapSec/2.) / (segLenSec - segOverlapSec))

    # the per-segment channels are numbered from the first segment that
    # could hold triggers
    position = gpsPosition - int(trigPosition)
    chanNumber = str(position)

    # now, read the data !!
    # The window width should be an input argument maybe ?
    duration = opts.plot_width

    chanNameSnr = chanStringBase + "_SNRSQ_" + chanNumber
    chanNameChisq = chanStringBase + "_CHISQ_" + chanNumber
    chanNameTemplate = chanStringBase + "_TEMPLATE"
    chanNamePSD = chanStringBase + "_PSD"
    squareSnr_tuple = Fr.frgetvect(frameFile,chanNameSnr,-1,segLenSec,0)
    squareChisq_tuple = Fr.frgetvect(frameFile,chanNameChisq,-1,segLenSec,0)

    # Try to extract the template; the template figures are skipped without it
    try:
        template_tuple = Fr.frgetvect(frameFile,chanNameTemplate,-1,segLenSec,0)
    except Exception:
        print("WARNING: couldn't extract TEMPLATE channel, continuing")
        template_tuple = None
    PSD_tuple = Fr.frgetvect(frameFile,chanNamePSD,-1,segLenSec*8,0)

    # offset of the trigger from the start of its segment
    snr_position = gpsValue - (gpsStart + gpsPosition * (segLenSec - segOverlapSec))
    chisq_position = snr_position

    # Element [0] of a frgetvect result is the data vector; element [3][0]
    # is used below as the spacing between consecutive samples.
    squareSnr = squareSnr_tuple[0]
    snrDeltaT = squareSnr_tuple[3][0]
    snr_vector = sqrt(squareSnr)
    snr_time = numpy.arange(segLen) * snrDeltaT - snr_position

    ASD_vector = PSD_tuple[0]
    ASD_freq = numpy.arange(len(ASD_vector)) * PSD_tuple[3][0]

    # compute the r^2
    rsq_vector = squareChisq_tuple[0]
    chisqDeltaT = squareChisq_tuple[3][0]
    chisq_time = numpy.arange(segLen) * chisqDeltaT - chisq_position

    # compute the normalized chisq
    if chisqBins > 0:
        if "--chisq-delta" not in params:
            raise ValueError("process params table lacks: --chisq-delta")
        chisqDelta = _parseNumber(params["--chisq-delta"])
        # float() keeps two integer-valued options from dividing to zero
        chisqNorm_vector = rsq_vector/(1 + float(chisqDelta)/chisqBins*squareSnr)
    else:
        chisqNorm_vector = rsq_vector

    # define the time boundaries of the plot
    snrWindow = _plotWindow(snr_position, duration, snrDeltaT)
    chisqWindow = _plotWindow(chisq_position, duration, chisqDeltaT)

    # thresholds of 100 or more (and an unset r^2 veto) are not drawn
    rsqLine = None
    if (rsqThreshold > 0) and (rsqThreshold < 100.):
        rsqLine = rsqThreshold
    chisqLine = None
    if chisqThreshold < 100.:
        chisqLine = chisqThreshold

    fnameList = []
    tagList = []
    triggerTitle = ifoName[0] + ' trigger: ' + gpsTime

    # (figure number, name, y label, times, values, threshold); each entry
    # is drawn twice: full width, then zoomed to a tenth of the width in
    # figure 11 * number
    seriesPlots = [
        (1, 'snr', r'snr',
         snr_time[snrWindow], snr_vector[snrWindow], snrThreshold),
        (2, 'rsq', r'chisq/p',
         chisq_time[chisqWindow], rsq_vector[chisqWindow], rsqLine),
        (3, 'chisq', r'chisq / (p + Delta * snrsq)',
         chisq_time[chisqWindow], chisqNorm_vector[chisqWindow], chisqLine),
    ]
    for figNumber, name, yLabel, times, values, threshold in seriesPlots:
        _plotTimeSeries(figNumber, times, values, threshold, duration/2.,
                        yLabel, triggerTitle)
        _saveFigure(opts, name, "Plot of " + name, fnameList, tagList)
        _plotTimeSeries(11 * figNumber, times, values, threshold, duration/20.,
                        yLabel, triggerTitle + ' Zoom')
        _saveFigure(opts, name + '_zoom', "Plot of " + name + " (zoom)",
                    fnameList, tagList)

    # plot the PSD
    figure(4)
    loglog(ASD_freq,ASD_vector)
    xlabel('freq (Hz)', size='x-large')
    ylabel(r'PSD',size='x-large')
    xlim(flow, ASD_freq[-1])
    title(triggerTitle)
    _saveFigure(opts, 'PSD', "Plot of PSD", fnameList, tagList)

    # find the relevant portion of the template if available
    if template_tuple:
        templateTitle = ifoName[0] + ' Template: ' + gpsTime

        #extract the relevant template piece
        newtemplate, tstart, tend = getTemplate(template_tuple)
        _plotTemplate(5, newtemplate[tstart:tend], templateTitle)
        _saveFigure(opts, 'template', "Plot of template waveform",
                    fnameList, tagList)

        # "Whiten" the template
        fft_template, whiteffttemplate, tmpFreq, fstart, fend, whitetemplate = whiteTmp(newtemplate, ASD_vector, ASD_freq)

        # plot the "whitened" template (only its real part carries data)
        _plotTemplate(6, numpy.real(whitetemplate[tstart:tend]), templateTitle)
        _saveFigure(opts, 'white_template',
                    "Plot of the whitened template waveform",
                    fnameList, tagList)

        figure(7)
        wfft = whiteffttemplate[fstart:fend]
        # every curve is normalized by its own largest magnitude
        loglog(tmpFreq[fstart:fend], abs(wfft)/numpy.max(abs(wfft)),'g')
        hold(1)
        loglog(tmpFreq, abs(fft_template)/numpy.max(abs(fft_template)),'r', linewidth=2)
        loglog(ASD_freq,ASD_vector/numpy.max(ASD_vector))
        hold(0)
        deltaF = ASD_freq[1] - ASD_freq[0]
        astart = int(40 / deltaF)
        astop = int(1024 / deltaF)
        xlim([40, 1024])
        ymin = numpy.min(ASD_vector[astart:astop]/numpy.max(ASD_vector))
        ylim([ymin, 100])
        xlabel('Freq (Hz)',size='x-large')
        ylabel(r'amplitude (normalized)',size='x-large')
        grid(True)
        # labels are listed in the order in which the curves were drawn
        legend(['ASD weighted Template','Template','ASD'], loc='upper right')
        title(ifoName[0] + ' Template and ASD: ' + gpsTime)
        _saveFigure(opts, 'fft_of_template_and_asd',
                    "Plot of the template and ASD", fnameList, tagList)
        # END if template_tuple

    if opts.enable_output is True:
        html_filename = InspiralUtils.write_html_output(opts, args, fnameList,
                                                        tagList)
        InspiralUtils.write_cache_output(opts, html_filename, fnameList)

##############################################################################
#
#  MAIN PROGRAM
#
##############################################################################
# Command line definition.  The parsed result is left in the module-level
# names `opts` and `args`, which the rest of the script reads.
usage = """ %prog [options]
"""

parser = OptionParser(usage, version=git_version.verbose_msg)

parser.add_option("-f","--frame-file",action="store",type="string",
    metavar=" FILE",help="use frame files list FILE")

parser.add_option("-t","--gps",action="store",type="string",
    metavar=" GPS",help="use gps time GPS")

parser.add_option("","--plot-width",action="store",type="float",
    default=2.0, metavar= "FLOAT",help="specify time width of the plots")

parser.add_option("-o","--output-path",action="store",type="string",
    default="", metavar=" PATH",
    help="path where the figures would be stored")

# A flag needs a boolean default: the string "false" used previously is
# truthy, so a plain `if opts.enable_output:` test would have treated the
# output as enabled even when -O was not given.
parser.add_option("-O","--enable-output",action="store_true",
    default=False,  metavar="OUTPUT",
    help="enable the generation of the html and cache documents")

parser.add_option("-x","--inspiral-xml-file", action="store",type="string",
    metavar=" XML",help="use inspiral-file")

parser.add_option("-T","--user-tag", action="store",type="string",
    default=None, metavar=" USERTAG",help="user tag for the output file name")

parser.add_option("","--ifo-times",action="store",
    type="string", default=None, metavar=" IFOTIMES",
    help="provide ifo times for naming figure")

parser.add_option("","--ifo-tag",action="store",
    type="string",  metavar=" IFOTAG",
    help="ifo tag gives the information about ifo times and stage")

parser.add_option("","--gps-start-time", action="store",type="float",
    metavar=" GPSSTARTTIME",help="gps start time (for naming figure and \
    output files)")

parser.add_option("","--gps-end-time", action="store",type="float",
    metavar=" GPSENDTIME",help="gps end time (for naming figure and \
    output files)")

command_line = sys.argv[1:]
(opts,args) = parser.parse_args()

#################################
# Sanity check of input arguments

# Each entry: (option value, what is missing, how to supply it).  The
# entries are tested in order and the script stops with exit status 1 at
# the first option that was not given.
for option_value, complaint, hint in (
    (opts.frame_file,
     "No frame file specified.",
     "Use --frame-file FILE to specify location."),
    (opts.gps,
     "No gps time specified.",
     "Use --gps GPS to specify location."),
    (opts.output_path,
     "No output path specified.",
     "Use --output-path PATH to specify location."),
    (opts.inspiral_xml_file,
     "No inspiral xml file specified.",
     "Use --inspiral-xml-file XML to specify a file"),
  ):
  if not option_value:
    sys.stderr.write(complaint + "\n")
    sys.stderr.write(hint + "\n")
    sys.exit(1)

# let InspiralUtils process the options before they are used
opts = InspiralUtils.initialise(opts, __prog__, git_version.verbose_msg)

# the process params table of the inspiral XML file supplies the job settings
doc = utils.load_filename(opts.inspiral_xml_file)
proc = table.get_table(doc, lsctables.ProcessParamsTable.tableName)

plotsnrchisq(
  opts,
  args,
  gpsTime=str(opts.gps),
  frameFile=opts.frame_file,
  outputPath=opts.output_path,
  inspProcParams=proc,
  user_tag=opts.user_tag)

