/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;

/**
 * Estimates the log probability of held-out documents under a trained topic model with a
 * left-to-right, particle-based procedure.
 *
 * <p>For each particle, every token's predictive probability is computed from the topic
 * assignments sampled for the tokens before it. Optionally, the earlier tokens are re-sampled
 * before each step. Per-position probabilities are averaged over particles and their logs
 * summed. The procedure appears to follow the left-to-right estimator of Wallach et al. (2009);
 * this attribution is a reviewer note, not stated by the original code.
 *
 * <p>Word-topic counts are stored packed. Each entry of {@code typeTopicCounts[type]} holds a
 * topic id in its low {@code topicBits} bits and a count in the remaining high bits. A type's
 * entries are read until the first non-positive value.
 *
 * <p>Not thread-safe: {@link #leftToRight} temporarily modifies {@code cachedCoefficients}
 * (restoring it before returning), and all calls share {@code random}.
 */
public class MarginalProbEstimator
implements Serializable {
    protected int numTopics;
    /** Mask selecting the topic id from a packed type-topic entry. */
    protected int topicMask;
    /** Number of low bits used for the topic id in a packed type-topic entry. */
    protected int topicBits;
    protected double[] alpha;
    protected double alphaSum;
    protected double beta;
    /** {@code beta} times the vocabulary size ({@code typeTopicCounts.length}). */
    protected double betaSum;
    /** Sum over topics of {@code alpha[t] * beta / (tokensPerTopic[t] + betaSum)}. */
    protected double smoothingOnlyMass = 0.0;
    /**
     * Per-topic coefficient {@code (alpha[t] + n_t|d) / (tokensPerTopic[t] + betaSum)}, where
     * {@code n_t|d} is the current document's count for topic t (zero between calls).
     */
    protected double[] cachedCoefficients;
    protected int[][] typeTopicCounts;
    protected int[] tokensPerTopic;
    protected Randoms random;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates an estimator from the sufficient statistics of a trained model.
     *
     * @param numTopics number of topics
     * @param alpha per-topic Dirichlet parameters; at least {@code numTopics} entries are read
     * @param alphaSum presumably the sum of {@code alpha} (not checked); used as the base of each
     *        token's predictive denominator
     * @param beta topic-word smoothing parameter
     * @param typeTopicCounts packed per-word-type topic counts (see the class comment); its length
     *        is used as the vocabulary size
     * @param tokensPerTopic total tokens assigned to each topic
     */
    public MarginalProbEstimator(int numTopics, double[] alpha, double alphaSum, double beta, int[][] typeTopicCounts, int[] tokensPerTopic) {
        this.numTopics = numTopics;
        if (Integer.bitCount(numTopics) == 1) {
            // Exact power of two: the ids 0..numTopics-1 fit in exactly log2(numTopics) bits.
            this.topicMask = numTopics - 1;
        } else {
            // Otherwise round up to the next all-ones mask that covers every topic id.
            this.topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
        }
        this.topicBits = Integer.bitCount(this.topicMask);
        this.typeTopicCounts = typeTopicCounts;
        this.tokensPerTopic = tokensPerTopic;
        this.alphaSum = alphaSum;
        this.alpha = alpha;
        this.beta = beta;
        this.betaSum = beta * (double) typeTopicCounts.length;
        this.random = new Randoms();
        this.cachedCoefficients = new double[numTopics];
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < numTopics; ++topic) {
            double denominator = (double) tokensPerTopic[topic] + this.betaSum;
            this.smoothingOnlyMass += alpha[topic] * beta / denominator;
            this.cachedCoefficients[topic] = alpha[topic] / denominator;
        }
        System.err.println("Topic Evaluator: " + numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /** Returns the internal per-topic token totals (not a copy). */
    public int[] getTokensPerTopic() {
        return this.tokensPerTopic;
    }

    /** Returns the internal packed type-topic counts (not a copy). */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /**
     * Estimates the total log likelihood of {@code testing}.
     *
     * <p>Each document is processed by {@code numParticles} independent runs of
     * {@link #leftToRight}. At every position the per-particle probabilities are averaged, and
     * the log of that average is added to the document's log likelihood. Positions with zero
     * average probability (unknown word types) are skipped.
     *
     * <p>The random generator is replaced with a fresh {@code new Randoms()} on each call.
     * Whether that makes results repeatable depends on {@code Randoms}' default seeding, which
     * is not visible here.
     *
     * @param testing documents whose data are {@code FeatureSequence}s
     * @param numParticles number of particles per document; must be positive
     * @param usingResampling whether earlier tokens are re-sampled before each new token
     * @param docProbabilityStream if non-null, receives each document's log likelihood on its
     *        own line
     * @return the summed log likelihood over all documents
     * @throws IllegalArgumentException if {@code numParticles <= 0}
     */
    public double evaluateLeftToRight(InstanceList testing, int numParticles, boolean usingResampling, PrintStream docProbabilityStream) {
        if (numParticles <= 0) {
            throw new IllegalArgumentException("numParticles must be positive: " + numParticles);
        }
        this.random = new Randoms();
        double logNumParticles = Math.log(numParticles);
        double totalLogLikelihood = 0.0;
        for (Instance instance : testing) {
            FeatureSequence tokenSequence = (FeatureSequence) instance.getData();
            double[][] particleProbabilities = new double[numParticles][];
            for (int particle = 0; particle < numParticles; ++particle) {
                particleProbabilities[particle] = this.leftToRight(tokenSequence, usingResampling);
            }
            double docLogLikelihood = 0.0;
            for (int position = 0; position < particleProbabilities[0].length; ++position) {
                double sum = 0.0;
                for (int particle = 0; particle < numParticles; ++particle) {
                    sum += particleProbabilities[particle][position];
                }
                if (sum > 0.0) {
                    // log(mean) = log(sum) - log(numParticles)
                    docLogLikelihood += Math.log(sum) - logNumParticles;
                }
            }
            if (docProbabilityStream != null) {
                docProbabilityStream.println(docLogLikelihood);
            }
            totalLogLikelihood += docLogLikelihood;
        }
        return totalLogLikelihood;
    }

    /**
     * Runs one left-to-right particle over a document.
     *
     * <p>{@code cachedCoefficients} is adjusted for the document's topic counts while this runs
     * and restored to its document-independent values before returning, including when an
     * exception is thrown.
     *
     * @param tokenSequence the document's word-type indices
     * @param usingResampling whether to re-sample all earlier tokens' topics before each token
     * @return an array parallel to the token positions. Entry i is the predictive probability of
     *         token i given the sampled assignments of the earlier tokens, or 0.0 when the token's
     *         type has no counts in {@code typeTopicCounts}.
     */
    protected double[] leftToRight(FeatureSequence tokenSequence, boolean usingResampling) {
        int docLength = tokenSequence.getLength();
        int[] oneDocTopics = new int[docLength];
        double[] wordProbabilities = new double[docLength];
        DocumentTopicState state = new DocumentTopicState(this.numTopics);
        int tokensSoFar = 0;
        try {
            for (int limit = 0; limit < docLength; ++limit) {
                if (usingResampling) {
                    for (int position = 0; position < limit; ++position) {
                        int[] earlierCounts = this.countsForType(tokenSequence.getIndexAtPosition(position));
                        if (earlierCounts == null) {
                            // Unknown types were never assigned a topic, so there is nothing to re-sample.
                            continue;
                        }
                        this.removeToken(state, oneDocTopics[position]);
                        this.scoreTopicTerms(state, earlierCounts);
                        int resampled = this.sampleTopic(state, earlierCounts);
                        oneDocTopics[position] = resampled;
                        this.addToken(state, resampled);
                    }
                }
                int[] currentTypeTopicCounts = this.countsForType(tokenSequence.getIndexAtPosition(limit));
                if (currentTypeTopicCounts == null) {
                    // Unknown type: probability stays 0.0 and it does not count toward tokensSoFar.
                    continue;
                }
                this.scoreTopicTerms(state, currentTypeTopicCounts);
                // The three masses together are the unnormalized sum over topics of
                // (alpha_t + n_t|d) * (beta + n_w|t) / (tokensPerTopic_t + betaSum).
                double totalMass = this.smoothingOnlyMass + state.topicBetaMass + state.topicTermMass;
                wordProbabilities[limit] = totalMass / (this.alphaSum + (double) tokensSoFar);
                ++tokensSoFar;
                int newTopic = this.sampleTopic(state, currentTypeTopicCounts);
                oneDocTopics[limit] = newTopic;
                this.addToken(state, newTopic);
            }
        } finally {
            // Topics with a zero document count already hold their base coefficient; only the
            // topics still present in this document need to be reset.
            for (int denseIndex = 0; denseIndex < state.nonZeroTopics; ++denseIndex) {
                int topic = state.nonZeroTopicIndex[denseIndex];
                this.cachedCoefficients[topic] = this.alpha[topic] / ((double) this.tokensPerTopic[topic] + this.betaSum);
            }
        }
        return wordProbabilities;
    }

    /** Mutable per-particle state for one document in {@link #leftToRight}. */
    private static final class DocumentTopicState {
        /** Number of this document's tokens currently assigned to each topic. */
        final int[] topicCounts;
        /** The topics with a non-zero count, sorted ascending, in the first {@code nonZeroTopics} slots. */
        final int[] nonZeroTopicIndex;
        int nonZeroTopics = 0;
        /** Sum over topics of {@code beta * topicCounts[t] / (tokensPerTopic[t] + betaSum)}. */
        double topicBetaMass = 0.0;
        /** Scratch: topic-term scores for the word type currently being sampled. */
        final double[] topicTermScores;
        /** Number of valid leading entries in {@code topicTermScores}. */
        int numTopicTermScores = 0;
        /** Sum of the valid entries of {@code topicTermScores}. */
        double topicTermMass = 0.0;

        DocumentTopicState(int numTopics) {
            this.topicCounts = new int[numTopics];
            this.nonZeroTopicIndex = new int[numTopics];
            this.topicTermScores = new double[numTopics];
        }
    }

    /** Returns the packed topic counts for {@code type}, or null if the type has none. */
    private int[] countsForType(int type) {
        if (type >= this.typeTopicCounts.length) {
            return null;
        }
        return this.typeTopicCounts[type];
    }

    /** Contribution of {@code topic} with document count {@code count} to the topic-beta mass. */
    private double topicBetaTerm(int topic, int count) {
        return this.beta * (double) count / ((double) this.tokensPerTopic[topic] + this.betaSum);
    }

    /** Recomputes the cached coefficient of {@code topic} from the current document count. */
    private void refreshCoefficient(DocumentTopicState state, int topic) {
        this.cachedCoefficients[topic] = (this.alpha[topic] + (double) state.topicCounts[topic]) / ((double) this.tokensPerTopic[topic] + this.betaSum);
    }

    /** Removes one token of {@code topic} from the document state, keeping all derived values in sync. */
    private void removeToken(DocumentTopicState state, int topic) {
        state.topicBetaMass -= this.topicBetaTerm(topic, state.topicCounts[topic]);
        state.topicCounts[topic]--;
        if (state.topicCounts[topic] == 0) {
            // Drop the topic from the sorted dense index by shifting the later entries left.
            int denseIndex = 0;
            while (state.nonZeroTopicIndex[denseIndex] != topic) {
                ++denseIndex;
            }
            System.arraycopy(state.nonZeroTopicIndex, denseIndex + 1, state.nonZeroTopicIndex, denseIndex, state.nonZeroTopics - denseIndex - 1);
            state.nonZeroTopics--;
        }
        state.topicBetaMass += this.topicBetaTerm(topic, state.topicCounts[topic]);
        this.refreshCoefficient(state, topic);
    }

    /** Adds one token of {@code topic} to the document state, keeping all derived values in sync. */
    private void addToken(DocumentTopicState state, int topic) {
        state.topicBetaMass -= this.topicBetaTerm(topic, state.topicCounts[topic]);
        state.topicCounts[topic]++;
        if (state.topicCounts[topic] == 1) {
            // Insert the newly present topic into the dense index, keeping it sorted ascending.
            int denseIndex = state.nonZeroTopics;
            while (denseIndex > 0 && state.nonZeroTopicIndex[denseIndex - 1] > topic) {
                state.nonZeroTopicIndex[denseIndex] = state.nonZeroTopicIndex[denseIndex - 1];
                --denseIndex;
            }
            state.nonZeroTopicIndex[denseIndex] = topic;
            state.nonZeroTopics++;
        }
        this.refreshCoefficient(state, topic);
        state.topicBetaMass += this.topicBetaTerm(topic, state.topicCounts[topic]);
    }

    /**
     * Fills the state's topic-term scores for one word type. Each score is the topic's cached
     * coefficient times the type's count for that topic. Also sets the number of scores and
     * their sum.
     */
    private void scoreTopicTerms(DocumentTopicState state, int[] currentTypeTopicCounts) {
        double mass = 0.0;
        int index = 0;
        while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
            int packed = currentTypeTopicCounts[index];
            int currentTopic = packed & this.topicMask;
            int currentValue = packed >> this.topicBits;
            double score = this.cachedCoefficients[currentTopic] * (double) currentValue;
            state.topicTermScores[index] = score;
            mass += score;
            ++index;
        }
        state.numTopicTermScores = index;
        state.topicTermMass = mass;
    }

    /**
     * Draws a topic for a token of the word type whose scores are currently in {@code state}.
     *
     * <p>The draw proceeds in three buckets. It first tries the topic-term bucket, then the
     * document topic-beta bucket, then the smoothing-only bucket. The first and last buckets are
     * bounded so that a draw of exactly zero or accumulated rounding error can never index past
     * the valid entries.
     */
    private int sampleTopic(DocumentTopicState state, int[] currentTypeTopicCounts) {
        double origSample = this.random.nextUniform() * (this.smoothingOnlyMass + state.topicBetaMass + state.topicTermMass);
        double sample = origSample;
        int newTopic = -1;
        if (sample < state.topicTermMass) {
            // topicTermMass > sample >= 0 guarantees at least one scored entry.
            int index = 0;
            sample -= state.topicTermScores[0];
            while (sample > 0.0 && index < state.numTopicTermScores - 1) {
                sample -= state.topicTermScores[++index];
            }
            newTopic = currentTypeTopicCounts[index] & this.topicMask;
        } else if ((sample -= state.topicTermMass) < state.topicBetaMass) {
            sample /= this.beta;
            for (int denseIndex = 0; denseIndex < state.nonZeroTopics; ++denseIndex) {
                int topic = state.nonZeroTopicIndex[denseIndex];
                sample -= (double) state.topicCounts[topic] / ((double) this.tokensPerTopic[topic] + this.betaSum);
                if (sample <= 0.0) {
                    newTopic = topic;
                    break;
                }
            }
        } else {
            sample -= state.topicBetaMass;
            sample /= this.beta;
            newTopic = 0;
            sample -= this.alpha[newTopic] / ((double) this.tokensPerTopic[newTopic] + this.betaSum);
            while (sample > 0.0 && newTopic < this.numTopics - 1) {
                ++newTopic;
                sample -= this.alpha[newTopic] / ((double) this.tokensPerTopic[newTopic] + this.betaSum);
            }
        }
        if (newTopic == -1) {
            // Only reachable from the topic-beta bucket, e.g. when rounding leaves a tiny positive
            // topicBetaMass with no topics present in the document.
            System.err.println("sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + state.topicBetaMass + " " + state.topicTermMass);
            newTopic = this.numTopics - 1;
        }
        return newTopic;
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeInt(this.numTopics);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeObject(this.alpha);
        out.writeDouble(this.alphaSum);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeObject(this.typeTopicCounts);
        out.writeObject(this.tokensPerTopic);
        out.writeObject(this.random);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Serial format version; only version 0 has ever been written, so it is not inspected.
        in.readInt();
        this.numTopics = in.readInt();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.alpha = (double[]) in.readObject();
        this.alphaSum = in.readDouble();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.typeTopicCounts = (int[][]) in.readObject();
        this.tokensPerTopic = (int[]) in.readObject();
        this.random = (Randoms) in.readObject();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[]) in.readObject();
    }

    /**
     * Reads an estimator previously written with Java serialization.
     *
     * <p>NOTE(review): this uses native Java deserialization. Only read files from trusted
     * sources, or install an {@code ObjectInputFilter}.
     *
     * @param f the serialized estimator file
     * @return the deserialized estimator
     * @throws Exception on I/O, class-resolution, or cast failures
     */
    public static MarginalProbEstimator read(File f) throws Exception {
        // Both streams are closed even if the ObjectInputStream header read or readObject fails.
        try (FileInputStream fileIn = new FileInputStream(f);
             ObjectInputStream ois = new ObjectInputStream(fileIn)) {
            return (MarginalProbEstimator) ois.readObject();
        }
    }
}

