/* -*- mode: c++; c-basic-offset: 3; -*- */
//  $Id$
#include "weventlist.hh"
#include "matlab_fcs.hh"
#include "AgglomClusterTree.hh"

#include "constant.hh"
#include "lcl_array.hh"
#include <algorithm>
#include <stdexcept>
#include <iostream>

using namespace wpipe;
using namespace std;

// Sentinel cluster ids, taken from the top of the size_t range so that they
// cannot collide with the cluster numbers, which are counted up from zero.
//   null_cluster_id : tile is not assigned to any cluster (largest size_t).
//   idle_cluster_id : tile is a potential cluster seed that has not been
//                     processed yet (largest size_t minus one).
const size_t null_cluster_id = static_cast<size_t>(-1);
const size_t idle_cluster_id = static_cast<size_t>(-2);

// Per-tile bookkeeping for the density clustering: the cluster the tile is
// currently assigned to, plus the list of tiles recorded as its neighbors.
struct tile_group {

   // One neighbor entry: the neighbor's tile index and its normalized energy.
   struct neighbor_id {
      neighbor_id(size_t inx, double erg);

      // Compares two entries by normalized energy.
      bool operator>(const neighbor_id& x) const;

      size_t index;   // tile index of the neighbor
      double normE;   // normalized energy of the neighbor
   };

   typedef std::vector<neighbor_id> nbor_list;

   size_t    clusterNumber;   // cluster id, or one of the sentinel ids
   nbor_list neighbor;        // neighbors recorded with add()

   // A new tile starts out unassigned (null_cluster_id).
   tile_group()
      : clusterNumber(null_cluster_id)
   {}

   // Sort the neighbor list by decreasing normalized energy.
   void sort();

   // Append tile number inx of event list el to the neighbor list.
   void add(const weventlist& el, size_t inx);
};

// Construct a neighbor entry from a tile index and that tile's normalized
// energy.
inline
tile_group::neighbor_id::neighbor_id(size_t inx, double erg)
   : index(inx),
     normE(erg)
{
}

// True when this entry has strictly more normalized energy than x.
inline bool
tile_group::neighbor_id::operator>(const neighbor_id& x) const {
   return x.normE < normE;
}

// Sort the neighbor list so that the highest normalized energy comes first.
inline void
tile_group::sort() {
   std::greater<tile_group::neighbor_id> byDecreasingEnergy;
   std::sort(neighbor.begin(), neighbor.end(), byDecreasingEnergy);
}

inline void 
tile_group::add(const weventlist& el, size_t inx) {
   neighbor.push_back(neighbor_id(inx, el[inx].normalizedEnergy));
}

//=======================================  Density recursion function.
// Forward declaration of the recursive cluster-growing helper; the
// definition appears after the weventlist methods below.  recursionNumber
// defaults to 0, so a top-level caller passes only the recursion limit.
void 
recurse(std::vector<tile_group>& tiles, size_t tileNumber, 
	int maximumRecursions, int recursionNumber=0);

//////////////////////////////////////////////////////////////////////////////
//                      Density clustering                                  //
//////////////////////////////////////////////////////////////////////////////
// Density based clustering of the significant tiles of one channel.
//
//   significants        tiles to cluster; reordered with sort() and updated
//                       in place with cluster_id and cluster_size
//   clusterRadius       two tiles are neighbors if their distance is <= this
//   clusterDensity      a tile is a cluster seed if it has at least this many
//                       tiles (itself included) within the cluster radius
//   clusterSingles      if true, every unclustered tile gets its own cluster
//   distanceMetric      metric name handed to wdistance()
//   durationInflation   forwarded to wdistance() and clusterFill()
//   bandwidthInflation  forwarded to wdistance() and clusterFill()
//   debugLevel          not used by this function
//
// On return this list holds one event per cluster (see clusterFill()).  If
// significants holds fewer than two tiles the function returns immediately
// and neither list is modified.
void
weventlist::wcluster(weventlist& significants, double clusterRadius, 
		     double clusterDensity, bool clusterSingles, 
		     const std::string& distanceMetric,
		     double durationInflation, double bandwidthInflation, 
		     int debugLevel) {

  //----------------------------------  Need at least one pair of tiles.
  const size_t nTiles = significants.size();
  if (nTiles < 2) return;

  //----------------------------------  Recursion depth limit for recurse().
  const int recursionLimit = 100;

  //----------------------------------  Reorder the tiles (weventlist::sort).
  significants.sort();

  //----------------------------------  Pairwise distances.
  //  One value per unordered pair, stored in the order
  //  (0,1), (0,2), ..., (0,N-1), (1,2), ..., (N-2,N-1).
  const size_t nPairs = nTiles * (nTiles - 1) / 2;
  lcl_array<double> pairDist(nPairs);
  significants.wdistance(pairDist.get(), distanceMetric, durationInflation, 
                         bandwidthInflation);

  //----------------------------------  Build the neighbor lists.
  //  Walk the condensed distance vector in step with the (first, second)
  //  index pair and record each close pair symmetrically.
  std::vector<tile_group> group(nTiles);
  size_t pairInx = 0;
  for (size_t first=0; first+1 < nTiles; ++first) {
    for (size_t second=first+1; second < nTiles; ++second, ++pairInx) {
      if (pairDist[pairInx] <= clusterRadius) {
        group[first].add(significants, second);
        group[second].add(significants, first);
      }
    }
  }

  //----------------------------------  Flag the potential seeds.
  //  The tile itself counts toward the density, hence the +1.
  for (size_t i=0; i < nTiles; ++i) {
    const bool dense = group[i].neighbor.size() + 1 >= clusterDensity;
    if (dense) group[i].clusterNumber = idle_cluster_id;
  }

  //----------------------------------  Grow a cluster from each idle seed.
  size_t nClusters = 0;
  for (size_t i=0; i < nTiles; ++i) {
    if (group[i].clusterNumber != idle_cluster_id) continue;

    //  Tentatively open a new cluster and let recurse() expand it.
    group[i].clusterNumber = nClusters;
    recurse(group, i, recursionLimit);

    //  The new id is kept only if the seed was not merged into an
    //  already existing cluster during the recursion.
    if (group[i].clusterNumber == nClusters) ++nClusters;
  }

  //----------------------------------  Copy the cluster ids to the tiles.
  //  With clusterSingles set, each tile still unassigned afterwards gets a
  //  fresh cluster id of its own.
  for (size_t i=0; i < nTiles; ++i) {
    significants[i].cluster_id = group[i].clusterNumber;
    if (clusterSingles && significants[i].cluster_id == null_cluster_id) {
      significants[i].cluster_id = nClusters++;
    }
  }

  //----------------------------------  Cluster properties.
  clusterFill(significants, durationInflation, bandwidthInflation);

  //----------------------------------  Tell each tile its cluster's size.
  for (size_t i=0; i < nTiles; ++i) {
    const size_t id = significants[i].cluster_id;
    if (id == null_cluster_id) continue;
    significants[i].cluster_size = _events[id].cluster_size;
  }
}

//////////////////////////////////////////////////////////////////////////////
//                      Hierarchical clustering                             //
//////////////////////////////////////////////////////////////////////////////
// Hierarchical (agglomerative) clustering of the significant tiles of one
// channel.
//
//   significants        tiles to cluster; updated in place with cluster_id
//                       and cluster_size
//   clusterLinkage      linkage method handed to AgglomClusterTree
//   clusterCriterion    criterion handed to AgglomClusterTree::cluster()
//   clusterThreshold    threshold handed to AgglomClusterTree::cluster()
//   distanceMetric      metric name handed to wdistance()
//   durationInflation   forwarded to wdistance() and clusterFill()
//   bandwidthInflation  forwarded to wdistance() and clusterFill()
//   debugLevel          not used by this function
//
// On return this list holds one event per cluster (see clusterFill()).  If
// significants holds fewer than two tiles the function returns immediately
// and neither list is modified.
//
// NOTE(review): an earlier version of this function carried switch branches
// for zero tiles and for one tile (the latter assigned the tile to cluster
// 0).  They could never execute because of the early return below and have
// been removed; confirm whether a lone tile is really meant to produce no
// cluster.
void
weventlist::wcluster(weventlist& significants,  
		     const std::string& clusterLinkage, 
		     const std::string& clusterCriterion, 
		     double clusterThreshold,
		     const std::string& distanceMetric,
		     double durationInflation, double bandwidthInflation, 
		     int debugLevel) {

  // number of tiles
  size_t numberOfTiles = significants.size();
  if (numberOfTiles < 2) return;

  ////////////////////////////////////////////////////////////////////////////
  //                              compute distances                         //
  ////////////////////////////////////////////////////////////////////////////

  // one distance per unordered tile pair
  size_t distMatrixSize = numberOfTiles * (numberOfTiles - 1) / 2;
  lcl_array<double> distances(distMatrixSize);
  significants.wdistance(distances, distanceMetric, durationInflation, 
                         bandwidthInflation);

  //////////////////////////////////////////////////////////////////////
  //                       hierarchical clustering                    //
  //////////////////////////////////////////////////////////////////////

  // produce hierarchical cluster tree
  AgglomClusterTree linktree(numberOfTiles, distances, clusterLinkage);

  // construct clusters from tree; the returned cluster count is not needed
  // because clusterFill() derives it from the assigned cluster ids
  std::vector<size_t> clustIDs;
  linktree.cluster(clusterCriterion, clusterThreshold, clustIDs);

  // copy the cluster ids to the tiles
  for (size_t i=0; i<numberOfTiles; i++) {
    significants._events[i].cluster_id = clustIDs[i];
  }

  /////////////////////////////////////////////////////////////////////////
  //                         cluster properties                          //
  /////////////////////////////////////////////////////////////////////////

  clusterFill(significants, durationInflation, bandwidthInflation);

  // tell each clustered tile the size of the cluster it belongs to
  for (size_t i=0; i<numberOfTiles; i++) {
    size_t clusterID = significants[i].cluster_id;
    if (clusterID != null_cluster_id) {
      significants[i].cluster_size = _events[clusterID].cluster_size;
    }
  }

}

void
weventlist::clusterFill(const weventlist& significants, 
			double durationInflation, 
			double bandwidthInflation) {
   size_t numberOfTiles = significants.size();
   size_t numberOfClusters = 0;
   for (size_t i=0; i<numberOfTiles; i++) {
      size_t clusterID = significants[i].cluster_id;
      if (clusterID != null_cluster_id && clusterID >= numberOfClusters) {
	 numberOfClusters = clusterID + 1;
      }
   }

   // Initialize the cluster bank
   wevent zero;
   zero.t_offset  = 0;
   zero.frequency = 0;
   zero.q         = 0;
   zero.duration  = 0;
   zero.bandwidth = 0;
   zero.amplitude = 0;
   zero.normalizedEnergy = 0;
   zero.incoherentEnergy = 0;
   zero.cluster_size = 0;
   zero.cluster_id   = 0;
   _events.clear();
   _events.resize(numberOfClusters, zero);

   _channelName = significants._channelName;
   _refTime = significants._refTime;
   
  //
  for (size_t tileIndex=0; tileIndex<numberOfTiles; tileIndex++) {
    //---------------------------  Get the cluster number for this tile.
    size_t clustNumber = significants[tileIndex].cluster_id;
    if (clustNumber == null_cluster_id) continue;

    // cluster size
    _events[clustNumber].cluster_size++;
    
    // cluster normalized energy
    double normE = significants[tileIndex].normalizedEnergy;
    _events[clustNumber].normalizedEnergy += normE;

    // cluster incoherent energy
    _events[clustNumber].incoherentEnergy +=
      significants[tileIndex].incoherentEnergy;

    // cluster time
    _events[clustNumber].t_offset += 
      significants[tileIndex].t_offset * normE;
    
    // cluster frequency
    _events[clustNumber].frequency += 
      significants[tileIndex].frequency * normE;
    
    // cluster Q
    _events[clustNumber].q += significants[tileIndex].q * normE;
    
    // cluster duration *********  Does this need to be fixed --- e.g.
    //                             need to accumulate duration^2 & time^2
    //                             independently?
    _events[clustNumber].duration += 
      (pow(significants[tileIndex].duration, 2) +
       pow(significants[tileIndex].t_offset, 2)) * normE;
 
    // cluster bandwidth
    _events[clustNumber].bandwidth += 
      (pow(significants[tileIndex].bandwidth, 2) +
       pow(significants[tileIndex].frequency, 2)) * normE;    
  }	// end loop over tiles

  double infFactor = durationInflation * bandwidthInflation;
  for (size_t clusNumber=0; clusNumber<numberOfClusters; clusNumber++) {
     _events[clusNumber].cluster_id = clusNumber;
     double sumNormE = _events[clusNumber].normalizedEnergy;
     _events[clusNumber].normalizedEnergy *= infFactor;
     _events[clusNumber].incoherentEnergy *= infFactor;
     _events[clusNumber].t_offset  /= sumNormE;
     _events[clusNumber].frequency /= sumNormE;
     _events[clusNumber].q         /= sumNormE;
     _events[clusNumber].duration  = infFactor / sumNormE *
	sqrt(_events[clusNumber].duration-pow(_events[clusNumber].t_offset,2));
     _events[clusNumber].bandwidth  = infFactor / sumNormE *
      sqrt(_events[clusNumber].bandwidth-pow(_events[clusNumber].frequency,2));
  } // end of loop over clusters

  //sort();
}

// WDISTANCE Compute distances between significant Q transform tiles
//
// WDISTANCE computes the distance between significant Q transform
// tiles produced by WTHRESHOLD or WSELECT.
//
// wdistance(distances, significants, distanceMetric, 
//           durationInflation, bandwidthInflation);
//
//   distances            cell array of significant tiles distances
//
//   significants         cell array of significant tiles properties
//   distanceMetric       choice of metric for computing distance
//   durationInflation    multiplicative scale factor for duration
//   bandwidthInflation   multiplicative scale factor for bandwidth
//
// WDISTANCE expects a cell array of Q transform event structures with
// one cell per channel.  The event structures must contain at least
// the following fields, which describe the properties of statistically
// significant tiles used to compute distance.
// 
//   time                 center time of tile [gps seconds]
//   frequency            center frequency of tile [Hz]
//   q                    quality factor of tile []
//   normalizedEnergy     normalized energy of tile []
//
// WDISTANCE returns a cell array of Q transform distance structures
// with one cell per channel.  In addition to a structure identifier,
// the distance structures contain the following single field.
//
//   distance             pairwise distance between tiles
//
// Distances are returned in the same format as the PDIST function.
// In particular, distances are reported as row vectors of length
// N * (N - 1) / 2, where N is the number of significant tiles for
// a given channel.  This row vector is arranged in the order of
// (1,2), (1,3), ..., (1,N), (2,3), ..., (2,N), ..., (N-1,N).  Use
// the SQUAREFORM function to convert distances into a matrix format.
//
// The following choices of distance metric are provided
//
//   'pointMismatch'
//
//     The second order expansions of the mismatch metric used for
//     tiling the signal space, evaluated at the center point between
//     two tiles.
//
//   'integratedMismatch'
//
//     The integrated second order expansion of the mismatch metric
//     between two tiles.  The center point between two tiles is used
//     to determine the metric coefficients.  Only the diagonal metric
//     terms are integrated, and these are added in quadrature.
//
//   'logMismatch'
//
//     The exact mismatch between two tiles.  The result is returned
//     as the negative natural logarithm of the overlap between two
//     tiles.
//
//   'euclidean'
// 
//     The dimensionless euclidean distance between tiles in the
//     time-frequency plane after normalizing by the mean duration and
//     bandwidth of each tile pair.
//
//   'modifiedEuclidean'
// 
//     A modified calculation of euclidean distance based on empirical
//     studies that gives more weight to differences in frequency than
//     in time.
//
// The optional durationInflation and bandwidthInflation arguments are
// multiplicative scale factors that are applied to the duration and
// bandwidth of significant tiles prior to determining their distance.
// If not specified, these parameters both default to unity such that
// the resulting tiles have unity time-frequency area.  They are only
// used in the calculation of the euclidean metric distance.
//
// See also WTHRESHOLD, WSELECT, WCLUSTER, SQUAREFORM, and PDIST.

// Rubab Khan <rmk2109@columbia.edu>
// Shourov Chatterji <shourov@ligo.caltech.edu>
void
weventlist::wdistance(double distances[],
		      const std::string& distanceMetric, 
		      double durationInflation, 
		      double bandwidthInflation) {

  ////////////////////////////////////////////////////////////////////////////
  //                      create pairwise list of tiles                     //
  ////////////////////////////////////////////////////////////////////////////
  
  // number of significant tiles
  size_t nSignif = _events.size();
  if (nSignif < 2) return;

  // number of unique significant tile pairs
  //size_t numberOfPairs = nSignif * (nSignif - 1)/2;

  ////////////////////////////////////////////////////////////////////////////
  //                         determine tile distances                       //
  ////////////////////////////////////////////////////////////////////////////
  
  string metric = tolower(distanceMetric);
  // handle case of point mismatch metric distance
  if (metric == "pointmismatch") {

    size_t dist_index = 0;
    for (size_t index1=0; index1<nSignif-1; index1++) {
      const wevent& sig1 = _events[index1];
      for (size_t index2=index1+1; index2<nSignif; index2++) {
	const wevent& sig2 = _events[index2];
	// mean properties of pairs of tiles
	double meanFrequency = sqrt(sig1.frequency * sig2.frequency);
	double meanQ = sqrt(sig1.q * sig2.q);

	// parameter distances between pairs of tiles
	double timeDistance = double(sig2.t_offset - sig1.t_offset);
	double frequencyDistance = sig2.frequency - sig1.frequency;
	double qDistance = sig2.q - sig1.q;

	// mismatch metric distances between pairs of tiles
	distances[dist_index++] = 
	  pow(timeDistance * 2 * pi * meanFrequency / meanQ, 2)  + 
	  + pow(frequencyDistance, 2) * (2 + meanQ*meanQ) / 
	    (4 * meanFrequency * meanFrequency)
	  + qDistance*qDistance / (2 * meanQ * meanQ) 
	  - frequencyDistance * qDistance / (meanFrequency * meanQ);
      }
    }
  }

  //------------------------------------  Integrated mismatch metric distance
  else if (metric == "integratedmismatch") {

    size_t dist_index = 0;
    for (size_t index1=0; index1<nSignif-1; index1++) {
      const wevent& sig1 = _events[index1];
      for (size_t index2=index1+1; index2<nSignif; index2++) {
	const wevent& sig2 = _events[index2];

	// mean properties of pairs of tiles
	double meanFrequency = sqrt(sig1.frequency * sig2.frequency);
	double meanQ = sqrt(sig1.q * sig2.q);

	// parameter distances between pairs of tiles
	double timeDistance = double(sig2.t_offset - sig1.t_offset) *
	                      twopi*meanFrequency/meanQ;
	double frequencyDistance = 0.5 * sqrt(2 + meanQ*meanQ) 
	                         * log(sig2.frequency / sig1.frequency);

	double qDistance = log(sig2.q / sig1.q) / sqrt(2);

	//cout << "t1 = " << sig1.t_offset << " t2 = " << sig2.t_offset
	//     << " <f> = " << meanFrequency  << " <Q> = " << meanQ << endl;
	// mismatch metric distances between pairs of tiles
	distances[dist_index++] = sqrt(timeDistance*timeDistance
				     + frequencyDistance*frequencyDistance
				     + qDistance * qDistance);
	//cout << "distance [" << index1 << "," << index2 << "] = " 
	//     << distances[dist_index-1] << " dt = " << timeDistance 
	//     << " dF = " << frequencyDistance << " dQ = " << qDistance
	//     << endl;
      }
    }
  }

  // handle case of log mismatch distance
  else if (metric ==  "logmismatch") {  
    error("logMismatch metric not yet implemented");
  }

  // handle case of euclidean distance
  else if (metric ==  "euclidean") {
    
    // tile dimensions
    //bandwidth1 = 2 * sqrt(pi) * frequency1 ./ q1;
    //bandwidth2 = 2 * sqrt(pi) * frequency2 ./ q2;
    //duration1 = 1 ./ bandwidth1;
    //duration2 = 1 ./ bandwidth2;

    // apply tile inflation factors
    //duration1 = duration1 * durationInflation;
    //duration2 = duration2 * durationInflation;
    //bandwidth1 = bandwidth1 * bandwidthInflation;
    //bandwidth2 = bandwidth2 * bandwidthInflation;

    // time and frequency distance scales
    // timeScale = (duration1 .* normalizedEnergy1 + 
    //	 duration2 .* normalizedEnergy2) / 
    //  (normalizedEnergy1 + normalizedEnergy2);
    //frequencyScale = (bandwidth1 .* normalizedEnergy1 + 
    //		      bandwidth2 .* normalizedEnergy2) / 
    //(normalizedEnergy1 + normalizedEnergy2);

    // normalized time and frequency distance between tiles
    //timeDistance = abs(time2 - time1) / timeScale;
    //frequencyDistance = abs(frequency2 - frequency1) / frequencyScale;

    // normalized euclidean distance between tiles
    //    distances{channelNumber}.distance = 
    //			       sqrt(timeDistance.^2 + frequencyDistance.^2);
    throw runtime_error("Not implemented");
  }

  // handle case of modified euclidean distance
  else if (metric == "modifiedeuclidean") {
    
    //  // tile dimensions
    //bandwidth1 = 2 * sqrt(pi) * frequency1 ./ q1;
    //bandwidth2 = 2 * sqrt(pi) * frequency2 ./ q2;
    //duration1 = 1 ./ bandwidth1;
    //duration2 = 1 ./ bandwidth2;

    // apply tile inflation factors
    //duration1 = duration1 * durationInflation;
    //duration2 = duration2 * durationInflation;
    //bandwidth1 = bandwidth1 * bandwidthInflation;
    //bandwidth2 = bandwidth2 * bandwidthInflation;

    // time and frequency distance scales
    //timeScale = (duration1 .* normalizedEnergy1 + 
    //	 duration2 .* normalizedEnergy2) / 
    // (normalizedEnergy1 + normalizedEnergy2);
    //frequencyScale = (bandwidth1 .* normalizedEnergy1 + 
    //	      bandwidth2 .* normalizedEnergy2) / 
    //(normalizedEnergy1 + normalizedEnergy2);

    // normalized time and frequency distance between tiles
    //double timeDistance = abs(time2 - time1) / timeScale;
    //double frequencyDistance = abs(frequency2 - frequency1) / frequencyScale;

    // modified normalized euclidean distance between tiles
    //distances[dist_index] = sqrt(timeDistance*timeDistance 
    //				 + 30.0 * frequencyDistance*frequencyDistance);
    throw runtime_error("Not implemented");
  }

  // handle unknown distance metric
  else {

    // report error
    error(string("unknown distance metric '") + distanceMetric + "'");
  }  // end switch on distance metric

  // for (int i=0; i<100; i++) cout << i << " " << distances[i] << endl;
}

//////////////////////////////////////////////////////////////////////////////
//                            recursion subfunction                        //
//////////////////////////////////////////////////////////////////////////////

// RECURSE Auxiliary function for density based clustering algorithm
//
// RECURSE is an auxiliary function for the WCLUSTER density based clustering
// algorithm.  It implements recursive clustering of tiles and should only be
// called by WCLUSTER.
//
// usage: recurse(std::vector<tile_group>& tiles, size_t tileNumber,
//                int maximumRecursions, int recursionNumber=0);
//
//   tiles               input tile neighbor and cluster structure
//   tileNumber          current tile under test
//   maximumRecursions   limit on allowed recursion depth
//   recursionNumber     current recursion depth
//
//   tiles               updated tile neighbor and cluster structure
void 
recurse(std::vector<tile_group>& tiles, size_t tileNumber, 
	int maximumRecursions, int recursionNumber) {

  tile_group& tileN = tiles[tileNumber];

  // return if (recursion limit is exceeded
  if (recursionNumber >= maximumRecursions) {
    tileN.clusterNumber = 0;
    return;
  }

  // increment recursion number counter
  recursionNumber++;

  // determine number of significant tiles
  size_t numberOfTiles = tiles.size();

  // begin loop over neighboring tiles
  for (size_t neighborInx=0; neighborInx<tileN.neighbor.size(); neighborInx++){
    
    // if (neighbor tile has less than critical density,
    size_t neighborNum   = tileN.neighbor[neighborInx].index;
    size_t neighborClust = tiles[neighborNum].clusterNumber;

    if (neighborClust == null_cluster_id) {

      // assign neighbor tile as border tile of current cluster
      tiles[neighborNum].clusterNumber = tileN.clusterNumber;
    }

    // if neighbor tile has critical density and has not been processed
    else if (neighborClust == idle_cluster_id) {

      // assign neighbor tile to current cluster
      tiles[neighborNum].clusterNumber = tileN.clusterNumber;

      // continue to recursively build the cluster
      recurse(tiles, neighborNum, maximumRecursions, recursionNumber);
    }

    // if (neighbor tile is already in a different cluster
    else if (neighborClust != tileN.clusterNumber) {
    
      // merge current cluster into the other cluster
      for (size_t mergeTileNumber=0; mergeTileNumber<numberOfTiles;
	   mergeTileNumber++) {
	if (tiles[mergeTileNumber].clusterNumber == tileN.clusterNumber) {
	  tiles[mergeTileNumber].clusterNumber = neighborClust;
	}
      } // end merge loop.
    }
  }   // end loop over neighboring tiles
}

//======================================  Multi-channel density clustering
// Density based clustering of every channel in significants.  This stack is
// resized to one list per channel (new lists are created as
// weventlist("cluster")) and each list clusters the matching channel of
// significants; all clustering parameters are passed through unchanged.
void
weventstack::wcluster(weventstack& significants, double clusterRadius, 
		      double clusterDensity, bool clusterSingles, 
		      const std::string& distanceMetric,
		      double durationInflation, double bandwidthInflation, 
		      int debugLevel) {
   const size_t nChannels = significants.numberOfChannels();
   _lists.resize(nChannels, weventlist("cluster"));

   for (size_t chan=0; chan < nChannels; ++chan) {
      weventlist& channelTiles = significants._lists[chan];
      _lists[chan].wcluster(channelTiles, clusterRadius, clusterDensity,
                            clusterSingles, distanceMetric, durationInflation,
                            bandwidthInflation, debugLevel);
   }
}

//====================================== Multi-channel hierarchical clustering
// Hierarchical clustering of every channel in significants.  This stack is
// resized to one list per channel (new lists are created as
// weventlist("cluster")) and each list clusters the matching channel of
// significants; all clustering parameters are passed through unchanged.
//
// The channel count is taken from significants, as in the density based
// overload.  Sizing from this stack's own list count made the resize a
// no-op: a fresh stack clustered nothing, and a stack with more lists than
// significants has channels indexed past the end of significants._lists.
void
weventstack::wcluster(weventstack& significants,  
		     const std::string& clusterLinkage, 
		     const std::string& clusterCriterion, 
		     double clusterThreshold,
		     const std::string& distanceMetric,
		     double durationInflation, double bandwidthInflation, 
		     int debugLevel) {
   size_t N = significants.numberOfChannels();
   _lists.resize(N, weventlist("cluster"));
   for (size_t i=0; i<N; i++) {
      _lists[i].wcluster(significants._lists[i], clusterLinkage, 
                         clusterCriterion, clusterThreshold, distanceMetric,
                         durationInflation, bandwidthInflation, debugLevel);
   }
}
