/*
  Copyright(C) 2007-2012 National Institute of Information and Communications Technology
*/

/*
  A Latent Variable Model for Sentiment Classification
*/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <float.h>
#include "lsm_common.h"
#include "exception.h"
#include "model.h"
#include "lbfgs.h"


/* Input buffer size.  train() also keeps at least this many free slots in
   its feature arrays before each makeftr1()/makeftr2() call. */
#define BUF_MAX (32 * 1024)
/* Gaussian prior on the weights.  train() divides by SGM * SGM, i.e. SGM is
   used as the standard deviation of the prior. */
#define SGM 1.0
/* Epsilon: convergence tolerance passed to lbfgs() as its eps argument. */
#define EPS 1.0e-5
/* L-BFGS parameter passed to lbfgs() as m (presumably the number of stored
   correction pairs); train() sizes the work array as n * (2M + 1) + 2M. */
#define M 5
/* Convergence condition: a decrease of the objective whose relative size
   (l_old - l) / l_old is below DLT counts as a small step. */
#define DLT 0.001
/* Number of consecutive small steps (see DLT) after which the L-BFGS
   iteration is considered converged. */
#define ITER 10
/* Minimum / maximum of the random offset added to each initial parameter
   value. */
#define PAR_MIN -0.1
#define PAR_MAX +0.1


/* Reads all training examples from fp; returns their count. */
static int readtrain(FILE *fp, XPR **xpr);
/* Trains a model on xpr[0 .. xpr_num-1] and returns it. */
static MDL *train(int xpr_num, XPR *xpr);


/*
  Entry point.

  Reads the *.xpr training examples from stdin, trains the model, and writes
  it to stdout.  Progress messages go to stderr.  No command-line arguments
  are accepted; if any are given, usage is printed and 1 is returned.
*/
int main(int argc, char **argv) {
  XPR *examples;
  MDL *model;
  int n_examples;

  /* Fixed seed so the random parameter initialization is reproducible. */
  srand(1);

  /* Option handling: input and output go through stdin/stdout only. */
  if (argc != 1) {
    fprintf(stderr, "usage: %s < <training *.xpr file> > <model file>\n", argv[0]);
    return 1;
  }

  /* Read the training data.
     NOTE(review): readtrain() never returns a negative count, so this
     exception() cannot currently fire. */
  fputs("# Reading the training data... ", stderr);
  n_examples = readtrain(stdin, &examples);
  exception(n_examples < 0, "cannot read the training data");
  fputs("done.\n", stderr);

  /* Train. */
  model = train(n_examples, examples);

  /* Write the model. */
  fputs("# Writing the model data... ", stderr);
  mdl_write(model, stdout);
  fputs("done.\n", stderr);

  return 0;
}


/*
  Read the training data.

  Calls readxpr() on successive slots of a growing array until it returns
  non-zero (presumably end of input or a read error).  Each slot has its
  buf_size set to -1 before readxpr() sees it.  Capacity grows as
  1, 3, 7, 15, ...  When at least one record was read, the array is trimmed to
  exactly that many entries.

  On return *xpr points to the array and the number of records is returned.
*/
static int readtrain(FILE *fp, XPR **xpr) {
  XPR *recs = NULL;
  int cap = 0;
  int count = 0;

  for (;;) {
    /* Ensure slot `count` exists. */
    if (count >= cap) {
      cap = cap * 2 + 1;
      recs = srealloc(recs, sizeof(XPR) * cap);
    }

    recs[count].buf_size = -1;
    if (readxpr(&recs[count], fp)) {
      break;
    }
    ++count;
  }

  /* Drop the unused tail capacity. */
  if (count > 0) {
    recs = srealloc(recs, sizeof(XPR) * count);
  }

  *xpr = recs;
  return count;
}


/*
  Training.

  Fits the feature weights mdl->lmd on xpr[0 .. xpr_num-1] by minimizing,
  with lbfgs(),

      l = sum_k lmd[k]^2 / (2 SGM^2) - sum_n log p_n

  p_n is the root (token 0) marginal mass from the unconstrained bp() run,
  taken over the classes allowed by xpr[n].lbl:
    - [0, CLS_MAX/2) when lbl == +1;
    - [CLS_MAX/2, CLS_MAX) otherwise.

  The gradient of -log p_n is accumulated as the difference between two sets
  of marginals:
    - bp() run without the label constraint (prb1u/prb2u);
    - bp() run with the disallowed root classes clamped to a -DBL_MAX
      potential (prb1c/prb2c).

  Returns the new model.  Aborts via exception() if mdl_new() fails or a
  feature name starting with 'A' is malformed.
*/
static MDL *train(int xpr_num, XPR *xpr) {
  int i, n, c, d, k;
  int dim;                      /* number of features (= number of weights) */
  MDL *mdl;
  int iter;
  int iprint[2], iflag;
  double *diag, *wrk;
  double l, *ll;                /* objective and its gradient */
  double l_old = -DBL_MAX;
  int l_cnt = 0;
  int *f1, *f2, f1_num, f2_num, f1_size, f2_size;
  double *prb1u, *prb2u, *prb1c, *prb2c, p;
  double *pot1, *pot2;
  VAR *v;
  FAC *f;

  /* Model initialization. */
  mdl = mdl_new();
  exception(mdl == NULL, "mdl_new() failed");

  /* Build the feature vectors.  For every example n, token i >= 1 and
     class c:
       - f1 receives one group of unigram feature ids;
       - f2 receives one group of bigram feature ids for every class d.
     Each group ends with a -2 terminator; the training loop below relies on
     this and walks the arrays in exactly this order.  At least BUF_MAX free
     slots are kept before every makeftr call. */
  f1_size = 0;
  f1 = NULL;
  f2_size = 0;
  f2 = NULL;
  f1_num = 0;
  f2_num = 0;
  for (n = 0; n < xpr_num; n++) {
    for (i = 1; i < xpr[n].num; i++) {
      for (c = 0; c < CLS_MAX; c++) {
        while (f1_num + BUF_MAX > f1_size) {
          f1_size = 2 * f1_size + 1;
          f1 = srealloc(f1, sizeof(int) * f1_size);
        }
        makeftr1(mdl->ft, 1, &xpr[n], i, c, f1_size, &f1_num, f1);  /* unigram features */
        for (d = 0; d < CLS_MAX; d++) {
          while (f2_num + BUF_MAX > f2_size) {
            f2_size = 2 * f2_size + 1;
            f2 = srealloc(f2, sizeof(int) * f2_size);
          }
          makeftr2(mdl->ft, 1, &xpr[n], i, xpr[n].head[i], c, d, f2_size, &f2_num, f2);  /* bigram features */
        }
      }
    }
  }

  /* No features are added after this point, so the size is fixed. */
  dim = (int)sdb_size(mdl->ft);

  /* Parameter initialization.  Weights start at 0.0, except features named
     "A:<a>:<b>" with a == b, which start at 1.0.  Every weight then receives
     a random offset in [PAR_MIN, PAR_MAX] (presumably to break the symmetry
     between latent classes). */
  mdl->lmd = smalloc(sizeof(double) * dim);
  for (k = 0; k < dim; k++) {
    char *name = sdb_str(mdl->ft, k);
    int from, to;               /* renamed from p, q: no longer shadows `p` */
    double r;

    mdl->lmd[k] = 0.0;
    if (name[0] == 'A') {
      exception(sscanf(name, "A:%d:%d", &from, &to) != 2, "invalid features");
      if (from == to) mdl->lmd[k] = 1.0;
    }
    r = rand() / (double)RAND_MAX;
    r = r * (PAR_MAX - PAR_MIN) + PAR_MIN;
    mdl->lmd[k] += r;
  }

  /* L-BFGS initialization.  The workspace size is computed in size_t so
     that it cannot overflow int for very large feature sets. */
  iprint[0] = -1;
  iprint[1] = 0;
  iflag = 0;
  diag = smalloc(sizeof(double) * dim);
  wrk = smalloc(sizeof(double) * ((size_t)dim * (2 * M + 1) + 2 * M));
  ll = smalloc(sizeof(double) * dim);

  /* Other work buffers for BP, sized by TKN_MAX (presumably the maximum
     number of tokens in one example). */
  prb1u = smalloc(sizeof(double) * TKN_MAX * CLS_MAX);
  prb2u = smalloc(sizeof(double) * TKN_MAX * CLS_MAX * CLS_MAX);
  prb1c = smalloc(sizeof(double) * TKN_MAX * CLS_MAX);
  prb2c = smalloc(sizeof(double) * TKN_MAX * CLS_MAX * CLS_MAX);
  v = smalloc(sizeof(VAR) * TKN_MAX);
  f = smalloc(sizeof(FAC) * 2 * TKN_MAX);
  pot1 = smalloc(sizeof(double) * TKN_MAX * CLS_MAX);
  pot2 = smalloc(sizeof(double) * TKN_MAX * CLS_MAX * CLS_MAX);

  /* Message. */
  fprintf(stderr, "# Training (%d features, delta=%g, iter=%d, sigma=%g, eps=%g)...\n", dim, DLT, ITER, SGM, EPS);

  /* Optimization (minimization). */
  for (iter = 1; ; iter++) {
    f1_num = 0;
    f2_num = 0;

    /* Objective and gradient: start with the Gaussian prior term. */
    l = 0.0;
    for (k = 0; k < dim; k++) {
      l += mdl->lmd[k] * mdl->lmd[k];
      ll[k] = mdl->lmd[k] / (SGM * SGM);
    }
    l /= 2.0 * SGM * SGM;

    for (n = 0; n < xpr_num; n++) {
      int tmp1, tmp2;
      int lo, hi;               /* root classes allowed by the label: [lo, hi) */

      /* Potentials: each one is the sum of the weights of its active
         features.  The feature cursors are rewound afterwards, so the
         gradient step below walks the same segment of f1/f2 again. */
      tmp1 = f1_num;
      tmp2 = f2_num;
      for (c = 0; c < CLS_MAX; c++) POT1(pot1, 0, c) = 0.0;
      for (i = 1; i < xpr[n].num; i++) {
        double tmp;
        for (c = 0; c < CLS_MAX; c++) {
          tmp = 0.0;
          while (f1[f1_num] != -2) {
            tmp += mdl->lmd[f1[f1_num++]];
          }
          f1_num++;             /* skip the -2 terminator */
          POT1(pot1, i, c) = tmp;
          for (d = 0; d < CLS_MAX; d++) {
            tmp = 0.0;
            while (f2[f2_num] != -2) {
              tmp += mdl->lmd[f2[f2_num++]];
            }
            f2_num++;           /* skip the -2 terminator */
            POT2(pot2, i, c, d) = tmp;
          }
        }
      }
      f1_num = tmp1;
      f2_num = tmp2;

      /* BP without the label constraint. */
      bp(&xpr[n], pot1, pot2, v, f, prb1u, prb2u);
      /* Debug dump of the result (disabled). */
#if 0
      {
        fprintf(stderr, "BP(unconstrained) %d\n", xpr[n].lbl);
        for (i = 0; i < xpr[n].num; i++) {
          fprintf(stderr, "%d:%d:%s:%d:%d : ", i, xpr[n].head[i], xpr[n].surf[i], xpr[n].pol[i], xpr[n].rev[i]);
          for (c = 0; c < CLS_MAX; c++) {
            fprintf(stderr, "%f, ", PRB1(prb1u, i, c));
          }
          fprintf(stderr, "\n");
        }
        fprintf(stderr, "\n");
      }
#endif

      /* Root classes allowed for this example's label.  The constraint and
         the objective below both use this one range, so they stay consistent
         for every label value (previously a label other than +1/-1 was
         constrained as negative but scored as positive). */
      if (xpr[n].lbl == +1) {
        lo = 0;
        hi = CLS_MAX / 2;
      } else {
        lo = CLS_MAX / 2;
        hi = CLS_MAX;
      }

      /* Constrained potentials: forbid root classes outside [lo, hi). */
      for (c = 0; c < CLS_MAX; c++) {
        if (c < lo || c >= hi) {
          POT1(pot1, 0, c) = -DBL_MAX;
        }
      }

      /* BP with the label constraint. */
      bp(&xpr[n], pot1, pot2, v, f, prb1c, prb2c);
      /* Debug dump of the result (disabled). */
#if 0
      {
        fprintf(stderr, "BP(constrained) %d\n", xpr[n].lbl);
        for (i = 0; i < xpr[n].num; i++) {
          fprintf(stderr, "%d:%d:%s:%d:%d : ", i, xpr[n].head[i], xpr[n].surf[i], xpr[n].pol[i], xpr[n].rev[i]);
          for (c = 0; c < CLS_MAX; c++) {
            fprintf(stderr, "%f, ", PRB1(prb1c, i, c));
          }
          fprintf(stderr, "\n");
        }
        fprintf(stderr, "\n");
      }
#endif

      /* -log p(label).  p is the unconstrained root mass on [lo, hi).
         Summing that half directly, rather than computing 1 - (other half),
         avoids cancellation to exactly 0 when p is tiny.  The clamp keeps l
         finite for lbfgs() if p still underflows. */
      p = 0.0;
      for (c = lo; c < hi; c++) p += PRB1(prb1u, 0, c);
      if (p < DBL_MIN) p = DBL_MIN;
      l -= log(p);

      /* Gradient of -log p: unconstrained minus constrained expectations. */
      /* unigram features */
      for (i = 1; i < xpr[n].num; i++) {
        for (c = 0; c < CLS_MAX; c++) {
          while (f1[f1_num] != -2) {
            ll[f1[f1_num++]] += PRB1(prb1u, i, c) - PRB1(prb1c, i, c);
          }
          f1_num++;
        }
      }
      /* bigram features */
      for (i = 1; i < xpr[n].num; i++) {
        for (c = 0; c < CLS_MAX; c++) {
          for (d = 0; d < CLS_MAX; d++) {
            while (f2[f2_num] != -2) {
              ll[f2[f2_num++]] += PRB2(prb2u, i, c, d) - PRB2(prb2c, i, c, d);
            }
            f2_num++;
          }
        }
      }
    }

    /* Parameter dump (disabled). */
#if 0
    fprintf(stderr, "### Current parameter\n");
    for (k = 0; k < dim; k++) {
      fprintf(stderr, " %d %s -- %f\n", k, sdb_str(mdl->ft, k), mdl->lmd[k]);
    }
#endif

    /* Gradient dump (disabled). */
#if 0
    fprintf(stderr, "### Gradient (l=%f)\n", l);
    for (k = 0; k < dim; k++) {
      fprintf(stderr, " %d %s -- %f\n", k, sdb_str(mdl->ft, k), ll[k]);
    }
#endif

    /* Log. */
    fprintf(stderr, "#  %5d\t%f\n", iter, l);

    /* Optimization step by L-BFGS (updates mdl->lmd in place). */
    lbfgs(dim, M, mdl->lmd, l, ll, 0, diag, iprint, EPS, DBL_EPSILON, wrk, &iflag);

    if (iflag < 0) {
      fprintf(stderr, "********** ERROR in LBFGS **********\n");
      fprintf(stderr, "#  the optimal solution is not solved\n");
      break;
    }
    if (iflag == 0) break;

    /* Stop after ITER consecutive relative decreases smaller than DLT.
       NOTE(review): lbfgs() has already written its next proposed point into
       mdl->lmd, so on this exit the returned weights are that (not yet
       evaluated) point. */
    if (l < l_old && (l_old - l) / l_old < DLT) l_cnt++; else l_cnt = 0;
    if (l_cnt >= ITER) break;
    l_old = l;
  }
  fprintf(stderr, "#  completed.\n");

  /* Release the working buffers; only mdl (including mdl->lmd) is returned. */
  free(f1);
  free(f2);
  free(diag);
  free(wrk);
  free(ll);
  free(prb1u);
  free(prb2u);
  free(prb1c);
  free(prb2c);
  free(v);
  free(f);
  free(pot1);
  free(pot2);

  return mdl;
}
