#!/usr/bin/python

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

__prog__ = "lalapps_cbc_plothzdist"
__author__ = "Collin Capano <collin.capano@ligo.org>"
# Description text handed to the option parser for the --help output.
description = \
'''
Plots horizon distance versus total mass for specified masses and SNR. Alternatively,
SNR vs total mass can be plotted for specified distance. Multiple waveforms (with the
same physical parameters) may be plotted.
'''
# Usage template; the %s placeholder is filled with the program name. All
# options that the script checks as required (including --waveform-model)
# are listed ahead of the bracketed alternatives.
usage = "%s --mtotal-range RANGE --mass-ratio Q --waveform-fmin WFMIN --overlap-fmin OFMIN --waveform-model WAVEFORM[:TAPER] --output-file FILE [--psd-model MODEL | --asd-file FILE] [--snr SNR | --distance DIST] [additional options]"

from optparse import OptionParser
import os, sys
import numpy
from optparse import OptionParser

import lal
import lalsimulation as lalsim

import matplotlib
# Select the 'Agg' backend. The call is deliberately placed before pyplot is
# imported, so the statement order of these three lines must be preserved.
matplotlib.use('Agg')
from matplotlib import pyplot
# Global plot styling applied to every figure made by this script: LaTeX text
# rendering, a bold serif font, and the font sizes of titles, axis labels,
# tick labels and legend entries.
pyplot.rcParams.update({
    "text.usetex": True,
    "text.verticalalignment": "center",
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "font.weight": "bold",
    "font.size": 12,
    "axes.titlesize": 16,
    "axes.labelsize": 16,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
    })


#
# =============================================================================
#
#                             Useful Functions
#
# =============================================================================
#

# default hubble constant comes from arXiv:1212.5225
default_H0 = 69.32 # km/s/Mpc

def get_redshift(distance, hubble_const = None):
    """
    Return the redshift for the given distance using the linear
    relation z = H0 * distance / c.

    distance:     distance in Mpc
    hubble_const: Hubble constant in km/s/Mpc; when None, default_H0
                  is used instead
    """
    h0 = default_H0 if hubble_const is None else hubble_const
    # the factor 1e3 turns km/s into m/s before dividing by lal.C_SI
    return distance * h0 * 1e3 / lal.C_SI

def get_source_mass(redshifted_mass, distance, hubble_const = None):
    """
    Convert a red-shifted mass into a source mass by dividing it by
    (1 + z), where z is the redshift that get_redshift gives for the
    distance.

    distance:     distance in Mpc
    hubble_const: passed through unchanged to get_redshift
    """
    z = get_redshift(distance, hubble_const)
    return redshifted_mass/(1 + z)

def get_proper_distance(distance, hubble_const = None):
    """
    Divide the given distance by (1 + z), where z is the redshift that
    get_redshift gives for that same distance.

    distance:     distance in Mpc
    hubble_const: passed through unchanged to get_redshift
    """
    z = get_redshift(distance, hubble_const)
    return distance / (1 + z)

def m1_m2_from_M_q(mtotal, q):
    """
    Split a total mass into two component masses for mass ratio
    q = m1/m2.

    Returns the tuple (m1, m2) with m2 = mtotal/(1+q) and
    m1 = mtotal - m2.
    """
    secondary = mtotal / float(1 + q)
    primary = mtotal - secondary
    return primary, secondary

def get_psd(N, df, fmin, psd_model):
    """
    Build a PSD frequency series from one of the lalsimulation noise
    models.

    N:         number of frequency bins in the series
    df:        frequency spacing of the series
    fmin:      used both as the start frequency (f0) of the series and as
               the low-frequency argument of lalsim.SimNoisePSD
    psd_model: model name; the function lalsim.SimNoisePSD<psd_model> is
               looked up and handed to lalsim.SimNoisePSD

    NOTE(review): the series is created with f0 = fmin, while get_overlap
    pairs PSD bins with waveform bins by index -- confirm that the
    resulting frequency alignment is the intended one.
    """
    series = lal.CreateREAL8FrequencySeries('psd', lal.LIGOTimeGPS(0,0), fmin, df, lal.HertzUnit, N)
    model_func = getattr(lalsim, 'SimNoisePSD'+psd_model)
    lalsim.SimNoisePSD(series, fmin, model_func)
    return series

def get_psd_from_file(N, df, fmin, asd_file):
    """
    Build a PSD frequency series from an ASD file via
    lalsim.SimNoisePSDFromFile.

    N:        number of frequency bins in the series
    df:       frequency spacing of the series
    fmin:     used both as the start frequency (f0) of the series and as
              the low-frequency argument of lalsim.SimNoisePSDFromFile
    asd_file: path of the ASD file; it is made absolute before being
              passed on

    NOTE(review): the series is created with f0 = fmin, while get_overlap
    pairs PSD bins with waveform bins by index -- confirm that the
    resulting frequency alignment is the intended one.
    """
    series = lal.CreateREAL8FrequencySeries('psd', lal.LIGOTimeGPS(0,0), fmin, df, lal.HertzUnit, N)
    asd_path = os.path.abspath(asd_file)
    lalsim.SimNoisePSDFromFile(series, fmin, asd_path)
    return series

def zero_pad_h(h, N, position, overwrite = False):
    """
    Resize the time series h to N samples, with the original data
    starting at index `position` of the result.

    h:         time series to pad (it is handled with the lal REAL8TimeSeries
               create/resize routines)
    N:         total number of samples of the result
    position:  index at which the original data should start; must be <= N
    overwrite: if True, h itself is resized and returned; otherwise a copy
               is made and h is left untouched

    Raises ValueError if position > N.
    """
    # check that desired position is in bounds
    if position > N:
        raise ValueError("position must be <= N")
    if overwrite:
        newh = h
    else:
        # work on a copy so as not to overwrite the original; the copy keeps
        # the sample units of the input series
        newh = lal.CreateREAL8TimeSeries(h.name, h.epoch, h.f0, h.deltaT, h.sampleUnits, h.data.length)
        newh.data.data[:] = h.data.data
    # first fix the amount of data kept after the start of the series ...
    lal.ResizeREAL8TimeSeries(newh, 0, N-position)
    # ... then move the start back by `position` samples, for N samples total
    lal.ResizeREAL8TimeSeries(newh, -1*position, N)
    return newh

def get_htilde(h, N, df, fftplan = None):
    """
    Fourier transform the time series h with lal.REAL8TimeFreqFFT.

    h:       time series to transform
    N:       FFT length; the output series holds int(N/2 + 1) bins
    df:      frequency spacing assigned to the output series
    fftplan: forward FFT plan to use; when None, a plan is created with
             lal.CreateForwardREAL8FFTPlan(N, 0)

    Returns the frequency series named "htilde".
    """
    nbins = int(N/2 + 1)
    result = lal.CreateCOMPLEX16FrequencySeries("htilde", h.epoch, h.f0, df, lal.HertzUnit, nbins)
    plan = fftplan
    if plan is None:
        plan = lal.CreateForwardREAL8FFTPlan(N, 0)
    # do the fft
    lal.REAL8TimeFreqFFT(result, h, plan)
    return result

def get_overlap(htilde, stilde, psd, fmin, fmax = None):
    """
    Compute the noise-weighted overlap

        4 * df * sum_k conj(htilde[k]) * stilde[k] / psd[k]

    over the frequency bins selected by fmin and fmax.

    htilde, stilde: frequency series; their data arrays are multiplied
                    bin by bin
    psd:            frequency series holding the noise weights
    fmin:           lower frequency bound; the sum starts at bin
                    int(fmin/df), but never below bin 1
    fmax:           upper frequency bound; if None, or beyond the last bin,
                    the sum runs up to (but excluding) the last bin,
                    otherwise up to (but excluding) bin int(fmax/df)

    All three series must share the same deltaF and length, otherwise a
    ValueError is raised.
    """
    df = htilde.deltaF
    if not (df == stilde.deltaF == psd.deltaF):
        raise ValueError("df mismatch")
    if not (htilde.data.length == stilde.data.length == psd.data.length):
        raise ValueError("vector length mismatch")
    # bin 0 is never included in the sum
    kstart = 1 if fmin/df < 1 else int(fmin/df)
    if fmax is None or fmax/df > htilde.data.length - 1:
        kstop = -1
    else:
        kstop = int(fmax/df)
    # note: slices are half-open, i.e. the sum goes from kstart to kstop-1
    # (with kstop = -1 the last bin is left out)
    integrand = (htilde.data.data.conj() * stilde.data.data)[kstart:kstop] / psd.data.data[kstart:kstop]
    # sum inside numpy rather than with the (much slower) builtin sum
    return 4 * df * integrand.sum()

def get_sigmasq(htilde, psd, fmin, fmax = None):
    """
    Return the real part of the overlap of htilde with itself, as
    computed by get_overlap with the given psd, fmin and fmax.
    """
    self_overlap = get_overlap(htilde, htilde, psd, fmin, fmax)
    return self_overlap.real

def get_psd_models():
    """
    List the PSD model names offered by lalsimulation.

    Every attribute of lalsim whose name starts with 'SimNoisePSD' is
    reported with that prefix removed, except for 'SimNoisePSD' itself
    and 'SimNoisePSDFromFile'.
    """
    prefix = 'SimNoisePSD'
    excluded = (prefix, prefix+'FromFile')
    models = []
    for attr_name in dir(lalsim):
        if attr_name.startswith(prefix) and attr_name not in excluded:
            models.append(attr_name.replace(prefix, ''))
    return models

# =============================================================================
#
#                                   Parse Options
#
# =============================================================================

# Build the command-line interface. Options fall into three groups: general
# options, waveform parameters and plot parameters.
parser = OptionParser(version = "", usage = usage % __prog__, description = description)

parser.add_option("-o", "--output-file", help = "File to save plot to. The suffix on the file name will determine the plot type, e.g., '.png' or '.pdf'")
parser.add_option("-p", "--psd-model", help = "PSD model to use. Options are %s. If not specified, must provide an asd file." % ', '.join(get_psd_models()))
parser.add_option("-a", "--asd-file", help = "Get the PSD from the ASD specified in the given file. The file must be an ASCII file consisting of two columns: the first column should give the frequency, the second column the ASD. If not using an ASD file, must specify a psd-model.")
parser.add_option("-m", "--mtotal-range", metavar = "min,max", help = "Required. Set the minimum and maximum total masses to plot (in solar masses).")
parser.add_option("-q", "--mass-ratio", type = 'float', help = "Required. What mass ratio to use.")
parser.add_option("-S", "--snr", type = "float", help = "If specified, will plot horizon distance vs total mass for the given SNR. If de-redshifting, proper distance will be plotted.")
parser.add_option("-D", "--distance", type = "float", help = "If specified, will plot SNR vs total mass at the given horizon distance. If de-redshifting, the given distance is assumed to be the proper distance; otherwise, the given distance is assumed to be the luminosity distance.")
parser.add_option("-z", "--de-redshift", action = "store_true", default = False, help = "Plot de-redshifted source masses and proper distance. Default is to plot red-shifted results. Note: redshift is approximated as z = H0*D/c. Values for z > 1 are therefore inaccurate.")
parser.add_option("", "--hubble-constant", type = 'float', default = default_H0, help = "What Hubble constant to use for de-redshifting, in units of km/s/Mpc. Default is %.2f." %(default_H0))
parser.add_option("-f", "--waveform-fmin", type = "float", help = "Frequency (Hz) at which to start waveform generation.")
parser.add_option("-F", "--overlap-fmin", type = "float", help = "Frequency (Hz) at which to start overlap calculation.")
parser.add_option("--overlap-fmax", type = "float", help = "Set the maximum overlap frequency. If not specified, Nyquist will be used.")
parser.add_option("-w", "--waveform-model", metavar = "WAVEFORM[:TAPER]", action = "append", help = "What waveform model to plot. Can specify multiple times. At least one is required. For time-domain waveforms you can additionally specify a taper by appending the taper type to the name. Taper options are: %s. Time-domain (TD) waveforms will be generated once at the lowest mass, then transformed to the frequency domain (FD). All higher mass waveforms are then rescaled from the initial mass. FD waveforms are regenerated for each total mass.  Note: if a waveform model has both a TD and FD implementation, the FD version is used. IMRPhenomB, for example, is treated as an FD waveform. This does not apply to Taylor(T|F) models; e.g., TaylorF2 is frequency domain, TaylorT2 is time domain." % ', '.join(['start', 'end', 'startend']))
# waveform parameters
parser.add_option("--sample-rate", type = "int", help = "Required for time-domain waveforms. What sample rate (in Hz) to use.")
parser.add_option("--seg-length", type = "int", help = "Required for frequency-domain waveforms. The inverse of this gives the deltaF used for waveform generation.")
parser.add_option("--f-max", type = "float", default = 0., help = "Max frequency to generate FD waveforms to. If not specified the default value for each waveform family will be used.")
parser.add_option("--f-ref", type = "float", help = "Set reference frequency for time-domain waveforms. Default is 0.")
parser.add_option("--spin1", metavar = "x,y,z", default = "0,0,0", help = "Specify x, y, and z spin components of the larger mass. Default is 0,0,0.")
parser.add_option("--spin2", metavar = "x,y,z", default = "0,0,0", help = "Specify x, y, and z spin components of the smaller mass. Default is 0,0,0.")
parser.add_option("--lambda1", type = "float", default = 0., help = "Specify tidal-deformation parameter of the larger mass. Default is 0.")
parser.add_option("--lambda2", type = "float", default = 0., help = "Specify tidal-deformation parameter of the smaller mass. Default is 0.")
parser.add_option("--amp-order", type = "int", default = 0, help = "Specify the amplitude order. Default is 0.")
parser.add_option("--phase-order", type = "int", default = 7, help = "Specify the phase order. Default is 7.")
parser.add_option("--inclination", type = "float", default = 0., help = "What inclination to use for all waveforms. Default is 0.")
parser.add_option("--initial-phase", type = "float", default = 0., help = "What initial phase to use for all waveforms. Default is 0.")
# plot parameters
parser.add_option("", "--x-min", type = "float", help = "Set minimum x value.")
parser.add_option("", "--x-max", type = "float", help = "Set maximum x value.")
parser.add_option("", "--y-min", type = "float", help = "Set minimum y value.")
parser.add_option("", "--y-max", type = "float", help = "Set maximum y value.")
parser.add_option("", "--linear-y", action = "store_true", default = False, help = "Make y-axis linear. Default is for y-axis to be log scale.")
parser.add_option("", "--title", type = "string", help = "Specify the plot title (in quotation marks). If none provided and an SNR is given, the title will print that SNR; if a distance is given, the title will print that distance.")
parser.add_option("", "--no-title", action = "store_true", default = False, help = "Do not add a title to the plot. Default is to add a title.")
parser.add_option("", "--legend-loc", default = "lower right", help = "Where to put the legend. Options are 'upper right', 'upper left', 'lower left', 'lower right'. Default is lower right.")
parser.add_option("", "--num-points", type = "int", default = 40, help = "number of points to use to generate plot; default is 40")
parser.add_option("", "--dpi", type = "int", default = 200, help = "dots-per-inch to use for plot; default is 200")
parser.add_option("", "--save-data", action = "store_true", default = False, help = "Save the plot data to a text file. There will be one file for each waveform generated, with filename ${WAVEFORM}_data-${PLOTNAME}.dat, where PLOTNAME is the name of the output file without its suffix. The files will be saved in the same directory as the plots.")
parser.add_option("-v", "--verbose", action = "store_true", default = False, help = "Be verbose.")

opts, args = parser.parse_args()

# make sure that every mandatory option was given on the command line
required_args = ['output-file', 'mtotal-range', 'mass-ratio', 'waveform-fmin', 'overlap-fmin', 'waveform-model']
forgotten_args = []
for arg in required_args:
    # an option is stored under its name with dashes turned into underscores
    if not getattr(opts, arg.replace('-', '_')):
        forgotten_args.append(arg)
if forgotten_args:
    parser.error("option(s) %s is/are required" % ', '.join(forgotten_args))

# split each waveform specification into the model name and, when a
# ':TAPER' suffix is present, the taper to apply to that model
waveforms = []
tapers = {}
for model_spec in opts.waveform_model:
    pieces = model_spec.split(':')
    if len(pieces) == 2:
        waveform, taper = pieces
        tapers[waveform] = lalsim.GetTaperFromString('TAPER_'+taper.upper())
    else:
        waveform = model_spec
    waveforms.append(waveform)

# shorthand names for the remaining options
# ... output and noise model
output_file = opts.output_file
psd_model = opts.psd_model
asd_file = opts.asd_file
# ... what to plot
snr = opts.snr
distance = opts.distance
q = opts.mass_ratio
npoints = opts.num_points
dpi = opts.dpi
# ... frequency settings
wFmin = opts.waveform_fmin
oFmin = opts.overlap_fmin
oFmax = opts.overlap_fmax
f_ref = opts.f_ref
f_max = opts.f_max
sample_rate = opts.sample_rate
seg_length = opts.seg_length
# ... waveform parameters
lmbda1 = opts.lambda1
lmbda2 = opts.lambda2
ampO = opts.amp_order
phaseO = opts.phase_order
inc = opts.inclination
phi0 = opts.initial_phase

# exactly one noise source (PSD model or ASD file) must be given
if not (psd_model or asd_file):
    parser.error("psd-model or asd-file required")
if psd_model is not None and asd_file is not None:
    parser.error("please specify an asd-file or a psd-model (not both)")
# exactly one of SNR and distance must be given
if not (opts.snr or opts.distance):
    parser.error("snr or distance required")
if snr is not None and distance is not None:
    parser.error("please specify snr or distance (not both)")
# when a sample rate is known, the overlap may not extend beyond Nyquist
if oFmax is None and sample_rate is not None:
    oFmax = sample_rate / 2.
elif sample_rate is not None and oFmax > sample_rate /2.:
    # report the offending overlap-fmax itself (not the waveform f-max)
    parser.error("overlap-fmax (%g) cannot be larger than Nyquist (%g)" %(oFmax, sample_rate/2.))
# sort the requested waveforms by domain; a model with both an FD and a TD
# implementation is counted as FD only
fd_waveforms = [waveform for waveform in waveforms if lalsim.SimInspiralImplementedFDApproximants(lalsim.GetApproximantFromString(waveform))]
td_waveforms = [waveform for waveform in waveforms if lalsim.SimInspiralImplementedTDApproximants(lalsim.GetApproximantFromString(waveform)) and waveform not in fd_waveforms]
if fd_waveforms and seg_length is None:
    parser.error("one or more FD waveforms desired (%s), but no seg. length given" % ', '.join(fd_waveforms))
if td_waveforms and sample_rate is None:
    parser.error("one or more TD waveforms desired (%s), but no sample rate given" % ', '.join(td_waveforms))

# derive the remaining settings from the parsed options
# total mass range, given as "min,max"
min_mtotal, max_mtotal = (float(bound) for bound in opts.mtotal_range.split(','))
# a given SNR means distance is plotted; otherwise SNR is plotted
plot_distance = snr is not None
plot_snr = not plot_distance
# the reference frequency falls back to 0
if f_ref is None:
    f_ref = 0.
# spin vectors, each given as "x,y,z"
s1x, s1y, s1z = (float(component) for component in opts.spin1.split(','))
s2x, s2y, s2z = (float(component) for component in opts.spin2.split(','))

# =============================================================================
#
#                                Main
#
# =============================================================================

# time step for TD waveforms and frequency step for FD waveforms
if sample_rate is not None:
    dt = 1./sample_rate
if seg_length is not None:
    df = 1./seg_length

# For every waveform model, evaluate the sensitivity on a regular grid of
# total masses; xvals/yvals hold the plotted points keyed by waveform name.
input_masses = numpy.linspace(min_mtotal, max_mtotal, num = npoints)
xvals = {}
yvals = {}
if opts.verbose:
    sys.stdout.write("Calculating sensitivity:\n")
for waveform in waveforms:
    xvals[waveform] = numpy.zeros(npoints)
    yvals[waveform] = numpy.zeros(npoints)
    # the approximant and its domain are the same for every mass, so they
    # are looked up once per waveform
    apprx = lalsim.GetApproximantFromString(waveform)
    is_fd = lalsim.SimInspiralImplementedFDApproximants(apprx)
    is_td = lalsim.SimInspiralImplementedTDApproximants(apprx)
    for ii,M in enumerate(input_masses):
        if opts.verbose:
            # progress line; the carriage return lets the next update
            # overwrite it
            sys.stdout.write("\t%s: %.2f%%\r" %(waveform, 100.*ii/float(npoints)))
            sys.stdout.flush()
        m1, m2 = m1_m2_from_M_q(M, q)

        if is_fd:
            # FD waveforms are regenerated for every total mass
            htilde, _ = lalsim.SimInspiralChooseFDWaveform(phi0, df, m1*lal.MSUN_SI, m2*lal.MSUN_SI, s1x, s1y, s1z, s2x, s2y, s2z, wFmin, f_max, 0.0, 1e6*lal.PC_SI, inc, lmbda1, lmbda2, None, None, ampO, phaseO, apprx)
            N = 2*(htilde.data.length-1)

        elif is_td:
            # TD waveforms are only generated once, for the lowest mass;
            # the rest are rescaled from that reference
            if ii == 0:
                hplus, hcross = lalsim.SimInspiralChooseTDWaveform(phi0, dt, m1*lal.MSUN_SI, m2*lal.MSUN_SI, s1x, s1y, s1z, s2x, s2y, s2z, wFmin, f_ref, 1e6*lal.PC_SI, inc, lmbda1, lmbda2, None, None, ampO, phaseO, apprx)
                if waveform in tapers:
                    lalsim.SimInspiralREAL8WaveTaper(hplus.data, tapers[waveform])

                # FFT length: twice the next power of two of the waveform length
                N = int(2*2**numpy.ceil(numpy.log2(hplus.data.length)))
                this_df = float(sample_rate)/N

                h = zero_pad_h(hplus, N, 0, overwrite = True)
                htilde = get_htilde(h, N, this_df)
                htilde_ref = htilde

            else:
                # Rescale the reference spectrum by scale_fac = M / M_ref:
                # the frequency spacing is divided by scale_fac, and the
                # amplitude is multiplied by scale_fac and by a correction
                # for the equivalent change in dt. Since the reference
                # series satisfies deltaF * N * dt = 1 with
                # N = 2 * (length - 1), that correction equals scale_fac.
                scale_fac = M/input_masses[0]
                htilde = lal.CreateCOMPLEX16FrequencySeries("htilde", htilde_ref.epoch, htilde_ref.f0, htilde_ref.deltaF/scale_fac, lal.HertzUnit, htilde_ref.data.length)
                htilde.data.data[:] = htilde_ref.data.data * scale_fac / (2 * htilde.deltaF * (htilde.data.length - 1) * dt)

        else:
            raise ValueError("unrecognized waveform approximant %s" % waveform)

        # get a psd with the same length and spacing as the waveform
        if asd_file is not None:
            psd = get_psd_from_file(htilde.data.length, htilde.deltaF, wFmin, asd_file)
        else:
            psd = get_psd(htilde.data.length, htilde.deltaF, wFmin, psd_model)

        # calculate sensitivity; the waveforms were generated at a distance
        # of 1e6*lal.PC_SI, so sqrt(sigmasq) / snr is a distance in Mpc
        sigmasq = get_sigmasq(htilde, psd, oFmin, oFmax)
        if plot_distance:
            distance = numpy.sqrt(sigmasq) / snr
            if opts.de_redshift:
                distance = get_proper_distance(distance, opts.hubble_constant)
            yvals[waveform][ii] = distance
        else:
            snr = numpy.sqrt(sigmasq) / distance
            yvals[waveform][ii] = snr

        if opts.de_redshift:
            xvals[waveform][ii] = get_source_mass(M, distance, opts.hubble_constant)
        else:
            xvals[waveform][ii] = M

    if opts.save_data:
        # the data file goes next to the plot; its name is built from the
        # plot name without its suffix (whatever the length of that suffix)
        data_dir = os.path.dirname(os.path.abspath(output_file))
        plot_name = os.path.splitext(os.path.basename(output_file))[0]
        numpy.savetxt('%s/%s_data-%s.dat' %(data_dir, waveform, plot_name), numpy.array([xvals[waveform], yvals[waveform]]).T)

    if opts.verbose:
        sys.stdout.write("\t%s: %.2f%%\n" %(waveform, 100.))

# Draw one curve per waveform, label and scale the axes, then save the figure.
if opts.verbose:
    sys.stdout.write("Plotting...\n")
fig = pyplot.figure()
ax = fig.add_subplot(111)
# the y-axis is log scale unless --linear-y was given
plot_curve = ax.plot if opts.linear_y else ax.semilogy
for waveform in waveforms:
    # underscores in the waveform name are escaped for the legend label
    plot_curve(xvals[waveform], yvals[waveform], lw = 2, label = waveform.replace('_','\\_'))
ax.set_xlabel(r'%s total mass ($\mathrm{M}_\odot$)' %('source rest-frame' if opts.de_redshift else 'red-shifted'))
if plot_distance:
    ax.set_ylabel('%s distance (Mpc)' %('proper' if opts.de_redshift else 'luminosity'))
else:
    ax.set_ylabel(r'$\rho$')
if not opts.no_title:
    if opts.title is not None:
        ax.set_title(opts.title)
    else:
        # default title: the fixed quantity (SNR or distance), the mass
        # ratio and, when de-redshifting, the Hubble constant that was used
        hubble_note = r',~H_0 = %.2f\,\mathrm{km/s/Mpc}' % opts.hubble_constant if opts.de_redshift else ''
        if plot_distance:
            ax.set_title(r'$\rho = %.2f,~ q = %.2f %s$' % (snr, q, hubble_note))
        else:
            ax.set_title(r'$D%s = %i\,\mathrm{Mpc},~ q = %.2f %s$' %('' if opts.de_redshift else '_L', distance, q, hubble_note))
# start from the automatic axis limits and override those given by the user
plt_xmin, plt_xmax = ax.get_xlim()
plt_ymin, plt_ymax = ax.get_ylim()
if opts.x_min is not None:
    plt_xmin = opts.x_min
if opts.x_max is not None:
    plt_xmax = opts.x_max
if opts.y_min is not None:
    plt_ymin = opts.y_min
if opts.y_max is not None:
    plt_ymax = opts.y_max
ax.set_xlim(plt_xmin, plt_xmax)
ax.set_ylim(plt_ymin, plt_ymax)
ax.legend(loc = opts.legend_loc)
# matplotlib releases with a major version below 1 get a different keyword
# for drawing both major and minor grid lines
if int(matplotlib.__version__.split('.')[0]) < 1:
    ax.grid(which = 'majorminor')
else:
    ax.grid(which = 'both')

fig.savefig(output_file, dpi = dpi)

if opts.verbose:
    sys.stdout.write("Finished!\n")

sys.exit(0)
