"""
generate toy datasets
based on code by Mark Rogers
"""

from math   import sin, pi
from random import gauss
import numpy
import datafunc

##
# Class generators:
##
def sineClass(xlim=(0, 1), ylim=(0, 1), n=20, sigma = 0.04) :
    """
    Generates a 2-D noisy sine wave.

    Points are evenly spaced along x starting at min(xlim), with step
    (max(xlim) - min(xlim)) / n.  The noise-free y values follow one full
    sine period per unit of x and oscillate between min(ylim) and
    max(ylim).  Independent Gaussian noise is added to both coordinates.

    Parameters:
      xlim     - sequence of length 2 that delimits the x value range
      ylim     - sequence of length 2 that delimits the noise-free y range
      n        - number of data points; n <= 0 yields an empty list
      sigma    - standard deviation of the Gaussian noise added to x and y
    Returns:
      list of n [x, y] pairs
    Note: for use with PyML demo2d, only use x and y values
          between -1 and 1
    """
    if n <= 0 :
        return []
    minx  = min(xlim)
    dx    = float(max(xlim) - minx) / n
    miny  = min(ylim)
    # Half the y range is the amplitude.  Centre the wave on the midpoint
    # of ylim so the noise-free curve spans exactly [min(ylim), max(ylim)]
    # (centring on min(ylim) would push half of the wave below the range).
    gamma = float(max(ylim) - miny) / 2.0
    ymid  = miny + gamma
    X = []
    for i in range(n) :
        xval = i * dx
        newx = minx + xval + gauss(0, sigma)
        newy = ymid + gamma * sin(xval * pi * 2) + gauss(0, sigma)
        X.append([newx, newy])

    return X

def multivariate_normal(mu, sigma=0.1, n=20) :
    """
    A wrapper around numpy's random.multivariate_normal function.
    Generates n points from a Gaussian distribution with mean mu.

    Parameters:
      mu      - mean vector (sequence of length d)
      sigma   - covariance specification.  Note this is a VARIANCE, not a
                standard deviation:
                  * a scalar (Python or numpy number): covariance sigma * I
                  * a 1-D sequence: per-dimension variances placed on the
                    diagonal of the covariance
                  * a square matrix: used as the full covariance
      n       - number of points to generate
    Returns:
      numpy array of shape (n, d)
    Raises:
      ValueError if a matrix-valued sigma is not square

    Note: for use with PyML demo2d, choose mu and sigma so that the
          generated populations stay between -1 and 1
    """
    dim = len(mu)
    # numpy.ndim() is 0 for every scalar, including numpy scalar types
    # (e.g. numpy.float64) that an exact type(...) == float check misses.
    if numpy.ndim(sigma) == 0 :
        cov = numpy.diag([sigma] * dim)
    else :
        cov = numpy.array(sigma)
        if cov.ndim == 1 :
            cov = numpy.diag(cov)
        elif cov.ndim != 2 or cov.shape[0] != cov.shape[1] :
            # explicit check instead of assert: asserts vanish under -O
            raise ValueError("sigma must be a scalar, a vector or a square "
                             "matrix; got shape %s" % (cov.shape,))

    return numpy.random.multivariate_normal(mu, cov, n)

def _broadcastPerClass(value, numClasses, name) :
    """
    Return value as a list with exactly one entry per class.

    A scalar, or a sequence of length 1, is shared by every class; any
    other sequence must already have numClasses entries.
    Raises ValueError otherwise (name is used in the message).
    """
    try :
        count = len(value)
    except TypeError :
        # unsized (int, float, numpy scalar, 0-d array): share across classes
        return [value] * numClasses
    if count == 1 :
        # replicate the single ELEMENT, not the enclosing sequence
        return [value[0]] * numClasses
    if count != numClasses :
        raise ValueError("%s has %d entries but there are %d classes"
                         % (name, count, numClasses))
    return list(value)

def gaussianData(mu, sigma, n) :
    """
    Generates a labelled dataset with one Gaussian population per class.

    Parameters:
      mu      - sequence of class means; its length sets the number of
                classes
      sigma   - covariance specification for each class (anything accepted
                by multivariate_normal), either one per class or a single
                value (scalar or length-1 sequence) shared by all classes
      n       - number of points per class: one count per class, or a
                single count (int or length-1 sequence) shared by all
    Returns:
      datafunc.VectorDataSet of the generated points; the label of each
      point is its class index as a string ('0', '1', ...)
    Raises:
      ValueError if sigma or n has neither 1 nor len(mu) entries
    """
    numClasses = len(mu)
    sigma = _broadcastPerClass(sigma, numClasses, 'sigma')
    n = _broadcastPerClass(n, numClasses, 'n')

    X = []
    Y = []
    for i in range(numClasses) :
        Y.extend([str(i)] * n[i])
        X.extend(multivariate_normal(mu[i], sigma[i], n[i]).tolist())

    return datafunc.VectorDataSet(X, L = Y)

def noisyData() :
    """
    Prints two 2-D populations as CSV rows of the form "id,label,x,y".

    The populations are usually linearly separable but have very different
    spread, which simulates a problem where one class is much noisier than
    the other.  Label -1 uses gaussCloud(-0.5, 0.0, sigma=0.05, n=20) and
    label 1 uses gaussCloud(0.3, 0.0, sigma=0.25, n=20).  Row ids start at
    1 and run across both populations.  The output is suitable for creating
    a PyML VectorDataSet (labelsColumn=1).

    NOTE(review): gaussCloud is not defined or imported in the code visible
    here, so as written this raises NameError.  It is presumably expected
    to return two equal-length sequences (x values, y values) -- confirm.
    """
    # (label, x argument, y argument, sigma) for each population, in order
    populations = ((-1, -0.5, 0.0, 0.05),
                   (1, 0.3, 0.0, 0.25))
    rowId = 0
    for label, cx, cy, spread in populations :
        xs, ys = gaussCloud(cx, cy, sigma=spread, n=20)
        for idx in range(len(xs)) :
            rowId += 1
            fields = {'p':rowId, 'l':label, 'x':xs[idx], 'y':ys[idx]}
            print("%(p)d,%(l)d,%(x)f,%(y)f" % fields)

def sineData(n = 30) :
    """
    Uses sine-wave populations to create two class populations that
    meander close to each other.

    Both classes are generated with sineClass over x limits [-0.8, 0.8];
    class '-1' uses y limits [-0.4, 0.2] and class '1' uses y limits
    [0, 0.6], so the two waves run parallel with a vertical offset.

    Parameters:
      n  - number of points generated per class
    Returns:
      datafunc.VectorDataSet with the class '-1' points first, followed
      by the class '1' points; labels are the strings '-1' and '1'
    """
    lim = 0.8
    X = []
    Y = []
    for label, ylim in ((-1, [-0.4, 0.2]), (1, [0, 0.6])) :
        X.extend(sineClass([-lim, lim], ylim, n))
        Y.extend([str(label)] * n)

    return datafunc.VectorDataSet(X, L = Y)

def separableData() :
    """
    Prints two linearly-separable 2-D populations as CSV rows of the form
    "id,label,x,y".

    Label -1 is centred at (-0.5, 0) and label 1 at (0.5, 0); both are
    drawn via gaussCloud with sigma=0.2 and n=20.  Row ids start at 1 and
    run across both populations.  The output is suitable for creating a
    PyML VectorDataSet (labelsColumn=1).

    NOTE(review): gaussCloud is not defined or imported in the code visible
    here, so as written this raises NameError.  It is presumably expected
    to return two equal-length sequences (x values, y values) -- confirm.
    """
    # (label, centre x) for each population, in output order
    centres = ((-1, -0.5), (1, 0.5))
    counter = 0
    for label, cx in centres :
        xs, ys = gaussCloud(cx, 0.0, sigma=0.2, n=20)
        for idx in range(len(xs)) :
            counter += 1
            fields = {'p':counter, 'l':label, 'x':xs[idx], 'y':ys[idx]}
            print("%(p)d,%(l)d,%(x)f,%(y)f" % fields)

## Main:
USAGE = """
Usage: python generate.py type
Where 'type' is one of:
    l  - two similar, linearly-separable populations
    n  - two linearly-separable populations, one with more
         noise than the other
    s  - two populations generated by sine waves (with some noise)
"""

if __name__ == '__main__' :
    import sys
    if len(sys.argv) != 2 :
        print(USAGE)
        sys.exit(1)
    # 'kind' rather than 'type' so the builtin is not shadowed
    kind = sys.argv[1]
    if kind == 'l' :
        separableData()
    elif kind == 'n' :
        noisyData()
    elif kind == 's' :
        # sineData() returns a VectorDataSet instead of printing rows like
        # the other generators, so print it here.
        # NOTE(review): this prints whatever str(VectorDataSet) produces,
        # not the "id,label,x,y" CSV of the other modes -- confirm intent.
        print(sineData())
    else :
        sys.stderr.write("Unrecognized data generation type: %s\n" % kind)
        sys.stderr.write(USAGE)
        sys.exit(1)
