/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.networks.training.propagation.quick;

import org.encog.EncogError;
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.ContainsFlat;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.Propagation;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.EngineArray;
import org.encog.util.validate.ValidateNetwork;

/**
 * Propagation trainer that sizes each weight step from the current slope,
 * the previous slope and the previous step (the QuickProp update rule).
 *
 * <p>The trainer only supports full-batch training: {@link #setBatchSize(int)}
 * rejects any non-zero batch size. The default learning rate is 2.0 and the
 * default output epsilon is 0.35.
 */
public class QuickPropagation
extends Propagation
implements LearningRate {
    /** Key under which the last gradients are stored in a {@link TrainingContinuation}. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /** Multiplier applied to the previous step when the maximum-size step is taken. */
    private double learningRate;

    /** Step applied to each weight on the previous call to {@link #updateWeight}. */
    private final double[] lastDelta;

    /** Weight-decay coefficient added (times the weight) to the slope. */
    private final double decay = 1.0E-4;

    /** Per-record epsilon: {@code outputEpsilon / recordCount}, set by {@link #initOthers()}. */
    private double eps;

    /** Epsilon before it is divided by the training record count. */
    private double outputEpsilon = 0.35;

    /** Threshold factor that selects between the maximum-size step and the quadratic step. */
    private double shrink;

    /**
     * Creates a trainer with the default learning rate of 2.0.
     *
     * @param network  the network to train
     * @param training the training data
     */
    public QuickPropagation(ContainsFlat network, MLDataSet training) {
        this(network, training, 2.0);
    }

    /**
     * Creates a trainer with an explicit learning rate.
     *
     * @param network         the network to train
     * @param training        the training data
     * @param theLearningRate the learning rate
     */
    public QuickPropagation(ContainsFlat network, MLDataSet training, double theLearningRate) {
        super(network, training);
        ValidateNetwork.validateMethodToData(network, training);
        this.learningRate = theLearningRate;
        // One remembered step per weight of the flat network.
        this.lastDelta = new double[this.network.getFlat().getWeights().length];
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * @return the internal array of last steps, one per weight (not a copy)
     */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    /**
     * @return the learning rate
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Checks whether a continuation object can be used to resume this trainer.
     *
     * <p>The state is valid when it is non-null, contains the
     * {@link #LAST_GRADIENTS} key, was produced by a trainer with the same
     * simple class name, and holds a {@code double[]} whose length equals the
     * number of weights in the current network.
     *
     * @param state the continuation to inspect
     * @return {@code true} if {@code state} can be passed to {@link #resume}
     */
    public boolean isValidResume(TrainingContinuation state) {
        if (state == null || !state.getContents().containsKey(LAST_GRADIENTS)) {
            return false;
        }
        // Constant-first equals so a missing training type yields false, not an NPE.
        if (!this.getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }
        Object stored = state.get(LAST_GRADIENTS);
        // A null or wrongly-typed entry is invalid data rather than a ClassCastException.
        if (!(stored instanceof double[])) {
            return false;
        }
        int weightCount = ((ContainsFlat)this.getMethod()).getFlat().getWeights().length;
        return ((double[])stored).length == weightCount;
    }

    /**
     * Captures the state needed to resume training later.
     *
     * <p>Only the last gradients are captured; the last deltas are not.
     *
     * @return a continuation holding a snapshot of the last gradients
     */
    @Override
    public TrainingContinuation pause() {
        TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(this.getClass().getSimpleName());
        // Store a copy: updateWeight mutates the live last-gradient array in place,
        // which would otherwise alter the snapshot after it was taken.
        double[] current = this.getLastGradient();
        result.set(LAST_GRADIENTS, current == null ? null : current.clone());
        return result;
    }

    /**
     * Restores the last gradients from a continuation.
     *
     * @param state the continuation produced by {@link #pause()}
     * @throws TrainingError if {@code state} fails {@link #isValidResume}
     */
    @Override
    public void resume(TrainingContinuation state) {
        if (!this.isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        double[] lastGradient = (double[])state.get(LAST_GRADIENTS);
        EngineArray.arrayCopy(lastGradient, this.getLastGradient());
    }

    /**
     * Sets the learning rate. This does not recompute {@code shrink}; that
     * happens only in {@link #initOthers()}.
     *
     * @param rate the new learning rate
     */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * @return the output epsilon (before division by the record count)
     */
    public double getOutputEpsilon() {
        return this.outputEpsilon;
    }

    /**
     * @return the shrink factor
     */
    public double getShrink() {
        return this.shrink;
    }

    /**
     * Sets the shrink factor. Note that {@link #initOthers()} overwrites this
     * value with {@code learningRate / (1 + learningRate)}.
     *
     * @param s the new shrink factor
     */
    public void setShrink(double s) {
        this.shrink = s;
    }

    /**
     * Sets the output epsilon. The per-record epsilon derived from it is only
     * recomputed in {@link #initOthers()}.
     *
     * @param theOutputEpsilon the new output epsilon
     */
    public void setOutputEpsilon(double theOutputEpsilon) {
        this.outputEpsilon = theOutputEpsilon;
    }

    /**
     * Derives the per-record epsilon and the shrink factor from the current
     * settings.
     *
     * <p>NOTE(review): presumably invoked by the base trainer before the first
     * iteration -- confirm against {@code Propagation}. A training set with a
     * record count of zero makes {@code eps} non-finite.
     */
    @Override
    public void initOthers() {
        this.eps = this.outputEpsilon / (double)this.getTraining().getRecordCount();
        this.shrink = this.learningRate / (1.0 + this.learningRate);
    }

    /**
     * Computes the step for one weight and records it, together with the
     * gradient, for the next call.
     *
     * @param gradients    the current gradients
     * @param lastGradient the gradients from the previous iteration
     * @param index        the index of the weight being updated
     * @return the step computed for the weight at {@code index}
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double weight = this.network.getFlat().getWeights()[index];
        double previousStep = this.lastDelta[index];
        // Slope is read from the supplied array, the same one recorded as the
        // last gradient below, plus a weight-decay term.
        double slope = -gradients[index] + this.decay * weight;
        double previousSlope = -lastGradient[index];
        double nextStep = 0.0;

        if (previousStep < 0.0) {
            // Last step was negative: add the linear term only while the slope is still positive.
            if (slope > 0.0) {
                nextStep -= this.eps * slope;
            }
            if (slope >= this.shrink * previousSlope) {
                // Slope is close to or larger than the previous one: take the maximum-size step.
                nextStep += this.learningRate * previousStep;
            } else {
                // Otherwise scale the previous step by slope / (previousSlope - slope).
                nextStep += previousStep * slope / (previousSlope - slope);
            }
        } else if (previousStep > 0.0) {
            // Last step was positive: mirror image of the branch above.
            if (slope < 0.0) {
                nextStep -= this.eps * slope;
            }
            if (slope <= this.shrink * previousSlope) {
                nextStep += this.learningRate * previousStep;
            } else {
                nextStep += previousStep * slope / (previousSlope - slope);
            }
        } else {
            // Last step was zero (or NaN): use only the linear term.
            nextStep -= this.eps * slope;
        }

        // Remember this step and gradient for the next call.
        this.lastDelta[index] = nextStep;
        this.getLastGradient()[index] = gradients[index];
        return nextStep;
    }

    /**
     * Accepts only a batch size of zero.
     *
     * @param theBatchSize the requested batch size
     * @throws EncogError if {@code theBatchSize} is not zero
     */
    @Override
    public void setBatchSize(int theBatchSize) {
        if (theBatchSize != 0) {
            throw new EncogError("Online training is not supported for:" + this.getClass().getSimpleName());
        }
    }
}

