/* -*- mode: c++; c-basic-offset: 3; -*- */
#include "wcondition.hh"
#include "wtile.hh"
#include "matlab_fcs.hh"
#include "medianmeanaveragespectrum.hh"

#include "DVecType.hh"
#include "TSeries.hh"
#include "fSeries/ASD.hh"
#include "fSeries/PSD.hh"
#include "IIRFilter.hh"
#include "IIRdesign.hh"

//---------------------
#include <iostream>
using namespace std;
//---------------------

using namespace wpipe;
using namespace containers;

// $Id: wcondition.m 3070 2010-05-14 21:38:01Z jrollins $

//  WhitenedDft results verified 2011.06.28.
//  Open question: the normalization of the whitened DFT may be
//  non-standard, possibly due to an extra factor in the whitening
//  coefficients.

//======================================  forward/backward zero phase filtering
//  Reset filter f, run it over the input series and return the output
//  reversed in time.
static TSeries
filter_and_reverse(IIRFilter& f, const TSeries& in) {
   f.reset();
   TSeries out = f(in);
   out.refDVect()->reverse();
   return out;
}

//  Zero-phase filtering: filter the data, filter the time-reversed result
//  again, and restore the original time order.  The caller's filter is
//  copied, so its internal state is left untouched.
TSeries
sosfiltfilt(const IIRFilter& filt, const TSeries& ts) {
   IIRFilter work(filt);
   // Inner call: forward pass (output time-reversed).
   // Outer call: pass over the reversed data, then reverse back.
   return filter_and_reverse(work, filter_and_reverse(work, ts));
}

//======================================  Condition constructor
//  Validate the raw data against the tiling, design the high-pass filter
//  (if a cutoff is set), and condition every channel with single_chan().
//
//  rawData       One raw time series per channel.  Each must contain
//                size_t(sampleFrequency * duration) samples.
//  tiling        Discrete Q-transform tiling that supplies the conditioning
//                parameters.
//  doubleWhiten  Forwarded to single_chan (apply the whitening twice).
//
//  Errors are reported through error() if the tiling id is wrong or any
//  channel has the wrong length.
wcondition::wcondition(const tser_vect& rawData, const wtile& tiling, 
		       bool doubleWhiten) 
   : highPassFilter(0)
{
   // number of channels to condition
   size_t numberOfChannels = rawData.size();

   // validate tiling structure
   if (tiling.id() != "Discrete Q-transform tile structure") {
      error("input argument is not a discrete Q transform tiling structure");
   }

   // required number of samples per channel
   size_t dataLength = size_t(tiling.sampleFrequency() * tiling.duration());

   // validate the data length of every channel
   for (size_t channelNumber=0; channelNumber<numberOfChannels; channelNumber++){
      if (rawData[channelNumber].getNSample() != dataLength) {
	 cout << "wcondition: Number of raw samples: " 
	      << rawData[channelNumber].getNSample()
	      << " tiling sample rate: " << tiling.sampleFrequency()
	      << " tiling duration: " << tiling.duration() << endl;
	 error("data length not consistent with tiling");
      }
   }

   //----------------------------------  Design high-pass filter if requested
   if (tiling.highPassCutoff() > 0) {

      // high pass filter order
      int hpfOrder = 12;

      // design high pass filter; a clone is owned by this object
      IIRFilter hpf = butter(kHighPass, hpfOrder, tiling.sampleFrequency(),
			     tiling.highPassCutoff());
      highPassFilter = hpf.clone();

      // Matlab original (not ported; hpfResponse is not computed here):
      //hpfArgument = (frequencies / tiling.highPassCutoff).^(2 * hpfOrder);
      //hpfResponse = hpfArgument ./ (1 + hpfArgument);
   }  // end test for high pass filtering

   // If anything below throws, the destructor is not run for this partially
   // constructed object, so the filter clone must be released here.
   try {
      /////////////////////////////////////////////////////////////////////////
      //                      initialize per-channel vectors                 //
      /////////////////////////////////////////////////////////////////////////

      // copies of the raw input data
      raw.resize(numberOfChannels);

      // high pass filtered time series
      highPassData.resize(numberOfChannels);

      // DFTs of the high pass filtered data
      highPassedDft.resize(numberOfChannels);

      // whitened DFTs
      whitenedDft.resize(numberOfChannels);

      // conditioning filter coefficients
      coefficients.resize(numberOfChannels);

      //--------------------------------  condition each channel
      for (size_t channelNumber=0; channelNumber < numberOfChannels;
	   channelNumber++) {
	 raw[channelNumber] = rawData[channelNumber];

	 single_chan(raw[channelNumber], tiling, highPassData[channelNumber], 
		     highPassedDft[channelNumber], whitenedDft[channelNumber],
		     coefficients[channelNumber], doubleWhiten);
      }
   } catch (...) {
      delete highPassFilter;
      highPassFilter = 0;
      throw;
   }
}

//======================================  Condition destructor
//  Release the high-pass filter clone made by the constructor.  The pointer
//  is null when no high-pass cutoff was requested, and deleting null is a
//  no-op.
//  NOTE(review): this object owns a raw pointer, so an implicitly generated
//  copy constructor or assignment would double-delete it.  Confirm that the
//  class declaration disables copying or deep-copies the filter.
wcondition::~wcondition()
{
   delete highPassFilter;
}

//======================================  Condition a single channel
//  Condition one channel of raw data:
//    1. if a high-pass cutoff is set, filter the data forward and backward
//       with the filter designed in the constructor;
//    2. zero the transient regions at both ends of the data;
//    3. take the DFT of the high-passed data;
//    4. if a whitening duration is set, estimate a median-mean PSD from the
//       un-zeroed part of the data, build whitening coefficients
//       sqrt(2/PSD), fold them into coeffs and apply them to the DFT
//       (twice if doubleWhiten).
//
//  rawData       Raw time series for this channel.
//  tiling        Supplies sample rate, duration, high-pass cutoff and
//                whitening duration.
//  hpData        [out] high-passed (or copied) data with the ends zeroed.
//  hpDft         [out] DFT of hpData.
//  whiteDft      [out] whitened DFT (equal to hpDft if not whitening).
//  coeffs        [in/out] conditioning coefficients.  The whitening
//                coefficients are multiplied in if coeffs is non-empty,
//                otherwise assigned.
//  doubleWhiten  Apply the whitening coefficients a second time.
void 
wcondition::single_chan(const TSeries& rawData, const wtile& tiling, 
			TSeries& hpData, DFT& hpDft, DFT& whiteDft,
			DFT& coeffs, bool doubleWhiten) {
   // Sample rate and Nyquist frequency
   double fSample = tiling.sampleFrequency();
   double nyquistFrequency = fSample / 2;

   // number of samples in this channel
   int dataLength = rawData.getNSample();

   /////////////////////////////////////////////////////////////////////////
   //                           high pass filter                          //
   /////////////////////////////////////////////////////////////////////////
   if (tiling.highPassCutoff() > 0) {
      if (tiling.debug()) cout << "    double high-pass filter data, fmin=" 
			       << tiling.highPassCutoff() << endl;

      // apply high pass filter forward and backward (zero phase)
      hpData = sosfiltfilt(*highPassFilter, rawData);

      // Include the high pass filter response in the conditioning
      // coefficients.
      // NOTE(review): hpfResponse is not computed anywhere in this file (its
      // calculation is commented out in the constructor).  Confirm that it is
      // initialized elsewhere.
      static_cast<fSeries&>(coeffs) = hpfResponse;
   }

   // no high pass filtering requested: pass the data through
   else {
      hpData = rawData;
   }  // end test for high pass filtering

   //----------------------------------  suppress high pass filter transients
   // Linear predictor error filter order: the whitening duration in samples.
   // That many samples are meant to be zeroed at each end of the data.
   // This assumes replace_with_zeros(offset, count, nZeros) semantics.
   int lpefOrder(int(ceil(fSample * tiling.whiteningDuration())));
   if (lpefOrder != 0) {
      DVector& hpdvref = *(hpData.refDVect());
      hpdvref.replace_with_zeros(0, lpefOrder, lpefOrder);
      hpdvref.replace_with_zeros(dataLength-lpefOrder, lpefOrder, lpefOrder);
   }

   /////////////////////////////////////////////////////////////////////////
   //                        fast fourier transform                       //
   /////////////////////////////////////////////////////////////////////////

   // fourier transform high pass filtered data
   hpDft = DFT(hpData);

   /////////////////////////////////////////////////////////////////////////
   //                           whitening filter                          //
   /////////////////////////////////////////////////////////////////////////

   if (tiling.whiteningDuration() > 0) {
      wlog(tiling.debug(), 1, "    whiten data");

      //-------------------------------  compute accurate power spectrum
      // Use the median mean average algorithm (as detailed in the FINDCHIRP
      // paper) to calculate an initial PSD.  This reduces the effect of
      // large glitches and/or injections.

      // Generate PSD with 1 Hz resolution.
      // TODO: the timescale should be taken from the tiling information
      double binHz  = 1.0;
      long nSegment = long(fSample / binHz + 0.5);
#ifndef OMEGA_MATLAB_BUGS
      // Use only whole segments that lie inside the un-zeroed region
      // [lpefOrder, dataLength - lpefOrder).
      // The number of skipped segments is rounded UP, so no zeroed leading
      // samples enter the PSD.  Only the segments that then fit before the
      // zeroed tail are counted.
      // When lpefOrder is a multiple of nSegment this matches the former
      // floor-based computation exactly.
      long nSkip  = (lpefOrder + nSegment - 1) / nSegment;
      long nValid = (dataLength - lpefOrder - nSkip * nSegment) / nSegment;
      Time t0 = hpData.getStartTime() + double(nSkip*nSegment)*hpData.getTStep();
      Interval dT = hpData.getTStep() * double(nSegment * nValid);
      PSD PSDintermediate = medianmeanaveragespectrum(hpData.extract(t0, dT),
						      fSample, nSegment);
#else
      PSD PSDintermediate = medianmeanaveragespectrum(hpData, fSample, nSegment);
#endif
      // -- try to correct normalization:
      //    Note the factor of 2 because highPassDft is already folded.
      //+_+_+_+_+_+_+_+_+_+_ standard normalization in medavgspec
      //+_+_+_+_+_+_+_+_+_+_ Should this be 2*size-1) instead?
      //PSDintermediate *= tiling.sampleFrequency() * 
      //                   2 * hpDft.size();
      PSDintermediate *= fSample * 2 * (hpDft.size() - 1);

      // -- interpolate to the DFT resolution (1 / duration)
      // TODO: choose a higher-order interpolation
      static_cast<fSeries&>(PSDintermediate) 
	 = PSDintermediate.interpolate(0, nyquistFrequency, 1/tiling.duration());

      // Whitening coefficients sqrt(2/PSD).  Bins whose PSD is zero are
      // left at zero.
      fSeries theCoefficients(PSDintermediate);
      size_t N = theCoefficients.size();
      DVectD& dvd = dynamic_cast<DVectD&>(theCoefficients.refDVect());
      for (size_t i=0; i<N; ++i) {
	 if (dvd[i] != 0) dvd[i] = sqrt(2.0/dvd[i]);
      }

      // zero the coefficients below the high pass cutoff
      if (tiling.highPassCutoff() > 0) {
	 int hpBin= theCoefficients.getBin(tiling.highPassCutoff());
	 theCoefficients.refDVect().replace_with_zeros(0, hpBin, hpBin);
      }

      // include whitening filter in conditioning coefficients
      if (!coeffs.empty()) {
	 coeffs *= theCoefficients;
      } else {
	 static_cast<fSeries&>(coeffs) = theCoefficients;
      }

      // apply whitening filter
      whiteDft  = hpDft;
      whiteDft *= theCoefficients;

      // reapply whitening filter if double whitening requested
      if (doubleWhiten) {
	 wlog(tiling.debug(), 1, "    double whiten data");
	 whiteDft *= theCoefficients;
      }

   } else {
      // whitening not requested: the whitened DFT is the high-passed DFT
      whiteDft = hpDft;
   }  // end test for whitening

#ifdef UNNORMALIZED_DFTS
   //----------------------------------  Emulate matlab unnormalized DFTs
   whiteDft *= fSample;
#endif
}

//======================================  Get coefficients
//  Copy the conditioning coefficients of every channel into dftvec.
//  The previous contents of dftvec are replaced.
void wcondition::coefficientDFT(dft_vect& dftvec) const
{
   dftvec = this->coefficients;
}

//======================================  Get a raw DFT
//  Return the DFT of the first channel's raw data.
//  The DFT is computed on first use and cached in rawDft[0]; while that
//  cache is non-empty it is returned directly.
//  Reports an error if no channels were conditioned, instead of indexing
//  an empty vector.
const containers::DFT&
wcondition::rawDFT(void) {
   if (raw.empty()) {
      error("wcondition::rawDFT: no conditioned channels");
   }
   if (rawDft.empty()) rawDft.push_back(DFT());
   if (rawDft[0].empty()) {
      rawDft[0] = DFT(raw[0]);
#ifdef UNNORMALIZED_DFTS
      // fix for matlab unnormalized dfts (scale by the sample rate)
      rawDft[0] *= 1.0/raw[0].getTStep();
#endif
   }
   return rawDft[0];
}

//=======================================  Get high-passed data
//  Return a copy of the first channel's high-pass filtered time series.
//  Reports an error if no channels were conditioned, instead of indexing
//  an empty vector.
TSeries 
wcondition::highPassedData(void) const {
   if (highPassData.empty()) {
      error("wcondition::highPassedData: no conditioned channels");
   }
   return highPassData[0];
}

//=======================================  Get whitened data
//  Inverse-transform one whitened DFT into a time series whose units are
//  set to "whitened".
//  When UNNORMALIZED_DFTS is defined, the result is multiplied by the DFT
//  sample time, presumably to compensate the "*= fSample" scaling that
//  single_chan applies in that configuration.
static TSeries
whitened_tseries(const containers::DFT& dft) {
   TSeries ts = dft.iFFT();
#ifdef UNNORMALIZED_DFTS
   ts *= dft.getSampleTime();
#endif
   ts.setUnits("whitened");
   return ts;
}

//  Return the whitened time series of the first channel.
//  Reports an error if no channels were conditioned, instead of indexing
//  an empty vector.
TSeries 
wcondition::whitenedData(void) const {
   if (whitenedDft.empty()) {
      error("wcondition::whitenedData: no conditioned channels");
   }
   return whitened_tseries(whitenedDft[0]);
}

//  Fill tsvec with the whitened time series of every channel.
//  tsvec is resized to the number of channels.
void 
wcondition::whitenedData(tser_vect& tsvec) const {
   size_t N = whitenedDft.size();
   tsvec.resize(N);
   for (size_t i=0; i<N; i++) {
      tsvec[i] = whitened_tseries(whitenedDft[i]);
   }
}

//  Copy the whitened DFT of every channel into dftvec.
//  The previous contents of dftvec are replaced.
void wcondition::whitenedDFT(dft_vect& dftvec) const
{
   dftvec = this->whitenedDft;
}
