#!/usr/bin/env python

# pylal_recolor: Recolor noise in frame files to a given ASD/PSD.
# Copyright (C) 2013 Will M. Farr <will.farr@ligo.org>

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
# 02110-1301, USA.

"""Example usage:

pylal_recolor --cache H.cache --channel H1:LSC-STRAIN --start 873739823 --seglen 32 --len 4096 --outdir ./early --prefix H-H1_RDS_C03_L2- --asdfile early_aligo.dat --psdout recolored-early-psd-H.dat --smoothing 10

This will recolor the data in the frame files listed in H.cache
starting at 873739823 running for 4096 seconds, estimating the PSD in
a 32 second window (this happens to be a science segment of S5), to
match the ASD provided in early_aligo.dat.  The recolored frame files
will have filenames of the form H-H1_RDS_C03_L2-TIME-LEN.gwf and
appear in the ./early directory.  The file recolored-early-psd-H.dat
will contain the PSD (not ASD!) of the output files.  Before
recoloring the estimated PSD of the original data will be smoothed on
a 10 Hz window so that the lines appearing in the data are not
recolored out of it."""

from argparse import ArgumentParser
import matplotlib.pyplot as pp
import numpy as np
import numpy.fft as npf
import os
import os.path as op
import pylal.Fr as fr
import pylal.frutils as fu
import scipy.signal as ss

__author__ = 'Will M. Farr <will.farr@ligo.org>'
__date__ = '2013/02/18'
__version__ = '0.1'

class QuasiWelchPSDEstimator(object):
    """Estimate the PSD using (something like) Welch's method.

    This differs from Welch's method because it does not overlap data
    segments, but is otherwise identical: every segment is windowed,
    Fourier transformed, and its periodogram is averaged into a
    running estimate."""

    def __init__(self, data_length, fsample):
        """Initialize for segments of ``data_length`` samples taken at
        a sampling frequency of ``fsample``."""

        # Stored as an int so that the bin count N//2+1 is always a
        # valid array size.
        self._data_length = int(data_length)
        self._fsample = fsample
        self.set_up_window()
        self.reset()

    @property
    def fs(self):
        """Return the frequencies at which the PSD is estimated.

        Bin ``k`` of the real FFT of ``data_length`` samples lies at
        ``k * fsample / data_length``.  The last bin is exactly
        ``fsample / 2`` only when ``data_length`` is even."""
        bin_width = self.fsample / float(self.data_length)

        return np.arange(self.psd.shape[0]) * bin_width

    @property
    def psd(self):
        """At any time, the current estimate of the PSD."""
        return self._psd

    @property
    def data_length(self):
        """Length of data (PSD will be N//2+1)."""
        return self._data_length

    @property
    def fsample(self):
        """Returns the sampling frequency."""
        return self._fsample

    @property
    def window(self):
        """Returns the window used to compute the PSD.  The default is
        a Hann window."""
        return self._window

    @property
    def n(self):
        """Returns the number of data segments averaged together to
        get the PSD."""
        return self._n

    def reset(self):
        """Reset the accumulated PSD."""
        # Floor division: true division would hand np.zeros a float
        # size, which is an error.
        self._psd = np.zeros(self._data_length // 2 + 1)
        self._n = 0

    def set_up_window(self):
        """Build the Hann window, scaled to unit mean-square value."""
        self._window = 0.5*(1.0 - np.cos(2.0*np.pi*np.arange(self._data_length)/(self._data_length-1)))

        # Unit power per sample in the window
        self._window *= np.sqrt(self._data_length/np.sum(self._window*self._window))

    def add_data(self, data):
        """Accumulate the given stretch of data into the PSD estimate.

        The input is left untouched.  It is flattened, converted to
        floating point, and then truncated or zero-padded to
        ``data_length`` samples before being windowed."""
        samples = np.asarray(data, dtype=float).ravel()

        # Copy into a fresh float buffer rather than resizing in place:
        # this works for integer-valued input and never trips the
        # reference checks of ndarray.resize.
        segment = np.zeros(self.data_length)
        n_copy = min(samples.shape[0], self.data_length)
        segment[:n_copy] = samples[:n_copy]

        segment *= self.window

        # NOTE(review): no factor of 2 is applied for folding negative
        # frequencies, i.e. this is a two-sided normalisation; confirm
        # that it matches the convention of whatever the PSD is
        # compared against.
        power = np.abs(npf.rfft(segment))
        power *= power
        power *= 1.0/(self.fsample*float(self.data_length))

        # Now average new sample into accumulated samples (running
        # mean, so no history needs to be kept).
        self._n += 1
        self._psd *= 1.0 - 1.0/float(self.n)
        self._psd += 1.0/float(self.n)*power

    def smoothed_psd(self, df=1.0):
        """Returns a median-filter-smoothed PSD, with smoothing width
        in frequency space of (approximately) ``df``."""

        # Spacing between neighbouring frequency bins.
        delta_f = self.fsample / float(self.data_length)

        width = int(df/delta_f + 1.0)

        # Median filter requires odd lengths
        if width % 2 == 0:
            width += 1

        return ss.medfilt(self.psd, width)

def psd_estimator_from_cache(cachefilename, channel, start, length, seglen):
    """Build a :class:`QuasiWelchPSDEstimator` holding the PSD
    estimate of a stretch of frame data.

    The stretch is walked in steps of ``seglen`` seconds; every step
    is fetched from the frame cache and accumulated into a single
    estimator.

    :param cachefilename: A cache file name.

    :param channel: The channel name.

    :param start: GPS start time.

    :param length: Length (seconds) of the data to consider.

    :param seglen: Length (seconds) of each segment used for PSD
      estimation.

    :returns: :class:`QuasiWelchPSDEstimator` containing a PSD
      estimate corresponding to ``seglen`` seconds of data, or
      ``None`` when the requested stretch contains no segment."""

    estimator = None

    with open(cachefilename, 'r') as cache_file:
        frames = fu.FrameCache(fu.Cache.fromfile(cache_file))

        for seg_start in range(start, start + length, seglen):
            segment = frames.fetch(channel, seg_start, seg_start + seglen)

            if estimator is None:
                # The first segment fixes the segment length and the
                # sampling frequency used for all later segments.
                n_samples = segment.shape[0]
                sample_rate = 1.0 / segment.metadata.dt
                estimator = QuasiWelchPSDEstimator(n_samples, sample_rate)

            estimator.add_data(segment)

    return estimator

if __name__ == '__main__':
    # Command-line driver: estimate the PSD of the input data, derive
    # the per-frequency amplitude rescaling needed to reach the
    # requested ASD, and write the rescaled data out as new frames.
    parser=ArgumentParser(description='recolor frame files')

    parser.add_argument('--cache', metavar='FILE', required=True, help='cache file')
    parser.add_argument('--channel', metavar='CHNAME', required=True, help='channel name')
    parser.add_argument('--asdfile', metavar='FILE', required=True, help='file for recolored ASD')
    parser.add_argument('--start', metavar='GPST', type=int, required=True, help='start GPS time')
    parser.add_argument('--seglen', metavar='DT', type=int, required=True, 
                        help='length of segments used to estimate data PSD')
    parser.add_argument('--len', metavar='DELTAT', type=int, required=True, 
                        help='length of data to recolor')
    parser.add_argument('--outdir', metavar='FILE', default='.', help='directory for output frames')
    parser.add_argument('--prefix', metavar='PRE', default='', help='filename prefix')
    parser.add_argument('--smoothing', metavar='DF', default=1.0, type=float, 
                        help='width of PSD smoothing in frequency space')
    parser.add_argument('--psdout', metavar='FILE', help='output the recolored PSD')

    args = parser.parse_args()

    est = psd_estimator_from_cache(args.cache, args.channel, args.start, args.len, args.seglen)

    # The smoothed PSD is used for the rescaling so that narrow lines
    # in the original data survive the recoloring.
    psd = est.smoothed_psd(df=args.smoothing)
    fs = est.fs

    pp.loglog(fs, psd)
    pp.show()

    # Columns of the ASD file: frequency, ASD.  Resample onto the
    # frequencies of the estimated PSD.
    new_asd_data = np.loadtxt(args.asdfile)
    new_asd = np.interp(fs, new_asd_data[:,0], new_asd_data[:,1])

    # If any zeros in PSD (maybe from smoothing?) just leave amp_ratio
    # at 1 there; dividing only where the PSD is non-zero avoids the
    # divide-by-zero warnings and infinities.
    amp_ratio = np.ones_like(psd)
    nonzero = psd != 0.0
    amp_ratio[nonzero] = new_asd[nonzero] / np.sqrt(psd[nonzero])

    # Output the actual PSD that will obtain, not the desired one.
    if args.psdout is not None:
        np.savetxt(args.psdout, np.column_stack((fs, est.psd * amp_ratio * amp_ratio)))

    # An already existing output directory is fine; any other failure
    # (permissions, a regular file in the way, ...) must not be hidden.
    try:
        os.makedirs(args.outdir)
    except OSError:
        if not op.isdir(args.outdir):
            raise

    with open(args.cache, 'r') as inp:
        cache = fu.Cache.fromfile(inp)
        fcache = fu.FrameCache(cache)
        
        for time in range(args.start, args.start+args.len, args.seglen):
            
            outname = op.join(args.outdir, '%s%d-%d.gwf'%(args.prefix, time, args.seglen))
            data = fcache.fetch(args.channel, time, time+args.seglen)
            fdata = np.fft.rfft(data)
            fdata *= amp_ratio
            
            # Pass the original length explicitly: without it irfft
            # assumes an even length and drops a sample from
            # odd-length segments.
            new_data_array = np.fft.irfft(fdata, n=data.shape[0])

            fr.frputvect(outname, 
                         [{'name' : args.channel,
                           'data' : new_data_array,
                           'start': time,
                           'dx' : data.metadata.dt,
                           'x_unit' : 's',
                           'type' : 1}]) # Time series
