/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.ref.SoftReference;
import java.util.Arrays;
import java.util.HashSet;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.CMOperator;

/**
 * Base class for compressed column groups whose distinct value tuples are stored in a dictionary
 * ({@link ADictionary}). Subclasses provide the row-to-tuple mapping through {@link #getCounts(int[])}
 * and the abstract decompression hooks; this class implements the dictionary-side operations
 * (aggregations, slicing, right matrix multiplication, value replacement and (de)serialization of
 * the dictionary).
 */
public abstract class AColGroupValue extends AColGroupCompressed implements Cloneable {
    private static final long serialVersionUID = -6835757655517301955L;

    /** Number of rows represented by this column group. */
    protected final int _numRows;

    /**
     * When true, the group also contains zero values that are not stored in the dictionary; min/max
     * aggregation and {@link #containsValue(double)} account for them.
     */
    protected boolean _zeros = false;

    /** Dictionary of distinct value tuples (one value per covered column per tuple). */
    protected ADictionary _dict;

    /**
     * Lazily computed counts per dictionary tuple, held softly so the GC may reclaim them.
     * NOTE(review): serialVersionUID suggests Java serialization is used; SoftReference is not
     * Serializable, so the field is transient and counts are recomputed on demand afterwards.
     */
    private transient SoftReference<int[]> counts = null;

    protected AColGroupValue(int numRows) {
        this._numRows = numRows;
    }

    /**
     * @param colIndices   column indices covered by this group
     * @param numRows      number of rows of the group
     * @param dict         dictionary of distinct tuples; must not be null
     * @param cachedCounts optional precomputed counts per dictionary tuple, or null
     * @throws NullPointerException if {@code dict} is null
     */
    protected AColGroupValue(int[] colIndices, int numRows, ADictionary dict, int[] cachedCounts) {
        super(colIndices);
        this._numRows = numRows;
        this._dict = dict;
        if (dict == null) {
            throw new NullPointerException("null dict is invalid");
        }
        if (cachedCounts != null) {
            this.counts = new SoftReference<int[]>(cachedCounts);
        }
    }

    /**
     * Decompresses part of this group into a dense block, dispatching on the dictionary layout.
     *
     * @param db   target dense block
     * @param rl   first row to decompress (presumably inclusive)
     * @param ru   row upper bound (presumably exclusive)
     * @param offR row offset into the target
     * @param offC column offset into the target
     */
    @Override
    public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
        if (_dict instanceof MatrixBlockDictionary) {
            final MatrixBlock mb = ((MatrixBlockDictionary) _dict).getMatrixBlock();
            // An empty dictionary contributes nothing; mirror decompressToSparseBlock and skip it
            // rather than hand a possibly unallocated value array to the subclass.
            if (mb.isEmpty()) {
                return;
            }
            if (mb.isInSparseFormat()) {
                decompressToDenseBlockSparseDictionary(db, rl, ru, offR, offC, mb.getSparseBlock());
            } else {
                decompressToDenseBlockDenseDictionary(db, rl, ru, offR, offC, mb.getDenseBlockValues());
            }
        } else {
            decompressToDenseBlockDenseDictionary(db, rl, ru, offR, offC, _dict.getValues());
        }
    }

    /**
     * Decompresses part of this group into a sparse block, dispatching on the dictionary layout.
     * Nothing is written when the dictionary is an empty matrix block.
     */
    @Override
    public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) {
        if (_dict instanceof MatrixBlockDictionary) {
            final MatrixBlock mb = ((MatrixBlockDictionary) _dict).getMatrixBlock();
            if (mb.isEmpty()) {
                return;
            }
            if (mb.isInSparseFormat()) {
                decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock());
            } else {
                decompressToSparseBlockDenseDictionary(sb, rl, ru, offR, offC, mb.getDenseBlockValues());
            }
        } else {
            decompressToSparseBlockDenseDictionary(sb, rl, ru, offR, offC, _dict.getValues());
        }
    }

    /** Subclass hook: decompress into a dense block from a sparse dictionary. */
    protected abstract void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock dictSb);

    /** Subclass hook: decompress into a dense block from a dense (row-major tuple) dictionary. */
    protected abstract void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, double[] dictValues);

    /** Subclass hook: decompress into a sparse block from a sparse dictionary. */
    protected abstract void decompressToSparseBlockSparseDictionary(SparseBlock sb, int rl, int ru, int offR, int offC, SparseBlock dictSb);

    /** Subclass hook: decompress into a sparse block from a dense (row-major tuple) dictionary. */
    protected abstract void decompressToSparseBlockDenseDictionary(SparseBlock sb, int rl, int ru, int offR, int offC, double[] dictValues);

    /** @return number of distinct tuples in the dictionary */
    @Override
    public int getNumValues() {
        return _dict.getNumberOfValues(_colIndexes.length);
    }

    /** @return the dictionary backing this group (not a copy) */
    public ADictionary getDictionary() {
        return _dict;
    }

    /**
     * Returns the counts per dictionary tuple, computing and caching them if no cached copy is
     * available.
     *
     * @return counts array of length {@link #getNumValues()}
     */
    public final int[] getCounts() {
        int[] ret = getCachedCounts();
        if (ret == null) {
            ret = getCounts(new int[getNumValues()]);
            counts = new SoftReference<int[]>(ret);
        }
        return ret;
    }

    /** @return the cached counts, or null if never computed or already reclaimed by the GC */
    public final int[] getCachedCounts() {
        return counts != null ? counts.get() : null;
    }

    /**
     * Finds the columns in {@code [cl, cu)} of a dense row-major right-hand side that contain at
     * least one non-zero in a row selected by this group's column indexes.
     *
     * @return sorted column indexes (absolute, i.e. in {@code [cl, cu)})
     */
    private int[] rightMMGetColsDense(double[] b, int cl, int cu, int cut) {
        final int retCols = cu - cl;
        final boolean[] nonZeroCol = new boolean[retCols];
        int found = 0;
        // Stop scanning rows once every candidate column has been found.
        for (int k = 0; k < _colIndexes.length && found < retCols; k++) {
            final int rowOff = _colIndexes[k] * cut;
            for (int h = cl; h < cu; h++) {
                if (b[rowOff + h] != 0.0 && !nonZeroCol[h - cl]) {
                    nonZeroCol[h - cl] = true;
                    found++;
                }
            }
        }
        return markedIndices(nonZeroCol, cl, found);
    }

    /**
     * Finds the columns of a sparse right-hand side that have a stored entry in a row selected by
     * this group's column indexes.
     *
     * @param retCols number of columns of the right-hand side
     * @return sorted column indexes
     */
    private int[] rightMMGetColsSparse(SparseBlock b, int retCols) {
        final boolean[] nonZeroCol = new boolean[retCols];
        int found = 0;
        for (int h = 0; h < _colIndexes.length && found < retCols; h++) {
            final int colIdx = _colIndexes[h];
            if (b.isEmpty(colIdx)) {
                continue;
            }
            final int[] sIndexes = b.indexes(colIdx);
            final int apos = b.pos(colIdx);
            final int alen = apos + b.size(colIdx);
            for (int i = apos; i < alen; i++) {
                if (!nonZeroCol[sIndexes[i]]) {
                    nonZeroCol[sIndexes[i]] = true;
                    found++;
                }
            }
        }
        return markedIndices(nonZeroCol, 0, found);
    }

    /** Returns, in ascending order, {@code i + offset} for every {@code marked[i]} that is true. */
    private static int[] markedIndices(boolean[] marked, int offset, int count) {
        final int[] ret = new int[count];
        for (int i = 0, j = 0; i < marked.length; i++) {
            if (marked[i]) {
                ret[j++] = i + offset;
            }
        }
        return ret;
    }

    /**
     * Pre-aggregates the dictionary against a sparse right-hand side, producing a new row-major
     * dictionary with {@code numVals} tuples of {@code aggregateColumns.length} values each.
     */
    private double[] rightMMPreAggSparse(int numVals, SparseBlock b, int[] aggregateColumns) {
        final int nCol = _colIndexes.length;
        final int retCols = aggregateColumns.length;
        final int retLen = numVals * retCols;
        final double[] ret = new double[retLen];
        for (int h = 0; h < nCol; h++) {
            final int colIdx = _colIndexes[h];
            if (b.isEmpty(colIdx)) {
                continue;
            }
            final double[] sValues = b.values(colIdx);
            final int[] sIndexes = b.indexes(colIdx);
            final int apos = b.pos(colIdx);
            final int alen = apos + b.size(colIdx);
            // Sparse column indexes are sorted, so the output position only moves forward.
            int retIdx = 0;
            for (int i = apos; i < alen; i++) {
                while (aggregateColumns[retIdx] < sIndexes[i]) {
                    retIdx++;
                }
                if (sIndexes[i] != aggregateColumns[retIdx]) {
                    continue;
                }
                final double v = sValues[i];
                for (int j = 0, offOrg = h; j < retLen; j += retCols, offOrg += nCol) {
                    ret[j + retIdx] += _dict.getValue(offOrg) * v;
                }
            }
        }
        return ret;
    }

    /** Aggregates {@code c} with every dictionary value (and 0 if {@link #_zeros} is set). */
    @Override
    protected double computeMxx(double c, Builtin builtin) {
        if (_zeros) {
            c = builtin.execute(c, 0.0);
        }
        return _dict.aggregate(c, builtin);
    }

    /** Column-wise variant of {@link #computeMxx(double, Builtin)} writing into {@code c}. */
    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        if (_zeros) {
            for (final int col : _colIndexes) {
                c[col] = builtin.execute(c[col], 0.0);
            }
        }
        _dict.aggregateCols(c, builtin, _colIndexes);
    }

    /**
     * Reads the superclass fields, the zeros flag and the dictionary. Any cached counts are
     * discarded because they belong to the previous content.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _zeros = in.readBoolean();
        _dict = DictionaryFactory.read(in);
        counts = null;
    }

    /** Writes the superclass fields, the zeros flag and the dictionary. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        out.writeBoolean(_zeros);
        _dict.write(out);
    }

    /** @return serialized size: superclass part + 1 byte for the zeros flag + the dictionary */
    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + 1L + _dict.getExactSizeOnDisk();
    }

    /**
     * Computes the count of each dictionary tuple into the given array.
     *
     * @param out array of length {@link #getNumValues()} to fill
     * @return the filled counts
     */
    public abstract int[] getCounts(int[] out);

    @Override
    protected void computeSum(double[] c, int nRows) {
        c[0] += _dict.sum(getCounts(), _colIndexes.length);
    }

    @Override
    public void computeColSums(double[] c, int nRows) {
        _dict.colSum(c, getCounts(), _colIndexes);
    }

    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSq(getCounts(), _colIndexes.length);
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        _dict.colSumSq(c, getCounts(), _colIndexes);
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        _dict.product(c, getCounts(), _colIndexes.length);
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        _dict.colProduct(c, getCounts(), _colIndexes);
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDouble(_colIndexes.length);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSq(_colIndexes.length);
    }

    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException();
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRows(builtin, _colIndexes.length);
    }

    /** Shallow clone; arrays and the dictionary are shared with the original. */
    @Override
    protected Object clone() {
        try {
            return super.clone();
        }
        catch (CloneNotSupportedException e) {
            throw new DMLCompressionException("Error while cloning: " + getClass().getSimpleName(), e);
        }
    }

    /** Returns a shallow clone that uses {@code newDictionary}. */
    protected AColGroup copyAndSet(ADictionary newDictionary) {
        final AColGroupValue clone = (AColGroupValue) clone();
        clone._dict = newDictionary;
        return clone;
    }

    private AColGroup copyAndSet(int[] colIndexes, double[] newDictionary) {
        return copyAndSet(colIndexes, new Dictionary(newDictionary));
    }

    /** Returns a shallow clone with new column indexes and dictionary. */
    private AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) {
        final AColGroupValue clone = (AColGroupValue) clone();
        clone._dict = newDictionary;
        clone.setColIndices(colIndexes);
        return clone;
    }

    @Override
    public AColGroupValue copy() {
        return (AColGroupValue) clone();
    }

    /**
     * Slices out the column at local position {@code idx} as a single-column group with column
     * index 0, or an empty group if the sliced dictionary is empty.
     */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final int[] retIndexes = new int[] {0};
        if (_colIndexes.length == 1) {
            final AColGroupValue ret = (AColGroupValue) clone();
            ret._colIndexes = retIndexes;
            ret._dict = ret._dict.clone();
            // NOTE(review): result discarded; kept in case the call has side effects — confirm and remove.
            ret._dict.getNumberOfValues(1);
            return ret;
        }
        final ADictionary retDict = _dict.sliceOutColumnRange(idx, idx + 1, _colIndexes.length);
        if (retDict == null) {
            return new ColGroupEmpty(retIndexes);
        }
        final AColGroupValue ret = (AColGroupValue) clone();
        ret._colIndexes = retIndexes;
        ret._dict = retDict;
        // NOTE(review): result discarded; kept in case the call has side effects — confirm and remove.
        ret._dict.getNumberOfValues(1);
        return ret;
    }

    /**
     * Slices out local columns {@code [idStart, idEnd)} as a group with column indexes
     * {@code outputCols}, or an empty group over {@code outputCols} if the sliced dictionary is empty.
     */
    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ADictionary retDict = _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length);
        if (retDict == null) {
            // The empty result must describe the sliced (output) columns, not this group's columns.
            return new ColGroupEmpty(outputCols);
        }
        final AColGroupValue ret = (AColGroupValue) clone();
        ret._dict = retDict;
        ret._colIndexes = outputCols;
        // NOTE(review): result discarded; kept in case the call has side effects — confirm and remove.
        ret._dict.getNumberOfValues(outputCols.length);
        return ret;
    }

    @Override
    protected void tsmm(double[] result, int numColumns, int nRows) {
        final int[] counts = getCounts();
        AColGroupValue.tsmm(result, numColumns, counts, _dict, _colIndexes);
    }

    /** @return true if {@code pattern} is 0 and the group has implicit zeros, or the dictionary contains it */
    @Override
    public boolean containsValue(double pattern) {
        if (pattern == 0.0 && _zeros) {
            return true;
        }
        return _dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        final int[] counts = getCounts();
        return _dict.getNumberNonZeros(counts, _colIndexes.length);
    }

    /** Converts the dictionary to a {@link MatrixBlockDictionary} if it is not one already. */
    public synchronized void forceMatrixBlockDictionary() {
        if (!(_dict instanceof MatrixBlockDictionary)) {
            _dict = _dict.getMBDict(_colIndexes.length);
        }
    }

    /**
     * Multiplies this group by {@code right} by pre-aggregating the dictionary, keeping only the
     * output columns that can be non-zero.
     *
     * @return the resulting group, or null if the product is known to be empty
     */
    @Override
    public final AColGroup rightMultByMatrix(MatrixBlock right) {
        if (right.isEmpty()) {
            return null;
        }
        final int cr = right.getNumColumns();
        final int numVals = getNumValues();
        if (right.isInSparseFormat()) {
            final SparseBlock sb = right.getSparseBlock();
            final int[] agCols = rightMMGetColsSparse(sb, cr);
            if (agCols.length == 0) {
                return null;
            }
            return copyAndSet(agCols, rightMMPreAggSparse(numVals, sb, agCols));
        }
        final double[] rightV = right.getDenseBlockValues();
        final int[] agCols = rightMMGetColsDense(rightV, 0, cr, cr);
        if (agCols.length == 0) {
            return null;
        }
        final ADictionary d = _dict.preaggValuesFromDense(numVals, _colIndexes, agCols, rightV, cr);
        if (d == null) {
            return null;
        }
        return copyAndSet(agCols, d);
    }

    @Override
    public long estimateInMemorySize() {
        // Fixed overhead of this class, itemized in the original as 8 + 4 + 1 + 1 + 2 + 8 = 24 bytes.
        // NOTE(review): presumably references, the int row count, the boolean flag and padding — confirm.
        return super.estimateInMemorySize() + 24L + _dict.getInMemorySize();
    }

    /** Returns a copy whose dictionary has {@code pattern} replaced by {@code replace}. */
    @Override
    public AColGroup replace(double pattern, double replace) {
        final ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.length);
        return copyAndSet(replaced);
    }

    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        return _dict.centralMoment(op.fn, getCounts(), nRows);
    }

    /** Expands values into {@code max} columns; returns an empty group if the expansion is empty. */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length);
        if (d == null) {
            return ColGroupEmpty.create(max);
        }
        return copyAndSet(Util.genColsIndices(max), d);
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s%s", "Values: ", _dict.getClass().getSimpleName()));
        sb.append(_dict.getString(_colIndexes.length));
        return sb.toString();
    }
}

