/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.vectors.query;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues;

public class ScoreScriptUtils {
    /**
     * Calculates the L1 (Manhattan) distance between a query vector and a document's dense vector.
     *
     * @param queryVector the query vector, one component per dimension
     * @param dvs doc values giving access to the document's encoded dense vector
     * @return the sum of the absolute per-dimension differences
     * @throws IllegalArgumentException if the query vector and the document vector differ in length
     */
    public static double l1norm(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
        final float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getEncodedValue());
        final int queryLength = queryVector.size();
        if (queryLength != docVector.length) {
            throw new IllegalArgumentException("Can't calculate l1norm! The number of dimensions of the query vector [" + queryLength + "] is different from the documents' vectors [" + docVector.length + "].");
        }
        final Iterator<Number> queryComponents = queryVector.iterator();
        double distance = 0.0;
        for (float docComponent : docVector) {
            // The difference is taken in float arithmetic and widened when accumulated.
            final float gap = queryComponents.next().floatValue() - docComponent;
            distance += Math.abs(gap);
        }
        return distance;
    }

    /**
     * Calculates the L2 (Euclidean) distance between a query vector and a document's dense vector.
     *
     * @param queryVector the query vector, one component per dimension
     * @param dvs doc values giving access to the document's encoded dense vector
     * @return the square root of the sum of squared per-dimension differences
     * @throws IllegalArgumentException if the query vector and the document vector differ in length
     */
    public static double l2norm(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
        final float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getEncodedValue());
        final int queryLength = queryVector.size();
        if (queryLength != docVector.length) {
            throw new IllegalArgumentException("Can't calculate l2norm! The number of dimensions of the query vector [" + queryLength + "] is different from the documents' vectors [" + docVector.length + "].");
        }
        final Iterator<Number> queryComponents = queryVector.iterator();
        double sumOfSquares = 0.0;
        for (float docComponent : docVector) {
            // The subtraction happens in float; the square is computed in double.
            final double delta = queryComponents.next().floatValue() - docComponent;
            sumOfSquares += delta * delta;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Calculates the dot product of a query vector and a document's dense vector.
     *
     * @param queryVector the query vector, one component per dimension
     * @param dvs doc values giving access to the document's encoded dense vector
     * @return the dot product of the two vectors
     * @throws IllegalArgumentException if the query vector and the document vector differ in length
     */
    public static double dotProduct(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
        final float[] docVector = VectorEncoderDecoder.decodeDenseVector(dvs.getEncodedValue());
        final int queryLength = queryVector.size();
        if (queryLength == docVector.length) {
            return intDotProduct(queryVector, docVector);
        }
        throw new IllegalArgumentException("Can't calculate dotProduct! The number of dimensions of the query vector [" + queryLength + "] is different from the documents' vectors [" + docVector.length + "].");
    }

    /**
     * Computes the dot product of a list-backed vector and an array-backed vector.
     *
     * <p>Components of {@code v1} are read through {@link Number#floatValue()}. Callers are
     * expected to have verified that {@code v1} has at least {@code v2.length} elements; the
     * loop consumes exactly {@code v2.length} elements from {@code v1}.
     *
     * @param v1 the first vector
     * @param v2 the second vector
     * @return the sum of the per-dimension products, accumulated in double precision
     */
    private static double intDotProduct(List<Number> v1, float[] v2) {
        double v1v2DotProduct = 0.0;
        Iterator<Number> v1Iter = v1.iterator();
        for (float v2Value : v2) {
            // Widen to double BEFORE multiplying: a float product overflows to Infinity for
            // large components and loses precision before it ever reaches the accumulator.
            v1v2DotProduct += (double) v1Iter.next().floatValue() * v2Value;
        }
        return v1v2DotProduct;
    }

    /**
     * Computes the dot product of two sparse vectors given as parallel (dimension, value) arrays.
     *
     * <p>The two-pointer walk only produces the full dot product when both dimension arrays
     * are sorted in ascending order; only dimensions present in both vectors contribute.
     *
     * @param v1Values values of the first vector, parallel to {@code v1Dims}
     * @param v1Dims dimensions of the first vector
     * @param v2Values values of the second vector, parallel to {@code v2Dims}
     * @param v2Dims dimensions of the second vector
     * @return the sum of products over the shared dimensions, accumulated in double precision
     */
    private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
        double v1v2DotProduct = 0.0;
        int v1Index = 0;
        int v2Index = 0;
        while (v1Index < v1Values.length && v2Index < v2Values.length) {
            final int v1Dim = v1Dims[v1Index];
            final int v2Dim = v2Dims[v2Index];
            if (v1Dim == v2Dim) {
                // Widen to double BEFORE multiplying so that large values do not overflow
                // to Infinity in float arithmetic and no precision is lost.
                v1v2DotProduct += (double) v1Values[v1Index] * v2Values[v2Index];
                ++v1Index;
                ++v2Index;
            } else if (v1Dim > v2Dim) {
                // v2 is behind: this v2 dimension has no counterpart in v1.
                ++v2Index;
            } else {
                // v1 is behind: this v1 dimension has no counterpart in v2.
                ++v1Index;
            }
        }
        return v1v2DotProduct;
    }

    /**
     * Cosine similarity between a fixed sparse query vector and sparse document vectors.
     *
     * <p>The query vector's magnitude is computed once at construction time and reused for
     * every document.
     */
    public static final class CosineSimilaritySparse extends VectorSparseFunctions {
        final double queryVectorMagnitude;

        /**
         * @param queryVector the query vector as a map from dimension (an integer in string
         *     form) to value
         * @throws IllegalArgumentException if a dimension key cannot be parsed as an integer
         */
        public CosineSimilaritySparse(Map<String, Number> queryVector) {
            super(queryVector);
            double sumOfSquares = 0.0;
            for (float queryValue : this.queryValues) {
                // Square in double, matching the document-side computation below; squaring in
                // float overflows to Infinity for large components.
                sumOfSquares += (double) queryValue * (double) queryValue;
            }
            this.queryVectorMagnitude = Math.sqrt(sumOfSquares);
        }

        /**
         * Calculates the cosine similarity between the query vector and a document's sparse vector.
         *
         * <p>No guard is applied for zero magnitudes: if either vector has magnitude zero the
         * division yields {@code NaN}.
         *
         * @param dvs doc values giving access to the document's encoded sparse vector
         * @return the dot product divided by the product of the two magnitudes
         */
        public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef value = dvs.getEncodedValue();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(value);
            double docSumOfSquares = 0.0;
            for (float docValue : docValues) {
                docSumOfSquares += (double) docValue * (double) docValue;
            }
            double docVectorMagnitude = Math.sqrt(docSumOfSquares);
            double docQueryDotProduct = ScoreScriptUtils.intDotProductSparse(this.queryValues, this.queryDims, docValues, docDims);
            return docQueryDotProduct / (docVectorMagnitude * this.queryVectorMagnitude);
        }
    }

    /**
     * Dot product between a fixed sparse query vector and sparse document vectors.
     */
    public static final class DotProductSparse extends VectorSparseFunctions {
        /**
         * @param queryVector the query vector as a map from dimension (an integer in string
         *     form) to value
         */
        public DotProductSparse(Map<String, Number> queryVector) {
            super(queryVector);
        }

        /**
         * Calculates the dot product of the query vector and a document's sparse vector.
         *
         * @param dvs doc values giving access to the document's encoded sparse vector
         * @return the dot product over the dimensions the two vectors share
         */
        public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            final BytesRef encoded = dvs.getEncodedValue();
            // Decode dimensions first, then values, from the same encoded bytes.
            final int[] documentDims = VectorEncoderDecoder.decodeSparseVectorDims(encoded);
            final float[] documentValues = VectorEncoderDecoder.decodeSparseVector(encoded);
            final double result = intDotProductSparse(this.queryValues, this.queryDims, documentValues, documentDims);
            return result;
        }
    }

    /**
     * L2 (Euclidean) distance between a fixed sparse query vector and sparse document vectors.
     */
    public static final class L2NormSparse extends VectorSparseFunctions {
        /**
         * @param queryVector the query vector as a map from dimension (an integer in string
         *     form) to value
         */
        public L2NormSparse(Map<String, Number> queryVector) {
            super(queryVector);
        }

        /**
         * Calculates the L2 distance between the query vector and a document's sparse vector.
         *
         * <p>A dimension missing from one vector is treated as zero there, so its contribution
         * is the square of the value in the other vector. The merge walk relies on both
         * dimension arrays being in ascending order.
         *
         * @param dvs doc values giving access to the document's encoded sparse vector
         * @return the square root of the sum of squared per-dimension differences
         */
        public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef value = dvs.getEncodedValue();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(value);
            int queryIndex = 0;
            int docIndex = 0;
            double l2norm = 0.0;
            // Merge phase: walk both vectors until one of them is exhausted.
            while (queryIndex < this.queryDims.length && docIndex < docDims.length) {
                final double diff;
                if (this.queryDims[queryIndex] == docDims[docIndex]) {
                    // Shared dimension (the subtraction itself is done in float).
                    diff = this.queryValues[queryIndex] - docValues[docIndex];
                    ++queryIndex;
                    ++docIndex;
                } else if (this.queryDims[queryIndex] > docDims[docIndex]) {
                    // Dimension present only in the document.
                    diff = docValues[docIndex];
                    ++docIndex;
                } else {
                    // Dimension present only in the query.
                    diff = this.queryValues[queryIndex];
                    ++queryIndex;
                }
                l2norm += diff * diff;
            }
            // Tail phases: square in double exactly like the merge phase above. Squaring in
            // float here would overflow to Infinity for large values and make a dimension's
            // contribution depend on which loop happened to process it.
            for (; queryIndex < this.queryDims.length; ++queryIndex) {
                final double remaining = this.queryValues[queryIndex];
                l2norm += remaining * remaining;
            }
            for (; docIndex < docDims.length; ++docIndex) {
                final double remaining = docValues[docIndex];
                l2norm += remaining * remaining;
            }
            return Math.sqrt(l2norm);
        }
    }

    /**
     * L1 (Manhattan) distance between a fixed sparse query vector and sparse document vectors.
     */
    public static final class L1NormSparse extends VectorSparseFunctions {
        /**
         * @param queryVector the query vector as a map from dimension (an integer in string
         *     form) to value
         */
        public L1NormSparse(Map<String, Number> queryVector) {
            super(queryVector);
        }

        /**
         * Calculates the L1 distance between the query vector and a document's sparse vector.
         *
         * <p>A dimension missing from one vector contributes the absolute value of the other
         * vector's component for that dimension.
         *
         * @param dvs doc values giving access to the document's encoded sparse vector
         * @return the sum of the absolute per-dimension differences
         */
        public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            final BytesRef encoded = dvs.getEncodedValue();
            final int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(encoded);
            final float[] docValues = VectorEncoderDecoder.decodeSparseVector(encoded);
            final int queryCount = this.queryDims.length;
            final int docCount = docDims.length;
            int q = 0;
            int d = 0;
            double distance = 0.0;
            // Single merge loop: once one side is exhausted, the other side is drained.
            while (q < queryCount || d < docCount) {
                final boolean queryDone = q >= queryCount;
                final boolean docDone = d >= docCount;
                if (docDone || (!queryDone && this.queryDims[q] < docDims[d])) {
                    // Dimension present only in the query.
                    distance += Math.abs(this.queryValues[q]);
                    q++;
                } else if (queryDone || this.queryDims[q] > docDims[d]) {
                    // Dimension present only in the document.
                    distance += Math.abs(docValues[d]);
                    d++;
                } else {
                    // Dimension present in both vectors.
                    distance += Math.abs(this.queryValues[q] - docValues[d]);
                    q++;
                    d++;
                }
            }
            return distance;
        }
    }

    /**
     * Base class for sparse-vector functions: converts a query vector supplied as a map into
     * parallel dimension and value arrays.
     */
    public static class VectorSparseFunctions {
        final float[] queryValues;
        final int[] queryDims;

        /**
         * Parses the query vector into parallel arrays and hands them to
         * {@code VectorEncoderDecoder.sortSparseDimsFloatValues} (presumably ordering both arrays
         * by ascending dimension, which the merge-style loops of the subclasses rely on).
         *
         * @param queryVector the query vector as a map from dimension (an integer in string
         *     form) to value
         * @throws IllegalArgumentException if a dimension key cannot be parsed as an integer
         */
        public VectorSparseFunctions(Map<String, Number> queryVector) {
            final int size = queryVector.size();
            this.queryValues = new float[size];
            this.queryDims = new int[size];
            int slot = 0;
            for (Map.Entry<String, Number> entry : queryVector.entrySet()) {
                this.queryDims[slot] = parseDimension(entry.getKey());
                this.queryValues[slot] = entry.getValue().floatValue();
                slot++;
            }
            VectorEncoderDecoder.sortSparseDimsFloatValues(this.queryDims, this.queryValues, size);
        }

        /** Parses a dimension key, translating a malformed number into an argument error. */
        private static int parseDimension(String key) {
            try {
                return Integer.parseInt(key);
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
            }
        }
    }

    /**
     * Cosine similarity between a fixed dense query vector and dense document vectors.
     *
     * <p>The query vector's magnitude is computed once at construction time and reused for
     * every document.
     */
    public static final class CosineSimilarity {
        final double queryVectorMagnitude;
        final List<Number> queryVector;

        /**
         * @param queryVector the query vector, one component per dimension; components are
         *     read through {@link Number#floatValue()}
         */
        public CosineSimilarity(List<Number> queryVector) {
            this.queryVector = queryVector;
            double sumOfSquares = 0.0;
            for (Number value : queryVector) {
                // Square in double, matching the document-side computation below; squaring in
                // float overflows to Infinity for large components.
                final double component = value.floatValue();
                sumOfSquares += component * component;
            }
            this.queryVectorMagnitude = Math.sqrt(sumOfSquares);
        }

        /**
         * Calculates the cosine similarity between the query vector and a document's dense vector.
         *
         * <p>No guard is applied for zero magnitudes: if either vector has magnitude zero the
         * division yields {@code NaN}.
         *
         * @param dvs doc values giving access to the document's encoded dense vector
         * @return the dot product divided by the product of the two magnitudes
         * @throws IllegalArgumentException if the query vector and the document vector differ in length
         */
        public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
            BytesRef value = dvs.getEncodedValue();
            float[] docVector = VectorEncoderDecoder.decodeDenseVector(value);
            if (this.queryVector.size() != docVector.length) {
                throw new IllegalArgumentException("Can't calculate cosineSimilarity! The number of dimensions of the query vector [" + this.queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "].");
            }
            double docSumOfSquares = 0.0;
            for (float docComponent : docVector) {
                docSumOfSquares += (double) docComponent * (double) docComponent;
            }
            double docVectorMagnitude = Math.sqrt(docSumOfSquares);
            double docQueryDotProduct = ScoreScriptUtils.intDotProduct(this.queryVector, docVector);
            return docQueryDotProduct / (docVectorMagnitude * this.queryVectorMagnitude);
        }
    }
}

