/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for builtin operations over a variable number of operands.
 * <p>
 * Two families of opcodes are handled:
 * <ul>
 *   <li>{@code cbind} / {@code rbind}: append all (matrix) inputs column- or row-wise;</li>
 *   <li>{@code nmin} / {@code nmax} / {@code n+}: element-wise n-ary operation over the
 *       matrix inputs, combined with any scalar inputs.</li>
 * </ul>
 */
public class BuiltinNarySPInstruction
extends SPInstruction
implements LineageTraceable {
    /** Opcodes handled by the element-wise (block-aligned) n-ary path. */
    private static final String[] NARY_ELEMENTWISE_OPCODES = {"nmin", "nmax", "n+"};

    private final CPOperand[] inputs;
    private final CPOperand output;

    protected BuiltinNarySPInstruction(CPOperand[] in, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.BuiltinNary, opcode, istr);
        this.inputs = in;
        this.output = out;
    }

    /**
     * Parses a serialized instruction into a {@code BuiltinNarySPInstruction}.
     * <p>
     * The string is split via {@link InstructionUtils#getInstructionPartsWithValueType(String)}.
     * The first part is the opcode and the last part is the output operand. Every part in
     * between is an input operand.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     */
    public static BuiltinNarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand output = new CPOperand(parts[parts.length - 1]);
        CPOperand[] inputs = new CPOperand[parts.length - 2];
        for (int i = 1; i < parts.length - 1; ++i) {
            inputs[i - 1] = new CPOperand(parts[i]);
        }
        return new BuiltinNarySPInstruction(inputs, output, opcode, str);
    }

    /**
     * Executes the instruction and binds the resulting RDD and its data characteristics
     * to the output variable.
     *
     * @param ec the execution context; expected to be a {@link SparkExecutionContext}
     * @throws UnsupportedOperationException if the opcode is not one of
     *         {@code cbind}, {@code rbind}, {@code nmin}, {@code nmax}, {@code n+}
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        String opcode = getOpcode();
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        DataCharacteristics mcOut;
        if (opcode.equals("cbind") || opcode.equals("rbind")) {
            boolean cbind = opcode.equals("cbind");
            mcOut = computeAppendOutputDataCharacteristics(sec, inputs, cbind);
            out = appendInputs(sec, inputs, mcOut, cbind);
        } else if (ArrayUtils.contains(NARY_ELEMENTWISE_OPCODES, opcode)) {
            mcOut = computeMinMaxOutputDataCharacteristics(sec, inputs);
            out = combineElementwise(sec, inputs, opcode);
        } else {
            // Fail fast instead of binding a null RDD / null characteristics to the output.
            throw new UnsupportedOperationException(
                "Unsupported opcode for BuiltinNarySPInstruction: " + opcode);
        }
        sec.getDataCharacteristics(output.getName()).set(mcOut);
        sec.setRDDHandleForVariable(output.getName(), out);
        // Scalar inputs are skipped; only non-scalar inputs are registered as RDD lineage.
        for (CPOperand input : inputs) {
            if (!input.isScalar()) {
                sec.addLineageRDD(output.getName(), input.getName());
            }
        }
    }

    /**
     * Builds the output of a cbind/rbind.
     * <p>
     * Each input's blocks are shifted by the running offset, which accumulates the
     * columns for cbind or the rows for rbind of all previous inputs. The shifted blocks
     * are padded to the output block sizes and unioned. The union is finally merged by
     * block index.
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> appendInputs(SparkExecutionContext sec,
            CPOperand[] inputs, DataCharacteristics mcOut, boolean cbind) {
        MatrixCharacteristics off = new MatrixCharacteristics(0L, 0L, mcOut.getBlocksize(), 0L);
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        for (CPOperand input : inputs) {
            DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec
                .getBinaryMatrixBlockRDDHandleForVariable(input.getName())
                .flatMapToPair(new AppendGSPInstruction.ShiftMatrix(off, mcIn, cbind))
                .mapToPair(new PadBlocksFunction(mcOut));
            out = (out == null) ? in : out.union(in);
            // Advance the offset by this input's extent before processing the next input.
            updateAppendDataCharacteristics(mcIn, off, cbind);
        }
        int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut);
        return RDDAggregateUtils.mergeByKey(out, numPartOut, false);
    }

    /**
     * Builds the output of nmin/nmax/n+.
     * <p>
     * The matrix inputs are joined on block index into a per-index {@code MatrixBlock[]}.
     * The n-ary operator is then applied to every block array, together with all scalar
     * inputs.
     *
     * @throws IllegalStateException if none of the inputs is a matrix
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> combineElementwise(SparkExecutionContext sec,
            CPOperand[] inputs, String opcode) {
        List<ScalarObject> scalars = sec.getScalarInputs(inputs);
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = null;
        for (CPOperand input : inputs) {
            if (!input.getDataType().isMatrix()) {
                continue;
            }
            JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec.getBinaryMatrixBlockRDDHandleForVariable(input.getName());
            if (in == null) {
                in = tmp.mapValues(new MapInputSignature());
            } else {
                in = in.join(tmp).mapValues(new MapJoinSignature());
            }
        }
        if (in == null) {
            throw new IllegalStateException(
                "Spark instruction '" + opcode + "' requires at least one matrix input.");
        }
        return in.mapValues(new MinMaxAddFunction(opcode, scalars));
    }

    /** Computes the output characteristics of appending all inputs (cbind or rbind). */
    private static DataCharacteristics computeAppendOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) {
        DataCharacteristics mcIn1 = sec.getDataCharacteristics(inputs[0].getName());
        MatrixCharacteristics mcOut = new MatrixCharacteristics(0L, 0L, mcIn1.getBlocksize(), 0L);
        for (CPOperand input : inputs) {
            updateAppendDataCharacteristics(sec.getDataCharacteristics(input.getName()), mcOut, cbind);
        }
        return mcOut;
    }

    /**
     * Grows {@code out} in place by appending {@code in}.
     * <p>
     * For cbind, the columns add up and the rows take the maximum. For rbind, the rows add
     * up and the columns take the maximum. The non-zero count is summed while both counts
     * are known; otherwise it becomes -1.
     */
    private static void updateAppendDataCharacteristics(DataCharacteristics in, DataCharacteristics out, boolean cbind) {
        long rows = cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows() + in.getRows();
        long cols = cbind ? out.getCols() + in.getCols() : Math.max(out.getCols(), in.getCols());
        out.setDimension(rows, cols);
        out.setNonZeros(out.getNonZeros() != -1L && in.dimsKnown(true) ? out.getNonZeros() + in.getNonZeros() : -1L);
    }

    /**
     * Computes the output characteristics of nmin/nmax/n+.
     * <p>
     * The output dimensions are the maximum rows and columns over all matrix inputs. The
     * block size is taken from the last matrix input.
     */
    private static DataCharacteristics computeMinMaxOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs) {
        MatrixCharacteristics mcOut = new MatrixCharacteristics();
        for (CPOperand input : inputs) {
            if (!input.getDataType().isMatrix()) {
                continue;
            }
            DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
            mcOut.setRows(Math.max(mcOut.getRows(), mcIn.getRows()));
            mcOut.setCols(Math.max(mcOut.getCols(), mcIn.getCols()));
            mcOut.setBlocksize(mcIn.getBlocksize());
        }
        return mcOut;
    }

    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        // Typed arguments are required so that Pair.of infers Pair<String, LineageItem>.
        return Pair.of(output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, inputs)));
    }

    /**
     * Applies {@link MatrixBlock#naryOperations} to an array of aligned blocks plus the scalar
     * operands. The opcode selects the operator: {@code n+} uses {@link Plus}. Any other
     * opcode uses the {@link Builtin} named by the opcode without its first character, for
     * example {@code nmin} becomes {@code min}.
     */
    private static class MinMaxAddFunction
    implements Function<MatrixBlock[], MatrixBlock> {
        private static final long serialVersionUID = -4227447915387484397L;
        private final SimpleOperator _op;
        private final ScalarObject[] _scalars;

        public MinMaxAddFunction(String opcode, List<ScalarObject> scalars) {
            this._scalars = scalars.toArray(new ScalarObject[0]);
            this._op = new SimpleOperator(opcode.equals("n+") ? Plus.getPlusFnObject() : Builtin.getBuiltinFnObject(opcode.substring(1)));
        }

        @Override
        public MatrixBlock call(MatrixBlock[] v1) throws Exception {
            return MatrixBlock.naryOperations(this._op, v1, this._scalars, new MatrixBlock());
        }
    }

    /**
     * Pads a block with empty rows and/or columns up to the block size expected at its index
     * in the output described by {@code mcOut}. Blocks that already have the expected size
     * are passed through unchanged.
     */
    public static class PadBlocksFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1291358959908299855L;
        private final DataCharacteristics _mcOut;

        public PadBlocksFunction(DataCharacteristics mcOut) {
            this._mcOut = mcOut;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock mb = arg0._2();
            int brlen = UtilFunctions.computeBlockSize(this._mcOut.getRows(), ix.getRowIndex(), this._mcOut.getBlocksize());
            int bclen = UtilFunctions.computeBlockSize(this._mcOut.getCols(), ix.getColumnIndex(), this._mcOut.getBlocksize());
            if (brlen == mb.getNumRows() && bclen == mb.getNumColumns()) {
                return arg0; // already the expected block size
            }
            // Pad rows first, keeping the block's current width so the rbind is well-formed.
            if (brlen > mb.getNumRows()) {
                mb = mb.append(new MatrixBlock(brlen - mb.getNumRows(), mb.getNumColumns(), true), new MatrixBlock(), false);
            }
            // Then pad columns using the (possibly already padded) row count.
            if (bclen > mb.getNumColumns()) {
                mb = mb.append(new MatrixBlock(mb.getNumRows(), bclen - mb.getNumColumns(), true), new MatrixBlock(), true);
            }
            return new Tuple2<>(ix, mb);
        }
    }
}

