/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Right-indexing hop: extracts a range of rows and columns from its first input.
 *
 * <p>Inputs, in order: 0 = indexed input, 1 = row lower bound, 2 = row upper bound,
 * 3 = column lower bound, 4 = column upper bound. Bounds are inclusive (a constant range
 * {@code [l, u]} yields {@code u - l + 1} rows/columns, see {@link #refreshSizeInformation()}).
 */
public class IndexingOp extends Hop {
    /** Operator string reported by {@link #getOpString()}. */
    public static String OPSTRING = "rix";

    /** True if exactly one row is selected (row lower and upper bound are the same hop). */
    private boolean _rowLowerEqualsUpper = false;
    /** True if exactly one column is selected (column lower and upper bound are the same hop). */
    private boolean _colLowerEqualsUpper = false;

    /** Bare constructor, only used by {@link #clone()}. */
    private IndexingOp() {
    }

    /**
     * Creates a right-indexing hop and registers it as a parent of all five inputs.
     *
     * @param l             hop name
     * @param dt            output data type
     * @param vt            output value type
     * @param inpMatrix     the indexed input
     * @param inpRowL       row lower bound (inclusive)
     * @param inpRowU       row upper bound (inclusive)
     * @param inpColL       column lower bound (inclusive)
     * @param inpColU       column upper bound (inclusive)
     * @param passedRowsLEU whether the row lower bound equals the row upper bound
     * @param passedColsLEU whether the column lower bound equals the column upper bound
     */
    public IndexingOp(String l, Types.DataType dt, Types.ValueType vt, Hop inpMatrix, Hop inpRowL, Hop inpRowU, Hop inpColL, Hop inpColU, boolean passedRowsLEU, boolean passedColsLEU) {
        super(l, dt, vt);
        this.getInput().add(0, inpMatrix);
        this.getInput().add(1, inpRowL);
        this.getInput().add(2, inpRowU);
        this.getInput().add(3, inpColL);
        this.getInput().add(4, inpColU);
        inpMatrix.getParent().add(this);
        inpRowL.getParent().add(this);
        inpRowU.getParent().add(this);
        inpColL.getParent().add(this);
        inpColU.getParent().add(this);
        this.setRowLowerEqualsUpper(passedRowsLEU);
        this.setColLowerEqualsUpper(passedColsLEU);
    }

    /** Verifies that exactly five inputs (data plus four bounds) are attached. */
    @Override
    public void checkArity() {
        HopsException.check(this._input.size() == 5, this, "should have 5 inputs but has %d inputs", this._input.size());
    }

    /** @return true if the indexing selects a single row */
    public boolean isRowLowerEqualsUpper() {
        return this._rowLowerEqualsUpper;
    }

    /** @return true if the indexing selects a single column */
    public boolean isColLowerEqualsUpper() {
        return this._colLowerEqualsUpper;
    }

    /** @param passed whether the row lower bound equals the row upper bound */
    public void setRowLowerEqualsUpper(boolean passed) {
        this._rowLowerEqualsUpper = passed;
    }

    /** @param passed whether the column lower bound equals the column upper bound */
    public void setColLowerEqualsUpper(boolean passed) {
        this._colLowerEqualsUpper = passed;
    }

    /**
     * GPU execution is only considered when accelerators are enabled, the output is a matrix,
     * and the input memory estimate is below 2.0E9 (presumably bytes).
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        return this.getDataType() == Types.DataType.MATRIX && this.getInputMemEstimate() < 2.0E9;
    }

    /**
     * Builds (or returns the cached) {@link RightIndex} lop. Unnecessary indexing (as decided by
     * {@link HopRewriteUtils#isUnnecessaryRightIndexing}) is replaced by the input's lop.
     */
    @Override
    public Lop constructLops() {
        if (this.getLops() != null) {
            return this.getLops();
        }
        Hop input = this.getInput().get(0);
        if (HopRewriteUtils.isUnnecessaryRightIndexing(this)) {
            this.setLops(input.constructLops());
        } else {
            try {
                Types.ExecType et = this.optFindExecType();
                if (et == Types.ExecType.SPARK) {
                    // Vector indexing and block-aligned ranges need no aggregation of partial blocks.
                    IndexingMethod method = IndexingOp.optFindIndexingMethod(this._rowLowerEqualsUpper, this._colLowerEqualsUpper, input.getDim1(), input.getDim2(), this.getDim1(), this.getDim2());
                    AggBinaryOp.SparkAggType aggtype = method == IndexingMethod.MR_VRIX || this.isBlockAligned() ? AggBinaryOp.SparkAggType.NONE : AggBinaryOp.SparkAggType.MULTI_BLOCK;
                    RightIndex reindex = new RightIndex(input.constructLops(), this.getInput(1).constructLops(), this.getInput(2).constructLops(), this.getInput(3).constructLops(), this.getInput(4).constructLops(), this.getDataType(), this.getValueType(), aggtype, et);
                    this.setOutputDimensions(reindex);
                    this.setLineNumbers(reindex);
                    this.setLops(reindex);
                } else {
                    RightIndex reindex = new RightIndex(input.constructLops(), this.getInput(1).constructLops(), this.getInput(2).constructLops(), this.getInput(3).constructLops(), this.getInput(4).constructLops(), this.getDataType(), this.getValueType(), et);
                    this.setOutputDimensions(reindex);
                    this.setLineNumbers(reindex);
                    this.setLops(reindex);
                }
            }
            catch (Exception e) {
                throw new HopsException(this.printErrorLocation() + "In IndexingOp Hop, error constructing Lops ", e);
            }
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /** @return the operator string {@link #OPSTRING} */
    @Override
    public String getOpString() {
        return OPSTRING;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Computes the default estimate, then tries to tighten it when the output dims are known.
     * The input's nnz serves as a worst-case nnz for the output.
     */
    @Override
    public void computeMemEstimate(MemoTable memo) {
        super.computeMemEstimate(memo);
        DataCharacteristics dcM1 = memo.getAllInputStats(this.getInput().get(0));
        if (dcM1 != null && this.dimsKnown() && dcM1.getNonZeros() >= 0L) {
            long lnnz = dcM1.getNonZeros();
            double lOutMemEst = this.computeOutputMemEstimate(this.getDim1(), this.getDim2(), lnnz);
            if (lOutMemEst < this._outputMemEstimate) {
                this._outputMemEstimate = lOutMemEst;
                this._memEstimate = this.getInputOutputSize();
            }
        }
    }

    /** Output size estimate; on GPU the output is assumed dense (sparsity 1.0). */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = this.isGPUEnabled() ? 1.0 : OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /** Right indexing needs no intermediate memory. */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0.0;
    }

    /**
     * Infers worst-case output characteristics from the input statistics. It starts from the
     * input dims and refines them with single-row/column flags and block-indexing patterns.
     *
     * @return the inferred characteristics, or null if no input statistics are available
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics ret = null;
        Hop input = this.getInput().get(0);
        DataCharacteristics dc = memo.getAllInputStats(input);
        if (dc != null) {
            // Worst case: as many non-zeros as the input, capped by the number of cells.
            long lnnz = dc.dimsKnown() ? Math.min(dc.getRows() * dc.getCols(), dc.getNonZeros()) : -1L;
            ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, lnnz);
            if (this._rowLowerEqualsUpper) {
                ret.setRows(1L);
            }
            if (this._colLowerEqualsUpper) {
                ret.setCols(1L);
            }
            Hop rl = this.getInput().get(1);
            Hop ru = this.getInput().get(2);
            Hop cl = this.getInput().get(3);
            Hop cu = this.getInput().get(4);
            if (IndexingOp.isBlockIndexingExpression(rl, ru)) {
                ret.setRows(IndexingOp.getBlockIndexingExpressionSize(rl, ru));
            }
            if (IndexingOp.isBlockIndexingExpression(cl, cu)) {
                ret.setCols(IndexingOp.getBlockIndexingExpressionSize(cl, cu));
            }
        }
        return ret;
    }

    /**
     * Checks whether the bounds form a block-indexing range of the shape
     * {@code c*(i-1)+1 : c*i} (or {@code i*c} as upper bound). Here {@code c} is a literal in
     * both bounds with equal values, and {@code i} is the very same hop in both bounds.
     * Such a range always covers exactly {@code c} rows/columns.
     */
    private static boolean isBlockIndexingExpression(Hop lbound, Hop ubound) {
        LiteralOp constant = null;
        Hop var = null;
        // Lower bound: PLUS(MULT(c, MINUS(i, 1)), 1)
        if (lbound instanceof BinaryOp && ((BinaryOp) lbound).getOp() == Types.OpOp2.PLUS
                && lbound.getInput(1) instanceof LiteralOp
                && HopRewriteUtils.getDoubleValueSafe((LiteralOp) lbound.getInput(1)) == 1.0
                && lbound.getInput(0) instanceof BinaryOp) {
            BinaryOp lmult = (BinaryOp) lbound.getInput(0);
            if (lmult.getOp() == Types.OpOp2.MULT && lmult.getInput(0) instanceof LiteralOp
                    && lmult.getInput(1) instanceof BinaryOp) {
                BinaryOp lminus = (BinaryOp) lmult.getInput(1);
                if (lminus.getOp() == Types.OpOp2.MINUS && lminus.getInput(1) instanceof LiteralOp
                        && HopRewriteUtils.getDoubleValueSafe((LiteralOp) lminus.getInput(1)) == 1.0) {
                    constant = (LiteralOp) lmult.getInput(0);
                    var = lminus.getInput(0);
                }
            }
        }
        if (var == null || constant == null) {
            return false;
        }
        // Upper bound: must be a multiplication c*i or i*c; any other binary op (e.g. i+c)
        // does not describe a block of size c.
        if (!(ubound instanceof BinaryOp) || ((BinaryOp) ubound).getOp() != Types.OpOp2.MULT) {
            return false;
        }
        LiteralOp constant2 = null;
        if (ubound.getInput(0) instanceof LiteralOp && ubound.getInput(1) == var) {
            constant2 = (LiteralOp) ubound.getInput(0);
        } else if (ubound.getInput(1) instanceof LiteralOp && ubound.getInput(0) == var) {
            constant2 = (LiteralOp) ubound.getInput(1);
        }
        return constant2 != null
            && HopRewriteUtils.getDoubleValueSafe(constant) == HopRewriteUtils.getDoubleValueSafe(constant2);
    }

    /**
     * Returns true if all four bounds are literals that align with the input's block size
     * (non-literal bounds are passed as -1).
     */
    private boolean isBlockAligned() {
        Hop input1 = this.getInput().get(0);
        Hop input2 = this.getInput().get(1);
        Hop input3 = this.getInput().get(2);
        Hop input4 = this.getInput().get(3);
        Hop input5 = this.getInput().get(4);
        long rl = input2 instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) input2) : -1L;
        long ru = input3 instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) input3) : -1L;
        long cl = input4 instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) input4) : -1L;
        long cu = input5 instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) input5) : -1L;
        int blen = input1.getBlocksize();
        return OptimizerUtils.isIndexingRangeBlockAligned(rl, ru, cl, cu, blen);
    }

    /**
     * Returns the block size {@code c} of a range accepted by
     * {@link #isBlockIndexingExpression(Hop, Hop)}. It must only be called after that check
     * succeeded.
     */
    private static long getBlockIndexingExpressionSize(Hop lbound, Hop ubound) {
        // The lower bound is PLUS(MULT(c, ...), 1), so c sits at a fixed position. The upper
        // bound's literal may be on either side of the MULT, and its value equals c.
        LiteralOp c = (LiteralOp) lbound.getInput(0).getInput(0);
        return HopRewriteUtils.getIntValueSafe(c);
    }

    /**
     * Chooses the execution type. A forced platform wins. Otherwise the choice is driven by the
     * memory estimate, or by the input dims for non-memory-based optimization levels. List
     * inputs always run in CP.
     */
    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            this._etype = OptimizerUtils.isMemoryBasedOptLevel() ? this.findExecTypeByMemEstimate() : (this.getInput().get(0).areDimsBelowThreshold() ? Types.ExecType.CP : Types.ExecType.SPARK);
            this.checkAndSetInvalidCPDimsAndSize();
        }
        if (this.getInput().get(0).getDataType() == Types.DataType.LIST) {
            this._etype = Types.ExecType.CP;
        }
        this.setRequiresRecompileIfNecessary();
        return this._etype;
    }

    /**
     * Selects vector indexing (a full single row or a full single column with known size)
     * over general range indexing.
     */
    private static IndexingMethod optFindIndexingMethod(boolean singleRow, boolean singleCol, long m1_dim1, long m1_dim2, long m2_dim1, long m2_dim2) {
        boolean fullRow = singleRow && m1_dim2 == m2_dim2 && m2_dim2 != -1L;
        boolean fullCol = singleCol && m1_dim1 == m2_dim1 && m2_dim1 != -1L;
        if (fullRow || fullCol) {
            return IndexingMethod.MR_VRIX;
        }
        return IndexingMethod.MR_RIX;
    }

    /**
     * Recomputes the single-row/column flags and the output dims from the current inputs.
     * Unknown dims are set to -1.
     */
    @Override
    public void refreshSizeInformation() {
        Hop input1 = this.getInput().get(0);
        Hop input2 = this.getInput().get(1);
        Hop input3 = this.getInput().get(2);
        Hop input4 = this.getInput().get(3);
        Hop input5 = this.getInput().get(4);
        this._rowLowerEqualsUpper = input2 == input3;
        this._colLowerEqualsUpper = input4 == input5;
        boolean allRows = this.isAllRows();
        boolean allCols = this.isAllCols();
        boolean constRowRange = input2 instanceof LiteralOp && input3 instanceof LiteralOp;
        boolean constColRange = input4 instanceof LiteralOp && input5 instanceof LiteralOp;
        if (this._rowLowerEqualsUpper) {
            this.setDim1(1L);
        } else if (allRows) {
            this.setDim1(input1.getDim1());
        } else if (constRowRange) {
            // inclusive bounds: [rl, ru] has ru - rl + 1 rows
            this.setDim1(HopRewriteUtils.getIntValueSafe((LiteralOp) input3) - HopRewriteUtils.getIntValueSafe((LiteralOp) input2) + 1L);
        } else if (IndexingOp.isBlockIndexingExpression(input2, input3)) {
            this.setDim1(IndexingOp.getBlockIndexingExpressionSize(input2, input3));
        } else {
            this.setDim1(-1L);
        }
        if (this._colLowerEqualsUpper) {
            this.setDim2(1L);
        } else if (allCols) {
            this.setDim2(input1.getDim2());
        } else if (constColRange) {
            // inclusive bounds: [cl, cu] has cu - cl + 1 columns
            this.setDim2(HopRewriteUtils.getIntValueSafe((LiteralOp) input5) - HopRewriteUtils.getIntValueSafe((LiteralOp) input4) + 1L);
        } else if (IndexingOp.isBlockIndexingExpression(input4, input5)) {
            this.setDim2(IndexingOp.getBlockIndexingExpressionSize(input4, input5));
        } else {
            this.setDim2(-1L);
        }
    }

    /**
     * @return true if the row range is {@code 1 : nrow(X)} of the indexed input, either via an
     *         NROW hop on the same input or a literal equal to its known row count
     */
    public boolean isAllRows() {
        Hop input1 = this.getInput().get(0);
        Hop input2 = this.getInput().get(1);
        Hop input3 = this.getInput().get(2);
        return HopRewriteUtils.isLiteralOfValue(input2, 1.0) && (HopRewriteUtils.isUnary(input3, Types.OpOp1.NROW) && input3.getInput().get(0) == input1 || HopRewriteUtils.isLiteralOfValue(input3, input1.getDim1()));
    }

    /**
     * @return true if the column range is {@code 1 : ncol(X)} of the indexed input, either via
     *         an NCOL hop on the same input or a literal equal to its known column count
     */
    public boolean isAllCols() {
        Hop input1 = this.getInput().get(0);
        Hop input4 = this.getInput().get(3);
        Hop input5 = this.getInput().get(4);
        return HopRewriteUtils.isLiteralOfValue(input4, 1.0) && (HopRewriteUtils.isUnary(input5, Types.OpOp1.NCOL) && input5.getInput().get(0) == input1 || HopRewriteUtils.isLiteralOfValue(input5, input1.getDim2()));
    }

    /** Shallow clone (no input/parent references) that also carries the single-row/column flags. */
    @Override
    public Object clone() throws CloneNotSupportedException {
        IndexingOp ret = new IndexingOp();
        ret.clone(this, false);
        // Hop-specific attributes are not handled by the generic clone.
        ret._rowLowerEqualsUpper = this._rowLowerEqualsUpper;
        ret._colLowerEqualsUpper = this._colLowerEqualsUpper;
        return ret;
    }

    /** Two indexing hops are equal if they reference the identical five input hops. */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof IndexingOp) || this.getInput().size() != that.getInput().size()) {
            return false;
        }
        return this.getInput().get(0) == that.getInput().get(0) && this.getInput().get(1) == that.getInput().get(1) && this.getInput().get(2) == that.getInput().get(2) && this.getInput().get(3) == that.getInput().get(3) && this.getInput().get(4) == that.getInput().get(4);
    }

    /** Physical indexing strategies considered by {@link #optFindIndexingMethod}. */
    private static enum IndexingMethod {
        CP_RIX,
        MR_RIX,
        MR_VRIX;

    }
}

