/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sgd;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.csv.CSVUtils;
import org.apache.mahout.classifier.sgd.RecordFactory;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.apache.mahout.vectorizer.encoders.TextValueEncoder;

public class CsvRecordFactory
implements RecordFactory {
    private static final String INTERCEPT_TERM = "Intercept Term";
    private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY = ImmutableMap.builder().put("continuous", ContinuousValueEncoder.class).put("numeric", ContinuousValueEncoder.class).put("n", ContinuousValueEncoder.class).put("word", StaticWordValueEncoder.class).put("w", StaticWordValueEncoder.class).put("text", TextValueEncoder.class).put("t", TextValueEncoder.class).build();
    private final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
    private int target;
    private final Dictionary targetDictionary;
    private String idName;
    private int id = -1;
    private List<Integer> predictors;
    private Map<Integer, FeatureVectorEncoder> predictorEncoders;
    private int maxTargetValue = Integer.MAX_VALUE;
    private final String targetName;
    private final Map<String, String> typeMap;
    private List<String> variableNames;
    private boolean includeBiasTerm;
    private static final String CANNOT_CONSTRUCT_CONVERTER = "Unable to construct type converter... shouldn't be possible";

    private List<String> parseCsvLine(String line) {
        try {
            return Arrays.asList(CSVUtils.parseLine(line));
        }
        catch (IOException e) {
            ArrayList<String> list = Lists.newArrayList();
            list.add(line);
            return list;
        }
    }

    private List<String> parseCsvLine(CharSequence line) {
        return this.parseCsvLine(((Object)line).toString());
    }

    public CsvRecordFactory(String targetName, Map<String, String> typeMap) {
        this.targetName = targetName;
        this.typeMap = typeMap;
        this.targetDictionary = new Dictionary();
    }

    public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) {
        this(targetName, typeMap);
        this.idName = idName;
    }

    @Override
    public void defineTargetCategories(List<String> values) {
        Preconditions.checkArgument(values.size() <= this.maxTargetValue, "Must have less than or equal to " + this.maxTargetValue + " categories for target variable, but found " + values.size());
        if (this.maxTargetValue == Integer.MAX_VALUE) {
            this.maxTargetValue = values.size();
        }
        for (String value : values) {
            this.targetDictionary.intern(value);
        }
    }

    @Override
    public CsvRecordFactory maxTargetValue(int max) {
        this.maxTargetValue = max;
        return this;
    }

    @Override
    public boolean usesFirstLineAsSchema() {
        return true;
    }

    @Override
    public void firstLine(String line) {
        final HashMap<String, Integer> vars = Maps.newHashMap();
        this.variableNames = this.parseCsvLine(line);
        int column = 0;
        for (String var : this.variableNames) {
            vars.put(var, column++);
        }
        this.target = (Integer)vars.get(this.targetName);
        if (this.idName != null) {
            this.id = (Integer)vars.get(this.idName);
        }
        this.predictors = Lists.newArrayList(Collections2.transform(this.typeMap.keySet(), new Function<String, Integer>(){

            @Override
            public Integer apply(String from) {
                Integer r = (Integer)vars.get(from);
                Preconditions.checkArgument(r != null, "Can't find variable %s, only know about %s", from, vars);
                return r;
            }
        }));
        if (this.includeBiasTerm) {
            this.predictors.add(-1);
        }
        Collections.sort(this.predictors);
        this.predictorEncoders = Maps.newHashMap();
        for (Integer predictor : this.predictors) {
            Class c;
            String name;
            if (predictor == -1) {
                name = INTERCEPT_TERM;
                c = ConstantValueEncoder.class;
            } else {
                name = this.variableNames.get(predictor);
                c = TYPE_DICTIONARY.get(this.typeMap.get(name));
            }
            try {
                Preconditions.checkArgument(c != null, "Invalid type of variable %s,  wanted one of %s", this.typeMap.get(name), TYPE_DICTIONARY.keySet());
                Constructor constructor = c.getConstructor(String.class);
                Preconditions.checkArgument(constructor != null, "Can't find correct constructor for %s", this.typeMap.get(name));
                FeatureVectorEncoder encoder = (FeatureVectorEncoder)constructor.newInstance(name);
                this.predictorEncoders.put(predictor, encoder);
                encoder.setTraceDictionary(this.traceDictionary);
            }
            catch (InstantiationException e) {
                throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
            }
            catch (IllegalAccessException e) {
                throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
            }
            catch (InvocationTargetException e) {
                throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
            }
            catch (NoSuchMethodException e) {
                throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
            }
        }
    }

    @Override
    public int processLine(String line, Vector featureVector) {
        List<String> values = this.parseCsvLine(line);
        int targetValue = this.targetDictionary.intern(values.get(this.target));
        if (targetValue >= this.maxTargetValue) {
            targetValue = this.maxTargetValue - 1;
        }
        for (Integer predictor : this.predictors) {
            String value = predictor >= 0 ? values.get(predictor) : null;
            this.predictorEncoders.get(predictor).addToVector(value, featureVector);
        }
        return targetValue;
    }

    public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
        List<String> values = this.parseCsvLine(line);
        int targetValue = -1;
        if (returnTarget && (targetValue = this.targetDictionary.intern(values.get(this.target))) >= this.maxTargetValue) {
            targetValue = this.maxTargetValue - 1;
        }
        for (Integer predictor : this.predictors) {
            String value = predictor >= 0 ? values.get(predictor) : null;
            this.predictorEncoders.get(predictor).addToVector(value, featureVector);
        }
        return targetValue;
    }

    public String getTargetString(CharSequence line) {
        List<String> values = this.parseCsvLine(line);
        return values.get(this.target);
    }

    public String getTargetLabel(int code) {
        for (String key : this.targetDictionary.values()) {
            if (this.targetDictionary.intern(key) != code) continue;
            return key;
        }
        return null;
    }

    public String getIdString(CharSequence line) {
        List<String> values = this.parseCsvLine(line);
        return values.get(this.id);
    }

    @Override
    public Iterable<String> getPredictors() {
        return Lists.transform(this.predictors, new Function<Integer, String>(){

            @Override
            public String apply(Integer v) {
                if (v >= 0) {
                    return (String)CsvRecordFactory.this.variableNames.get(v);
                }
                return CsvRecordFactory.INTERCEPT_TERM;
            }
        });
    }

    @Override
    public Map<String, Set<Integer>> getTraceDictionary() {
        return this.traceDictionary;
    }

    @Override
    public CsvRecordFactory includeBiasTerm(boolean useBias) {
        this.includeBiasTerm = useBias;
        return this;
    }

    @Override
    public List<String> getTargetCategories() {
        List<String> r = this.targetDictionary.values();
        if (r.size() > this.maxTargetValue) {
            r.subList(this.maxTargetValue, r.size()).clear();
        }
        return r;
    }

    public String getIdName() {
        return this.idName;
    }

    public void setIdName(String idName) {
        this.idName = idName;
    }
}

