#!/usr/bin/python

__author__ = "Alexander Dietz <Alexander.Dietz@astro.cf.ac.uk>"
__prog__ = "plotinsppop"
__title__ = "Injection Population Plots"


import sys
import os
from optparse import *
import re
import exceptions
import glob
from types import *
import itertools
from glue import lal, segmentsUtils
from pylal import SimInspiralUtils, InspiralUtils
from pylal import git_version
from pylal import plotutils
from lal import PI as LAL_PI
from lal import MTSUN_SI as LAL_MTSUN_SI

def create_hist(data, xlabel, ylabel, title):
  """
  Draw a vertical-bar histogram of data.

  The histogram is a plotutils.VerticalBarHistogram constructed from
  xlabel, ylabel and title.  The data set is added with color 'r' and
  is labelled with xlabel; the plot is finalized with normed=False.
  Nothing is returned.
  """
  histogram = plotutils.VerticalBarHistogram(xlabel, ylabel, title)
  histogram.add_content(data, label=xlabel, color='r')
  histogram.finalize(normed=False)

def get_total_angmom(sim):
  """
  Return the total angular momentum J = L + S of one injection.

  sim must provide mass1, mass2, f_lower, inclination and the six spin
  components spin1x..spin2z.  The result is a length-3 float array.

  L has magnitude
    mass1 * mass2 * (LAL_PI * LAL_MTSUN_SI * (mass1 + mass2) * f_lower)**(-1/3)
  and lies in the x-z plane, tilted by the inclination.  S is the sum of
  the two spin vectors, each weighted by the square of its mass.
  """
  m1 = sim.mass1
  m2 = sim.mass2

  # scale of the orbital angular momentum
  scale = m1 * m2 * \
          np.power(LAL_PI * LAL_MTSUN_SI * (m1 + m2) * sim.f_lower, -1.0/3.0)

  # orbital part: no y component
  orbital = np.array([scale * np.sin(sim.inclination),
                      0.0,
                      scale * np.cos(sim.inclination)], dtype=float)

  # spin part: mass-squared weighted sum of both spins
  w1 = m1 * m1
  w2 = m2 * m2
  spin = np.array([w1 * sim.spin1x + w2 * sim.spin2x,
                   w1 * sim.spin1y + w2 * sim.spin2y,
                   w1 * sim.spin1z + w2 * sim.spin2z], dtype=float)

  return orbital + spin


##############################################################################
# Help text shown by --help.  It only advertises options that the parser
# below actually defines.
usage = """usage: %prog [options]

Inspiral Injection Plotting Functions

Generate a set of summary plots from an injection file to understand
which regions of parameter space are sampled.

The injection files are taken either from a cache file (--cache-file,
optionally restricted with --ifo-times) or from a glob pattern
(--glob-injfile).  Positional arguments are not used.

Available plots:

1) Any two-dimensional plot is available for each of the values that are
   stored in the sim_inspiral table, including 'total_mass', 'spin1' and
   'spin2'.  Specify the values you want to have plotted with --x-value
   and --y-value; repeat the pair of options to get several plots.  Add
   'LOG' at the end of a value to make it a log axis.

2) Any one-dimensional histogram is available with --hist <column name>;
   specify multiple times to get histograms of multiple columns.  Add
   'LOG' at the end for a 'logy' histogram or 'LOGLOG' for a 'loglog'
   histogram.  All available numeric columns can be histogrammed with
   '--hist all'.

3) A sky map of the injection positions is available with --plot-skymap.

4) Histograms of the spin magnitudes and of the total angular momentum
   are available with --plot-spins.

Figures are only saved (together with the html and cache documents) when
--enable-output is given; use --show-plot to display them instead.

Example:
plotinsppop --glob-injfile "../Injection/HL/HL*.xml" --enable-output --x-value mass1 --y-value mass2

plotinsppop --glob-injfile "../Injection/HL/HL*.xml" --enable-output --x-value total_mass --y-value distanceLOG --x-value mass1 --y-value mass2 --x-value inclination --y-value distanceLOG

plotinsppop --glob-injfile "../Injection/HL/HL*.xml" --enable-output --hist all

"""
parser = OptionParser(usage, version=git_version.verbose_msg)

# --- input selection
parser.add_option("-C", "--cache-file", action="store", type="string",
    default=None, metavar=" CACHE",
    help="setting cache-filename")
parser.add_option("--inj-pattern", default="INJECTION",
    help="sieve pattern for the injection files of interest")
parser.add_option("-F", "--glob-injfile", action="store", type="string",
    default=None, metavar=" INJFILE",
    help="glob of the injection file")
parser.add_option("-I", "--ifo-times", action="store", type="string",
    default=None, metavar=" IFO-TIMES",
    help="setting IFO-times")
parser.add_option("-B", "--gps-start-time", action="store", type="int",
    default=None, metavar=" TEXT",
    help="setting GPS start time")
parser.add_option("-E", "--gps-end-time", action="store", type="int",
    default=None, metavar=" TEXT",
    help="setting GPS end time")

# --- output control
parser.add_option("-s", "--show-plot", action="store_true", default=False,
    help="display the figures on the terminal")
parser.add_option("-f", "--enable-output", action="store_true",
    default=False,
    help="enable the generation of the html and cache documents")
parser.add_option("-u", "--user-tag", action="store", type="string",
    default=None, metavar=" USER-TAG",
    help="setting the user-tag")
parser.add_option("--output-path", action="store", type="string",
    default=None, metavar=" output",
    help="setting the output-path")

# --- plot selection
parser.add_option("-X", "--x-value", action="append", type="string",
    default=None, metavar=" TEXT",
    help="variable to plot on the x-axis. Add 'LOG' at the end to make it a log  axis")
parser.add_option("-Y", "--y-value", action="append", type="string",
    default=None, metavar=" TEXT",
    help="variable to plot on the y-axis. Add 'LOG' at the end to make it a log  axis")
parser.add_option("-H", "--hist", action="append", default=[],
    help="column to histogram; specify multiple times for multiple "
    "histograms. Add 'LOG' at the end to make it a log "
    "plot.  Make 'all' or 'allLOG' to histogram all columns of numeric data.")
parser.add_option("--plot-skymap", action="store_true", default=False,
    help="plot the sky positions on a Mollweide-projection map; "
    "requires the basemap package")
parser.add_option("--plot-spins", action="store_true",
    default=False, help="create plots related to spin.")
parser.add_option("--verbose", action="store_true",
    default=False,
    help="setting verbosity")



######################################################
## basic preparations
######################################################
# parse the command line; the positional leftovers returned by the parser
# are discarded and args holds the complete raw command line instead
opts, _leftover = parser.parse_args()
args = sys.argv[1:]
opts = InspiralUtils.initialise(opts, __prog__, git_version.verbose_msg)

# bookkeeping shared by all plotting sections below
colors = InspiralUtils.colors
fnameList, tagList = [], []       # saved figure names and their tags
plot_num = itertools.count(1)     # source of consecutive figure numbers

# loading the graphics libs; when nothing is displayed, select the Agg
# backend before pylab is imported
if not opts.show_plot:
  import matplotlib
  matplotlib.use('Agg')
from pylab import *
from pylal import viz

  


######################################################
### reading data
######################################################

if opts.cache_file:
  # Sieve the cache with the user-supplied --inj-pattern (default
  # "INJECTION") and take element [0] of the checkfilesexist() result.
  # Everything happens inside the with-block so the cache file handle is
  # closed again once the list of file names has been built.
  with open(opts.cache_file) as cache_fobj:
    injfiles = lal.Cache.fromfile(cache_fobj).sieve(
        ifos=opts.ifo_times, description=opts.inj_pattern
        ).checkfilesexist()[0].pfnlist()

elif opts.glob_injfile:
  # glob the list of files to read in
  injfiles = glob.glob(opts.glob_injfile)

else:
  # without any input, fail with a readable message instead of an
  # obscure error from glob.glob(None)
  parser.error("one of --cache-file or --glob-injfile must be given")

if not injfiles:
  sys.stderr.write("warning: no injection files found\n")

# an instance of the simInspiralTable
binaries = SimInspiralUtils.ReadSimInspiralFromFiles(injfiles)

# resolve --hist=all, --hist=allLOG, and --hist=allLOGLOG: expand to every
# column whose type name starts with "real" or "int", keeping the suffix
if len(opts.hist) == 1 and opts.hist[0].startswith("all"):
  log_suffix = opts.hist[0][3:]
  opts.hist = [str(attr) + log_suffix for attr, attr_type in
               zip(binaries.columnnames, binaries.columntypes)
               if attr_type.startswith(("real", "int"))]

######################################################
### plotting routines
######################################################


## create the sky map
if opts.plot_skymap:
  # math is imported explicitly rather than relying on the pylab
  # star-import to provide it
  import math
  from matplotlib.figure import Figure
  from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  from mpl_toolkits.basemap import Basemap

  # a stand-alone figure with an Agg canvas, independent of the pylab
  # figure numbering used by the other sections
  fig = Figure()
  canvas = FigureCanvas(fig)
  ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])
  m = Basemap(projection='moll', lon_0=0.0, lat_0=0.0, ax=ax)

  # longitude and latitude are converted to degrees before being put in
  # map coords
  lon_deg = [math.degrees(b.longitude) for b in binaries]
  lat_deg = [math.degrees(b.latitude) for b in binaries]
  x, y = m(lon_deg, lat_deg)
  ax.plot(x, y, "b.")

  m.drawmapboundary()
  # parallels every 30 degrees from -90 to 90
  m.drawparallels(np.linspace(-90, 90, 7), labels=[1,0,0,0], labelstyle='+/-')
  # 12 meridians every 30 degrees (0, 30, ..., 330); endpoint=False keeps
  # the spacing even and avoids drawing the 0/360 meridian twice
  m.drawmeridians(np.linspace(0, 360, 12, endpoint=False), labels=[0,0,0,1],
                  labelstyle='+/-')
  ax.set_title("Injection sky locations")

  if opts.enable_output:
    fname = InspiralUtils.set_figure_name(opts, "skymap")
    InspiralUtils.savefig_pylal(fname, fig=fig)
    fnameList.append(fname)
    tagList.append("Sky map")
  del fig, canvas, m

## make the a vs b plots
if opts.x_value and opts.y_value:

  # --x-value and --y-value are paired by position; a surplus value on
  # either side has no partner and is skipped instead of raising an
  # IndexError in the middle of the run
  if len(opts.x_value) != len(opts.y_value):
    sys.stderr.write("warning: %d --x-value but %d --y-value options given; "
                     "unpaired values are ignored\n"
                     % (len(opts.x_value), len(opts.y_value)))

  # The parser in this file only defines --gps-start-time/--gps-end-time
  # (opts.gps_start_time/opts.gps_end_time).  Use opts.start_time and
  # opts.end_time when they exist and fall back to the parsed options
  # otherwise.
  plot_start_time = getattr(opts, 'start_time', opts.gps_start_time)
  plot_end_time = getattr(opts, 'end_time', opts.gps_end_time)
  if (plot_start_time is None or plot_end_time is None) and \
     any("end_time" in xval for xval in opts.x_value):
    parser.error("plotting an end_time column on the x-axis requires "
                 "--gps-start-time and --gps-end-time")

  # loop over the given pairs of values
  for xval, yval in zip(opts.x_value, opts.y_value):

    # a trailing 'LOG' requests a logarithmic axis and is stripped from
    # the column name
    xlog = xval.endswith('LOG')
    if xlog:
      xval = xval[:-3]
    ylog = yval.endswith('LOG')
    if ylog:
      yval = yval[:-3]

    figure(next(plot_num))
    px = binaries.get_column(xval)
    py = binaries.get_column(yval)

    # end times are plotted relative to the start time
    shift_time = "end_time" in xval
    if shift_time:
      px = [t - plot_start_time for t in px]

    # pick the plotting call matching the requested axis scaling; both
    # axes logarithmic needs loglog, not a linear plot
    if xlog and ylog:
      loglog(px, py, 'bo')
    elif xlog:
      semilogx(px, py, 'bo')
    elif ylog:
      semilogy(px, py, 'bo')
    else:
      plot(px, py, 'bo')

    if shift_time:
      # stretch the x-axis over the whole analysed time, keep the y-range
      limits = axis()
      if opts.verbose:
        sys.stdout.write("%s\n%s\n" % (limits, plot_start_time))
      axis([0, plot_end_time - plot_start_time, limits[2], limits[3]])

    xlabel(xval.replace('_', '-'))
    ylabel(yval.replace('_', '-'))
    titleString = yval + ' versus ' + xval
    title(titleString.replace('_', '-'))
    grid(True)
    if opts.enable_output:
      fname = InspiralUtils.set_figure_name(opts, xval + '_' + yval)
      InspiralUtils.savefig_pylal(fname)
      fnameList.append(fname)
      tagList.append(xval + ' vs. ' + yval)
    if not opts.show_plot:
      close()


## create all the histograms
if opts.hist:
  nbins = 30

  # The loop variable is deliberately not called 'hist', so that the
  # hist() function provided by the pylab star-import is not overwritten.
  for hist_spec in opts.hist:

    # one trailing 'LOG' selects the 'logy' plot type, two select 'loglog';
    # the suffixes are stripped to obtain the column name
    column = hist_spec
    plot_type = 'normal'
    if column.endswith('LOG'):
      column = column[:-3]
      plot_type = 'logy'
      if column.endswith('LOG'):
        column = column[:-3]
        plot_type = 'loglog'

    figure(next(plot_num))
    viz.histcol(binaries, column, nbins=nbins, plot_type=plot_type)
    if opts.enable_output:
      fname = InspiralUtils.set_figure_name(opts, 'hist_' + column)
      InspiralUtils.savefig_pylal(fname)
      fnameList.append(fname)
      tagList.append('histogram ' + column)
    if not opts.show_plot:
      close()




if opts.plot_spins:

  def _spin_magnitude(sim, prefix):
    """Return the Euclidean norm of the spin named prefix ('spin1' or 'spin2')."""
    sx = getattr(sim, prefix + 'x')
    sy = getattr(sim, prefix + 'y')
    sz = getattr(sim, prefix + 'z')
    return np.sqrt(sx * sx + sy * sy + sz * sz)

  def _plot_spin_hist(data, name, axis_label, plot_title):
    """
    Histogram data with create_hist; when output is enabled, save the
    figure as 'hist_<name>' and register it under the tag
    'histogram <name>'; close the figure unless plots are shown.
    """
    create_hist(data, axis_label, "number", plot_title)
    if opts.enable_output:
      fname = InspiralUtils.set_figure_name(opts, 'hist_' + name)
      InspiralUtils.savefig_pylal(fname)
      fnameList.append(fname)
      tagList.append('histogram ' + name)
    if not opts.show_plot:
      close()

  # create histograms of the spin1 and spin2 magnitudes
  for spin_name in ('spin1', 'spin2'):
    data = [_spin_magnitude(b, spin_name) for b in binaries]
    _plot_spin_hist(data, spin_name, "magnitude " + spin_name,
                    "histogram " + spin_name + " magnitude")

  # create final total spin plot (see G110113)
  # The keys are kept in an explicit tuple so that the histograms are
  # always produced (and listed in the output) in the same order.
  j_keys = ('Jmag', 'Jx', 'Jy', 'Jz', 'Jincl')
  datadict = dict((key, []) for key in j_keys)
  for sim in binaries:

    # get the total angular momentum
    J = get_total_angmom(sim)

    # magnitude, components and the angle between J and the z-axis
    datadict['Jmag'].append(np.sqrt(np.dot(J, J)))
    datadict['Jx'].append(J[0])
    datadict['Jy'].append(J[1])
    datadict['Jz'].append(J[2])
    datadict['Jincl'].append(
        np.arctan2(np.sqrt(J[0] * J[0] + J[1] * J[1]), J[2]))

  for key in j_keys:
    _plot_spin_hist(datadict[key], key, key, key + " histogram")



# write the html page and the cache for everything that was saved
if opts.enable_output:
  summary_page = InspiralUtils.write_html_output(opts, args, fnameList, tagList)
  InspiralUtils.write_cache_output(opts, summary_page, fnameList)


# display the figures when requested
if opts.show_plot:
  show()

