/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

abstract class AbstractConfusionMatrixMetric
implements SoftClassificationMetric {
    public static final ParseField AT = new ParseField("at", new String[0]);
    protected final double[] thresholds;

    protected AbstractConfusionMatrixMetric(double[] thresholds) {
        this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
        if (thresholds.length == 0) {
            throw ExceptionsHelper.badRequestException("[" + this.getMetricName() + "." + AT.getPreferredName() + "] must have at least one value", new Object[0]);
        }
        for (double threshold : thresholds) {
            if (!(threshold < 0.0) && !(threshold > 1.0)) continue;
            throw ExceptionsHelper.badRequestException("[" + this.getMetricName() + "." + AT.getPreferredName() + "] values must be in [0.0, 1.0]", new Object[0]);
        }
    }

    protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException {
        this.thresholds = in.readDoubleArray();
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeDoubleArray(this.thresholds);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(AT.getPreferredName(), (Object)this.thresholds);
        builder.endObject();
        return builder;
    }

    @Override
    public final List<AggregationBuilder> aggs(String actualField, List<SoftClassificationMetric.ClassInfo> classInfos) {
        ArrayList<AggregationBuilder> aggs = new ArrayList<AggregationBuilder>();
        for (double threshold : this.thresholds) {
            aggs.addAll(this.aggsAt(actualField, classInfos, threshold));
        }
        return aggs;
    }

    protected abstract List<AggregationBuilder> aggsAt(String var1, List<SoftClassificationMetric.ClassInfo> var2, double var3);

    protected String aggName(SoftClassificationMetric.ClassInfo classInfo, double threshold, Condition condition) {
        return this.getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
    }

    protected AggregationBuilder buildAgg(SoftClassificationMetric.ClassInfo classInfo, double threshold, Condition condition) {
        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        switch (condition) {
            case TP: {
                boolQuery.must(classInfo.matchingQuery());
                boolQuery.must((QueryBuilder)QueryBuilders.rangeQuery((String)classInfo.getProbabilityField()).gte((Object)threshold));
                break;
            }
            case FP: {
                boolQuery.mustNot(classInfo.matchingQuery());
                boolQuery.must((QueryBuilder)QueryBuilders.rangeQuery((String)classInfo.getProbabilityField()).gte((Object)threshold));
                break;
            }
            case TN: {
                boolQuery.mustNot(classInfo.matchingQuery());
                boolQuery.must((QueryBuilder)QueryBuilders.rangeQuery((String)classInfo.getProbabilityField()).lt((Object)threshold));
                break;
            }
            case FN: {
                boolQuery.must(classInfo.matchingQuery());
                boolQuery.must((QueryBuilder)QueryBuilders.rangeQuery((String)classInfo.getProbabilityField()).lt((Object)threshold));
                break;
            }
            default: {
                throw new IllegalArgumentException("Unknown enum value: " + (Object)((Object)condition));
            }
        }
        return AggregationBuilders.filter((String)this.aggName(classInfo, threshold, condition), (QueryBuilder)boolQuery);
    }

    protected static enum Condition {
        TP,
        FP,
        TN,
        FN;

    }
}

