/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.clustering;

import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.IntCursor;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.pos.POSNode;
import com.googlecode.clearnlp.util.pair.IntDoublePair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * K-means clustering of "units" (sets of lexical items) using cosine similarity.
 *
 * <p>Each unit added through one of the {@code addUnit} overloads is stored as a sorted
 * array of lexicon indices, i.e. a binary bag-of-lexica vector. {@link #cluster(int, double)}
 * then picks {@code k} units as the initial centroids using a fixed random seed and
 * iteratively re-assigns units and recomputes centroids.</p>
 *
 * <p>Progress is written to {@code System.out}. The class keeps mutable state and performs
 * no synchronization.</p>
 */
public class Kmeans {
    /** Fixed seed so that the random choice of the initial centroids is reproducible. */
    private static final int RAND_SEED = 0;
    /** Number of clusters requested by the last call to {@link #cluster(int, double)}. */
    private int K;
    /** Number of units at the time clustering started. */
    private int N;
    /** Number of distinct lexical items (vector dimensionality) at the time clustering started. */
    private int D;
    /** Maps each distinct lexical item to its dense index, assigned in order of first appearance. */
    private ObjectIntOpenHashMap<String> m_lexica = new ObjectIntOpenHashMap<String>();
    /** One sorted array of lexicon indices per added unit. */
    private List<int[]> v_units = new ArrayList<int[]>();
    /** Flattened K x D centroid matrix; the row of cluster {@code k} starts at {@code k * D}. */
    private double[] d_centroid;
    /** Euclidean norm of each centroid, cached for the cosine computation. */
    private double[] d_scala;

    /**
     * Adds one unit described by its set of lexical items.
     *
     * <p>Items not seen before are assigned the next free lexicon index.</p>
     *
     * @param lexica the distinct lexical items of the unit
     */
    public void addUnit(Set<String> lexica) {
        int[] unit = new int[lexica.size()];
        int i = 0;
        for (String lexicon : lexica) {
            int index;
            if (this.m_lexica.containsKey(lexicon)) {
                index = this.m_lexica.get(lexicon);
            } else {
                index = this.m_lexica.size();
                this.m_lexica.put(lexicon, index);
            }
            unit[i++] = index;
        }
        Arrays.sort(unit);
        this.v_units.add(unit);
    }

    /**
     * Adds one unit made of the distinct lemmas of the given nodes.
     *
     * @param nodes the nodes whose {@code lemma} fields form the unit
     */
    public void addUnit(POSNode[] nodes) {
        Set<String> lexica = new HashSet<String>();
        for (POSNode node : nodes) {
            lexica.add(node.lemma);
        }
        this.addUnit(lexica);
    }

    /**
     * Adds one unit made of the distinct lemmas of the nodes in the given tree.
     *
     * @param tree the tree whose node lemmas form the unit
     */
    public void addUnit(DEPTree tree) {
        Set<String> lexica = new HashSet<String>();
        int size = tree.size();
        // Index 0 is skipped; presumably it is the artificial root node (confirm against DEPTree).
        for (int i = 1; i < size; ++i) {
            lexica.add(tree.get(i).lemma);
        }
        this.addUnit(lexica);
    }

    /**
     * Clusters all units added so far into {@code k} clusters.
     *
     * <p>At most {@code N / k} iterations are run, where {@code N} is the number of units.
     * After each iteration the mean cosine similarity between the units and their (updated)
     * centroids is computed. If that score did not improve over the previous iteration, the
     * previous iteration's clusters are returned. If it reaches {@code threshold}, iteration
     * stops and the current clusters are returned.</p>
     *
     * @param k         the number of clusters; must satisfy {@code 1 <= k <= number of units}
     * @param threshold the mean cosine similarity at which iteration stops
     * @return one list per cluster; each pair holds a unit index ({@code i}) and the similarity
     *         of that unit to the centroid it was assigned to ({@code d})
     * @throws IllegalArgumentException if {@code k} is not in {@code [1, number of units]} or
     *                                  the centroid matrix would exceed the maximum array size
     */
    public List<List<IntDoublePair>> cluster(int k, double threshold) {
        int unitCount = this.v_units.size();
        int dimension = this.m_lexica.size();
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        if (k > unitCount) {
            // Without this check the seed selection could never collect k distinct units.
            throw new IllegalArgumentException(
                    "k (" + k + ") must not exceed the number of units (" + unitCount + ")");
        }
        if ((long)k * (long)dimension > (long)Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "centroid matrix too large: k=" + k + ", dimension=" + dimension);
        }

        List<List<IntDoublePair>> currCluster = null;
        List<List<IntDoublePair>> prevCluster = null;
        double prevRss = -1.0;
        this.K = k;
        this.N = unitCount;
        this.D = dimension;
        this.initCentroids();
        int max = this.N / this.K;
        for (int iter = 0; iter < max; ++iter) {
            System.out.printf("===== Iteration: %d =====\n", iter);
            currCluster = this.getClusters();
            this.updateCentroids(currCluster);
            double currRss = this.getRSS(currCluster);
            if (prevRss >= currRss) {
                // No improvement: fall back to the previous assignment.
                return prevCluster;
            }
            if (currRss >= threshold) {
                break;
            }
            prevRss = currRss;
            prevCluster = currCluster;
        }
        return currCluster;
    }

    /** Allocates the centroid storage and seeds each centroid with a distinct, randomly chosen unit. */
    private void initCentroids() {
        IntOpenHashSet seeds = new IntOpenHashSet();
        Random rand = new Random(RAND_SEED);
        this.d_centroid = new double[this.K * this.D];
        this.d_scala = new double[this.K];
        // Terminates because cluster() guarantees K <= N.
        while (seeds.size() < this.K) {
            seeds.add(rand.nextInt(this.N));
        }
        int k = 0;
        for (IntCursor cur : seeds) {
            int[] unit = this.v_units.get(cur.value);
            for (int index : unit) {
                this.d_centroid[this.getCentroidIndex(k, index)] = 1.0;
            }
            // A binary vector with unit.length ones has norm sqrt(unit.length).
            this.d_scala[k++] = Math.sqrt(unit.length);
        }
    }

    /**
     * Recomputes every centroid as the mean of the units assigned to it and refreshes the
     * cached centroid norms. A cluster with no members gets an all-zero centroid with norm 0.
     *
     * @param cluster the current assignment, one list of members per cluster
     */
    private void updateCentroids(List<List<IntDoublePair>> cluster) {
        Arrays.fill(this.d_centroid, 0.0);
        Arrays.fill(this.d_scala, 0.0);
        System.out.print("Updating centroids: ");
        for (int k = 0; k < this.K; ++k) {
            List<IntDoublePair> members = cluster.get(k);
            // Accumulate per-dimension counts over the members of cluster k.
            for (IntDoublePair member : members) {
                for (int index : this.v_units.get(member.i)) {
                    this.d_centroid[this.getCentroidIndex(k, index)] += 1.0;
                }
            }
            int size = members.size();
            double squaredNorm = 0.0;
            int rowEnd = (k + 1) * this.D;
            for (int i = k * this.D; i < rowEnd; ++i) {
                // Only non-zero entries are divided, so an empty cluster never divides by 0.
                if (this.d_centroid[i] > 0.0) {
                    this.d_centroid[i] /= (double)size;
                    squaredNorm += this.d_centroid[i] * this.d_centroid[i];
                }
            }
            this.d_scala[k] = Math.sqrt(squaredNorm);
            System.out.print(".");
        }
        System.out.println();
    }

    /**
     * Assigns every unit to the centroid it is most similar to; ties go to the lowest
     * cluster index.
     *
     * @return one list per cluster holding (unit index, similarity to the chosen centroid) pairs
     */
    private List<List<IntDoublePair>> getClusters() {
        List<List<IntDoublePair>> cluster = new ArrayList<List<IntDoublePair>>(this.K);
        for (int k = 0; k < this.K; ++k) {
            cluster.add(new ArrayList<IntDoublePair>());
        }
        System.out.print("Clustering: ");
        IntDoublePair max = new IntDoublePair(-1, -1.0);
        for (int i = 0; i < this.N; ++i) {
            int[] unit = this.v_units.get(i);
            max.set(-1, -1.0);
            for (int k = 0; k < this.K; ++k) {
                double sim = this.cosine(unit, k);
                // cosine() never returns less than 0, so some k always beats the -1.0 sentinel.
                if (sim > max.d) {
                    max.set(k, sim);
                }
            }
            cluster.get(max.i).add(new IntDoublePair(i, max.d));
            if (i % 10000 == 0) {
                System.out.print(".");
            }
        }
        System.out.println();
        for (int k = 0; k < this.K; ++k) {
            System.out.printf("- %4d: %d\n", k, cluster.get(k).size());
        }
        return cluster;
    }

    /**
     * Returns the offset of dimension {@code index} of centroid {@code k} in the flattened matrix.
     *
     * @param k     the cluster index
     * @param index the lexicon index
     * @return the position in {@code d_centroid}
     */
    private int getCentroidIndex(int k, int index) {
        return k * this.D + index;
    }

    /**
     * Computes the clustering score used for the stopping criteria.
     *
     * <p>Despite its name, this is the mean cosine similarity between each unit and the
     * centroid of the cluster it belongs to (higher is better). The value returned is the
     * same value that is printed, so {@code threshold} in {@link #cluster(int, double)} is
     * on the same scale as the logged number.</p>
     *
     * @param cluster the current assignment, one list of members per cluster
     * @return the mean unit-to-centroid cosine similarity
     */
    private double getRSS(List<List<IntDoublePair>> cluster) {
        double sim = 0.0;
        System.out.print("Calulating RSS: ");
        for (int k = 0; k < this.K; ++k) {
            for (IntDoublePair member : cluster.get(k)) {
                sim += this.cosine(this.v_units.get(member.i), k);
            }
            System.out.print(".");
        }
        System.out.println();
        // Normalise exactly once; the previous code divided by N a second time on return.
        sim /= (double)this.N;
        System.out.println("RSS = " + sim);
        return sim;
    }

    /**
     * Computes the cosine similarity between a binary unit vector and centroid {@code k}.
     *
     * @param unit the lexicon indices of the unit
     * @param k    the cluster index
     * @return the cosine similarity, or {@code 0.0} if the unit is empty or the centroid has
     *         zero norm (instead of {@code NaN})
     */
    private double cosine(int[] unit, int k) {
        double norm = Math.sqrt(unit.length) * this.d_scala[k];
        if (!(norm > 0.0)) {
            // Empty unit or empty cluster: define the similarity as 0 rather than dividing by 0.
            return 0.0;
        }
        double dot = 0.0;
        for (int index : unit) {
            dot += this.d_centroid[this.getCentroidIndex(k, index)];
        }
        return dot / norm;
    }
}

