/***************************************************************************
    File        : NotchFilter.cpp
    Description : Implements class NotchFilter
 ---------------------------------------------------------------------------
    Begin       : Mon Sep 3 2001
    Author(s)   : Roberto Grosso
 ***************************************************************************/

#include "NotchFilter.h"

#include <algorithm>
#include <cmath>
#include <complex>
#include <vector>



// Method: timeDomainFilter
//   Computes the Notch Filter in Time
//   Domain
void
gwd::NotchFilter::ComputeFilter(const double SamplingRate,const unsigned int order,
                                const Vector& frq,
                                const Vector& wdt,
                                Vector& filter)
{
  typedef Vector::size_type SizeType;
  // Set size of frequency arrays:
  // work with filter size instead of filter order
  // Size of complex data
  unsigned int lOrder = order;
  if (IsOdd(lOrder)) lOrder++;
  // filter length = order + 1
  lOrder++;
  unsigned int cSize = lOrder/2 + 1;

  // Create filter in frequency domain
  std::vector<Complex> data(cSize);
  
  // Frequency Response
  FrequencyResponse(SamplingRate,frq,wdt,data);

  // Time shift: causality
  double factor = 0.5 * PI * static_cast<double>(lOrder - 1) / static_cast<double>(cSize - 1);
  for (unsigned int nn = 0; nn < cSize; nn++)
  {
    double val = factor*static_cast<double>(nn);
    data[nn] = std::complex<double>(cos(val),-sin(val))*data[nn];
  }

  // Transform to time domain
  Fourier dft;
  dft.dft(data,filter);
    
  // Normalize inverse fourier transform
  const double fSize  = static_cast<double>(filter.size());
  gwd::Multiply<double>(filter,1./fSize);

  // Create window
  Vector wnd;
  Window window;
  window.Hamming(filter.size(),wnd);

  // Apply window
  gwd::Multiply<double>(filter,wnd);
}


// Method: freqDomainFilter
//   Computes the Notch filter in Frequency
//   Domain
void
gwd::NotchFilter::FrequencyResponse(const double SamplingRate,
                                    const Vector& frq,
                                    const Vector& wdt,
                                    std::vector<Complex>& filter)
{
  // Set filter to identity
  //for (std::vector<Complex>::size_type nn = 0; nn < filter.size(); nn++) filter[nn] = double(1);
  std::fill(filter.begin(),filter.end(),Complex(1));
  
  // For each frequency in file
  // remove frequencies from filter
  // Remark: DFT of complex to real -> NN = Nyquist frequency
  //         and nn = 2*(NN-1) is the real (or time domain) dimension
  int NN = static_cast<int>(filter.size());
  double factor = static_cast<double>(2*(NN-1))/SamplingRate;
  for (Vector::size_type nn = 0; nn < frq.size(); nn++)
  {
    double frequency = frq[nn];
    double bandwidth = wdt[nn]/2;
    int maxIndex = static_cast<int>(factor * (frequency + bandwidth));
    int minIndex = static_cast<int>(factor * (frequency - bandwidth));

    // Check index boundaries
    if (minIndex < 0)
      minIndex = 0;
    if ((maxIndex) >= NN)
      maxIndex = NN-1;
    // Eliminate frequencies
    for (int index = minIndex; index < maxIndex; index++)
    {
      filter[index] = 0.;
    }
    // Smmoth out
    if (minIndex > 0)
    {
      int index = minIndex - 1;
      if ( filter[index].real() > 0. ) filter[index] = 0.5;
    }
    if (maxIndex < (NN-1))
    {
      int index = maxIndex + 1;
      if ( filter[index].real() > 0. ) filter[index] = 0.5;
    }
  } // for loop: frw, nn
} // freqDomainFilter()


