/*
  Copyright(C) 2007-2012 National Institute of Information and Communications Technology
*/

/*
  svmtools
  Multiclass classification module
*/


#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "exception.h"
#include "svm_common.h"
#include "svm_smo.h"
#include "svm_mc.h"


/* Size in bytes of the line buffer used when reading a model file. */
#define BUF_SIZE (32 * 1024)


/* Forward declaration; svm_train() is defined further down in this file. */
static int svm_train(SVM_TPRM *tprm, SVM_EXM *exm, FILE *fp);


/*
  Train a multiclass SVM and write the resulting model to fp.

  method : SVM_MC_METHOD_OVR, SVM_MC_METHOD_PAIRWISE or SVM_MC_METHOD_SBT.
           Any other value writes only the common header and no classifier.
  class  : number of classes; labels in mcexm are matched against
           0 .. class-1
  tprm   : training parameters forwarded to every binary training run
  mcexm  : multiclass training examples (num, label[], sv[])
  fp     : output stream for the model

  Output layout: the line "svmtools multiclass", the method, the class
  count, one line per class listing the classes by decreasing number of
  training examples, then the binary models in the order they are trained.

  Returns 0 on success and 1 as soon as one binary training run fails.
  All scratch memory is released on both paths.
*/
int svm_mc_makemdl(int method, int class, SVM_TPRM *tprm, SVM_EXM *mcexm, FILE *fp) {
  int i, j, k, l;
  SVM_EXM exm;
  int *rank;
  int *count;
  int rc = 1;  /* set to 0 only after every classifier has been written */

  /* Scratch example set reused for each binary problem.  It only borrows
     the support-vector pointers from mcexm; it never owns them. */
  exm.label = smalloc(sizeof(int) * mcexm->num);
  exm.sv = smalloc(sizeof(SVM_SV *) * mcexm->num);

  /* Count the examples per class and order the classes by that count. */
  rank = smalloc(sizeof(int) * class);
  count = smalloc(sizeof(int) * class);
  for (i = 0; i < class; i++) {
    rank[i] = i;
    count[i] = 0;
  }
  for (i = 0; i < mcexm->num; i++) {
    int lab = mcexm->label[i];
    /* A label outside 0..class-1 matches no class below, so it is simply
       not counted instead of indexing count[] out of bounds. */
    if (lab >= 0 && lab < class) count[lab]++;
  }
  /* Selection sort, descending.  The strict '>' keeps the exact ordering
     that existing model files were written with. */
  for (i = 0; i < class; i++) {
    int max, tmp;
    max = i;
    for (j = i + 1; j < class; j++) {
      if (count[rank[j]] > count[rank[max]]) max = j;
    }
    if (i != max) {
      tmp = rank[i];
      rank[i] = rank[max];
      rank[max] = tmp;
    }
  }
  free(count);

  /* Header shared by all methods. */
  fprintf(fp, "svmtools multiclass\n");
  fprintf(fp, "%d\n", method);
  fprintf(fp, "%d\n", class);
  for (i = 0; i < class; i++) {
    fprintf(fp, "%d\n", rank[i]);
  }

  /* Training */
  switch (method) {
  case SVM_MC_METHOD_OVR:
    /* Every classifier sees all examples; only the labels change. */
    for (i = 0; i < mcexm->num; i++) {
      exm.sv[i] = mcexm->sv[i];
    }
    exm.num = mcexm->num;

    /* One classifier per class, written in class order. */
    for (i = 0; i < class; i++) {
      if (svm_verbose) fprintf(stderr, "class=%d: ", i);

      /* Class i is positive, everything else negative. */
      for (j = 0; j < mcexm->num; j++) {
        exm.label[j] = (mcexm->label[j] == i) ? +1 : -1;
      }

      if (svm_train(tprm, &exm, fp)) goto done;
    }
    break;

  case SVM_MC_METHOD_PAIRWISE:
    /* class * (class - 1) / 2 classifiers, one per pair (i, j), i < j. */
    for (i = 0; i < class; i++) {
      /* Class i supplies the positives; they occupy slots 0 .. l-1 and
         are reused for every partner class j. */
      l = 0;
      for (k = 0; k < mcexm->num; k++) {
        if (mcexm->label[k] == i) {
          exm.label[l] = +1;
          exm.sv[l] = mcexm->sv[k];
          l++;
        }
      }
      for (j = i + 1; j < class; j++) {
        /* Append class j as negatives after the positives. */
        exm.num = l;
        for (k = 0; k < mcexm->num; k++) {
          if (mcexm->label[k] == j) {
            exm.label[exm.num] = -1;
            exm.sv[exm.num] = mcexm->sv[k];
            (exm.num)++;
          }
        }

        if (svm_verbose) fprintf(stderr, "class=%d/%d: ", i, j);

        if (svm_train(tprm, &exm, fp)) goto done;
      }
    }
    break;

  case SVM_MC_METHOD_SBT:
    /* (class - 1) classifiers.  Without at least one class there is no
       rank[class - 1] to read and nothing to train. */
    if (class < 1) break;

    /* Start with the lowest-ranked class as the negatives. */
    exm.num = 0;
    for (i = 0; i < mcexm->num; i++) {
      if (mcexm->label[i] == rank[class - 1]) {
        exm.label[exm.num] = -1;
        exm.sv[exm.num] = mcexm->sv[i];
        (exm.num)++;
      }
    }
    /* Work upwards from the second-lowest rank: rank[i] is positive
       against all lower-ranked classes accumulated so far. */
    for (i = class - 2; i >= 0; i--) {
      int tmpnum;

      if (svm_verbose) fprintf(stderr, "class=%d: ", i);

      /* Append the positives, remembering where they start. */
      tmpnum = exm.num;
      for (j = 0; j < mcexm->num; j++) {
        if (mcexm->label[j] == rank[i]) {
          exm.label[exm.num] = +1;
          exm.sv[exm.num] = mcexm->sv[j];
          (exm.num)++;
        }
      }

      if (svm_train(tprm, &exm, fp)) goto done;

      /* These positives become negatives for the next, higher rank. */
      for (j = tmpnum; j < exm.num; j++) exm.label[j] = -1;
    }
    break;

  default:
    /* Unknown method: only the header is written, as before. */
    break;
  }

  rc = 0;

done:
  /* Single exit so a failed training run no longer leaks the buffers. */
  free(rank);
  free(exm.sv);
  free(exm.label);

  return rc;
}


/*
  Train one binary SVM and append its model to fp.

  tprm : training parameters; passed to svm_smo(), and its kprm member is
         passed to svm_createmdl()
  exm  : training examples (exm->num entries)
  fp   : output stream handed to svm_createmdl()

  Returns 0 on success, 1 if svm_smo() or svm_createmdl() returns non-zero.
  The work buffer is released on every path.
*/
static int svm_train(SVM_TPRM *tprm, SVM_EXM *exm, FILE *fp) {
  double *alpha = NULL;
  double b = 0.0;
  int rc = 1;

  /* One double per example for svm_smo() to fill in; an empty example
     set gets no buffer at all. */
  if (exm->num > 0) {
    alpha = smalloc(sizeof(double) * exm->num);
  }

  /* Optimise, then serialise; the second step runs only if the first
     succeeded. */
  if (svm_smo(tprm, exm, alpha, &b) == 0 &&
      svm_createmdl(&tprm->kprm, exm, alpha, b, fp) == 0) {
    rc = 0;
  }

  /* free(NULL) is a no-op, so no guard is needed. */
  free(alpha);

  return rc;
}


/*
  Read one line from fp into buf (capacity size) and parse it as a decimal
  int.  Trailing white space, including the newline, is accepted; an empty
  line, trailing garbage or a value outside the range of int is an error.

  Returns 0 and stores the value in *out on success, -1 otherwise.
*/
static int svm_mc_read_int(FILE *fp, char *buf, int size, int *out) {
  char *end;
  long v;

  if (fgets(buf, size, fp) == NULL) return -1;

  errno = 0;
  v = strtol(buf, &end, 10);
  if (end == buf || errno == ERANGE || v < INT_MIN || v > INT_MAX) return -1;

  while (isspace((unsigned char)*end)) end++;
  if (*end != '\0') return -1;

  *out = (int)v;
  return 0;
}


/*
  Read a multiclass SVM model from fp.

  Expected layout: the line "svmtools multiclass", the method, the number
  of classes, one rank entry per class, then the binary models, each read
  with svm_readmdl():
    SVM_MC_METHOD_OVR      : class models
    SVM_MC_METHOD_PAIRWISE : class * (class - 1) / 2 models
    SVM_MC_METHOD_SBT      : class - 1 models
  For any other method no binary model is read and mdl is left NULL.

  Returns the newly allocated model, or NULL when the input is truncated or
  malformed: wrong ID line, non-numeric field, class count below 1, or a
  rank entry outside 0 .. class-1.  The vote array is allocated but not
  initialised here.
*/
SVM_MC_MDL *svm_mc_readmdl(FILE *fp) {
  int i;
  int value;
  size_t n;
  size_t nmdl;
  char buf[BUF_SIZE];
  SVM_MC_MDL *mcmdl;

  mcmdl = smalloc(sizeof(SVM_MC_MDL));
  /* Start from NULL so the failure path can free unconditionally. */
  mcmdl->rank = NULL;
  mcmdl->vote = NULL;
  mcmdl->mdl = NULL;

  /* ID line.  strcspn() strips the newline only if there is one, so an
     empty or unterminated line no longer writes before the buffer or
     loses its last character. */
  if (fgets(buf, BUF_SIZE, fp) == NULL) goto fail;
  buf[strcspn(buf, "\n")] = '\0';
  if (strcmp("svmtools multiclass", buf) != 0) goto fail;

  /* Method */
  if (svm_mc_read_int(fp, buf, BUF_SIZE, &value)) goto fail;
  mcmdl->method = value;

  /* Number of classes.  It sizes every allocation below and
     svm_mc_decision() reads rank[0], so it must be at least 1. */
  if (svm_mc_read_int(fp, buf, BUF_SIZE, &value) || value < 1) goto fail;
  mcmdl->class = value;

  /* Class ranking.  Each entry is later used as an index into mdl[] and
     vote[], so it is range-checked here. */
  mcmdl->rank = smalloc(sizeof(int) * mcmdl->class);
  for (i = 0; i < mcmdl->class; i++) {
    if (svm_mc_read_int(fp, buf, BUF_SIZE, &value)) goto fail;
    if (value < 0 || value >= mcmdl->class) goto fail;
    mcmdl->rank[i] = value;
  }

  mcmdl->vote = smalloc(sizeof(int) * mcmdl->class);

  /* How many binary models follow depends on the method. */
  switch (mcmdl->method) {
  case SVM_MC_METHOD_OVR:
    nmdl = (size_t)mcmdl->class;
    break;
  case SVM_MC_METHOD_PAIRWISE:
    nmdl = (size_t)mcmdl->class * (size_t)(mcmdl->class - 1) / 2;
    break;
  case SVM_MC_METHOD_SBT:
    nmdl = (size_t)(mcmdl->class - 1);
    break;
  default:
    nmdl = 0;
    break;
  }

  /* The models are stored back to back in the file, so one sequential
     loop covers all three methods. */
  if (nmdl > 0) {
    mcmdl->mdl = smalloc(sizeof(SVM_MDL *) * nmdl);
    for (n = 0; n < nmdl; n++) {
      if ((mcmdl->mdl[n] = svm_readmdl(fp)) == NULL) goto fail;
    }
  }

  return mcmdl;

fail:
  /* NOTE(review): binary models already returned by svm_readmdl() are not
     released here because no matching release function is visible in this
     file; only the containers allocated above are freed. */
  free(mcmdl->mdl);
  free(mcmdl->vote);
  free(mcmdl->rank);
  free(mcmdl);
  return NULL;
}


/*
  Classify x with a multiclass model.

  OVR      : the class whose classifier returns the highest score.
  PAIRWISE : the class with the most votes; a pair whose score is exactly
             zero casts no vote.  mcmdl->vote is overwritten as scratch
             space.
  SBT      : classifiers are tried from index class-2 downwards and the
             first one with a non-negative score decides; if none does,
             the last entry of mcmdl->rank is returned.

  For OVR and PAIRWISE a tie goes to the class listed earlier in
  mcmdl->rank.  Returns the chosen class, or -1 for an unknown method.
*/
int svm_mc_decision(SVM_MC_MDL *mcmdl, SVM_SV *x) {
  int a, b, pair;
  int winner;
  float score;

  switch (mcmdl->method) {
  case SVM_MC_METHOD_OVR:
    /* Seed with the top-ranked class; later classes must beat it
       strictly. */
    winner = mcmdl->rank[0];
    score = svm_decision(mcmdl->mdl[winner], x);
    for (a = 1; a < mcmdl->class; a++) {
      float cand = svm_decision(mcmdl->mdl[mcmdl->rank[a]], x);
      if (cand > score) {
        winner = mcmdl->rank[a];
        score = cand;
      }
    }
    return winner;

  case SVM_MC_METHOD_PAIRWISE:
    /* Reset the tally. */
    for (a = 0; a < mcmdl->class; a++) {
      mcmdl->vote[a] = 0;
    }

    /* Models are stored in (a, b) order with a < b. */
    pair = 0;
    for (a = 0; a < mcmdl->class; a++) {
      for (b = a + 1; b < mcmdl->class; b++) {
        score = svm_decision(mcmdl->mdl[pair], x);
        pair++;
        if (score > 0.0) {
          mcmdl->vote[a]++;
        } else if (score < 0.0) {
          mcmdl->vote[b]++;
        }
      }
    }

    /* Pick the largest tally, scanning in rank order. */
    winner = mcmdl->rank[0];
    score = mcmdl->vote[winner];
    for (a = 1; a < mcmdl->class; a++) {
      int c = mcmdl->rank[a];
      if (mcmdl->vote[c] > score) {
        winner = c;
        score = mcmdl->vote[winner];
      }
    }
    return winner;

  case SVM_MC_METHOD_SBT:
    /* Stop at the first classifier that accepts x.  If none accepts,
       a ends at -1 and the index below is class - 1. */
    a = mcmdl->class - 2;
    while (a >= 0) {
      score = svm_decision(mcmdl->mdl[a], x);
      if (score >= 0.0) break;
      a--;
    }
    return mcmdl->rank[mcmdl->class - 2 - a];

  default:
    return -1;
  }
}
