/* PSPP - a program for statistical analysis.
   Copyright (C) 2017, 2019 Free Software Foundation, Inc.

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>. */

#include <config.h>

#include "matrix-reader.h"

#include <stdbool.h>

#include <libpspp/message.h>
#include <libpspp/str.h>
#include <data/casegrouper.h>
#include <data/casereader.h>
#include <data/dictionary.h>
#include <data/variable.h>
#include <data/data-out.h>
#include <data/format.h>

#include "gettext.h"
#define _(msgid) gettext (msgid)
#define N_(msgid) msgid


/*
This module interprets a "data matrix", typically generated by the command
MATRIX DATA.  The dictionary of such a matrix takes the form:

 s_0, s_1, ... s_m, ROWTYPE_, VARNAME_, v_0, v_1, .... v_n

where s_0, s_1 ... s_m are the variables defining the splits, and
v_0, v_1 ... v_n are the continuous variables.

m >= 0; n >= 0

The ROWTYPE_ variable is of type A8.
The VARNAME_ variable is a string type whose width is not predetermined.
The variables s_x are of type F4.0 (although this reader accepts any type),
and v_x are of any numeric type.

The values of the ROWTYPE_ variable are in the set {MEAN, STDDEV, N, CORR, COV}
and determine the purpose of that case.
The values of the VARNAME_ variable must correspond to the names of the variables
in {v_0, v_1 ... v_n} and indicate the rows of the correlation or covariance
matrices.



A typical example is as follows:

s_0 ROWTYPE_   VARNAME_   v_0         v_1         v_2

0   MEAN                5.0000       4.0000       3.0000
0   STDDEV              1.0000       2.0000       3.0000
0   N                   9.0000       9.0000       9.0000
0   CORR       V1       1.0000        .6000        .7000
0   CORR       V2        .6000       1.0000        .8000
0   CORR       V3        .7000        .8000       1.0000
1   MEAN                9.0000       8.0000       7.0000
1   STDDEV              5.0000       6.0000       7.0000
1   N                   9.0000       9.0000       9.0000
1   CORR       V1       1.0000        .4000        .3000
1   CORR       V2        .4000       1.0000        .2000
1   CORR       V3        .3000        .2000       1.0000

*/

/* State for reading a matrix dataset one split group at a time. */
struct matrix_reader
{
  /* Dictionary describing the matrix dataset; also consulted for its
     encoding when VARNAME_ values are formatted. */
  const struct dictionary *dict;
  /* The VARNAME_ variable (a string variable). */
  const struct variable *varname;
  /* The ROWTYPE_ variable (a string variable). */
  const struct variable *rowtype;
  /* Splits the input cases into groups, using the variables that precede
     ROWTYPE_ in the dictionary. */
  struct casegrouper *grouper;

  /* Matrices for the group most recently read by next_matrix_from_reader,
     each allocated there as N_VARS x N_VARS.  Column J is filled, in every
     row, with the value that variable J has in the N, MEAN and STDDEV row
     respectively; the STDDEV value is squared before it is stored in
     var_vectors. */
  gsl_matrix *n_vectors;
  gsl_matrix *mean_vectors;
  gsl_matrix *var_vectors;
};

/* Looks up the variable called NAME in DICT and checks that it is a string
   variable.  Returns the variable on success.  Otherwise reports an error
   that refers to the variable as DISPLAY_NAME and returns a null pointer. */
static const struct variable *
lookup_matrix_string_var (const struct dictionary *dict, const char *name,
                          const char *display_name)
{
  const struct variable *var = dict_lookup_var (dict, name);
  if (var == NULL)
    {
      msg (ME, _("Matrix dataset lacks a variable called %s."), display_name);
      return NULL;
    }

  if (!var_is_alpha (var))
    {
      msg (ME, _("Matrix dataset variable %s should be of string type."),
           display_name);
      return NULL;
    }

  return var;
}

/* Creates and returns a matrix reader for the cases in IN_READER, which are
   described by DICT.

   DICT must contain string variables VARNAME_ and ROWTYPE_.  If either is
   missing or is not a string variable, an error is reported and a null
   pointer is returned; IN_READER is not touched by this function in that
   case.

   The variables that follow VARNAME_ in the dictionary are taken to be the
   continuous variables.  If N_VARS is nonnull, their number is stored in
   *N_VARS.  If VARS is nonnull, *VARS is set to a newly allocated array of
   them, which is not freed by this function.  VARS and N_VARS may be
   supplied independently of each other.

   The variables that precede ROWTYPE_ are used as split variables for
   grouping the cases of IN_READER. */
struct matrix_reader *
create_matrix_reader_from_case_reader (const struct dictionary *dict, struct casereader *in_reader,
				       const struct variable ***vars, size_t *n_vars)
{
  struct matrix_reader *mr = xzalloc (sizeof *mr);
  mr->dict = dict;

  mr->varname = lookup_matrix_string_var (dict, "varname_", "VARNAME_");
  if (mr->varname == NULL)
    {
      free (mr);
      return NULL;
    }

  mr->rowtype = lookup_matrix_string_var (dict, "rowtype_", "ROWTYPE_");
  if (mr->rowtype == NULL)
    {
      free (mr);
      return NULL;
    }

  size_t n_dict_vars;
  const struct variable **dict_vars = NULL;
  dict_get_vars (dict, &dict_vars, &n_dict_vars, DC_SCRATCH);

  /* Index of the first variable after VARNAME_.  Guard the subtraction so
     that a VARNAME_ at or beyond the end of DICT_VARS yields zero variables
     instead of wrapping around. */
  size_t first = var_get_dict_index (mr->varname) + 1;
  size_t n_continuous = n_dict_vars > first ? n_dict_vars - first : 0;

  if (n_vars)
    *n_vars = n_continuous;

  if (vars)
    {
      /* Use the local count here: N_VARS is allowed to be null. */
      *vars = xcalloc (n_continuous, sizeof **vars);
      for (size_t i = 0; i < n_continuous; ++i)
        (*vars)[i] = dict_vars[first + i];
    }

  /* All the variables before ROWTYPE_ (if any) are split variables */
  mr->grouper = casegrouper_create_vars (in_reader, dict_vars,
                                         var_get_dict_index (mr->rowtype));

  free (dict_vars);

  return mr;
}

/* Destroys MR: destroys its case grouper and frees MR itself.

   Returns the result of casegrouper_destroy, or false if MR is null.

   NOTE(review): the n/mean/variance matrices that MR points to are not
   released here; confirm whether that is intentional. */
bool
destroy_matrix_reader (struct matrix_reader *mr)
{
  bool ok = false;

  if (mr != NULL)
    {
      ok = casegrouper_destroy (mr->grouper);
      free (mr);
    }

  return ok;
}


/*
   Stores into row MROW of *MATRIX the numeric values that case C holds
   for each of the N_VARS variables in VARS, one variable per column.

   If *MATRIX is null, an N_VARS x N_VARS matrix is allocated first and
   stored in *MATRIX.

   NOTE(review): gsl_matrix_alloc presumably leaves the new matrix
   uninitialised, so rows never passed to this function would hold
   indeterminate values -- confirm against the GSL documentation.
*/
static void
matrix_fill_row (gsl_matrix **matrix,
      const struct ccase *c, int mrow,
      const struct variable **vars, size_t n_vars)
{
  if (*matrix == NULL)
    *matrix = gsl_matrix_alloc (n_vars, n_vars);

  gsl_matrix *m = *matrix;
  for (size_t j = 0; j < n_vars; ++j)
    {
      double value = case_data (c, vars[j])->f;

      /* Both indices must lie inside the matrix, which may have been
         allocated by an earlier call. */
      assert (j < m->size2);
      assert (mrow < m->size1);
      gsl_matrix_set (m, mrow, j, value);
    }
}

bool
next_matrix_from_reader (struct matrix_material *mm,
			 struct matrix_reader *mr,
			 const struct variable **vars, int n_vars)
{
  struct casereader *group;

  assert (vars);

  gsl_matrix_free (mr->n_vectors);
  gsl_matrix_free (mr->mean_vectors);
  gsl_matrix_free (mr->var_vectors);

  if (!casegrouper_get_next_group (mr->grouper, &group))
    return false;

  mr->n_vectors    = gsl_matrix_alloc (n_vars, n_vars);
  mr->mean_vectors = gsl_matrix_alloc (n_vars, n_vars);
  mr->var_vectors  = gsl_matrix_alloc (n_vars, n_vars);

  mm->n = mr->n_vectors;
  mm->mean_matrix = mr->mean_vectors;
  mm->var_matrix = mr->var_vectors;

  struct substring *var_names = XCALLOC (n_vars,  struct substring);
  for (int i = 0; i < n_vars; ++i)
    {
      ss_alloc_substring (var_names + i, ss_cstr (var_get_name (vars[i])));
    }

  struct ccase *c;
  for (; (c = casereader_read (group)); case_unref (c))
    {
      const union value *uv = case_data (c, mr->rowtype);
      const char *row_type = CHAR_CAST (const char *, uv->s);
      int col, row;
      for (col = 0; col < n_vars; ++col)
	{
	  const struct variable *cv = vars[col];
	  double x = case_data (c, cv)->f;
	  if (0 == strncasecmp (row_type, "N       ", 8))
	    for (row = 0; row < n_vars; ++row)
	      gsl_matrix_set (mr->n_vectors, row, col, x);
	  else if (0 == strncasecmp (row_type, "MEAN    ", 8))
	    for (row = 0; row < n_vars; ++row)
	      gsl_matrix_set (mr->mean_vectors, row, col, x);
	  else if (0 == strncasecmp (row_type, "STDDEV  ", 8))
	    for (row = 0; row < n_vars; ++row)
	      gsl_matrix_set (mr->var_vectors, row, col, x * x);
	}

      const char *enc = dict_get_encoding (mr->dict);

      const union value *uvv  = case_data (c, mr->varname);
      int w = var_get_width (mr->varname);

      struct fmt_spec fmt = { .type = FMT_A };
      fmt.w = w;
      char *vname = data_out (uvv, enc, &fmt, settings_get_fmt_settings ());
      struct substring the_name = ss_cstr (vname);

      int mrow = -1;
      for (int i = 0; i < n_vars; ++i)
	{
	  if (ss_equals (var_names[i], the_name))
	    {
	      mrow = i;
	      break;
	    }
	}
      free (vname);

      if (mrow == -1)
	continue;

      if (0 == strncasecmp (row_type, "CORR    ", 8))
	{
	  matrix_fill_row (&mm->corr, c, mrow, vars, n_vars);
	}
      else if (0 == strncasecmp (row_type, "COV     ", 8))
	{
	  matrix_fill_row (&mm->cov, c, mrow, vars, n_vars);
	}
    }

  casereader_destroy (group);

  for (int i = 0; i < n_vars; ++i)
    ss_dealloc (var_names + i);
  free (var_names);

  return true;
}
