/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.classification.algorithm;

import com.carrotsearch.hppc.IntArrayList;
import com.googlecode.clearnlp.classification.algorithm.AbstractAlgorithm;
import com.googlecode.clearnlp.classification.prediction.IntPrediction;
import com.googlecode.clearnlp.classification.train.AbstractTrainSpace;
import com.googlecode.clearnlp.util.UTArray;
import com.googlecode.clearnlp.util.triple.Triple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/**
 * Online multi-class large-margin trainer whose per-weight step sizes follow AdaGrad.
 *
 * <p>For every training instance the current weights score all labels. When the
 * highest-scoring label is wrong, or is right by a margin smaller than {@code 1.0} over the
 * runner-up, the weights of the gold label are increased and those of the competing label are
 * decreased. The step size of each weight is {@code alpha / (rho + sqrt(G))}, where {@code G}
 * is the running sum of squared feature values accumulated for that weight.
 */
public class AdaGrad
extends AbstractAlgorithm {
    /** Number of passes over the training instances. */
    protected int n_iter;
    /** Random generator used to shuffle the instances before every pass. */
    protected Random r_rand;
    /** Base learning rate (numerator of the AdaGrad step size). */
    protected double d_alpha;
    /** Smoothing term added to the square root in the step-size denominator. */
    protected double d_rho;

    /**
     * Creates an AdaGrad trainer.
     *
     * @param iter  number of training passes
     * @param alpha base learning rate
     * @param rho   smoothing term added to {@code sqrt(G)} in the step-size denominator
     * @param rand  random generator used to shuffle instances
     */
    public AdaGrad(int iter, double alpha, double rho, Random rand) {
        this.n_iter = iter;
        this.r_rand = rand;
        this.d_alpha = alpha;
        this.d_rho = rho;
    }

    /**
     * Trains a fresh, zero-initialized weight vector on {@code space}.
     *
     * @param space      training instances
     * @param numThreads ignored; training runs on the calling thread
     * @return the learned weights, of length {@code featureSize * labelSize}
     */
    @Override
    public double[] getWeight(AbstractTrainSpace space, int numThreads) {
        double[] weights = new double[space.getFeatureSize() * space.getLabelSize()];
        this.updateWeight(space, weights);
        return weights;
    }

    /**
     * Continues training on the array returned by {@code space.getModel().getWeights()}.
     * NOTE(review): this only updates the model in place if {@code getWeights()} returns the
     * model's backing array rather than a copy — confirm.
     *
     * @param space training instances and the model whose weights are trained
     */
    public void updateWeight(AbstractTrainSpace space) {
        this.updateWeight(space, space.getModel().getWeights());
    }

    /**
     * Runs {@code n_iter} shuffled passes over the instances in {@code space}, updating
     * {@code weights} in place, and prints the training accuracy of each pass to stdout.
     *
     * @param space   training instances (labels, feature indices, optional feature values)
     * @param weights weight vector of length {@code featureSize * labelSize}; modified in place
     */
    public void updateWeight(AbstractTrainSpace space, double[] weights) {
        final int D = space.getFeatureSize();
        final int L = space.getLabelSize();
        final int N = space.getInstanceSize();
        double[] gs = new double[D * L];
        IntArrayList ys = space.getYs();
        ArrayList<int[]> xs = space.getXs();
        ArrayList<double[]> vs = space.getVs();
        // Stays null when the space carries no feature values (every value is treated as 1).
        double[] vi = null;

        for (int i = 0; i < this.n_iter; ++i) {
            int[] indices = this.getShuffledIndices(N);
            // Squared-gradient sums restart at every pass, so step sizes are reset per pass.
            // NOTE(review): textbook AdaGrad accumulates over the whole run — confirm intent.
            Arrays.fill(gs, 0.0);
            int correct = 0;

            for (int j = 0; j < N; ++j) {
                int k = indices[j];
                int yi = ys.get(k);
                int[] xi = xs.get(k);
                if (space.hasWeight()) {
                    vi = vs.get(k);
                }

                Triple<IntPrediction, IntPrediction, IntPrediction> ps = this.getPredictions(L, yi, xi, vi, weights);
                IntPrediction fst = ps.o1;
                IntPrediction snd = ps.o2;

                if (fst.label == yi) {
                    // A correct prediction counts toward accuracy even when it is still updated.
                    ++correct;
                    if (fst.score - snd.score < 1.0) {
                        // Correct but inside the margin: separate the gold label from the runner-up.
                        this.applyUpdate(L, gs, yi, snd.label, xi, vi, weights);
                    }
                } else {
                    this.applyUpdate(L, gs, yi, fst.label, xi, vi, weights);
                }
            }

            double acc = (N == 0) ? 0.0 : 100.0 * (double)correct / (double)N;
            System.out.printf("- %3d: acc = %7.4f\n", i + 1, acc);
        }
    }

    /** Accumulates squared gradients, then moves weights toward {@code yp} and away from {@code yn}. */
    private void applyUpdate(int L, double[] gs, int yp, int yn, int[] x, double[] v, double[] weights) {
        this.updateCounts(L, gs, yp, yn, x, v);
        this.updateWeights(L, gs, yp, yn, x, v, weights);
    }

    /**
     * Returns a uniformly shuffled permutation of {@code 0..N-1} (Fisher-Yates using {@code r_rand}).
     */
    private int[] getShuffledIndices(int N) {
        int[] indices = new int[N];
        for (int i = 0; i < N; ++i) {
            indices[i] = i;
        }
        for (int i = 0; i < N; ++i) {
            int j = i + this.r_rand.nextInt(N - i);
            UTArray.swap(indices, i, j);
        }
        return indices;
    }

    /** Adds each feature's weight (times its value, or 1 when {@code v} is null) to every label's score. */
    private void addScores(int L, int[] x, double[] v, double[] weights, double[] scores) {
        for (int i = 0; i < x.length; ++i) {
            int feature = x[i];
            // Multiplying by 1.0 is exact, so the unweighted case is unchanged numerically.
            double value = (v != null) ? v[i] : 1.0;
            for (int label = 0; label < L; ++label) {
                scores[label] += weights[this.getWeightIndex(L, label, feature)] * value;
            }
        }
    }

    /**
     * Returns the highest-scoring label under cost-augmented scoring: every label except the
     * gold label {@code y} starts with a bonus of {@code 1.0}. Ties keep the lower label.
     *
     * @param L       number of labels
     * @param y       gold label; must be in {@code [0, L)}
     * @param x       feature indices
     * @param v       feature values, or {@code null} if all values are 1
     * @param weights weight vector
     * @return the best label and its augmented score
     */
    protected IntPrediction getPrediction(int L, int y, int[] x, double[] v, double[] weights) {
        double[] scores = new double[L];
        Arrays.fill(scores, 1.0);
        scores[y] = 0.0;
        this.addScores(L, x, v, weights, scores);

        IntPrediction max = new IntPrediction(0, scores[0]);
        for (int label = 1; label < L; ++label) {
            if (max.score < scores[label]) {
                max.set(label, scores[label]);
            }
        }
        return max;
    }

    /**
     * Scores all labels and returns the best label, the runner-up and the gold label's score.
     * Requires {@code L >= 2}. A tie between labels 0 and 1 ranks label 1 first; later ties keep
     * the earlier label.
     *
     * @param L       number of labels (at least 2)
     * @param y       gold label; must be in {@code [0, L)}
     * @param x       feature indices
     * @param v       feature values, or {@code null} if all values are 1
     * @param weights weight vector
     * @return (best, second-best, gold) predictions
     */
    protected Triple<IntPrediction, IntPrediction, IntPrediction> getPredictions(int L, int y, int[] x, double[] v, double[] weights) {
        double[] scores = new double[L];
        this.addScores(L, x, v, weights, scores);

        IntPrediction fst;
        IntPrediction snd;
        if (scores[0] > scores[1]) {
            fst = new IntPrediction(0, scores[0]);
            snd = new IntPrediction(1, scores[1]);
        } else {
            fst = new IntPrediction(1, scores[1]);
            snd = new IntPrediction(0, scores[0]);
        }

        for (int label = 2; label < L; ++label) {
            if (fst.score < scores[label]) {
                snd.set(fst.label, fst.score);
                fst.set(label, scores[label]);
            } else if (snd.score < scores[label]) {
                snd.set(label, scores[label]);
            }
        }
        return new Triple<IntPrediction, IntPrediction, IntPrediction>(fst, snd, new IntPrediction(y, scores[y]));
    }

    /**
     * Adds the squared gradient contribution of {@code x} ({@code v[i]^2}, or 1 when {@code v} is
     * null) to the accumulators of both labels {@code yp} and {@code yn}.
     */
    protected void updateCounts(int L, double[] gs, int yp, int yn, int[] x, double[] v) {
        for (int i = 0; i < x.length; ++i) {
            int feature = x[i];
            double g = (v != null) ? v[i] * v[i] : 1.0;
            gs[this.getWeightIndex(L, yp, feature)] += g;
            gs[this.getWeightIndex(L, yn, feature)] += g;
        }
    }

    /**
     * Raises the weights of label {@code yp} and lowers those of label {@code yn} for every
     * feature in {@code x}, each by its AdaGrad step size times the feature value (1 if {@code v}
     * is null).
     */
    protected void updateWeights(int L, double[] gs, int yp, int yn, int[] x, double[] v, double[] weights) {
        for (int i = 0; i < x.length; ++i) {
            int feature = x[i];
            double value = (v != null) ? v[i] : 1.0;
            weights[this.getWeightIndex(L, yp, feature)] += this.getUpdate(L, gs, yp, feature) * value;
            weights[this.getWeightIndex(L, yn, feature)] -= this.getUpdate(L, gs, yn, feature) * value;
        }
    }

    /** Returns the AdaGrad step size {@code alpha / (rho + sqrt(G))} for weight (label {@code y}, feature {@code x}). */
    protected double getUpdate(int L, double[] gs, int y, int x) {
        return this.d_alpha / (this.d_rho + Math.sqrt(gs[this.getWeightIndex(L, y, x)]));
    }
}

