/* -*- mode: c++; c-basic-offset: 3; -*- */
#include "wresample.hh"
#include "matlab_fcs.hh"
#include "lcl_array.hh"
#include "Bits.hh"
//#include "DecimateBy2.hh"
#include "Hanning.hh"
#include "DVecType.hh"
#include "FIRdft.hh"
#include "FIRdesign.hh"
#include <iostream>

//--------------------------------------------------------------------------
using namespace std;

namespace wpipe {

   //  Class-wide state shared by every single_resample: the number of
   //  live instances, and the cache of prototype anti-alias filters
   //  keyed by (input rate, output rate).
   size_t                    single_resample::use_count(0);
   single_resample::filt_map single_resample::rate_map;

   //===================================  Default constructor: no filter
   //                                     is installed until init() runs.
   single_resample::single_resample() {
      ++use_count;
   }

   //===================================  Construct a resampler and
   //                                     immediately install the filter
   //                                     for the from -> to rate pair.
   single_resample::single_resample(long from, long to) {
      ++use_count;
      init(from, to);
   }

   //===================================  Destroy a single resampler
   //
   //  Drops this instance from the live count.  When the last instance
   //  goes away, every cached prototype filter in rate_map is deleted
   //  and the cache is emptied.
   //  NOTE(review): only the constructors in this file increment
   //  use_count.  If copies are made through a compiler-generated copy
   //  constructor (e.g. by a container that copies elements), the count
   //  can fall out of balance.  Confirm against the class declaration.
   single_resample::~single_resample() {
      --use_count;
      if (use_count != 0) return;

      for (filt_iter it = rate_map.begin(); it != rate_map.end(); ++it) {
	 delete it->second;
	 it->second = 0;
      }
      rate_map.clear();
   }

   //===================================  Initialize a single resampler
   void
   single_resample::init(long from, long to) {
      _filter.set(filter(from, to));
      _outData.Clear();
   }

   //===================================  Design a single resampler filter
   //
   //  Returns a newly cloned FIRdft anti-alias filter for converting
   //  fRawSample Hz data to fSample Hz.  The prototype for each rate pair
   //  is designed once and cached in the static rate_map.  Callers always
   //  receive a clone, never the cached prototype.
   //
   //  Design steps:
   //  - a firls low-pass with cutoff 0.99 / max(up, down), where up/down
   //    is the gcd-reduced rate ratio;
   //  - order 512 * max(up, down), capped at 16384;
   //  - Hanning windowed, then scaled by up / (window centre value).
   Pipe*
   single_resample::filter(long fRawSample, long fSample) {
      rate_pair key(fRawSample, fSample);
      filt_iter cached = rate_map.find(key);
      if (cached == rate_map.end()) {
	 //---------------------------  Reduce the rate ratio to lowest terms.
	 long common   = gcd(fSample, fRawSample);
	 long upFact   = fSample / common;
	 long downFact = fRawSample / common;
	 long maxFact  = max(upFact, downFact);

	 //---------------------------  Least-squares low-pass design.
	 const int nBand = 2;
	 int order = 2 * 256 * maxFact;
	 if (order > 16384) order = 16384;
	 double cutoff = 0.99 / maxFact;
	 double freqs[2*nBand] = {0.0, cutoff, cutoff, 1.0};
	 double mags[2*nBand]  = {1.0, 1.0, 0.0, 0.0};
	 double weights[nBand] = {1.0, 1.0};
	 DVectD coefs(order + 1);
	 firls(order, nBand, freqs, mags, weights, coefs.refTData());

	 //---------------------------  Apply the window and set the gain.
	 Hanning window(order + 1);
	 coefs *= window.refDVect();
	 double gain = double(upFact) / window.refDVect().getDouble(order / 2);
	 coefs *= gain;

	 //---------------------------  Build and cache the prototype.
	 FIRdft* proto = new FIRdft(order, double(fRawSample));
	 proto->setCoefs(order + 1, coefs.refTData());
	 proto->setMode(FIRdft::fm_zero_phase);
	 rate_map.insert(filt_node(key, proto));
	 cached = rate_map.find(key);
      }
      return cached->second->clone();
   }

   //===================================  Resample a time-series
   //
   //  Passes ts through this channel's anti-alias filter, decimates the
   //  filtered series to fSample and appends the result to _outData.
   //
   //  - If the filter reports no current time, or the input starts after
   //    the filter's current time (a data gap, reported on cerr), the
   //    filter and the accumulated output are reset first.
   //  - Input earlier than the filter's current time has already been
   //    filtered, so it is trimmed off before filtering.
   //  - Output earlier than the input start time is removed at the end.
   //
   //  _filter must hold an FIRdft, as installed by init().  The
   //  dynamic_cast throws std::bad_cast otherwise.
   void
   single_resample::resample(const TSeries& ts, double fSample) {
      TSeries tsin(ts);
      FIRdft& filt = dynamic_cast<FIRdft&>(*_filter);

      //--------------------------------  Check for input gaps.
      Time tFilt  = filt.getCurrentTime();
      Time tInput = tsin.getStartTime();
      if (!tFilt) {
	 filt.reset();
	 _outData.Clear();
	 tFilt = Time(0);
      }
      else if (tFilt < tInput) {
	 cerr << "wresample: Input data gap, filter current-time: "
	      << tFilt << " data start: " << tInput << endl;
	 filt.reset();
	 _outData.Clear();
	 tFilt = Time(0);
      }

      //--------------------------------  Erase overlapping input data
      //                                  that have already been filtered.
      if (tFilt > tInput) tsin.eraseStart(tFilt - tInput);

      //--------------------------------  Filter and decimate.
      //  The input/output rate ratio is nominally an integer.  Round it
      //  rather than truncate, so an inexact sample interval (a ratio of
      //  1.99999 instead of 2) cannot silently drop a factor and leave
      //  the output at the wrong rate.
      double rateRatio = 1.0 / (double(tsin.getTStep()) * fSample);
      long downSampleFactor = long(rateRatio + 0.5);
      _outData.Append(filt(tsin).decimate(downSampleFactor));

      //--------------------------------  Remove decimated data from before
      //                                  nominal start time.
      Time tOutput = _outData.getStartTime();
      if (tInput > tOutput) _outData.eraseStart(tInput - tOutput);
   }

   //===================================  Reset a single resampler
   //
   //  Discards buffered output and releases the installed filter (set to
   //  null).  wresample() re-installs a filter via init() when it finds
   //  the filter null.
   void
   single_resample::reset() {
      _outData.Clear();
      _filter.set(0);
   }

   //===================================  Resampler constructor.  Nothing
   //                                     is set up here; per-channel
   //                                     filters are created on demand
   //                                     by wresample().
   resampler::resampler() {
   }

   //===================================  Resampler destructor: discard
   //                                     all per-channel filter state.
   resampler::~resampler() {
      this->reset();
   }

   //===================================  Resample function
   //
   //  Convenience overload: resamples every channel in data, i.e. the
   //  three-argument form with all selection flags set to true.
   tser_vect
   resampler::wresample(const tser_vect& data, double sampleFrequency) {
      return wresample(data, sampleFrequency, bool_vect(data.size(), true));
   }

   //===================================  Resample function
   //
   //  Resamples each selected channel of data to sampleFrequency.
   //
   //  data             Input time series, one per channel.
   //  sampleFrequency  Target sample rate in Hz.  The native rate of each
   //                   selected channel must be a power-of-2 multiple of
   //                   it.  Rates are treated as integers.
   //  select           One flag per channel of data.  Channels whose flag
   //                   is false are skipped.
   //
   //  Returns a vector of data.size() entries.  Resampled selected
   //  channels are stored contiguously from index 0, in channel order.
   //  Any remaining entries are left default-constructed.
   //  NOTE(review): confirm callers expect this packed layout rather than
   //  one parallel to data.
   //
   //  Selected channels are converted to double.  Channels already at the
   //  target rate are passed through unfiltered.  Resampled channels keep
   //  per-channel filter state across calls and are padded to the end time
   //  of their input.  error() is called if select is shorter than data,
   //  if the selected durations differ, or if a rate reduction is not an
   //  integer power of 2.
   tser_vect
   resampler::wresample(const tser_vect& data, double sampleFrequency,
			const bool_vect& select) {
      size_t nChans = data.size();
      if (select.size() < nChans) {
	 error("channel selection list is shorter than data list");
      }
      if (pipe_vect.size() != nChans) reset();
      if (pipe_vect.empty()) pipe_vect.resize(nChans);

#if 0
      //--------------------------------  Print
      cout << "Resampler input data: " << endl;
      cout << "channel names: ";
      for (size_t i=0; i<nChans; i++) cout << " " << data[i].getName();
      cout << endl;
      cout << "durations: ";
      for (size_t i=0; i<nChans; i++) cout << " " << data[i].getInterval();
      cout << endl;
#endif

      //--------------------------------  Validate equal data durations.
      //  Use an explicit flag for "first selected channel seen".  Testing
      //  the stored duration for zero would let a zero-length first
      //  channel escape the comparison.
      bool haveDuration = false;
      Interval dataDuration = 0;
      for (size_t chanNumber=0; chanNumber < nChans; chanNumber++) {
	 if (!select[chanNumber]) continue;
	 Interval chanDuration = data[chanNumber].getInterval();
	 if (!haveDuration) {
	    dataDuration = chanDuration;
	    haveDuration = true;
	 } else if (dataDuration != chanDuration) {
	    error("data durations are not equal");
	 }
      }

      //--------------------------------  Begin loop over channels.
      //  Rates are nominally integers.  Round them, rather than truncate,
      //  so that an inexact sample interval (e.g. 1/16383.99) still
      //  yields the intended integer rate.
      tser_vect ts(nChans);
      long fSample = long(sampleFrequency + 0.5);
      size_t iValid=0;
      for (size_t chanNumber=0; chanNumber<nChans; chanNumber++) {
	 if (!select[chanNumber]) continue;

	 //-----------------------------  Input comes from the same index
	 //                               as select[] and pipe_vect[].
	 TSeries tsin(data[chanNumber]);
	 tsin.Convert(DVector::t_double);

	 //-----------------------------  Demand reduction by 2^N.  The
	 //                               factor must be an integer (within
	 //                               rounding) that is >= 1 and a
	 //                               power of 2.
	 double decimFact = sampleFrequency*double(tsin.getTStep());
	 if (decimFact == 0) {
	    error("Decimation factor is zero");
	 }
	 decimFact = 1.0/decimFact;
	 int nDecim = int(decimFact + 0.5);
	 double resid = decimFact - double(nDecim);
	 if (nDecim < 1 || resid > 0.01 || resid < -0.01
	     || !is_power_of_2(nDecim)) {
	    cerr << "Target rate: " << sampleFrequency << " input t-step: " 
		 << tsin.getTStep() << endl;
	    error("Sample rate reduction not a power of 2");
	 }

	 long fRawSample = long(1.0/double(tsin.getTStep()) + 0.5);
	 if (fSample != fRawSample) {
	    single_resample& chan = pipe_vect[chanNumber];

	    //--------------------------  Get an anti-aliasing filter
	    //                            (integer sample rates assumed).
	    if (chan._filter.null()) {
	       chan.init(fRawSample, fSample);
	    }

	    //--------------------------  Resample and pad data to nominal end.
	    chan.resample(tsin, fSample);
	    tsin = chan._outData;
	    tsin.extend(data[chanNumber].getEndTime());
	 }
	 ts[iValid] = tsin;
	 iValid++;
      }      // end loop over channels

      //--------------------------------  Return to calling function
      return ts;
   }

   //====================================  Reset all filters.
   //
   //  Throws away every per-channel resampler.  The next wresample() call
   //  re-creates them, and installs filters as they are needed.
   void
   resampler::reset() {
      pipe_vect.clear();
   }

}  // namespace wpipe
