/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.util;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import gnu.trove.TIntIntHashMap;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.logging.Logger;

public class FeatureCooccurrenceCounter {
    private static Logger logger = MalletLogger.getLogger(FeatureCooccurrenceCounter.class.getName());
    static CommandOption.String inputFile = new CommandOption.String(FeatureCooccurrenceCounter.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String weightsFile = new CommandOption.String(FeatureCooccurrenceCounter.class, "weights-filename", "FILENAME", true, null, "The filename to write the word-word weights file.", null);
    static CommandOption.Double idfCutoff = new CommandOption.Double(FeatureCooccurrenceCounter.class, "idf-cutoff", "NUMBER", true, 3.0, "Words with IDF below this threshold will not be linked to any other word.", null);
    static CommandOption.String unlinkedFile = new CommandOption.String(FeatureCooccurrenceCounter.class, "unlinked-filename", "FILENAME", true, null, "A file to write words that were not linked.", null);
    TIntIntHashMap[] featureFeatureCounts;
    InstanceList instances;
    int numFeatures;
    int[] documentFrequencies;

    public FeatureCooccurrenceCounter(InstanceList instances) {
        this.instances = instances;
        this.numFeatures = instances.getDataAlphabet().size();
        this.featureFeatureCounts = new TIntIntHashMap[this.numFeatures];
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            this.featureFeatureCounts[feature] = new TIntIntHashMap();
        }
        this.documentFrequencies = new int[this.numFeatures];
    }

    public void count() {
        TIntIntHashMap featureCounts = new TIntIntHashMap();
        int index = 0;
        for (Instance instance : this.instances) {
            FeatureSequence features = (FeatureSequence)instance.getData();
            for (int i = 0; i < features.getLength(); ++i) {
                featureCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
            }
            int[] keys = featureCounts.keys();
            for (int i = 0; i < keys.length - 1; ++i) {
                int leftFeature = keys[i];
                for (int j = i + 1; j < keys.length; ++j) {
                    int rightFeature = keys[j];
                    this.featureFeatureCounts[leftFeature].adjustOrPutValue(rightFeature, 1, 1);
                    this.featureFeatureCounts[rightFeature].adjustOrPutValue(leftFeature, 1, 1);
                }
            }
            int[] nArray = keys;
            int n = nArray.length;
            for (int i = 0; i < n; ++i) {
                int key;
                int n2 = key = nArray[i];
                this.documentFrequencies[n2] = this.documentFrequencies[n2] + 1;
            }
            featureCounts = new TIntIntHashMap();
            if (++index % 1000 != 0) continue;
            System.err.println(index);
        }
    }

    public double g2(double left, double right, double both, double total) {
        double justLeft = left - both + 0.01;
        double justRight = right - both + 0.01;
        double neither = total - left - right + (both += 0.01) + 0.01;
        double leftMarginalProb = (justLeft + both) / (total += 0.04);
        double rightMarginalProb = (justRight + both) / total;
        double logLeft = Math.log(leftMarginalProb);
        double logRight = Math.log(rightMarginalProb);
        double logNotLeft = Math.log(1.0 - leftMarginalProb);
        double logNotRight = Math.log(1.0 - rightMarginalProb);
        double g2 = both * (Math.log(both / total) - logLeft - logRight) + justLeft * (Math.log(justLeft / total) - logLeft - logNotRight) + justRight * (Math.log(justRight / total) - logNotLeft - logRight) + neither * (Math.log(neither / total) - logNotLeft - logNotRight);
        return g2;
    }

    public void printCounts() throws IOException {
        int feature;
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(3);
        Alphabet alphabet = this.instances.getDataAlphabet();
        double logTotalDocs = Math.log(this.instances.size());
        double[] logCache = new double[this.instances.size() + 1];
        for (int n = 1; n < logCache.length; ++n) {
            logCache[n] = Math.log(n);
        }
        if (FeatureCooccurrenceCounter.unlinkedFile.value != null) {
            PrintWriter out = new PrintWriter(FeatureCooccurrenceCounter.unlinkedFile.value);
            for (feature = 0; feature < this.numFeatures; ++feature) {
                double featureIDF = logTotalDocs - logCache[this.documentFrequencies[feature]];
                if (!(featureIDF < FeatureCooccurrenceCounter.idfCutoff.value)) continue;
                out.println(alphabet.lookupObject(feature));
            }
            out.close();
        }
        PrintWriter out = new PrintWriter(FeatureCooccurrenceCounter.weightsFile.value);
        for (feature = 0; feature < this.numFeatures; ++feature) {
            TIntIntHashMap featureCounts = this.featureFeatureCounts[feature];
            int[] keys = featureCounts.keys();
            double featureIDF = logTotalDocs - logCache[this.documentFrequencies[feature]];
            StringBuilder output = new StringBuilder();
            output.append(alphabet.lookupObject(feature));
            output.append("\t");
            output.append("1.0");
            if (this.documentFrequencies[feature] <= 5) {
                out.println(output);
                continue;
            }
            if (featureIDF - FeatureCooccurrenceCounter.idfCutoff.value > 0.0) {
                Object[] sortedWeights = new IDSorter[keys.length];
                int i = 0;
                for (int key : keys) {
                    double keyIDF = logTotalDocs - logCache[this.documentFrequencies[key]];
                    sortedWeights[i] = keyIDF - FeatureCooccurrenceCounter.idfCutoff.value > 0.0 ? new IDSorter(key, (keyIDF - FeatureCooccurrenceCounter.idfCutoff.value) / (featureIDF - FeatureCooccurrenceCounter.idfCutoff.value) * ((double)featureCounts.get(key) / (double)this.documentFrequencies[feature])) : new IDSorter(key, 0);
                    ++i;
                }
                Arrays.sort(sortedWeights);
                for (i = 0; i < 10 && i < sortedWeights.length; ++i) {
                    int key = ((IDSorter)sortedWeights[i]).getID();
                    Object word = alphabet.lookupObject(((IDSorter)sortedWeights[i]).getID());
                    double weight = ((IDSorter)sortedWeights[i]).getWeight();
                    if (weight < 0.05) break;
                    output.append("\t" + word + "\t" + weight);
                }
            }
            out.println(output);
        }
        out.close();
    }

    public static void main(String[] args) throws Exception {
        CommandOption.setSummary(FeatureCooccurrenceCounter.class, "Build a file containing weights between word types");
        CommandOption.process(FeatureCooccurrenceCounter.class, args);
        InstanceList training = InstanceList.load(new File(FeatureCooccurrenceCounter.inputFile.value));
        FeatureCooccurrenceCounter counter = new FeatureCooccurrenceCounter(training);
        counter.count();
        counter.printCounts();
    }
}

