/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import org.apache.sysml.lops.WeightedCrossEntropy;
import org.apache.sysml.lops.WeightedDivMM;
import org.apache.sysml.lops.WeightedSigmoid;
import org.apache.sysml.lops.WeightedSquaredLoss;
import org.apache.sysml.lops.WeightedUnaryMM;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;

public class QuaternaryInstruction
extends MRInstruction
implements IDistributedCacheConsumer {
    private byte _input1 = (byte)-1;
    private byte _input2 = (byte)-1;
    private byte _input3 = (byte)-1;
    private byte _input4 = (byte)-1;
    private boolean _cacheU = false;
    private boolean _cacheV = false;

    private QuaternaryInstruction(Operator op, byte in1, byte in2, byte in3, byte in4, byte out, boolean cacheU, boolean cacheV, String istr) {
        super(MRInstruction.MRType.Quaternary, op, out);
        this.instString = istr;
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._input4 = in4;
        this._cacheU = cacheU;
        this._cacheV = cacheV;
    }

    public byte getInput1() {
        return this._input1;
    }

    public byte getInput2() {
        return this._input2;
    }

    public byte getInput3() {
        return this._input3;
    }

    public byte getInput4() {
        return this._input4;
    }

    public void computeMatrixCharacteristics(MatrixCharacteristics mc1, MatrixCharacteristics mc2, MatrixCharacteristics mc3, MatrixCharacteristics dimOut) {
        QuaternaryOperator qop = (QuaternaryOperator)this.optr;
        if (qop.wtype1 != null || qop.wtype4 != null) {
            dimOut.set(1L, 1L, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        } else if (qop.wtype2 != null || qop.wtype5 != null) {
            dimOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        } else if (qop.wtype3 != null) {
            boolean mapwdivmm;
            boolean bl = mapwdivmm = this._cacheU && this._cacheV;
            long rank = qop.wtype3.isLeft() ? (mapwdivmm ? mc3.getCols() : mc3.getNonZeros()) : (mapwdivmm ? mc2.getCols() : mc2.getNonZeros());
            MatrixCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mc1.getRows(), mc1.getCols(), rank);
            dimOut.set(mcTmp.getRows(), mcTmp.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        }
    }

    public static QuaternaryInstruction parseInstruction(String str) {
        boolean cacheV;
        int addInput4;
        String opcode = InstructionUtils.getOpCode(str);
        if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
            throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
        }
        if ("mapwsloss".equalsIgnoreCase(opcode) || "redwsloss".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwsloss".equalsIgnoreCase(opcode);
            if (isRed) {
                InstructionUtils.checkNumFields(str, 8);
            } else {
                InstructionUtils.checkNumFields(str, 6);
            }
            String[] parts = InstructionUtils.getInstructionParts(str);
            byte in1 = Byte.parseByte(parts[1]);
            byte in2 = Byte.parseByte(parts[2]);
            byte in3 = Byte.parseByte(parts[3]);
            byte in4 = Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            WeightedSquaredLoss.WeightsType wtype = WeightedSquaredLoss.WeightsType.valueOf(parts[6]);
            boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
            boolean cacheV2 = isRed ? Boolean.parseBoolean(parts[8]) : true;
            return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV2, str);
        }
        if ("mapwumm".equalsIgnoreCase(opcode) || "redwumm".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwumm".equalsIgnoreCase(opcode);
            if (isRed) {
                InstructionUtils.checkNumFields(str, 8);
            } else {
                InstructionUtils.checkNumFields(str, 6);
            }
            String[] parts = InstructionUtils.getInstructionParts(str);
            String uopcode = parts[1];
            byte in1 = Byte.parseByte(parts[2]);
            byte in2 = Byte.parseByte(parts[3]);
            byte in3 = Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            WeightedUnaryMM.WUMMType wtype = WeightedUnaryMM.WUMMType.valueOf(parts[6]);
            boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
            boolean cacheV3 = isRed ? Boolean.parseBoolean(parts[8]) : true;
            return new QuaternaryInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, -1, out, cacheU, cacheV3, str);
        }
        if ("mapwdivmm".equalsIgnoreCase(opcode) || "redwdivmm".equalsIgnoreCase(opcode)) {
            boolean isRed = opcode.startsWith("red");
            if (isRed) {
                InstructionUtils.checkNumFields(str, 8);
            } else {
                InstructionUtils.checkNumFields(str, 6);
            }
            String[] parts = InstructionUtils.getInstructionParts(str);
            WeightedDivMM.WDivMMType wtype = WeightedDivMM.WDivMMType.valueOf(parts[6]);
            byte in1 = Byte.parseByte(parts[1]);
            byte in2 = Byte.parseByte(parts[2]);
            byte in3 = Byte.parseByte(parts[3]);
            byte in4 = wtype.hasScalar() ? (byte)-1 : (byte)Byte.parseByte(parts[4]);
            byte out = Byte.parseByte(parts[5]);
            boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
            boolean cacheV4 = isRed ? Boolean.parseBoolean(parts[8]) : true;
            return new QuaternaryInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV4, str);
        }
        boolean isRed = opcode.startsWith("red");
        int n = addInput4 = opcode.endsWith("wcemm") ? 1 : 0;
        if (isRed) {
            InstructionUtils.checkNumFields(str, 7 + addInput4);
        } else {
            InstructionUtils.checkNumFields(str, 5 + addInput4);
        }
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte out = Byte.parseByte(parts[4 + addInput4]);
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
        boolean bl = cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;
        if (opcode.endsWith("wsigmoid")) {
            return new QuaternaryInstruction(new QuaternaryOperator(WeightedSigmoid.WSigmoidType.valueOf(parts[5])), in1, in2, in3, -1, out, cacheU, cacheV, str);
        }
        if (opcode.endsWith("wcemm")) {
            return new QuaternaryInstruction(new QuaternaryOperator(WeightedCrossEntropy.WCeMMType.valueOf(parts[6])), in1, in2, in3, -1, out, cacheU, cacheV, str);
        }
        return null;
    }

    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        if (this._cacheU && this._cacheV) {
            return index == this._input2 && index != this._input1 && index != this._input4 || index == this._input3 && index != this._input1 && index != this._input4;
        }
        return this._cacheU && index == this._input2 && index != this._input1 && index != this._input4 || this._cacheV && index == this._input3 && index != this._input1 && index != this._input4;
    }

    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        if (this._cacheU) {
            indexes.add(this._input2);
        }
        if (this._cacheV) {
            indexes.add(this._input3);
        }
    }

    @Override
    public byte[] getInputIndexes() {
        QuaternaryOperator qop = (QuaternaryOperator)this.optr;
        if (qop.hasFourInputs()) {
            return new byte[]{this._input1, this._input2, this._input3, this._input4};
        }
        return new byte[]{this._input1, this._input2, this._input3};
    }

    @Override
    public byte[] getAllIndexes() {
        QuaternaryOperator qop = (QuaternaryOperator)this.optr;
        if (qop.hasFourInputs()) {
            return new byte[]{this._input1, this._input2, this._input3, this._input4, this.output};
        }
        return new byte[]{this._input1, this._input2, this._input3, this.output};
    }

    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) {
        QuaternaryOperator qop = (QuaternaryOperator)this.optr;
        ArrayList<IndexedMatrixValue> blkList = cachedValues.get(this._input1);
        if (blkList != null) {
            for (IndexedMatrixValue imv : blkList) {
                MatrixValue Vj;
                MatrixValue Wij;
                if (imv == null) continue;
                MatrixIndexes inIx = imv.getIndexes();
                MatrixBlock inVal = (MatrixBlock)imv.getValue();
                IndexedMatrixValue iout = null;
                iout = this.output == this._input1 ? tempValue : cachedValues.holdPlace(this.output, valueClass);
                MatrixIndexes outIx = iout.getIndexes();
                MatrixValue outVal = iout.getValue();
                MatrixBlock Xij = inVal;
                IndexedMatrixValue iWij = this._input4 != -1 ? cachedValues.getFirst(this._input4) : null;
                MatrixValue matrixValue = Wij = iWij != null ? iWij.getValue() : null;
                if (null == Wij && qop.hasFourInputs()) {
                    String[] parts = InstructionUtils.getInstructionParts(this.instString);
                    MatrixBlock mb = new MatrixBlock(Double.valueOf(parts[4]));
                    Wij = mb;
                }
                MatrixValue Ui = !this._cacheU ? cachedValues.getFirst(this._input2).getValue() : MRBaseForCommonInstructions.dcValues.get(this._input2).getDataBlock((int)inIx.getRowIndex(), 1).getValue();
                MatrixValue matrixValue2 = Vj = !this._cacheV ? cachedValues.getFirst(this._input3).getValue() : MRBaseForCommonInstructions.dcValues.get(this._input3).getDataBlock((int)inIx.getColumnIndex(), 1).getValue();
                if (Ui.getNumColumns() != Vj.getNumColumns()) {
                    Vj = LibMatrixReorg.reorg((MatrixBlock)Vj, new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()), new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
                }
                Xij.quaternaryOperations(qop, (MatrixBlock)Ui, (MatrixBlock)Vj, (MatrixBlock)Wij, (MatrixBlock)outVal);
                if (qop.wtype1 != null || qop.wtype4 != null) {
                    outIx.setIndexes(1L, 1L);
                } else if (qop.wtype2 != null || qop.wtype5 != null || qop.wtype3 != null && qop.wtype3.isBasic()) {
                    outIx.setIndexes(inIx);
                } else {
                    boolean left = qop.wtype3.isLeft();
                    outIx.setIndexes(left ? inIx.getColumnIndex() : inIx.getRowIndex(), 1L);
                }
                if (iout != tempValue) continue;
                cachedValues.add(this.output, iout);
            }
        }
    }
}

