#! /usr/bin/env python

# Copyright (C) 2011 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA    02110-1301, USA.

"""
Interferometer (system) frequency noise projection generator.
Display the noise projection for an interferometer or interferometer
system using a configuration file at the given GPS time.
"""

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import numpy
import ConfigParser
import optparse
import os
import sys
import re
import datetime
import lal

from scipy import signal

# set plotting backend
from matplotlib import use
use("Agg")

from pylal import seriesutils
from pylal import htmlutils
from pylal import git_version
from pylal import plotutils
from pylal.dq import noisebudget

from glue import segments
from glue.lal import Cache

# set metadata
__author__  = 'Duncan M. Macleod <duncan.macleod@ligo.org>'
__version__ = git_version.id
__date__    = git_version.date

# useful regexes and global variables
# _cchar matches any run of non-word characters and/or underscores; a raw
# string is used so that '\W' is not treated as a (invalid) string escape
_cchar = re.compile(r'[\W_]+')
# interferometer prefixes accepted by --ifo: every detector cached by lal,
# plus "C1"
_ifos = [d.frDetector.prefix for d in lal.lalCachedDetectors]+["C1"]

plotutils.set_rcParams()

# =============================================================================
# Verbose
# =============================================================================

def print_verbose(message, verbose=False, stream=sys.stdout, profile=True):
    """
    Write message to stream (stdout by default) and flush it straight away,
    but only if verbose is True; otherwise do nothing.

    The profile argument is accepted for compatibility and is not used.
    """
    if not verbose:
        return
    stream.write(message)
    stream.flush()

# =============================================================================
# Construct ZPK filter
# =============================================================================

def freqresp(zeros, poles, gain, f0=0, deltaF=1, length=100):
    """
    Compute the magnitude of the frequency response of a ZPK filter.

    Arguments:

        zeros, poles : sequences of the filter zeros and poles (may be empty)
        gain : overall filter gain
        f0 : first frequency of the output array
        deltaF : frequency spacing of the output array
        length : number of frequency samples

    Returns a numpy array of size length holding |H(1j*f)| for
    f = f0 + deltaF * [0, 1, ..., length-1].
    """
    # convert to numerator/denominator polynomial coefficients explicitly,
    # rather than relying on num/den attributes of a ZPK-constructed lti
    # object, which are not available in all scipy versions
    num, den = signal.zpk2tf(numpy.asarray(zeros), numpy.asarray(poles), gain)
    f = numpy.arange(length) * deltaF + f0
    # the response is evaluated at s = 1j*f, i.e. the frequency values are
    # used directly with no 2*pi factor, exactly as before
    # NOTE(review): confirm the zeros/poles are given in matching units
    s = 1j * f
    # numpy.polyval accepts arrays, so no per-sample Python loop is needed
    fresp = numpy.polyval(num, s) / numpy.polyval(den, s)
    return numpy.abs(fresp)

# =============================================================================
# Parse command line
# =============================================================================

def parse_command_line():
    """
    Build the option parser for this script, parse the command line, and
    pass the result through sanity_check_command_line.

    Returns the (options, args) pair produced by that check.
    """
    usage = ("%s --config-file ${CONFIG_FILE} --gps-start-time ${GPSSTART} "
             "--gps-end-time ${GPSEND} --ifo ${IFO} [OPTIONS]"
             % os.path.basename(sys.argv[0]))
    epilog = ("If you're having trouble, e-mail detchar+code@ligo.org. "
              "To report a bug, please visit "
              "https://bugs.ligo.org/redmine/projects/detchar and submit "
              "an Issue.")
    parser = optparse.OptionParser(
        usage=usage,
        description=__doc__[1:],
        formatter=optparse.IndentedHelpFormatter(4),
        epilog=epilog)

    # general options
    parser.add_option(
        "-p", "--profile", action="store_true", default=False,
        help="show second timer with verbose statements, default: %default")
    parser.add_option(
        "-v", "--verbose", action="store_true", default=False,
        help="show verbose output, default: %default")
    parser.add_option(
        "-V", "--version", action="version",
        help="show program's version number and exit")
    parser.version = __version__

    # options describing the noise projection itself
    budget_group = optparse.OptionGroup(parser, "Noise projection options")
    budget_group.add_option(
        "-i", "--ifo", action="store", default=None, metavar="IFO",
        help="interferometer")
    budget_group.add_option(
        "-o", "--output-dir", action="store", type="string",
        default=os.getcwd(), help="output directory, default: %default")
    budget_group.add_option(
        "-f", "--config-file", action="store", type="string",
        metavar="FILE", default='config.ini',
        help="ini file for analysis, default: %default")
    budget_group.add_option(
        "-s", "--gps-start-time", action="store", type="int",
        metavar="GPSSTART", help="GPS start time")
    budget_group.add_option(
        "-e", "--gps-end-time", action="store", type="int",
        metavar="GPSEND", help="GPS end time")
    budget_group.add_option(
        "--online", action="store_true", default=False,
        help="run in online monitor mode, default: %default")

    # options describing how the data are read
    data_group = optparse.OptionGroup(parser, "Data access options")
    data_group.add_option(
        "-c", "--data-cache", action="store", type="string",
        metavar="FILE", default=None,
        help="read GWF frame locations from FILE, default: use ligo_data_find")
    data_group.add_option(
        "-n", "--nds", action="store_true", default=False,
        help="use NDS(2) to access data, default: %default")

    parser.add_option_group(budget_group)
    parser.add_option_group(data_group)

    checked_opts, leftover = sanity_check_command_line(parser.parse_args())
    return checked_opts, leftover

def sanity_check_command_line(opts_args):
    """
    Sanity check the command line arguments given.

    Arguments:

        opts_args : the (options, args) pair returned by
            optparse.OptionParser.parse_args

    The following checks and normalisations are applied to the options:

        - --ifo and --config-file must be given, and --ifo must be one of
          the known interferometer prefixes
        - --config-file (and --data-cache, if given) must be existing files;
          both are converted to absolute paths, as is --output-dir
        - with --online, if either GPS time is missing the 60 seconds
          leading up to the current GPS time are used
        - otherwise --gps-start-time must be given, and a missing
          --gps-end-time defaults to 60 seconds after the start

    Returns the (options, args) pair with options updated in place.
    Raises AssertionError if any check fails.
    """
    # accept the pair as a single argument and unpack it here; tuple
    # parameters in the signature are Python-2-only syntax
    opts, args = opts_args

    def _require(condition, message):
        # raise explicitly rather than use ``assert`` so that the checks are
        # still performed when python is run with -O
        if not condition:
            raise AssertionError(message)

    # a GPS start time is optional only in online mode, where the most
    # recent minute is used instead
    req_opts = ['ifo', 'config_file']
    if not opts.online:
        req_opts.append('gps_start_time')
    for opt in req_opts:
        _require(getattr(opts, opt),
                 '--%s must be given.' % opt.replace('_', '-'))

    # verify ifo
    _require(opts.ifo in _ifos,
             "--ifo=%s invalid. --ifo must be one of \n%s"
             % (opts.ifo, ', '.join(_ifos)))

    # verify configuration file
    _require(os.path.isfile(opts.config_file),
             "--config-file=%s invalid. %s cannot be found."
             % (opts.config_file, opts.config_file))
    opts.config_file = os.path.abspath(opts.config_file)

    # verify data cache
    if opts.data_cache:
        _require(os.path.isfile(opts.data_cache),
                 "--data-cache=%s invalid. %s cannot be found."
                 % (opts.data_cache, opts.data_cache))
        opts.data_cache = os.path.abspath(opts.data_cache)

    # verify output directory
    opts.output_dir = os.path.abspath(opts.output_dir)

    # in online mode fall back to the last 60 seconds
    if opts.online and (not opts.gps_start_time or not opts.gps_end_time):
        opts.gps_end_time = int(lal.GPSTimeNow())
        opts.gps_start_time = opts.gps_end_time - 60

    # verify GPS end time or set to 60 seconds after start
    if not opts.gps_end_time:
        opts.gps_end_time = opts.gps_start_time + 60

    return opts, args

# =============================================================================
# Main function
# =============================================================================

def build_noise_budget(cp, start, end, ifo, outdir, cache=None, online=False,\
                       usends=False, verbose=False):
    """
    Generate a noise projection based on the content of the
    ConfigParser.ConfigParser object cp, between the given start and end times.

    Arguments:

        cp : ConfigParser.ConfigParser
            parsed configuration. The 'input', 'spectrum' and 'nds' sections
            and any section whose name starts with 'plot' hold settings;
            every other section describes one noise term
        start, end : GPS start and end times of the data to use
        ifo : interferometer prefix, used to find data and to name files
        outdir : output directory, created if it does not exist
        cache : optional cache of frame files; if empty and NDS is not
            used, frames are located with dqFrameUtils.get_cache
        online : if True, plots are written with fixed '-0-0' file names
            and are given a time-stamp subtitle
        usends : if True, data are read over NDS using the [nds] section
        verbose : if True, progress messages are printed

    Plots and an index.html summary page are written into outdir.
    Nothing is returned.
    """
    import itertools

    def _get(section, option, default=None, getter=cp.get):
        # return the configured value if the option is present, else the
        # default; unlike 'has_option(..) and get(..) or default' this keeps
        # explicitly configured falsy values such as 0
        if cp.has_option(section, option):
            return getter(section, option)
        return default

    def _load_series(channel, frame_type=None):
        # read [start, end) of the given channel over NDS or from frames
        if usends:
            return seriesutils.fromNDS(channel, start, end-start,\
                                       server=ndsserver, port=ndsport)
        subcache = cache.sieve(description=frame_type, exact_match=True)
        return seriesutils.fromlalcache(subcache, channel, start, end-start)

    def _store_spectrum(term, spectrum, f0, deltaF):
        # attach a ready-made spectrum to a noise term
        # NOTE(review): fromarray is given the frequency spacing through its
        # deltaT argument, as in the original code - confirm against pylal
        term.frequencyseries = seriesutils.fromarray(spectrum,\
                                                     str(term.name),\
                                                     epoch=term.epoch,\
                                                     deltaT=deltaF, f0=f0,\
                                                     frequencyseries=True)
        term.spectrum = spectrum
        term.f0 = f0
        term.deltaF = deltaF

    def _add_online_subtitle(params):
        # in online mode, stamp plots with the UTC time of the end of data
        if online:
            params.setdefault("subtitle",\
                              datetime.datetime(*lal.GPSToUTC(int(end))[:6])\
                                  .strftime("%B %d %Y, %H:%M:%S %ZUTC"))

    #
    # setup output
    #

    if not os.path.isdir(outdir):
        os.makedirs(outdir)
    print_verbose("Will output to %s.\n" % outdir, verbose)

    #
    # setup parameters
    #

    print_verbose("Misc config info understood.\n", verbose)

    # NOTE: the [spectrum] 'type' option is not read here, since nothing in
    # this function passes an averaging method on to the spectrum estimation

    # default 2 second FFT (a configured length of 0 also falls back to 2)
    NFFT    = _get("spectrum", "segment-length", 2, cp.getfloat) or 2
    # default overlap of half the FFT length; an explicit 0 is honoured
    overlap = _get("spectrum", "segment-overlap", NFFT / 2.0, cp.getfloat)
    # whether to keep power spectra rather than take their square root
    power_spectrum = _get("input", "power-spectrum", False, cp.getboolean)

    print_verbose("Spectrum config info understood.\n", verbose)

    #
    # get cache from datafind
    #

    if usends:
        ndsserver = cp.get("nds", "server")
        ndsport   = cp.getint("nds", "port")
    else:
        ndsserver = ndsport = None
        # collect every frame type named in the configuration, once each.
        # 'frame-type' is what the noise terms below actually read;
        # 'frame-cache' is kept because the original scan looked for it
        frametypes = []
        for sec in cp.sections():
            for opt in cp.options(sec):
                if opt.endswith(('frame-type', 'frame-cache')):
                    ftype = cp.get(sec, opt)
                    if ftype not in frametypes:
                        frametypes.append(ftype)

        if not cache:
            # NOTE(review): dqFrameUtils was used here without ever being
            # imported; it is assumed to be pylal.dq.dqFrameUtils - confirm
            from pylal.dq import dqFrameUtils
            cache = Cache()
            for ftype in frametypes:
                cache.extend(dqFrameUtils.get_cache(start, end, ifo, ftype))

    #
    # construct NoiseBudget
    #

    name = _get("input", "budget-name") or ifo
    budget = noisebudget.NoiseBudget(name=name)

    # find sources: every section that is not a settings or plot section
    misc_config_sections = ['input', 'spectrum', "nds"]
    for sec in cp.sections():
        lsec = sec.lower()
        if lsec in misc_config_sections or lsec.startswith('plot'):
            continue
        budget.append(noisebudget.NoiseTerm(name=sec))

    # pull h(t) to the front; False sorts before True, so 'hoft' comes first
    # (the previous key, 'cond and 0 or 1', evaluated to 1 for every term)
    budget.sort()
    budget.sort(key=lambda t: t.name.lower() != 'hoft')

    # fallback colours for terms with no 'color' option
    # NOTE(review): the original code used an undefined 'colors' iterator;
    # this particular colour sequence is a choice, not recovered behaviour
    colors = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])

    for noise_term in budget:

        term_name = noise_term.name
        print_verbose("\n--------------------------------------------\n"+\
                      "Processing %s...\n" % term_name, verbose)

        # get colour
        if cp.has_option(term_name, 'color'):
            noise_term.color = cp.get(term_name, 'color')
        else:
            noise_term.color = next(colors)

        noise_term.plot_noise_curve = not cp.has_option(term_name,\
                                                        "skip-plot")

        # load reference data, converting power references to amplitude
        if cp.has_option(term_name, 'reference-file'):
            noise_term.loadreference(cp.get(term_name, 'reference-file'))
            reftype = _get(term_name, 'reference-type')
            if reftype and re.match('(psd|power)', reftype, re.I):
                noise_term.ref_spectrum = noise_term.ref_spectrum ** 0.5

        # load theoretical noise curve
        if cp.has_option(term_name, 'noise-curve')\
        or cp.has_option(term_name, 'design-curve'):
            # load design curve from file: column 0 is frequency, column 1
            # is the curve itself
            curve_file = _get(term_name, 'noise-curve')\
                         or cp.get(term_name, 'design-curve')
            f, curve = numpy.loadtxt(curve_file, usecols=[0,1], unpack=True)
            if not power_spectrum:
                curve = curve ** 0.5
            _store_spectrum(noise_term, curve, f[0], f[1]-f[0])
            print_verbose("Noise curve loaded from file.\n", verbose)
            print_verbose("%s processed.\n" % term_name, verbose)
            continue

        # load squeezing parameters for shot_noise
        if re.match('shot[-_ ]noise', term_name.lower(), re.I)\
        and cp.has_option(term_name, 'dc-power-channel'):

            # make sure we've already processed hoft (it is sorted to the
            # front of the budget above)
            hoft = None
            for other in budget:
                if other.name.lower() == 'hoft':
                    hoft = other
                    break
            if hoft is None or getattr(hoft, 'spectrum', None) is None:
                raise ValueError("To calculate squeezed shot noise, please "+\
                                 "include [hoft] in the configuration.")
            hoft_freq = numpy.arange(hoft.spectrum.size) * hoft.deltaF +\
                        hoft.f0

            # load squeezing parameters, keyed by what follows each prefix
            # (e.g. 'dc-power-channel' -> 'channel')
            dc     = dict((key.split('-',2)[2], val)\
                          for key,val in cp.items(term_name)\
                          if key.startswith('dc-power-'))
            opgain = dict((key.split('-',2)[2], val)\
                          for key,val in cp.items(term_name)\
                          if key.startswith('optical-gain-'))
            cavpf  = dict((key.split('-',3)[3], val)\
                          for key,val in cp.items(term_name)\
                          if key.startswith('cavity-pole-frequency-'))
            calfac  = cp.getfloat(term_name, 'calibration-factor')
            shotlim = [float(x) for x in\
                       cp.get(term_name, 'shot-noise-region').split(',')]
            shotregion = (hoft_freq >= shotlim[0]) &\
                         (hoft_freq < shotlim[1])

            # get data: the median of each auxiliary channel over the span
            sqzdata = dict()
            for sqz_param,n in zip([dc, opgain, cavpf], ['dc', 'gain', 'pole']):
                if usends:
                    frame_type = None
                else:
                    frame_type = sqz_param['frame-type']
                series = _load_series(sqz_param['channel'], frame_type)
                sqzdata[n] = numpy.median(series.data.data)
            shot_model = calfac * abs(1 + (1j*hoft_freq)/\
                         sqzdata['pole']) * \
                         (numpy.sqrt(-sqzdata['dc'] + float(dc['offset'])) /\
                                     sqzdata['gain'])
            # rescale the model onto h(t) in the shot-noise region, unless
            # the model already sits well below it
            sqzratio    = numpy.median(shot_model[shotregion]) /\
                          numpy.median(hoft.spectrum[shotregion])
            if sqzratio > 0.8: shot_model /= sqzratio
            _store_spectrum(noise_term, shot_model, hoft.f0, hoft.deltaF)
            print_verbose("%s processed.\n" % term_name, verbose)
            continue

        # get data params
        noise_term.channel = cp.get(term_name, 'channel')
        if not usends:
            noise_term.frame_type = cp.get(term_name, 'frame-type')
        slope     = _get(term_name, 'slope', 1.0, cp.getfloat)
        dc_offset = _get(term_name, 'dc-offset', 0.0, cp.getfloat)
        # default calibration is linear; it is only ever called within this
        # loop iteration, so closing over slope/dc_offset is safe
        calfunc   = lambda d: slope*(d-dc_offset)
        if cp.has_option(term_name, 'transform'):
            # NOTE: eval() executes arbitrary code from the configuration
            # file; only use configuration files from a trusted source
            calfunc = eval(cp.get(term_name, 'transform')) or calfunc
        flim = None
        if cp.has_option(term_name, 'frequency-band'):
            flim = [float(x) for x in\
                    cp.get(term_name, 'frequency-band').split(',')]

        # load data for channel
        if usends:
            series = _load_series(noise_term.channel)
        else:
            series = _load_series(noise_term.channel, noise_term.frame_type)
        print_verbose("Data loaded from frames.\n", verbose)

        # calibrate data
        series.data.data = calfunc(series.data.data)

        # generate spectrum, converting lengths from seconds to samples
        noise_term.compute_average_spectrum(series, NFFT/series.deltaT,\
                                            overlap/series.deltaT)
        print_verbose("Median-mean spectrum calculated.\n", verbose)
        if not power_spectrum:
            noise_term.spectrum = noise_term.spectrum ** 0.5

        # get transfer function
        if cp.has_option(term_name, 'transfer-function'):
            tf_spec = cp.get(term_name, 'transfer-function')

            if tf_spec.startswith('lambda'):
                # NOTE: eval() of configuration content, see above
                transfer_function = eval(tf_spec)
                noise_term.apply_spectrum_calibration(transfer_function)
            elif os.path.isfile(tf_spec):
                # load transfer function from the second column of a file
                transfer_function = numpy.loadtxt(tf_spec,\
                                                  dtype=complex,\
                                                  unpack=True, usecols=[1])
                noise_term.apply_spectrum_calibration(transfer_function)
            else:
                # otherwise expect a (zeros, poles, gain) expression
                # NOTE: eval() of configuration content, see above
                zeros, poles, gain = eval(tf_spec)
                mag = freqresp(zeros, poles, gain, noise_term.f0,\
                               noise_term.deltaF, noise_term.spectrum.size)
                noise_term.apply_spectrum_calibration(mag)

        # restrict to frequency limits
        if flim:
            noise_term.apply_frequency_band(*flim)

        # set target
        if cp.has_option(term_name, 'target'):
            budget.target = noise_term
            print_verbose("%s identified as noise projection target.\n"\
                          % term_name, verbose)

        print_verbose("%s processed.\n" % term_name, verbose)

    print_verbose("\n--------------------------------------------\n"+\
                  "Noise components complete.\n"+\
                  "--------------------------------------------\n", verbose)

    #
    # plot budget, ratio deficit and deficit
    #

    plotlist = []
    name = budget.name
    if re.match(r"%s[-_\s]" % ifo, name): name = name[3:]
    cname = _cchar.sub('_', name.upper())
    # online plots always use the same '-0-0' file names
    if online:
        gps, duration = 0, 0
    else:
        gps, duration = start, end-start

    # (configuration section, file name tag, NoiseBudget plotting method)
    budget_plots = [('plot noise-budget', 'NOISE_BUDGET', 'plot'),\
                    ('plot ratio-deficit', 'NOISE_BUDGET_RATIO_DEFICIT',\
                     'plot_ratio_deficit'),\
                    ('plot deficit', 'NOISE_BUDGET_DEFICIT', 'plot_deficit')]
    for section, tag, method in budget_plots:
        if not cp.has_section(section):
            continue
        _,params = plotutils.parse_plot_config(cp, section)
        _add_online_subtitle(params)
        outfile = '%s/%s-%s_%s-%d-%d.png'\
                  % (outdir, ifo, cname, tag, gps, duration)
        getattr(budget, method)(outfile, **params)
        print_verbose("%s written.\n" % outfile, verbose)
        plotlist.append(os.path.basename(outfile))

    # Plot noise components
    subplotlist = []
    if cp.has_section('plot noise-terms'):
        _,params = plotutils.parse_plot_config(cp, 'plot noise-terms')
        _add_online_subtitle(params)
        for term in budget:
            if not term.plot_noise_curve: continue
            label = term.name
            # make sure the file name starts with an interferometer prefix;
            # terms without a channel (noise curves) fall back to ifo
            if not re.match(re.escape(ifo), label):
                prefix = (getattr(term, 'channel', None) or ifo)[:2]
                label = "%s-%s" % (prefix, label)
            stub = _cchar.sub('-', _cchar.sub('_', label), 1)
            outfile = '%s/%s_NOISE_CURVE-%d-%d.png'\
                      % (outdir, stub, gps, duration)
            term.plot(outfile, **params)
            print_verbose("%s written.\n" % outfile, verbose)
            subplotlist.append(os.path.basename(outfile))

    # print an HTML overview page
    html = '%s/index.html' % outdir
    rows = [[term.name, getattr(term, 'channel', None)] for term in budget]
    text = htmlutils.write_table(["Noise terms", "Channel"], rows)
    # record the command line, with file arguments made absolute
    cmd  = ' '.join([os.path.abspath(p) if os.path.isfile(p) else p\
                     for p in sys.argv])
    bodystyle = 'min-width: 400; max-width:800; width: auto; margin: 0 auto;'
    page = htmlutils.summary_page(header=budget.name,\
                                  htag='h1', plotlist=plotlist, text=text,\
                                  subplotlist=subplotlist, info=cmd, init=True,\
                                  title=budget.name,\
                                  bodyattrs={'style': bodystyle})
    # close the file deterministically rather than leaking the handle
    with open(html, 'w') as htmlfile:
        htmlfile.write(page())
    print_verbose("%s written.\n" % html, verbose)

# =============================================================================
# Run main function if command line
# ============================================================================= 

if __name__=='__main__':

    # parse and validate the command line
    opts,args = parse_command_line()

    # print opener
    print_verbose("Imports complete, command line read.\n", opts.verbose)

    # read inifile; optionxform = str keeps option names case-sensitive
    cp = ConfigParser.ConfigParser()
    cp.optionxform = str
    cp.read(opts.config_file)
    cp.filename = opts.config_file
    print_verbose("Config file read.\n", opts.verbose)

    # read data cache, keeping only existing files that overlap the
    # requested GPS span
    if opts.data_cache:
        # use a context manager so the cache file is closed once read
        with open(opts.data_cache, 'r') as cachefile:
            cache = Cache.fromfile(cachefile)
        cache = cache.sieve(segment=segments.segment(opts.gps_start_time,\
                                                     opts.gps_end_time))
        cache,_ = cache.checkfilesexist(on_missing='warn')
        print_verbose("%d frames located from data cache.\n" % len(cache),\
                      opts.verbose)
        if len(cache)==0:
            raise RuntimeError("No frames in cache.")
    else:
        # no cache given: build_noise_budget locates the data itself
        cache = None

    # generate noise projection
    build_noise_budget(cp, opts.gps_start_time, opts.gps_end_time, opts.ifo,\
                       opts.output_dir, cache=cache, online=opts.online,\
                       usends=opts.nds, verbose=opts.verbose)
    print_verbose("Finished.\n", opts.verbose)
