/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AbstractConfusionMatrixMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;

/**
 * Soft-classification precision metric, evaluated at each configured threshold.
 *
 * <p>For every threshold in {@code thresholds}, precision is computed as {@code TP / (TP + FP)}
 * using the doc counts of the true-positive and false-positive filter aggregations requested by
 * {@link #aggsAt}. When both counts are zero, the precision at that threshold is {@code 0.0}.
 */
public class Precision extends AbstractConfusionMatrixMetric {

    public static final ParseField NAME = new ParseField("precision");

    /**
     * Parser for a {@code precision} object. Its single constructor argument is the list of
     * thresholds declared under the superclass's {@code AT} field (NOTE(review): field name is
     * defined in {@code AbstractConfusionMatrixMetric}; confirm there).
     */
    // The cast of a[0] is unchecked: the parser stores constructor args as Object, but
    // declareDoubleArray below guarantees the value is a List<Double>.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<Precision, Void> PARSER =
        new ConstructingObjectParser<>(NAME.getPreferredName(), a -> new Precision((List<Double>) a[0]));

    static {
        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
    }

    /**
     * Parses a {@link Precision} metric from the given parser.
     *
     * @param parser the x-content parser positioned at the metric object
     * @return the parsed metric
     */
    public static Precision fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates the metric for the given thresholds.
     *
     * @param at the thresholds at which precision is evaluated; unboxed into a {@code double[]}
     *           before being handed to the superclass
     */
    public Precision(List<Double> at) {
        super(at.stream().mapToDouble(Double::doubleValue).toArray());
    }

    /**
     * Deserializes the metric from a stream; reading is delegated to the superclass.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Precision(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getMetricName() {
        return NAME.getPreferredName();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Precision that = (Precision) o;
        return Arrays.equals(thresholds, that.thresholds);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(thresholds);
    }

    /**
     * Requests, for every class and the given threshold, the true-positive and false-positive
     * aggregations that {@link #evaluate} later reads back by name.
     *
     * @param labelField unused by this metric; kept for the superclass contract
     * @param classInfos the classes to build aggregations for
     * @param threshold  the score threshold the aggregations are built at
     * @return a TP and an FP aggregation per class, in that order
     */
    @Override
    protected List<AggregationBuilder> aggsAt(String labelField, List<SoftClassificationMetric.ClassInfo> classInfos,
                                              double threshold) {
        List<AggregationBuilder> aggs = new ArrayList<>();
        for (SoftClassificationMetric.ClassInfo classInfo : classInfos) {
            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
            aggs.add(buildAgg(classInfo, threshold, Condition.FP));
        }
        return aggs;
    }

    /**
     * Computes precision at every configured threshold for the given class.
     *
     * @param classInfo the class being evaluated
     * @param aggs      aggregation results containing the TP/FP filters built by {@link #aggsAt}
     * @return a {@link ScoreByThresholdResult} pairing each threshold with its precision
     */
    @Override
    public EvaluationMetricResult evaluate(SoftClassificationMetric.ClassInfo classInfo, Aggregations aggs) {
        double[] precisions = new double[thresholds.length];
        for (int i = 0; i < precisions.length; i++) {
            double threshold = thresholds[i];
            Filter tpAgg = (Filter) aggs.get(aggName(classInfo, threshold, Condition.TP));
            Filter fpAgg = (Filter) aggs.get(aggName(classInfo, threshold, Condition.FP));
            long tp = tpAgg.getDocCount();
            long fp = fpAgg.getDocCount();
            long predictedPositives = tp + fp;
            // No predicted positives at this threshold: report 0.0 rather than dividing by zero.
            precisions[i] = predictedPositives == 0L ? 0.0 : (double) tp / (double) predictedPositives;
        }
        return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, precisions);
    }
}

