/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.df.builder;

import com.google.common.collect.Sets;
import java.util.HashSet;
import java.util.Random;
import org.apache.mahout.classifier.df.builder.TreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.data.conditions.Condition;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.OptIgSplit;
import org.apache.mahout.classifier.df.split.RegressionSplit;
import org.apache.mahout.classifier.df.split.Split;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Recursive {@link TreeBuilder} that, at every node, draws a random subset of the attributes that
 * are still available, picks the split with the highest information gain among them and recurses
 * into the resulting subsets.
 *
 * <p>The builder is stateful: the attribute selection mask, the number of attributes drawn per
 * node, the split strategy, the first data set seen and the minimum variance are all initialised
 * lazily by {@link #build(Random, Data)} and kept in instance fields afterwards. {@code build}
 * mutates those fields without any synchronisation.
 */
public class DecisionTreeBuilder implements TreeBuilder {
    private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);
    private static final int[] NO_ATTRIBUTES = new int[0];
    /** A best split whose information gain is below this value is replaced by a leaf. */
    private static final double EPSILON = 1.0E-6;

    /** {@code selected[i]} is true when attribute {@code i} must not be drawn; the label is always marked. */
    private boolean[] selected;
    /** Number of attributes drawn at each node; 0 means "compute a default on the first build". */
    private int m;
    /** Split strategy; when null a default is chosen on the first build. */
    private IgSplit igSplit;
    /** When true, categorical nodes get a child for every value found in {@link #fullSet}. */
    private boolean complemented = true;
    /** Subsets holding fewer instances than this prevent a split. */
    private double minSplitNum = 2.0;
    /** Fraction of the first node's label variance used to derive {@link #minVariance}. */
    private double minVarianceProportion = 0.001;
    /** The first non-trivial data set handed to {@link #build(Random, Data)}. */
    private Data fullSet;
    /** Variance threshold for regression leaves; NaN until computed by the first regression build. */
    private double minVariance = Double.NaN;

    /**
     * Sets the number of attributes drawn at random for each node.
     *
     * @param m number of attributes; 0 lets {@link #build(Random, Data)} compute a default
     */
    public void setM(int m) {
        this.m = m;
    }

    /**
     * Sets the strategy used to compute candidate splits.
     *
     * @param igSplit split strategy; when left null a default is chosen by the first build
     */
    public void setIgSplit(IgSplit igSplit) {
        this.igSplit = igSplit;
    }

    /**
     * Enables or disables complemented categorical nodes.
     *
     * @param complemented true to give categorical nodes a child for every value of the first data
     *                     set seen, false to use only the values present at the node
     */
    public void setComplemented(boolean complemented) {
        this.complemented = complemented;
    }

    /**
     * Sets the minimum number of instances a subset must hold for a split to be kept.
     *
     * @param minSplitNum minimum subset size
     */
    public void setMinSplitNum(int minSplitNum) {
        this.minSplitNum = minSplitNum;
    }

    /**
     * Sets the proportion of the initial label variance below which a regression node becomes a leaf.
     *
     * @param minVarianceProportion proportion applied to the label variance of the first regression build
     */
    public void setMinVarianceProportion(double minVarianceProportion) {
        this.minVarianceProportion = minVarianceProportion;
    }

    /**
     * Builds the (sub)tree for {@code data}.
     *
     * <p>Returns a {@link Leaf} when the data is empty, when the stopping criteria are met (label
     * variance below the minimum for a numerical label; identical instances or identical labels
     * otherwise), when no attribute can be drawn, when the best information gain is below
     * {@link #EPSILON}, or when the subsets are too small. Otherwise returns a
     * {@link NumericalNode} or {@link CategoricalNode} whose children are built recursively.
     *
     * @param rng  random source used to draw attributes and passed to {@code Data.majorityLabel}
     * @param data instances to build the tree from
     * @return the root of the built (sub)tree
     */
    @Override
    public Node build(Random rng, Data data) {
        Dataset dataset = data.getDataset();
        boolean regression = dataset.isNumerical(dataset.getLabelId());

        if (selected == null) {
            selected = new boolean[dataset.nbAttributes()];
            // the label must never be drawn as a split attribute
            selected[dataset.getLabelId()] = true;
        }
        if (m == 0) {
            // default m: ceil(e / 3) for a numerical label, ceil(sqrt(e)) otherwise
            double e = dataset.nbAttributes() - 1;
            m = regression ? (int) Math.ceil(e / 3.0) : (int) Math.ceil(Math.sqrt(e));
        }
        if (data.isEmpty()) {
            return new Leaf(Double.NaN);
        }

        double sum = 0.0;
        if (regression) {
            double sumSquared = 0.0;
            for (int i = 0; i < data.size(); i++) {
                double label = dataset.getLabel(data.get(i));
                sum += label;
                sumSquared += label * label;
            }
            double variance = (sumSquared - sum * sum / data.size()) / data.size();

            // the threshold is derived once, from the first regression data set seen
            if (Double.isNaN(minVariance)) {
                minVariance = variance * minVarianceProportion;
                log.debug("minVariance:{}", minVariance);
            }
            if (variance < minVariance) {
                log.debug("variance({}) < minVariance({}) Leaf({})",
                    new Object[]{variance, minVariance, sum / data.size()});
                return new Leaf(sum / data.size());
            }
        } else {
            if (isIdentical(data)) {
                return new Leaf(data.majorityLabel(rng));
            }
            if (data.identicalLabel()) {
                return new Leaf(dataset.getLabel(data.get(0)));
            }
        }

        if (fullSet == null) {
            fullSet = data;
        }

        int[] attributes = randomAttributes(rng, selected, m);
        if (attributes == null || attributes.length == 0) {
            double label = leafLabel(data, sum, rng);
            log.warn("attribute which can be selected is not found Leaf({})", label);
            return new Leaf(label);
        }

        if (igSplit == null) {
            igSplit = regression ? new RegressionSplit() : new OptIgSplit();
        }

        // keep the first candidate, then any candidate with a strictly higher information gain
        Split best = null;
        for (int attr : attributes) {
            Split split = igSplit.computeSplit(data, attr);
            if (best == null || best.getIg() < split.getIg()) {
                best = split;
            }
        }

        if (best.getIg() < EPSILON) {
            double label = leafLabel(data, sum, rng);
            log.debug("ig is near to zero Leaf({})", label);
            return new Leaf(label);
        }
        log.debug("best split attr:{}, split:{}, ig:{}",
            new Object[]{best.getAttr(), best.getSplit(), best.getIg()});

        boolean alreadySelected = selected[best.getAttr()];
        if (alreadySelected) {
            log.warn("attribute {} already selected in a parent node", best.getAttr());
        }

        if (dataset.isNumerical(best.getAttr())) {
            return buildNumericalNode(rng, data, best, sum, alreadySelected);
        }
        return buildCategoricalNode(rng, data, best, sum, alreadySelected);
    }

    /** Splits {@code data} in two around the threshold of {@code best} and builds both children. */
    private Node buildNumericalNode(Random rng, Data data, Split best, double sum, boolean alreadySelected) {
        int attr = best.getAttr();
        Data loSubset = data.subset(Condition.lesser(attr, best.getSplit()));
        Data hiSubset = data.subset(Condition.greaterOrEquals(attr, best.getSplit()));

        // Check the subset sizes BEFORE touching the selection state: returning a leaf after the
        // state was modified would leak the modification into sibling branches and later builds.
        if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
            double label = leafLabel(data, sum, rng);
            log.debug("branch is not split Leaf({})", label);
            return new Leaf(label);
        }

        boolean[] saved = null;
        if (loSubset.isEmpty() || hiSubset.isEmpty()) {
            // the split did not partition the data: keep the children from drawing this attribute
            selected[attr] = true;
        } else {
            // the data was partitioned: children start from a mask where only the categorical
            // attributes (and the label) remain marked
            saved = selected;
            selected = cloneCategoricalAttributes(data.getDataset(), selected);
        }

        Node loChild = build(rng, loSubset);
        Node hiChild = build(rng, hiSubset);

        // restore the selection state seen by the caller
        if (saved != null) {
            selected = saved;
        } else {
            selected[attr] = alreadySelected;
        }
        return new NumericalNode(attr, best.getSplit(), loChild, hiChild);
    }

    /** Splits {@code data} by the values of the attribute of {@code best} and builds one child per value. */
    private Node buildCategoricalNode(Random rng, Data data, Split best, double sum, boolean alreadySelected) {
        int attr = best.getAttr();
        double[] values = data.values(attr);

        // when complemented, remember which values occur here and branch on every value of fullSet
        HashSet<Double> subsetValues = null;
        if (complemented) {
            subsetValues = Sets.newHashSet();
            for (double value : values) {
                subsetValues.add(value);
            }
            values = fullSet.values(attr);
        }

        // subsets[i] stays null for values that do not occur in data
        int cnt = 0;
        Data[] subsets = new Data[values.length];
        for (int index = 0; index < values.length; index++) {
            if (complemented && !subsetValues.contains(values[index])) {
                continue;
            }
            subsets[index] = data.subset(Condition.equals(attr, values[index]));
            if (subsets[index].size() >= minSplitNum) {
                cnt++;
            }
        }

        // at least two subsets must be large enough for the split to be kept
        if (cnt < 2) {
            double label = leafLabel(data, sum, rng);
            log.debug("branch is not split Leaf({})", label);
            return new Leaf(label);
        }

        selected[attr] = true;
        Node[] children = new Node[values.length];
        for (int index = 0; index < values.length; index++) {
            if (subsets[index] == null) {
                // value absent from data (only possible when complemented): predict this node's label
                double label = leafLabel(data, sum, rng);
                log.debug("complemented Leaf({})", label);
                children[index] = new Leaf(label);
            } else {
                children[index] = build(rng, subsets[index]);
            }
        }
        selected[attr] = alreadySelected;

        return new CategoricalNode(attr, values, children);
    }

    /**
     * Label predicted by a leaf created for {@code data}: the mean label ({@code sum / size}) when
     * the label is numerical, otherwise the majority label.
     */
    private static double leafLabel(Data data, double sum, Random rng) {
        Dataset dataset = data.getDataset();
        if (dataset.isNumerical(dataset.getLabelId())) {
            return sum / data.size();
        }
        return data.majorityLabel(rng);
    }

    /**
     * Checks whether all instances of {@code data} carry the same value for every attribute that
     * is not marked in {@link #selected}.
     *
     * @return true when the data is empty or no unmarked attribute differs between instances
     */
    private boolean isIdentical(Data data) {
        if (data.isEmpty()) {
            return true;
        }
        Instance instance = data.get(0);
        for (int attr = 0; attr < selected.length; attr++) {
            if (selected[attr]) {
                continue;
            }
            for (int index = 1; index < data.size(); index++) {
                if (data.get(index).get(attr) != instance.get(attr)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns a copy of {@code selected} in which numerical attributes are unmarked, categorical
     * attributes keep their mark and the label is marked.
     */
    private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
        boolean[] cloned = new boolean[selected.length];
        for (int i = 0; i < selected.length; i++) {
            cloned[i] = !dataset.isNumerical(i) && selected[i];
        }
        cloned[dataset.getLabelId()] = true;
        return cloned;
    }

    /**
     * Draws up to {@code m} distinct attributes among those not marked in {@code selected}.
     *
     * <p>{@code selected} is used as scratch space while drawing but is restored before returning.
     *
     * @return an empty array when every attribute is marked, every unmarked attribute when there
     *         are at most {@code m} of them, otherwise {@code m} randomly drawn unmarked attributes
     */
    private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
        int nbNonSelected = 0;
        for (boolean sel : selected) {
            if (!sel) {
                nbNonSelected++;
            }
        }
        if (nbNonSelected == 0) {
            log.warn("All attributes are selected !");
            return NO_ATTRIBUTES;
        }

        int[] result;
        if (nbNonSelected <= m) {
            // not enough candidates to draw from: return them all, in index order
            result = new int[nbNonSelected];
            int index = 0;
            for (int attr = 0; attr < selected.length; attr++) {
                if (!selected[attr]) {
                    result[index++] = attr;
                }
            }
        } else {
            result = new int[m];
            for (int index = 0; index < m; index++) {
                // redraw until an unmarked attribute comes up, then mark it to avoid duplicates
                int rind;
                do {
                    rind = rng.nextInt(selected.length);
                } while (selected[rind]);
                result[index] = rind;
                selected[rind] = true;
            }
            // undo the temporary marks
            for (int attr : result) {
                selected[attr] = false;
            }
        }
        return result;
    }
}

