// Jubatus: Online machine learning framework for distributed environment
// Copyright (C) 2016 Preferred Networks and Nippon Telegraph and Telephone Corporation.
//
// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License version 2.1 as published by the Free Software Foundation.
//
// This library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

#include <cfloat>
#include <cmath>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "cosine_similarity_regression.hpp"
#include "nearest_neighbor_regression_util.hpp"

using jubatus::util::lang::shared_ptr;
using jubatus::util::concurrent::scoped_lock;

namespace jubatus {
namespace core {
namespace regression {

// Constructs the regressor by forwarding `conf` to the
// inverted_index_regression base constructor; no additional work is done here.
cosine_similarity_regression::cosine_similarity_regression(const config& conf)
    : inverted_index_regression(conf) {}

// Estimates the target value for the feature vector `fv` by averaging the
// stored target values (double column 0 of `values_`) of its nearest
// neighbors.
//
// Up to `config_.nearest_neighbor_num` (id, score) pairs are requested from
// the storage model via calc_scores(). `1.0 - score` is treated as the
// distance to a neighbor.
//
// - If `config_.weight` is set to "distance":
//     * when the first neighbor is within DBL_EPSILON of the query, the
//       plain mean of the leading run of such "same point" neighbors is
//       returned;
//     * otherwise every neighbor is weighted by 1 / |distance|.
// - Otherwise all neighbors are weighted equally.
//
// Neighbors whose id cannot be resolved in `values_` are skipped so that an
// unrelated row is never mixed into the result.
//
// Returns 0.0 when no neighbor is found or none of them can be resolved.
//
// NOTE(review): the "same point" handling looks only at ids[0] and stops at
// the first more distant neighbor, so it presumably relies on calc_scores()
// returning neighbors in descending score order -- confirm against the
// storage implementation.
double cosine_similarity_regression::estimate(
    const common::sfv_t& fv) const {
  typedef std::vector<std::pair<std::string, double> > scores_t;

  scores_t ids;
  {
    // The read lock is held only for the similarity search itself.
    util::concurrent::scoped_rlock lk(storage_mutex_);
    mixable_storage_->get_model()->calc_scores(
        fv, ids, config_.nearest_neighbor_num);
  }
  if (ids.empty()) {
    return 0.0;
  }

  const bool weight_by_distance =
      config_.weight && *config_.weight == "distance";
  // True when the closest neighbor coincides with the query point; in that
  // case 1 / distance is unusable as a weight and the mean of the coinciding
  // points is used instead.
  const bool has_same_point =
      weight_by_distance && 1.0 - ids[0].second <= DBL_EPSILON;

  double sum = 0.0;
  double sum_w = 0.0;
  bool used_any = false;
  for (scores_t::const_iterator it = ids.begin(); it != ids.end(); ++it) {
    const double distance = 1.0 - it->second;
    if (has_same_point && distance > DBL_EPSILON) {
      // Left the run of coinciding points; the rest must not contribute.
      break;
    }

    const std::pair<bool, uint64_t> index =
        values_->get_model()->exact_match(it->first);
    if (!index.first) {
      // The id has no row in the value table; index.second is not a valid
      // row for this id, so do not read through it.
      continue;
    }

    double w = 1.0;
    if (weight_by_distance && !has_same_point) {
      w = 1.0 / std::abs(distance);
    }
    sum += w * values_->get_model()->get_double_column(0)[index.second];
    sum_w += w;
    used_any = true;
  }

  if (!used_any) {
    // Avoid 0 / 0 when every neighbor had to be skipped.
    return 0.0;
  }
  return sum / sum_w;
}

// Returns the fixed, human-readable name of this regression method.
std::string cosine_similarity_regression::name() const {
  return std::string("cosine similarity regression");
}

}  // namespace regression
}  // namespace core
}  // namespace jubatus
