/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.sql.analysis.analyzer;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.xpack.sql.capabilities.Unresolvable;
import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.AttributeSet;
import org.elasticsearch.xpack.sql.expression.Exists;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.FunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.Score;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
import org.elasticsearch.xpack.sql.plan.logical.Limit;
import org.elasticsearch.xpack.sql.plan.logical.LocalRelation;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.plan.logical.OrderBy;
import org.elasticsearch.xpack.sql.plan.logical.Project;
import org.elasticsearch.xpack.sql.plan.logical.command.Command;
import org.elasticsearch.xpack.sql.stats.FeatureMetric;
import org.elasticsearch.xpack.sql.stats.Metrics;
import org.elasticsearch.xpack.sql.tree.Node;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.EsField;
import org.elasticsearch.xpack.sql.util.StringUtils;

public final class Verifier {
    private final Metrics metrics;

    /**
     * Creates a verifier bound to the given metrics holder.
     *
     * @param metrics the instance whose {@code inc} method is called by
     *                {@code verify} for each feature found in a plan that
     *                produced no failures
     */
    public Verifier(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Builds a {@link Failure} for the given node.
     *
     * @param source  the node the failure is attached to
     * @param message the message pattern handed to {@link LoggerMessageFormat#format}
     * @param args    the values substituted into the pattern
     * @return a failure carrying the node and the formatted message
     */
    private static Failure fail(Node<?> source, String message, Object... args) {
        String rendered = LoggerMessageFormat.format(message, args);
        return new Failure(source, rendered);
    }

    /**
     * Verifies the plan and exposes the failures as a map from the offending
     * node to its message.
     *
     * <p>The map iterates in the order in which {@link #verify(LogicalPlan)}
     * discovered the failures, so callers that render the messages get a
     * deterministic ordering.
     *
     * @param plan the plan to verify
     * @return failure messages keyed by their source node; empty if the plan is valid
     */
    public Map<Node<?>, String> verifyFailures(LogicalPlan plan) {
        Map<Node<?>, String> failuresBySource = new LinkedHashMap<>();
        for (Failure failure : this.verify(plan)) {
            // verify() already de-duplicates failures by source (see Failure#equals),
            // so each source is seen at most once here.
            failuresBySource.putIfAbsent(failure.source(), failure.message());
        }
        return failuresBySource;
    }

    /**
     * Runs all verification passes over the plan.
     *
     * <ol>
     *   <li>Bottom-up: reports unresolved plans/expressions, type resolution
     *       problems and constructs flagged as unsupported (DISTINCT, EXISTS).</li>
     *   <li>Top-down (only if pass 1 found nothing): runs the semantic checks
     *       (grouping, filtering, score, nested and geo restrictions) and marks
     *       nodes without failures as analyzed.</li>
     *   <li>If there are still no failures: records which features the plan uses
     *       in the metrics.</li>
     * </ol>
     *
     * @param plan the plan to verify
     * @return the failures found, in discovery order; empty if the plan is valid
     */
    Collection<Failure> verify(LogicalPlan plan) {
        Set<Failure> failures = new LinkedHashSet<>();

        // Pass 1: resolution problems, collected from the leaves up.
        plan.forEachUp(p -> {
            // Skip nodes already verified and nodes whose children are broken
            // (the children are reported instead).
            if (p.analyzed() || !p.childrenResolved()) {
                return;
            }
            Set<Failure> localFailures = new LinkedHashSet<>();
            if (p instanceof Unresolvable) {
                localFailures.add(fail(p, ((Unresolvable) p).unresolvedMessage()));
            } else if (p instanceof Distinct) {
                localFailures.add(fail(p, "SELECT DISTINCT is not yet supported"));
            } else {
                p.forEachExpressions(e -> {
                    if (e.resolved()) {
                        return;
                    }
                    e.forEachUp(ae -> {
                        if (!ae.childrenResolved()) {
                            return;
                        }
                        if (ae instanceof Unresolvable) {
                            if (ae instanceof UnresolvedAttribute) {
                                UnresolvedAttribute ua = (UnresolvedAttribute) ae;
                                // Only attributes without a custom message get "did you mean" suggestions.
                                if (!ua.customMessage()) {
                                    boolean useQualifier = ua.qualifier() != null;
                                    List<String> potentialMatches = new ArrayList<>();
                                    for (Attribute a : p.inputSet()) {
                                        String nameCandidate = useQualifier ? a.qualifiedName() : a.name();
                                        // Suggest only supported, primitive-typed attributes.
                                        if (a.dataType() != DataType.UNSUPPORTED && a.dataType().isPrimitive()) {
                                            potentialMatches.add(nameCandidate);
                                        }
                                    }
                                    List<String> matches = StringUtils.findSimilar(ua.qualifiedName(), potentialMatches);
                                    if (!matches.isEmpty()) {
                                        ae = ua.withUnresolvedMessage(UnresolvedAttribute.errorMessage(ua.qualifiedName(), matches));
                                    }
                                }
                            }
                            localFailures.add(fail(ae, ((Unresolvable) ae).unresolvedMessage()));
                            return;
                        }
                        if (ae.typeResolved().unresolved()) {
                            localFailures.add(fail(ae, ae.typeResolved().message()));
                        } else if (ae instanceof Exists) {
                            localFailures.add(fail(ae, "EXISTS is not yet supported"));
                        }
                    });
                });
            }
            failures.addAll(localFailures);
        });

        // Pass 2: semantic checks, only meaningful on a fully resolved plan.
        if (failures.isEmpty()) {
            Map<String, Function> resolvedFunctions = Functions.collectFunctions(plan);
            // Aggregates already reported through a parent (ORDER BY / HAVING) check;
            // they are not run through checkGroupBy again.
            Set<LogicalPlan> groupingFailures = new LinkedHashSet<>();
            plan.forEachDown(p -> {
                if (p.analyzed() || !p.childrenResolved()) {
                    return;
                }
                Set<Failure> localFailures = new LinkedHashSet<>();
                checkGroupingFunctionInGroupBy(p, localFailures);
                checkFilterOnAggs(p, localFailures);
                checkFilterOnGrouping(p, localFailures);
                if (!groupingFailures.contains(p)) {
                    checkGroupBy(p, localFailures, resolvedFunctions, groupingFailures);
                }
                checkForScoreInsideFunctions(p, localFailures);
                checkNestedUsedInGroupByOrHaving(p, localFailures);
                checkForGeoFunctionsOnDocValues(p, localFailures);
                if (localFailures.isEmpty()) {
                    p.setAnalyzed();
                }
                failures.addAll(localFailures);
            });
        }

        // Pass 3: feature usage is only counted for plans that verified cleanly.
        if (failures.isEmpty()) {
            recordFeatureMetrics(plan);
        }
        return failures;
    }

    /**
     * Increments the metric of every feature the plan uses, once per feature
     * regardless of how many nodes use it.
     */
    private void recordFeatureMetrics(LogicalPlan plan) {
        // One bit per FeatureMetric ordinal, so each feature is counted at most once.
        BitSet used = new BitSet(FeatureMetric.values().length);
        plan.forEachDown(p -> {
            if (p instanceof Aggregate) {
                used.set(FeatureMetric.GROUPBY.ordinal());
            } else if (p instanceof OrderBy) {
                used.set(FeatureMetric.ORDERBY.ordinal());
            } else if (p instanceof Filter) {
                // A filter directly on top of an aggregate is counted as HAVING, otherwise as WHERE.
                if (((Filter) p).child() instanceof Aggregate) {
                    used.set(FeatureMetric.HAVING.ordinal());
                } else {
                    used.set(FeatureMetric.WHERE.ordinal());
                }
            } else if (p instanceof Limit) {
                used.set(FeatureMetric.LIMIT.ordinal());
            } else if (p instanceof LocalRelation) {
                used.set(FeatureMetric.LOCAL.ordinal());
            } else if (p instanceof Command) {
                used.set(FeatureMetric.COMMAND.ordinal());
            }
        });
        for (int i = used.nextSetBit(0); i >= 0; i = used.nextSetBit(i + 1)) {
            this.metrics.inc(FeatureMetric.values()[i]);
        }
    }

    /**
     * Runs the GROUP BY related checks in a fixed order, stopping at the first
     * one that reports {@code false}.
     *
     * @return {@code true} only if every check returned {@code true}
     */
    private static boolean checkGroupBy(LogicalPlan p, Set<Failure> localFailures, Map<String, Function> resolvedFunctions, Set<LogicalPlan> groupingFailures) {
        if (!Verifier.checkGroupByInexactField(p, localFailures)) {
            return false;
        }
        if (!Verifier.checkGroupByAgg(p, localFailures, resolvedFunctions)) {
            return false;
        }
        if (!Verifier.checkGroupByOrder(p, localFailures, groupingFailures)) {
            return false;
        }
        if (!Verifier.checkGroupByHaving(p, localFailures, groupingFailures, resolvedFunctions)) {
            return false;
        }
        return Verifier.checkGroupByTime(p, localFailures);
    }

    /**
     * Checks that an ORDER BY sitting on top of an aggregate (optionally through
     * a Project and/or a Filter) only sorts on grouped columns, aliases of
     * grouped columns, or aggregate functions.
     *
     * @param p                 the plan node to inspect; anything but an {@link OrderBy} passes
     * @param localFailures     receives the failure, if any
     * @param groupingFailures  receives the offending aggregate so it is not checked again
     * @return {@code false} if a non-grouped ordering was reported, {@code true} otherwise
     */
    private static boolean checkGroupByOrder(LogicalPlan p, Set<Failure> localFailures, Set<LogicalPlan> groupingFailures) {
        if (!(p instanceof OrderBy)) {
            return true;
        }
        OrderBy o = (OrderBy) p;

        // Look through an optional Project and an optional Filter to reach the aggregate.
        LogicalPlan child = o.child();
        if (child instanceof Project) {
            child = ((Project) child).child();
        }
        if (child instanceof Filter) {
            child = ((Filter) child).child();
        }
        if (!(child instanceof Aggregate)) {
            return true;
        }
        Aggregate a = (Aggregate) child;

        // Valid ordering targets: the groupings plus every aggregate alias whose
        // child matches a grouping. Depends only on the aggregate, so build it once.
        List<Expression> groupingAndMatchingAggregatesAliases = new ArrayList<>(a.groupings());
        a.aggregates().forEach(as -> {
            if (as instanceof Alias) {
                Alias al = (Alias) as;
                if (Expressions.anyMatch(a.groupings(), g -> Expressions.equalsAsAttribute(al.child(), g))) {
                    groupingAndMatchingAggregatesAliases.add(al);
                }
            }
        });

        // Offending expression -> the order clause it came from (used as failure source).
        Map<Expression, Node<?>> missing = new LinkedHashMap<>();
        o.order().forEach(oe -> {
            Expression e = oe.child();
            // Ordering by an aggregate is always allowed.
            if (Functions.isAggregate(e) || e instanceof AggregateFunctionAttribute) {
                return;
            }
            if (e.anyMatch(expression -> Expressions.anyMatch(groupingAndMatchingAggregatesAliases,
                    g -> expression.semanticEquals(expression instanceof Attribute ? Expressions.attribute(g) : g)))) {
                return;
            }
            missing.put(e, oe);
        });

        if (missing.isEmpty()) {
            return true;
        }
        String plural = missing.size() > 1 ? "s" : "";
        localFailures.add(Verifier.fail(missing.values().iterator().next(),
            "Cannot order by non-grouped column" + plural + " {}, expected {} or an aggregate function",
            Expressions.names(missing.keySet()), Expressions.names(a.groupings())));
        groupingFailures.add(a);
        return false;
    }

    /**
     * Checks the condition of a filter placed directly on top of an aggregate
     * (a HAVING clause): it may not reference non-aggregated attributes and may
     * not use functions flagged as unsupported by
     * {@link #checkGroupByHavingHasOnlyAggs}.
     *
     * @param p                 the plan node to inspect; anything but a Filter over an Aggregate passes
     * @param localFailures     receives the failure, if any
     * @param groupingFailures  receives the offending aggregate so it is not checked again
     * @param functions         resolved functions by id, used to look up function attributes
     * @return {@code false} if a failure was reported, {@code true} otherwise
     */
    private static boolean checkGroupByHaving(LogicalPlan p, Set<Failure> localFailures, Set<LogicalPlan> groupingFailures, Map<String, Function> functions) {
        if (!(p instanceof Filter)) {
            return true;
        }
        Filter f = (Filter) p;
        if (!(f.child() instanceof Aggregate)) {
            return true;
        }
        Aggregate a = (Aggregate) f.child();

        Set<Expression> missing = new LinkedHashSet<>();
        Set<Expression> unsupported = new LinkedHashSet<>();
        Expression condition = f.condition();
        condition.collectFirstChildren(c -> Verifier.checkGroupByHavingHasOnlyAggs(c, missing, unsupported, functions));

        // Non-aggregate references take precedence over unsupported functions.
        if (!missing.isEmpty()) {
            String plural = missing.size() > 1 ? "s" : "";
            localFailures.add(Verifier.fail(condition, "Cannot use HAVING filter on non-aggregate" + plural + " {}; use WHERE instead", Expressions.names(missing)));
            groupingFailures.add(a);
            return false;
        }
        if (!unsupported.isEmpty()) {
            String plural = unsupported.size() > 1 ? "s" : "";
            localFailures.add(Verifier.fail(condition, "HAVING filter is unsupported for function" + plural + " {}", Expressions.names(unsupported)));
            groupingFailures.add(a);
            return false;
        }
        return true;
    }

    /**
     * Classifies one expression of a HAVING condition. Used as the predicate of
     * {@code collectFirstChildren}, so the return value tells the traversal
     * whether this expression has been dealt with.
     *
     * @param e            the expression to classify
     * @param missing      receives plain attributes (non-aggregates)
     * @param unsupported  receives SCORE, TopHits and MIN/MAX over string fields
     * @param functions    resolved functions by id, used to look up function attributes
     * @return {@code true} if the expression was handled here, {@code false} otherwise
     */
    private static boolean checkGroupByHavingHasOnlyAggs(Expression e, Set<Expression> missing, Set<Expression> unsupported, Map<String, Function> functions) {
        // Swap a function attribute for the function behind it.
        final Expression target;
        if (e instanceof FunctionAttribute) {
            Function backing = functions.get(((FunctionAttribute) e).functionId());
            if (backing == null) {
                return false;
            }
            target = backing;
        } else {
            target = e;
        }

        // Scalars are unwrapped: each argument is classified on its own.
        if (target instanceof ScalarFunction) {
            for (Expression argument : ((ScalarFunction) target).arguments()) {
                argument.collectFirstChildren(c -> Verifier.checkGroupByHavingHasOnlyAggs(c, missing, unsupported, functions));
            }
            return true;
        }

        boolean notSupported = target instanceof Score
            || target instanceof TopHits
            || ((target instanceof Min || target instanceof Max) && ((AggregateFunction) target).field().dataType().isString());
        if (notSupported) {
            unsupported.add(target);
            return true;
        }

        // Constants, aggregates and grouping functions are fine as they are.
        if (target.foldable() || Functions.isAggregate(target) || Functions.isGrouping(target)) {
            return true;
        }

        if (target instanceof Attribute) {
            missing.add(target);
            return true;
        }
        return false;
    }

    /**
     * Reports every field attribute used in a grouping whose exact info says it
     * has no exact variant.
     *
     * @return always {@code true}; failures are only recorded in {@code localFailures}
     */
    private static boolean checkGroupByInexactField(LogicalPlan p, Set<Failure> localFailures) {
        if (!(p instanceof Aggregate)) {
            return true;
        }
        for (Expression grouping : ((Aggregate) p).groupings()) {
            grouping.forEachUp(field -> {
                EsField.Exact exactInfo = field.getExactInfo();
                if (exactInfo.hasExact()) {
                    return;
                }
                String reason = "Field [" + field.sourceText() + "] of data type [" + field.dataType().typeName
                    + "] cannot be used for grouping; " + exactInfo.errorMsg();
                localFailures.add(Verifier.fail(field, reason));
            }, FieldAttribute.class);
        }
        return true;
    }

    /**
     * Reports every grouping expression whose data type is time based.
     *
     * @return always {@code true}; failures are only recorded in {@code localFailures}
     */
    private static boolean checkGroupByTime(LogicalPlan p, Set<Failure> localFailures) {
        if (!(p instanceof Aggregate)) {
            return true;
        }
        for (Expression grouping : ((Aggregate) p).groupings()) {
            if (!grouping.dataType().isTimeBased()) {
                continue;
            }
            String reason = "Function [" + grouping.sourceText() + "] with data type [" + grouping.dataType().typeName
                + "] cannot be used for grouping";
            localFailures.add(Verifier.fail(grouping, reason));
        }
        return true;
    }

    /**
     * Validates an aggregate's groupings and its output expressions:
     * <ul>
     *   <li>groupings may not contain aggregate functions or SCORE();</li>
     *   <li>a grouping function may not be nested inside another grouping expression;</li>
     *   <li>every non-aggregated output expression must match a grouping.</li>
     * </ul>
     *
     * @param p              the plan node to inspect; anything but an {@link Aggregate} passes
     * @param localFailures  receives the failures; note that failures already present
     *                       from earlier checks also cause the early {@code false} return
     * @param functions      resolved functions by id, used to look up function attributes
     * @return {@code false} if {@code localFailures} is non-empty after the grouping checks
     *         or a non-grouped column was reported, {@code true} otherwise
     */
    private static boolean checkGroupByAgg(LogicalPlan p, Set<Failure> localFailures, Map<String, Function> functions) {
        if (!(p instanceof Aggregate)) {
            return true;
        }
        Aggregate a = (Aggregate) p;

        // Groupings cannot be built on aggregates or on the score.
        a.groupings().forEach(e -> e.forEachUp(c -> {
            if (Functions.isAggregate(c)) {
                localFailures.add(Verifier.fail(c, "Cannot use an aggregate [" + c.nodeName().toUpperCase(Locale.ROOT) + "] for grouping"));
            }
            if (c instanceof Score) {
                localFailures.add(Verifier.fail(c, "Cannot use [SCORE()] for grouping"));
            }
        }));

        // A grouping function must be the grouping itself, not a part of a larger expression.
        a.groupings().forEach(e -> {
            if (!Functions.isGrouping(e)) {
                e.collectFirstChildren(c -> {
                    if (Functions.isGrouping(c)) {
                        localFailures.add(Verifier.fail(c, "Cannot combine [{}] grouping function inside GROUP BY, found [{}]; consider moving the expression inside the histogram", Expressions.name(c), Expressions.name(e)));
                        return true;
                    }
                    return false;
                });
            }
        });
        if (!localFailures.isEmpty()) {
            return false;
        }

        // Offending expression -> the output expression it came from (used as failure source).
        Map<Expression, Node<?>> missing = new LinkedHashMap<>();
        a.aggregates().forEach(ne -> ne.collectFirstChildren(c -> Verifier.checkGroupMatch(c, ne, a.groupings(), missing, functions)));
        if (missing.isEmpty()) {
            return true;
        }
        String plural = missing.size() > 1 ? "s" : "";
        localFailures.add(Verifier.fail(missing.values().iterator().next(), "Cannot use non-grouped column" + plural + " {}, expected {}", Expressions.names(missing.keySet()), Expressions.names(a.groupings())));
        return false;
    }

    /**
     * Decides whether an expression from an aggregate's output is covered by the
     * groupings. Used as the predicate of {@code collectFirstChildren}, so the
     * return value tells the traversal whether this expression has been dealt with.
     *
     * @param e          the expression to check
     * @param source     the node recorded as origin when the expression is not grouped
     * @param groupings  the aggregate's grouping expressions
     * @param missing    receives expressions that are not covered by any grouping
     * @param functions  resolved functions by id, used to look up function attributes
     * @return {@code true} if the expression was handled here, {@code false} otherwise
     */
    private static boolean checkGroupMatch(Expression e, Node<?> source, List<Expression> groupings, Map<Expression, Node<?>> missing, Map<String, Function> functions) {
        // Direct hit against one of the groupings.
        if (Expressions.match(groupings, e::semanticEquals)) {
            return true;
        }

        // Swap a function attribute for the function behind it.
        final Expression candidate;
        if (e instanceof FunctionAttribute) {
            Function backing = functions.get(((FunctionAttribute) e).functionId());
            if (backing == null) {
                return false;
            }
            candidate = backing;
        } else {
            candidate = e;
        }

        if (candidate instanceof ScalarFunction) {
            // Either the scalar appears somewhere in the groupings, or its arguments are checked one by one.
            if (!Expressions.anyMatch(groupings, candidate::semanticEquals)) {
                for (Expression argument : ((ScalarFunction) candidate).arguments()) {
                    argument.collectFirstChildren(c -> Verifier.checkGroupMatch(c, source, groupings, missing, functions));
                }
            }
            return true;
        }

        if (candidate instanceof Score) {
            missing.put(candidate, source);
            return true;
        }

        if (candidate.foldable() || Functions.isAggregate(candidate)) {
            return true;
        }

        // Only leaves are judged here; anything with children is left to the traversal.
        if (!candidate.children().isEmpty()) {
            return false;
        }
        boolean grouped = Expressions.match(groupings,
            g -> candidate.semanticEquals(candidate instanceof Attribute ? Expressions.attribute(g) : g));
        if (!grouped) {
            missing.put(candidate, source);
        }
        return true;
    }

    /**
     * Checks the placement of grouping functions: in a projection they are
     * always reported; in an aggregate's output they must match a function among
     * the groupings. Grouping functions that pass, and those inside the groupings,
     * are additionally checked by {@link #checkGroupingFunctionTarget}.
     */
    private static void checkGroupingFunctionInGroupBy(LogicalPlan p, Set<Failure> localFailures) {
        if (p instanceof Project) {
            ((Project) p).projections().forEach(projection -> projection.forEachDown(
                gf -> localFailures.add(Verifier.fail(gf, "[{}] needs to be part of the grouping", Expressions.name(gf))),
                GroupingFunction.class));
            return;
        }
        if (!(p instanceof Aggregate)) {
            return;
        }
        Aggregate a = (Aggregate) p;

        a.aggregates().forEach(output -> output.forEachDown(gf -> {
            boolean partOfGrouping = !a.groupings().isEmpty()
                && Expressions.anyMatch(a.groupings(), g -> g instanceof Function && gf.functionEquals((Function) g));
            if (partOfGrouping) {
                Verifier.checkGroupingFunctionTarget(gf, localFailures);
            } else {
                localFailures.add(Verifier.fail(gf, "[{}] needs to be part of the grouping", Expressions.name(gf)));
            }
        }, GroupingFunction.class));

        for (Expression grouping : a.groupings()) {
            grouping.forEachDown(gf -> Verifier.checkGroupingFunctionTarget(gf, localFailures), GroupingFunction.class);
        }
    }

    /**
     * Reports a grouping function whose field expression itself contains a
     * grouping function.
     *
     * @param f              the grouping function whose field is inspected
     * @param localFailures  receives the failure, attached to {@code f.field()}
     */
    private static void checkGroupingFunctionTarget(GroupingFunction f, Set<Failure> localFailures) {
        Expression target = f.field();
        target.forEachDown(
            nested -> localFailures.add(Verifier.fail(f.field(),
                "Cannot embed grouping functions within each other, found [{}] in [{}]",
                Expressions.name(f.field()), Expressions.name(f))),
            GroupingFunction.class);
    }

    /**
     * Reports aggregate functions (or attributes of them) used in the condition
     * of a filter that does not sit directly on an aggregate.
     */
    private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures) {
        if (!(p instanceof Filter)) {
            return;
        }
        Filter where = (Filter) p;
        // A filter directly on an aggregate is handled by the HAVING checks.
        if (where.child() instanceof Aggregate) {
            return;
        }
        where.condition().forEachDown(expression -> {
            boolean aggregate = Functions.isAggregate(expression) || expression instanceof AggregateFunctionAttribute;
            if (aggregate) {
                localFailures.add(Verifier.fail(expression,
                    "Cannot use WHERE filtering on aggregate function [{}], use HAVING instead",
                    Expressions.name(expression)));
            }
        }, Expression.class);
    }

    /**
     * Reports grouping functions (or attributes of them) used in any filter
     * condition.
     */
    private static void checkFilterOnGrouping(LogicalPlan p, Set<Failure> localFailures) {
        if (!(p instanceof Filter)) {
            return;
        }
        Expression condition = ((Filter) p).condition();
        condition.forEachDown(expression -> {
            boolean grouping = Functions.isGrouping(expression) || expression instanceof GroupingFunctionAttribute;
            if (grouping) {
                localFailures.add(Verifier.fail(expression,
                    "Cannot filter on grouping function [{}], use its argument instead",
                    Expressions.name(expression)));
            }
        }, Expression.class);
    }

    /**
     * Reports every function argument that contains a {@link Score} anywhere in
     * its tree.
     */
    private static void checkForScoreInsideFunctions(LogicalPlan p, Set<Failure> localFailures) {
        p.forEachExpressions(e -> e.forEachUp(function -> {
            for (Expression argument : function.arguments()) {
                if (argument.anyMatch(Score.class::isInstance)) {
                    localFailures.add(Verifier.fail(argument, "[SCORE()] cannot be an argument to a function"));
                }
            }
        }, Function.class));
    }

    /**
     * Reports nested fields used in the groupings of any aggregate below
     * {@code p}, and nested fields used in the condition of any filter that sits
     * directly on an aggregate (HAVING).
     *
     * @param p              the root of the sub-plan to scan
     * @param localFailures  receives at most one failure per category, attached to
     *                       the first nested field found
     */
    private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure> localFailures) {
        List<FieldAttribute> nested = new ArrayList<>();
        Consumer<FieldAttribute> match = fa -> {
            if (fa.isNested()) {
                nested.add(fa);
            }
        };

        // GROUP BY
        p.forEachDown(a -> a.groupings().forEach(agg -> agg.forEachUp(match, FieldAttribute.class)), Aggregate.class);
        if (!nested.isEmpty()) {
            localFailures.add(Verifier.fail(nested.get(0), "Grouping isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
            // Start from scratch so the HAVING report lists only its own fields.
            nested.clear();
        }

        // HAVING
        p.forEachDown(f -> {
            if (f.child() instanceof Aggregate) {
                f.condition().forEachUp(match, FieldAttribute.class);
            }
        }, Filter.class);
        if (!nested.isEmpty()) {
            localFailures.add(Verifier.fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
        }
    }

    /**
     * Reports geo shape fields used in filter conditions, in aggregate
     * groupings, or in ORDER BY clauses anywhere below {@code p}.
     */
    private static void checkForGeoFunctionsOnDocValues(LogicalPlan p, Set<Failure> localFailures) {
        Consumer<FieldAttribute> inFilter = Verifier.geoShapeReporter("geo shapes cannot be used for filtering", localFailures);
        p.forEachDown(filter -> filter.condition().forEachUp(inFilter, FieldAttribute.class), Filter.class);

        Consumer<FieldAttribute> inGrouping = Verifier.geoShapeReporter("geo shapes cannot be used in grouping", localFailures);
        p.forEachDown(aggregate -> aggregate.groupings().forEach(
            grouping -> grouping.forEachUp(inGrouping, FieldAttribute.class)), Aggregate.class);

        Consumer<FieldAttribute> inSorting = Verifier.geoShapeReporter("geo shapes cannot be used for sorting", localFailures);
        p.forEachDown(orderBy -> orderBy.order().forEach(
            order -> order.forEachUp(inSorting, FieldAttribute.class)), OrderBy.class);
    }

    /** Returns a visitor that records a failure with the given message for each geo shape field it sees. */
    private static Consumer<FieldAttribute> geoShapeReporter(String message, Set<Failure> localFailures) {
        return fa -> {
            if (fa.field().getDataType() == DataType.GEO_SHAPE) {
                localFailures.add(Verifier.fail(fa, message));
            }
        };
    }

    /**
     * A verification failure: the node that caused it plus a human readable
     * message.
     *
     * <p>Identity is defined by the source node alone, so a set of failures
     * keeps at most one failure per node regardless of the message.
     */
    static class Failure {
        private final Node<?> source;
        private final String message;

        Failure(Node<?> source, String message) {
            this.source = source;
            this.message = message;
        }

        /** Returns the node this failure is attached to. */
        Node<?> source() {
            return this.source;
        }

        /** Returns the failure message. */
        String message() {
            return this.message;
        }

        @Override
        public int hashCode() {
            // Null-safe, matching the null-safe comparison in equals().
            return Objects.hashCode(this.source);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            Failure other = (Failure) obj;
            // The message is deliberately not part of the identity.
            return Objects.equals(this.source, other.source);
        }

        @Override
        public String toString() {
            return this.message;
        }
    }
}

