/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;

public abstract class ColumnEncoder
implements Externalizable,
Encoder,
Comparable<ColumnEncoder> {
    protected static final Log LOG = LogFactory.getLog((String)ColumnEncoder.class.getName());
    protected static final int APPLY_ROW_BLOCKS_PER_COLUMN = 1;
    public static int BUILD_ROW_BLOCKS_PER_COLUMN = 1;
    private static final long serialVersionUID = 2299156350718979064L;
    protected int _colID;
    protected Set<Integer> _sparseRowsWZeros = null;

    protected ColumnEncoder(int colID) {
        this._colID = colID;
    }

    @Override
    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol) {
        return this.apply(in, out, outputCol, 0, -1);
    }

    public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        long t0;
        long l = t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (out.isInSparseFormat()) {
            this.applySparse(in, out, outputCol, rowStart, blk);
        } else {
            this.applyDense(in, out, outputCol, rowStart, blk);
        }
        if (DMLScript.STATISTICS) {
            long t = System.nanoTime() - t0;
            switch (this.getTransformType()) {
                case RECODE: {
                    Statistics.incTransformRecodeApplyTime(t);
                    break;
                }
                case BIN: {
                    Statistics.incTransformBinningApplyTime(t);
                    break;
                }
                case DUMMYCODE: {
                    Statistics.incTransformDummyCodeApplyTime(t);
                    break;
                }
                case FEATURE_HASH: {
                    Statistics.incTransformFeatureHashingApplyTime(t);
                    break;
                }
                case PASS_THROUGH: {
                    Statistics.incTransformPassThroughApplyTime(t);
                    break;
                }
            }
        }
        return out;
    }

    protected abstract double getCode(CacheBlock var1, int var2);

    protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        int index = this._colID - 1;
        for (int r = rowStart; r < UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk); ++r) {
            SparseRowVector row = (SparseRowVector)out.getSparseBlock().get(r);
            row.values()[index] = this.getCode(in, r);
            row.indexes()[index] = outputCol;
        }
    }

    protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        for (int i = rowStart; i < UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk); ++i) {
            out.quickSetValue(i, outputCol, this.getCode(in, i));
        }
    }

    protected abstract TransformType getTransformType();

    public boolean isApplicable() {
        return this._colID != -1;
    }

    public boolean isApplicable(int colID) {
        return colID == this._colID;
    }

    @Override
    public void prepareBuildPartial() {
    }

    @Override
    public void buildPartial(FrameBlock in) {
    }

    public void mergeAt(ColumnEncoder other) {
        throw new DMLRuntimeException(this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
    }

    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int colOffset) {
    }

    public MatrixBlock getColMapping(FrameBlock meta) {
        return null;
    }

    @Override
    public void writeExternal(ObjectOutput os) throws IOException {
        os.writeInt(this._colID);
    }

    @Override
    public void readExternal(ObjectInput in) throws IOException {
        this._colID = in.readInt();
    }

    public int getColID() {
        return this._colID;
    }

    public void setColID(int colID) {
        this._colID = colID;
    }

    public void shiftCol(int columnOffset) {
        this._colID += columnOffset;
    }

    @Override
    public int compareTo(ColumnEncoder o) {
        return Integer.compare(EncoderFactory.getEncoderType(this), EncoderFactory.getEncoderType(o));
    }

    public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
        ArrayList<Callable<Object>> tasks = new ArrayList<Callable<Object>>();
        ArrayList<Object> dep = null;
        int nRows = in.getNumRows();
        int[] blockSizes = UtilFunctions.getBlockSizes(nRows, this.getNumBuildRowPartitions());
        if (blockSizes.length == 1) {
            tasks.add(this.getBuildTask(in));
        } else {
            HashMap<Integer, Object> ret = new HashMap<Integer, Object>();
            int startRow = 0;
            for (int i = 0; i < blockSizes.length; ++i) {
                tasks.add(this.getPartialBuildTask(in, startRow, blockSizes[i], ret));
                startRow += blockSizes[i];
            }
            tasks.add(this.getPartialMergeBuildTask(ret));
            dep = new ArrayList<Object>(Collections.nCopies(tasks.size() - 1, null));
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    public Callable<Object> getBuildTask(CacheBlock in) {
        throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building");
    }

    public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow, int blockSize, HashMap<Integer, Object> ret) {
        throw new DMLRuntimeException("Trying to get the PartialBuild task of an Encoder which does not support  partial building");
    }

    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        throw new DMLRuntimeException("Trying to get the BuildMergeTask task of an Encoder which does not support partial building");
    }

    public List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, int outputCol) {
        ArrayList<Callable<Object>> tasks = new ArrayList<Callable<Object>>();
        ArrayList<Object> dep = null;
        int[] blockSizes = UtilFunctions.getBlockSizes(in.getNumRows(), this.getNumApplyRowPartitions());
        int startRow = 0;
        for (int i = 0; i < blockSizes.length; ++i) {
            if (out.isInSparseFormat()) {
                tasks.add(this.getSparseTask(in, out, outputCol, startRow, blockSizes[i]));
            } else {
                tasks.add(this.getDenseTask(in, out, outputCol, startRow, blockSizes[i]));
            }
            startRow += blockSizes[i];
        }
        if (tasks.size() > 1) {
            dep = new ArrayList<Object>(Collections.nCopies(tasks.size(), null));
            tasks.add(() -> null);
            dep.add(tasks.subList(0, tasks.size() - 1));
        }
        return DependencyThreadPool.createDependencyTasks(tasks, dep);
    }

    protected ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<ColumnEncoder>(this, in, out, outputCol, startRow, blk);
    }

    protected ColumnApplyTask<? extends ColumnEncoder> getDenseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new ColumnApplyTask<ColumnEncoder>(this, in, out, outputCol, startRow, blk);
    }

    public Set<Integer> getSparseRowsWZeros() {
        return this._sparseRowsWZeros;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    protected void addSparseRowsWZeros(Set<Integer> sparseRowsWZeros) {
        ColumnEncoder columnEncoder = this;
        synchronized (columnEncoder) {
            if (this._sparseRowsWZeros == null) {
                this._sparseRowsWZeros = new HashSet<Integer>();
            }
            this._sparseRowsWZeros.addAll(sparseRowsWZeros);
        }
    }

    protected int getNumApplyRowPartitions() {
        return ConfigurationManager.getParallelApplyBlocks();
    }

    protected int getNumBuildRowPartitions() {
        return ConfigurationManager.getParallelBuildBlocks();
    }

    protected static class ColumnApplyTask<T extends ColumnEncoder>
    implements Callable<Object> {
        protected final T _encoder;
        protected final CacheBlock _input;
        protected final MatrixBlock _out;
        protected final int _outputCol;
        protected final int _startRow;
        protected final int _blk;

        protected ColumnApplyTask(T encoder, CacheBlock input, MatrixBlock out, int outputCol) {
            this(encoder, input, out, outputCol, 0, -1);
        }

        protected ColumnApplyTask(T encoder, CacheBlock input, MatrixBlock out, int outputCol, int startRow, int blk) {
            this._encoder = encoder;
            this._input = input;
            this._out = out;
            this._outputCol = outputCol;
            this._startRow = startRow;
            this._blk = blk;
        }

        @Override
        public Object call() throws Exception {
            assert (this._outputCol >= 0);
            ((ColumnEncoder)this._encoder).apply(this._input, this._out, this._outputCol, this._startRow, this._blk);
            return null;
        }

        public String toString() {
            return this.getClass().getSimpleName() + "<Encoder: " + this._encoder.getClass().getSimpleName() + "; ColId: " + ((ColumnEncoder)this._encoder)._colID + ">";
        }
    }

    public static enum EncoderType {
        Recode,
        FeatureHash,
        PassThrough,
        Bin,
        Dummycode,
        Omit,
        MVImpute,
        Composite;

    }

    protected static enum TransformType {
        BIN,
        RECODE,
        DUMMYCODE,
        FEATURE_HASH,
        PASS_THROUGH,
        N_A;

    }
}

