/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDCFOR;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingleZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.FORUtil;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * DDC column group: every row {@code r} of the group is represented by an index {@code _data.getIndex(r)} into a
 * dictionary {@code _dict} of value tuples, one value per column in {@code _colIndexes}.
 */
public class ColGroupDDC
extends APreAgg {
    private static final long serialVersionUID = -5769772089913918987L;

    /** Row-to-dictionary-entry mapping; {@code _data.getIndex(r)} is the dictionary tuple used by row {@code r}. */
    protected AMapToData _data;

    /**
     * Creates an uninitialized group (fields are expected to be populated afterwards, e.g. by {@link #readFields}).
     *
     * @param numRows number of rows of the group
     */
    protected ColGroupDDC(int numRows) {
        super(numRows);
    }

    /**
     * Creates a DDC group, validating that the mapping and the dictionary agree on the number of distinct tuples.
     *
     * @throws DMLCompressionException if {@code data.getUnique()} differs from the dictionary's number of tuples
     */
    private ColGroupDDC(int[] colIndexes, int numRows, ADictionary dict, AMapToData data, int[] cachedCounts) {
        super(colIndexes, numRows, dict, cachedCounts);
        if (data.getUnique() != dict.getNumberOfValues(colIndexes.length)) {
            throw new DMLCompressionException("Invalid construction of DDC group " + data.getUnique() + " vs. " + dict.getNumberOfValues(colIndexes.length));
        }
        this._zeros = false;
        this._data = data;
    }

    /**
     * Factory for DDC groups.
     *
     * @return a {@link ColGroupEmpty} when {@code dict} is null, otherwise a new {@code ColGroupDDC}
     */
    protected static AColGroup create(int[] colIndexes, int numRows, ADictionary dict, AMapToData data, int[] cachedCounts) {
        if (dict == null) {
            return new ColGroupEmpty(colIndexes);
        }
        return new ColGroupDDC(colIndexes, numRows, dict, data, cachedCounts);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.DDC;
    }

    /** Adds rows {@code [rl, ru)} of this group into {@code db}, reading tuples from a sparse dictionary. */
    @Override
    protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) {
        for (int r = rl, offT = rl + offR; r < ru; ++r, ++offT) {
            final int vr = _data.getIndex(r);
            if (sb.isEmpty(vr)) {
                continue;
            }
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            final int apos = sb.pos(vr);
            final int alen = sb.size(vr) + apos;
            final int[] aix = sb.indexes(vr);
            final double[] aval = sb.values(vr);
            for (int j = apos; j < alen; ++j) {
                c[off + _colIndexes[aix[j]]] += aval[j];
            }
        }
    }

    /**
     * Adds rows {@code [rl, ru)} of this group into {@code db} from a dense (row-major, tuple-by-tuple) dictionary,
     * dispatching to a specialized kernel based on the output layout.
     */
    @Override
    protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        if (db.isContiguous() && _colIndexes.length == 1) {
            if (db.getDim(1) == 1) {
                decompressToDenseBlockDenseDictSingleColOutContiguous(db, rl, ru, offR, offC, values);
            } else {
                decompressToDenseBlockDenseDictSingleColContiguous(db, rl, ru, offR, offC, values);
            }
        } else if (db.isContiguous() && _colIndexes.length == db.getDim(1) && offC == 0) {
            decompressToDenseBlockDenseDictAllColumnsContiguous(db, rl, ru, offR, values);
        } else if (db.isContiguous() && offC == 0) {
            decompressToDenseBlockDenseDictNoColOffset(db, rl, ru, offR, values);
        } else {
            decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values);
        }
    }

    /** Single-column group into a contiguous multi-column output: one strided write per row. */
    private void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final double[] c = db.values(0);
        final int nCols = db.getDim(1);
        final int colOff = _colIndexes[0] + offC;
        for (int i = rl, offT = (rl + offR) * nCols + colOff; i < ru; ++i, offT += nCols) {
            c[offT] += values[_data.getIndex(i)];
        }
    }

    /** Single-column group into a contiguous single-column output: consecutive writes. */
    private void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final double[] c = db.values(0);
        int offT = rl + offR + _colIndexes[0] + offC;
        for (int i = rl; i < ru; ++i) {
            c[offT++] += values[_data.getIndex(i)];
        }
    }

    /**
     * Group covers every output column (and offC == 0): each dictionary tuple is copied as a whole output row.
     * NOTE(review): assumes {@code _colIndexes} is {0..nCol-1} in order, as implied by the dispatch condition.
     */
    private void decompressToDenseBlockDenseDictAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR, double[] values) {
        final double[] c = db.values(0);
        final int nCol = _colIndexes.length;
        for (int r = rl; r < ru; ++r) {
            final int start = _data.getIndex(r) * nCol;
            final int end = start + nCol;
            int off = (offR + r) * nCol;
            for (int vOff = start; vOff < end; ++vOff) {
                c[off++] += values[vOff];
            }
        }
    }

    /** Contiguous output without column offset: scatter each tuple to its column positions. */
    private void decompressToDenseBlockDenseDictNoColOffset(DenseBlock db, int rl, int ru, int offR, double[] values) {
        final int nCol = _colIndexes.length;
        final int colOut = db.getDim(1);
        // Contiguous block: the backing array is loop-invariant, so fetch it once.
        final double[] c = db.values(rl + offR);
        int off = (rl + offR) * colOut;
        for (int i = rl; i < ru; ++i, off += colOut) {
            final int rowIndex = _data.getIndex(i) * nCol;
            for (int j = 0; j < nCol; ++j) {
                c[off + _colIndexes[j]] += values[rowIndex + j];
            }
        }
    }

    /** Fallback for any dense block layout (row-wise {@code values}/{@code pos} lookups). */
    private void decompressToDenseBlockDenseDictGeneric(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = _colIndexes.length;
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            final int rowIndex = _data.getIndex(i) * nCol;
            for (int j = 0; j < nCol; ++j) {
                c[off + _colIndexes[j]] += values[rowIndex + j];
            }
        }
    }

    /**
     * Appends rows {@code [rl, ru)} of this group to a sparse output, reading tuples from a sparse dictionary.
     * Mirrors {@link #decompressToDenseBlockSparseDictionary} but writes via {@code SparseBlock.append}, like
     * {@link #decompressToSparseBlockDenseDictionary}.
     */
    @Override
    protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, SparseBlock sb) {
        for (int r = rl, offT = rl + offR; r < ru; ++r, ++offT) {
            final int vr = _data.getIndex(r);
            if (sb.isEmpty(vr)) {
                continue;
            }
            final int apos = sb.pos(vr);
            final int alen = sb.size(vr) + apos;
            final int[] aix = sb.indexes(vr);
            final double[] aval = sb.values(vr);
            for (int j = apos; j < alen; ++j) {
                ret.append(offT, _colIndexes[aix[j]] + offC, aval[j]);
            }
        }
    }

    /** Appends rows {@code [rl, ru)} of this group to a sparse output from a dense dictionary. */
    @Override
    protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = _colIndexes.length;
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            final int rowIndex = _data.getIndex(i) * nCol;
            for (int j = 0; j < nCol; ++j) {
                ret.append(offT, _colIndexes[j] + offC, values[rowIndex + j]);
            }
        }
    }

    /** @return the value at row {@code r} and local column {@code colIdx} (index into {@code _colIndexes}). */
    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(_data.getIndex(r) * _colIndexes.length + colIdx);
    }

    /** Adds the pre-aggregated per-tuple row sum to {@code c[rix]} for every row in {@code [rl, ru)}. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        for (int rix = rl; rix < ru; ++rix) {
            c[rix] += preAgg[_data.getIndex(rix)];
        }
    }

    /** Combines {@code c[i]} with the pre-aggregated per-tuple value via {@code builtin} (e.g. min/max). */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        for (int i = rl; i < ru; ++i) {
            c[i] = builtin.execute(c[i], preAgg[_data.getIndex(i)]);
        }
    }

    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    /**
     * Computes {@code result[rl:ru, _colIndexes] += matrix[rl:ru, cl:cu] * this[cl:cu, :]} without building a
     * pre-aggregate. Column {@code c} of {@code matrix} pairs with row {@code c} of this group.
     */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (_colIndexes.length == 1) {
            leftMultByMatrixNoPreAggSingleCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    private void leftMultByMatrixNoPreAggSingleCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] dictVals = _dict.getValues();
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        }
    }

    /** Sparse left matrix, single-column group; only non-zeros in columns {@code [cl, cu)} contribute. */
    private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; ++r) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int outPos = r * nColRet + colOut;
            for (int i = apos; i < alen; ++i) {
                final int col = aix[i];
                // Respect the requested column range, consistent with the dense kernel.
                if (col < cl || col >= cu) {
                    continue;
                }
                retV[outPos] += aval[i] * vals[_data.getIndex(col)];
            }
        }
    }

    /** Dense left matrix, single-column group. */
    private void lmDenseMatrixNoPreAggSingleCol(double[] mV, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; ++r) {
            final int offL = r * nColM;
            final int outPos = r * nColRet + colOut;
            for (int c = cl; c < cu; ++c) {
                // Column c of the left matrix pairs with row c of this group (previously indexed by r: bug).
                retV[outPos] += mV[offL + c] * vals[_data.getIndex(c)];
            }
        }
    }

    private void lmMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    /** Sparse left matrix, multi-column group; only non-zeros in columns {@code [cl, cu)} contribute. */
    private void lmSparseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        final SparseBlock sb = matrix.getSparseBlock();
        for (int r = rl; r < ru; ++r) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int offR = r * nColRet;
            for (int i = apos; i < alen; ++i) {
                final int col = aix[i];
                if (col < cl || col >= cu) {
                    continue;
                }
                _dict.multiplyScalar(aval[i], retV, offR, _data.getIndex(col), _colIndexes);
            }
        }
    }

    /** Dense left matrix, multi-column group. */
    private void lmDenseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] mV = matrix.getDenseBlockValues();
        for (int r = rl; r < ru; ++r) {
            final int offL = r * nColM;
            final int offR = r * nColRet;
            for (int c = cl; c < cu; ++c) {
                _dict.multiplyScalar(mV[offL + c], retV, offR, _data.getIndex(c), _colIndexes);
            }
        }
    }

    @Override
    public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) {
        _data.preAggregateDense(m, preAgg, rl, ru, cl, cu);
    }

    @Override
    public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) {
        _data.preAggregateSparse(sb, preAgg, rl, ru);
    }

    @Override
    public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
        _data.preAggregateDDC_DDC(that._data, that._dict, ret, that._colIndexes.length);
    }

    @Override
    public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
        _data.preAggregateDDC_SDCZ(that._data, that._dict, that._indexes, ret, that._colIndexes.length);
    }

    /**
     * For every offset of {@code that}, adds {@code that}'s single tuple (entry 0) into the {@code ret} entry
     * selected by this group's mapping at that row.
     */
    @Override
    public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret) {
        final AIterator itThat = that._indexes.getIterator();
        final int nCol = that._colIndexes.length;
        final int finalOff = that._indexes.getOffsetToLast();
        final double[] v = ret.getValues();
        while (true) {
            final int to = _data.getIndex(itThat.value());
            that._dict.addToEntry(v, 0, to, nCol);
            if (itThat.value() == finalOff) {
                break;
            }
            itThat.next();
        }
    }

    /** Two DDC groups share index structure only if they reference the very same mapping object. */
    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        return that instanceof ColGroupDDC && ((ColGroupDDC) that)._data == _data;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.DDC;
    }

    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + _data.getInMemorySize();
    }

    /**
     * Applies a scalar op to the dictionary. For plus/minus on a sparse {@link MatrixBlockDictionary}, the
     * dictionary is kept as-is and the shift is stored as a reference in a {@link ColGroupDDCFOR} instead.
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        if ((op.fn instanceof Plus || op.fn instanceof Minus) && _dict instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) _dict).getMatrixBlock().isInSparseFormat()) {
            final double v0 = op.executeScalar(0.0);
            if (v0 == 0.0) {
                return this;
            }
            final double[] reference = FORUtil.createReference(_colIndexes.length, v0);
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), reference);
        }
        return ColGroupDDC.create(_colIndexes, _numRows, _dict.applyScalarOp(op), _data, getCachedCounts());
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        return ColGroupDDC.create(_colIndexes, _numRows, _dict.applyUnaryOp(op), _data, getCachedCounts());
    }

    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final ADictionary ret = _dict.binOpLeft(op, v, _colIndexes);
        return ColGroupDDC.create(_colIndexes, _numRows, ret, _data, getCachedCounts());
    }

    /**
     * Applies a row-vector op on the right. Plus/minus on a sparse {@link MatrixBlockDictionary} is turned into a
     * {@link ColGroupDDCFOR} with a reference row, leaving the dictionary unchanged.
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        if ((op.fn instanceof Plus || op.fn instanceof Minus) && _dict instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) _dict).getMatrixBlock().isInSparseFormat()) {
            final double[] reference = ColGroupUtils.binaryDefRowRight(op, v, _colIndexes);
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), reference);
        }
        final ADictionary ret = _dict.binOpRight(op, v, _colIndexes);
        return ColGroupDDC.create(_colIndexes, _numRows, ret, _data, getCachedCounts());
    }

    /** Serializes the base group state followed by the row mapping. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _data.write(out);
    }

    /** Deserializes the base group state followed by the row mapping (inverse of {@link #write}). */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _data = MapToFactory.readIn(in);
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _data.getExactSizeOnDisk();
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nVals = getNumValues();
        final int nCols = getNumCols();
        return e.getCost(nRows, nRows, nCols, nVals, _dict.getSparsity());
    }

    /** Every row maps to a tuple, so all rows participate in multiplication. */
    @Override
    protected int numRowsToMultiply() {
        return _numRows;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s ", "Data: "));
        sb.append(_data);
        return sb.toString();
    }
}

