/*
  Copyright(C) 2007-2012 National Institute of Information and Communications Technology
*/

/*
  A Latent Variable Model for Sentiment Classification
  Routines for reading input, extracting features, and belief propagation.
*/


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <float.h>
#include "lsm_common.h"
#include "exception.h"
#include "split.h"


#define BUF_MAX (16 * 32 * 1024)	/* buffer size in bytes (input line buffer and feature-name buffers) */
/* log(exp(a) + exp(b)), computed by factoring out the larger argument to avoid overflow.
   NOTE: each argument is evaluated more than once; pass only side-effect-free expressions. */
#define LOGSUMEXP(a, b) (((a) > (b)) ? (a) + log(1.0 + exp((b) - (a))) : (b) + log(1.0 + exp((a) - (b))))


/*
  Replace every space in s with NUL. A space-separated field becomes a run of
  consecutive NUL-terminated strings; the overall length is unchanged.
*/
static void readxpr_spaces_to_nul(char *s) {
  for (; *s != '\0'; s++) {
    if (*s == ' ') *s = '\0';
  }
}


/*
  Read one expression (one input line) into xpr.

  Line format, tab-separated: a label, then 7 fields per token, stored in
  head, surf, base, cpos, fpos, pol, rev. Spaces inside surf/base/cpos/fpos
  are turned into NULs, so each of those fields becomes a sequence of
  NUL-separated strings.
  NOTE(review): the feature builders walk such a sequence until they reach an
  empty string. Each field therefore presumably ends with a trailing space,
  which yields a double NUL. Confirm this against the data format.

  Index 0 of every per-token array is a synthetic root: empty strings,
  pol = -1, rev = 0, head[0] left unset. Real tokens occupy 1..num-1.
  All string pointers point into xpr->buf, so they are valid only until the
  buffer is reused or freed.

  If xpr->buf_size < 0, a BUF_MAX-byte buffer is allocated and afterwards
  shrunk to fit exactly this line.

  Returns 0 on success and 1 at end of file. Malformed input (line too long,
  too many tokens, field count not 1 + 7k, head out of range) is reported
  through exception().
*/
int readxpr(XPR *xpr, FILE *fp) {
  int i, j;
  int len;
  int tkn_num;
  char *tkn[1 + 7 * TKN_MAX + 1];
  int adj = 0;

  /* allocate a maximal buffer on first use; it is shrunk after reading */
  if (xpr->buf_size < 0) {
    xpr->buf_size = BUF_MAX;
    xpr->buf = smalloc(sizeof(char) * xpr->buf_size);
    adj = 1;
  }

  if (fgets(xpr->buf, xpr->buf_size, fp) == NULL) return 1;
  len = strlen(xpr->buf);
  if (len > 0 && xpr->buf[len - 1] == '\n') {
    xpr->buf[--len] = '\0';
  } else {
    /* No newline was read. Either the line did not fit in the buffer, or it is
       the last line of a file that lacks a trailing newline (accepted). */
    exception(!feof(fp), "buffer overflow");
  }

  /* shrink a freshly allocated buffer to the line actually read */
  if (adj) {
    xpr->buf_size = len + 1;
    xpr->buf = srealloc(xpr->buf, sizeof(char) * xpr->buf_size);
  }

  /* tkn has one spare slot so that "too many tokens" can be detected */
  tkn_num = split(xpr->buf, '\t', tkn, 1 + 7 * TKN_MAX + 1);
  exception(tkn_num > 1 + 7 * TKN_MAX, "too many tokens");
  exception((tkn_num - 1) % 7 != 0, "wrong format");
  tkn_num = (tkn_num - 1) / 7 + 1;	/* number of tokens, including the root */

  xpr->lbl = atoi(tkn[0]);
  xpr->num = tkn_num;

  /* root */
  xpr->surf[0] = "";
  xpr->base[0] = "";
  xpr->cpos[0] = "";
  xpr->fpos[0] = "";
  xpr->pol[0] = -1;
  xpr->rev[0] = 0;

  /* store the fields of each token */
  for (i = 1; i < tkn_num; i++) {
    j = 7 * i - 6;	/* index of this token's first field in tkn */

    xpr->head[i] = atoi(tkn[j]);
    /* bp() uses head as a variable index, so it must name a position in this line */
    exception(xpr->head[i] < 0 || xpr->head[i] >= tkn_num, "head out of range");

    xpr->surf[i] = tkn[j + 1];
    readxpr_spaces_to_nul(tkn[j + 1]);
    xpr->base[i] = tkn[j + 2];
    readxpr_spaces_to_nul(tkn[j + 2]);
    xpr->cpos[i] = tkn[j + 3];
    readxpr_spaces_to_nul(tkn[j + 3]);
    xpr->fpos[i] = tkn[j + 4];
    readxpr_spaces_to_nul(tkn[j + 4]);
    xpr->pol[i] = atoi(tkn[j + 5]);
    xpr->rev[i] = atoi(tkn[j + 6]);
  }

  return 0;
}


/*
  Append one feature id to ftr. Capacity is checked before the write, so an
  overflow is reported without touching memory past ftr[ftr_size - 1].
*/
static void makeftr1_push(int id, int ftr_size, int *ftr_num, int *ftr) {
  exception(*ftr_num >= ftr_size, "buffer overflow (ftr1)");
  ftr[(*ftr_num)++] = id;
}


/*
  Append the node (unary) features of token i of xpr under class c to ftr.

  Ids come from sdb_id(ft, name, add). They are written starting at
  ftr[*ftr_num], and *ftr_num is advanced past them. This call's list is
  terminated with -2. ftr_size bounds the total number of entries in ftr
  (callers presumably pass FTR1_MAX).
  NOTE(review): presumably a nonzero add registers unseen names in ft; confirm
  against sdb_id.

  Templates (every name includes the class c):
    a  class alone           b  pol[i]            c  pol[i] and rev[i]
    h  each surf value       e  each cpos value   f  each fpos value
    g  adjacent base pairs   i  adjacent surf pairs
  Template d (each base value) is disabled.
  Multi-valued fields are walked as NUL-separated strings until the first
  empty string.
*/
void makeftr1(SDB *ft, int add, XPR *xpr, int i, int c, int ftr_size, int *ftr_num, int *ftr) {
  /* NOTE(review): BUF_MAX bytes of stack per call; this may be too large for small thread stacks */
  char buf[BUF_MAX];
  char *s, *ss;

  snprintf(buf, sizeof buf, "a:%d:", c);
  makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

  snprintf(buf, sizeof buf, "b:%d:%d", c, xpr->pol[i]);
  makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

  snprintf(buf, sizeof buf, "c:%d:%d_%d", c, xpr->pol[i], xpr->rev[i]);
  makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

#if 0
  /* template d: each base value (disabled) */
  for (s = xpr->base[i]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "d:%d:%s", c, s);
    makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }
#endif

  for (s = xpr->surf[i]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "h:%d:%s", c, s);
    makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }

  for (s = xpr->cpos[i]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "e:%d:%s", c, s);
    makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }

  for (s = xpr->fpos[i]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "f:%d:%s", c, s);
    makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }

  /* pairs of consecutive values: ss is the previous value, s the current one */
  for (s = xpr->base[i], ss = NULL; *s != '\0'; s += strlen(s) + 1) {
    if (ss != NULL) {
      snprintf(buf, sizeof buf, "g:%d:%s_%s", c, ss, s);
      makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
    }
    ss = s;
  }

  for (s = xpr->surf[i], ss = NULL; *s != '\0'; s += strlen(s) + 1) {
    if (ss != NULL) {
      snprintf(buf, sizeof buf, "i:%d:%s_%s", c, ss, s);
      makeftr1_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
    }
    ss = s;
  }

  /* terminator for this call's feature list */
  makeftr1_push(-2, ftr_size, ftr_num, ftr);

  return;
}


/*
  Append one feature id to ftr. Capacity is checked before the write, so an
  overflow is reported without touching memory past ftr[ftr_size - 1].
*/
static void makeftr2_push(int id, int ftr_size, int *ftr_num, int *ftr) {
  exception(*ftr_num >= ftr_size, "buffer overflow (ftr2)");
  ftr[(*ftr_num)++] = id;
}


/*
  Append the edge (pairwise) features linking tokens i and j of xpr, with
  classes c and d, to ftr.
  NOTE(review): presumably i is a token and j = head[i]; confirm at call sites.

  Ids come from sdb_id(ft, name, add). They are written starting at
  ftr[*ftr_num], and *ftr_num is advanced past them. This call's list is
  terminated with -2. ftr_size bounds the total number of entries in ftr
  (callers presumably pass FTR2_MAX).

  Templates (every name includes c and d):
    A  classes alone    B  rev[j]    C  rev[j] and pol[j]
    D  each base value of i         E  each base value of j
  Base values are walked as NUL-separated strings until the first empty string.
*/
void makeftr2(SDB *ft, int add, XPR *xpr, int i, int j, int c, int d, int ftr_size, int *ftr_num, int *ftr) {
  /* NOTE(review): BUF_MAX bytes of stack per call; this may be too large for small thread stacks */
  char buf[BUF_MAX];
  char *s;

  snprintf(buf, sizeof buf, "A:%d:%d", c, d);
  makeftr2_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

  snprintf(buf, sizeof buf, "B:%d:%d_%d", c, d, xpr->rev[j]);
  makeftr2_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

  snprintf(buf, sizeof buf, "C:%d:%d_%d_%d", c, d, xpr->rev[j], xpr->pol[j]);
  makeftr2_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);

  for (s = xpr->base[i]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "D:%d:%d_%s", c, d, s);
    makeftr2_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }

  for (s = xpr->base[j]; *s != '\0'; s += strlen(s) + 1) {
    snprintf(buf, sizeof buf, "E:%d:%d_%s", c, d, s);
    makeftr2_push(sdb_id(ft, buf, add), ftr_size, ftr_num, ftr);
  }

  /* terminator for this call's feature list */
  makeftr2_push(-2, ftr_size, ftr_num, ftr);

  return;
}


/*
  Send the message from factor m to its i-th neighbor, in the log domain.
  Add it to that variable's belief in prb1, and mark the slot as sent.
  A node factor (m < num) sends its unary potential. An edge factor
  marginalizes its pairwise potential plus the incoming message from its
  other slot.
*/
static void bp_send_f2v(XPR *xpr, double *pot1, double *pot2, VAR *v, FAC *f, double *prb1, int m, int i) {
  int n = f[m].nbr[i];
  int c, d;

  f[m].snd[i] = 1;
  v[n].cnt++;
  for (c = 0; c < CLS_MAX; c++) {
    double msg;

    if (m < xpr->num) {	/* node factor */
      msg = POT1(pot1, m, c);
    } else {	/* edge factor: slot 0 indexes the first class axis of POT2, slot 1 the second */
      msg = -DBL_MAX;
      for (d = 0; d < CLS_MAX; d++) {
        double tmp;
        if (i == 0) {
          tmp = POT2(pot2, m - xpr->num, c, d) + f[m].v2f[1][d];
        } else {
          tmp = POT2(pot2, m - xpr->num, d, c) + f[m].v2f[0][d];
        }
        msg = LOGSUMEXP(msg, tmp);
      }
    }
    f[m].f2v[i][c] = msg;
    PRB1(prb1, n, c) += msg;
  }
}


/*
  Send the message from variable n to slot i of factor m: n's current belief,
  minus factor m's own earlier message when exclude is nonzero. For an edge
  factor the message is also added to the pairwise belief in prb2.
*/
static void bp_send_v2f(XPR *xpr, FAC *f, double *prb1, double *prb2, int n, int m, int i, int exclude) {
  int c, d;

  f[m].rcv[i] = 1;
  f[m].cnt++;
  for (c = 0; c < CLS_MAX; c++) {
    double msg = PRB1(prb1, n, c);

    if (exclude) msg -= f[m].f2v[i][c];
    f[m].v2f[i][c] = msg;
    if (m >= xpr->num) {
      for (d = 0; d < CLS_MAX; d++) {
        if (i == 0) {
          PRB2(prb2, m - xpr->num, c, d) += msg;
        } else {
          PRB2(prb2, m - xpr->num, d, c) += msg;
        }
      }
    }
  }
}


/*
  Sum-product belief propagation, in the log domain, over the factor graph of xpr.

  Graph built here:
    - variables v[0..num-1], one per position (0 is the root);
    - node factors f[0..num-1]; f[m] is attached to variable m and has log
      potential POT1(pot1, m, .);
    - edge factors f[num+1..2*num-1]. For token t = m - num, slot 0 is
      variable t and slot 1 is variable head[t], with log potential
      POT2(pot2, t, c, d): c is t's class, d is the head's class.
      f[num] is never used, because the root has no head.
  v must hold at least num entries and f at least 2*num.

  Scheduling:
    - A node sends to its single silent neighbor once all its other neighbors
      have reported.
    - It sends to all remaining neighbors once every neighbor has reported.
  If the head links contain a cycle (including a self-headed token), no further
  message can ever be sent. This stall is reported with exception().

  On return:
    PRB1(prb1, n, c)    = normalized belief that position n has class c;
    PRB2(prb2, t, c, d) = normalized joint belief of (t, head[t]) for t = 1..num-1.
  On a tree these are the exact marginals.
*/
void bp(XPR *xpr, double *pot1, double *pot2, VAR *v, FAC *f, double *prb1, double *prb2) {
  int i;
  int com;
  int m, n;
  int c, d;
  double z, max;

  /* ---- initialize the graph ---- */

  /* variable nodes */
  for (n = 0; n < xpr->num; n++) {
    v[n].con = 0;
    v[n].cnt = 0;
    v[n].flg = 0;
  }
  /* node factors */
  for (m = 0; m < xpr->num; m++) {
    f[m].con = 1;
    f[m].cnt = 0;
    f[m].flg = 0;
    f[m].nbr[0] = m;
    v[f[m].nbr[0]].con++;
    f[m].rcv[0] = 0;
    f[m].snd[0] = 0;
    for (c = 0; c < CLS_MAX; c++) {
      f[m].v2f[0][c] = 0.0;
      f[m].f2v[0][c] = 0.0;
    }
  }
  /* edge factors (f[num] is skipped) */
  for (m = xpr->num + 1; m < 2 * xpr->num; m++) {
    f[m].con = 2;
    f[m].cnt = 0;
    f[m].flg = 0;
    /* slot 0: the token itself */
    f[m].nbr[0] = m - xpr->num;
    v[f[m].nbr[0]].con++;
    f[m].rcv[0] = 0;
    f[m].snd[0] = 0;
    for (c = 0; c < CLS_MAX; c++) {
      f[m].v2f[0][c] = 0.0;
      f[m].f2v[0][c] = 0.0;
    }
    /* slot 1: its head */
    f[m].nbr[1] = xpr->head[m - xpr->num];
    v[f[m].nbr[1]].con++;
    f[m].rcv[1] = 0;
    f[m].snd[1] = 0;
    for (c = 0; c < CLS_MAX; c++) {
      f[m].v2f[1][c] = 0.0;
      f[m].f2v[1][c] = 0.0;
    }
  }
  /* beliefs: node beliefs start at 0 (log 1), edge beliefs at the pairwise potential */
  for (i = 0; i < xpr->num; i++) {
    for (c = 0; c < CLS_MAX; c++) {
      PRB1(prb1, i, c) = 0.0;
      for (d = 0; d < CLS_MAX; d++) {
        PRB2(prb2, i, c, d) = POT2(pot2, i, c, d);
      }
    }
  }

  /* ---- inference: repeat until every variable has exchanged all messages ---- */

  for (com = 0; com < xpr->num; ) {
    int progress = 0;

    /* factor -> variable messages */
    for (m = 0; m < 2 * xpr->num; m++) {
      if (m == xpr->num) continue;	/* unused factor */

      /* every neighbor has reported: send to each neighbor not yet served */
      if (f[m].cnt == f[m].con && f[m].flg < 2) {
        for (i = 0; i < f[m].con; i++) {
          if (!f[m].snd[i]) bp_send_f2v(xpr, pot1, pot2, v, f, prb1, m, i);
        }
        f[m].flg = 2;
        progress = 1;
      }

      /* all but one neighbor have reported: send to the silent one */
      if (f[m].cnt == f[m].con - 1 && f[m].flg < 1) {
        for (i = 0; i < f[m].con; i++) {
          if (!f[m].rcv[i]) bp_send_f2v(xpr, pot1, pot2, v, f, prb1, m, i);
        }
        f[m].flg = 1;
        progress = 1;
      }
    }

    /* variable -> factor messages */
    for (n = 0; n < xpr->num; n++) {
      /* every adjacent factor has reported: send to each factor not yet served,
         excluding that factor's own contribution */
      if (v[n].cnt == v[n].con && v[n].flg < 2) {
        for (m = 0; m < 2 * xpr->num; m++) {
          if (m == xpr->num) continue;
          for (i = 0; i < f[m].con; i++) {
            if (f[m].nbr[i] == n && !f[m].rcv[i]) bp_send_v2f(xpr, f, prb1, prb2, n, m, i, 1);
          }
        }
        v[n].flg = 2;
        com++;
        progress = 1;
      }

      /* all but one factor have reported: send the current belief to the silent one */
      if (v[n].cnt == v[n].con - 1 && v[n].flg < 1) {
        for (m = 0; m < 2 * xpr->num; m++) {
          if (m == xpr->num) continue;
          for (i = 0; i < f[m].con; i++) {
            if (f[m].nbr[i] == n && !f[m].snd[i]) bp_send_v2f(xpr, f, prb1, prb2, n, m, i, 0);
          }
        }
        v[n].flg = 1;
        progress = 1;
      }
    }

    /* A pass that changes no state would repeat forever (the head links are not a tree). */
    if (!progress) {
      exception(1, "belief propagation stalled (head links do not form a tree?)");
      break;
    }
  }

  /* ---- node beliefs -> probabilities ---- */
  for (n = 0; n < xpr->num; n++) {
    /* log partition function, shifted by the maximum for stability */
    max = PRB1(prb1, n, 0);
    for (c = 0; c < CLS_MAX; c++) {
      if (PRB1(prb1, n, c) > max) max = PRB1(prb1, n, c);
    }
    z = 0.0;
    for (c = 0; c < CLS_MAX; c++) {
      z += exp(PRB1(prb1, n, c) - max);
    }
    z = max + log(z);
    for (c = 0; c < CLS_MAX; c++) {
      PRB1(prb1, n, c) = exp(PRB1(prb1, n, c) - z);
    }
  }

  /* ---- edge beliefs -> probabilities (token 0 is the root, which has no edge) ---- */
  for (m = 1; m < xpr->num; m++) {
    max = PRB2(prb2, m, 0, 0);
    for (c = 0; c < CLS_MAX; c++) {
      for (d = 0; d < CLS_MAX; d++) {
        if (PRB2(prb2, m, c, d) > max) max = PRB2(prb2, m, c, d);
      }
    }
    z = 0.0;
    for (c = 0; c < CLS_MAX; c++) {
      for (d = 0; d < CLS_MAX; d++) {
        z += exp(PRB2(prb2, m, c, d) - max);
      }
    }
    z = max + log(z);
    for (c = 0; c < CLS_MAX; c++) {
      for (d = 0; d < CLS_MAX; d++) {
        PRB2(prb2, m, c, d) = exp(PRB2(prb2, m, c, d) - z);
      }
    }
  }

  return;
}
