/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for the {@code qpick} opcode: picks quantile values (VALUEPICK),
 * the median (MEDIAN), or the interquartile mean (IQM) from a blocked single-column
 * input.
 *
 * <p>The input ({@code input1}) has either one column (unweighted values) or two
 * columns, in which case column 0 holds the value and column 1 the weight of each row.
 * NOTE(review): the position-based lookups below assume the rows are already sorted by
 * value (presumably by a preceding sort instruction) — confirm against the producing lop.
 */
public class QuantilePickSPInstruction
extends BinarySPInstruction {
    /** The pick operation to perform. */
    private final PickByCount.OperationTypes _type;

    private QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        this(op, in, null, out, type, inmem, opcode, istr);
    }

    // NOTE(review): 'inmem' is parsed but not used by this Spark instruction.
    private QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        super(SPInstruction.SPType.QPick, op, in, in2, out, opcode, istr);
        this._type = type;
    }

    /**
     * Parses a {@code qpick} instruction string.
     *
     * <p>Accepted layouts, by number of parts (opcode included):
     * <ul>
     * <li>4: {@code qpick in1 in2 out} — always IQM</li>
     * <li>5: {@code qpick in1 out type inmem}</li>
     * <li>6: {@code qpick in1 in2 out type inmem}</li>
     * </ul>
     *
     * @param str the instruction string
     * @return the parsed instruction (never {@code null})
     * @throws DMLRuntimeException if the opcode is not {@code qpick} or the number of
     *         parts is not 4, 5 or 6
     */
    public static QuantilePickSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("qpick")) {
            throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickSPInstruction: " + str);
        }
        switch (parts.length) {
            case 4: {
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand in2 = new CPOperand(parts[2]);
                CPOperand out = new CPOperand(parts[3]);
                return new QuantilePickSPInstruction(null, in1, in2, out, PickByCount.OperationTypes.IQM, false, opcode, str);
            }
            case 5: {
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand out = new CPOperand(parts[2]);
                PickByCount.OperationTypes ptype = PickByCount.OperationTypes.valueOf(parts[3]);
                boolean inmem = Boolean.parseBoolean(parts[4]);
                return new QuantilePickSPInstruction(null, in1, out, ptype, inmem, opcode, str);
            }
            case 6: {
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand in2 = new CPOperand(parts[2]);
                CPOperand out = new CPOperand(parts[3]);
                PickByCount.OperationTypes ptype = PickByCount.OperationTypes.valueOf(parts[4]);
                boolean inmem = Boolean.parseBoolean(parts[5]);
                return new QuantilePickSPInstruction(null, in1, in2, out, ptype, inmem, opcode, str);
            }
            default:
                // previously returned null, which only surfaced later as an NPE
                throw new DMLRuntimeException("Invalid number of operands in qpick instruction: " + str);
        }
    }

    /**
     * Executes the pick operation and writes the scalar or matrix result.
     *
     * <ul>
     * <li>VALUEPICK with a scalar {@code input2}: scalar value at that quantile.</li>
     * <li>VALUEPICK with a matrix {@code input2}: column vector with one picked value per
     * quantile in {@code input2}.</li>
     * <li>MEDIAN: scalar value at quantile 0.5.</li>
     * <li>IQM: interquartile mean, via {@link MatrixBlock#computeIQMCorrection}.</li>
     * </ul>
     *
     * @throws DMLRuntimeException for unsupported operation types
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
        DataCharacteristics mc = sec.getDataCharacteristics(input1.getName());
        switch (_type) {
            case VALUEPICK: {
                if (input2.isScalar()) {
                    ScalarObject quantile = ec.getScalarInput(input2);
                    double[] wt = getWeightedQuantileSummary(in, mc, new double[]{quantile.getDoubleValue()});
                    // layout for one quantile: [total, pos, frac, value]
                    ec.setScalarOutput(output.getName(), new DoubleObject(wt[3]));
                    break;
                }
                double[] wt = getWeightedQuantileSummary(in, mc, DataConverter.convertToDoubleVector(ec.getMatrixInput(input2.getName())));
                ec.releaseMatrixInput(input2.getName());
                int qlen = wt.length / 3;
                MatrixBlock out = new MatrixBlock(qlen, 1, false);
                for (int i = 0; i < qlen; i++) {
                    // picked values live in the last third of the summary
                    out.quickSetValue(i, 0, wt[2 * qlen + i + 1]);
                }
                ec.setMatrixOutput(output.getName(), out);
                break;
            }
            case MEDIAN: {
                double[] wt = getWeightedQuantileSummary(in, mc, new double[]{0.5});
                ec.setScalarOutput(output.getName(), new DoubleObject(wt[3]));
                break;
            }
            case IQM: {
                // layout for two quantiles: [total, pos25, pos75, frac25, frac75, val25, val75]
                double[] wt = getWeightedQuantileSummary(in, mc, new double[]{0.25, 0.75});
                long key25 = (long) Math.ceil(wt[1]);
                long key75 = (long) Math.ceil(wt[2]);
                // sum the rows strictly after the 25% key up to and including the 75% key
                JavaPairRDD<MatrixIndexes, MatrixBlock> partialSums = in
                    .filter(new FilterFunction(key25 + 1L, key75, mc.getBlocksize()))
                    .mapToPair(new ExtractAndSumFunction(key25 + 1L, key75, mc.getBlocksize()));
                double sum = RDDAggregateUtils.sumStable(partialSums).getValue(0, 0);
                double val = MatrixBlock.computeIQMCorrection(sum, wt[0], wt[3], wt[5], wt[4], wt[6]);
                ec.setScalarOutput(output.getName(), new DoubleObject(val));
                break;
            }
            default: {
                throw new DMLRuntimeException("Unsupported qpick operation type: " + _type);
            }
        }
    }

    /**
     * Computes a quantile summary for the given quantiles.
     *
     * <p>The returned array has length {@code 3*q+1} for {@code q = quantiles.length}:
     * <ul>
     * <li>{@code [0]}: total weight (two-column input) or row count (one-column input)</li>
     * <li>{@code [1..q]}: per-quantile position</li>
     * <li>{@code [q+1..2q]}: per-quantile fractional part</li>
     * <li>{@code [2q+1..3q]}: per-quantile picked value</li>
     * </ul>
     * For two-column (weighted) input, the position is the global row index of the
     * picked row, and the fractional part is the cumulative weight through that row
     * minus {@code quantile*totalWeight}. For one-column input, the position is
     * {@code quantile*rows}, and the fractional part is {@code ceil(pos) - pos}.
     */
    private static double[] getWeightedQuantileSummary(JavaPairRDD<MatrixIndexes, MatrixBlock> w, DataCharacteristics mc, double[] quantiles) {
        final int qlen = quantiles.length;
        double[] ret = new double[3 * qlen + 1];
        if (mc.getCols() == 2L) {
            // order blocks by index so that partition order follows row order
            JavaPairRDD<MatrixIndexes, MatrixBlock> sorted = w.sortByKey();
            List<Tuple2<Integer, Double>> partWeights = sorted
                .mapPartitionsWithIndex(new SumWeightsFunction(), false).collect();
            ret[0] = partWeights.stream().mapToDouble(p -> p._2()).sum();

            double[] qdKeys = new double[qlen];
            long[] qiKeys = new long[qlen];
            int[] partitionIDs = new int[qlen];
            double[] offsets = new double[qlen];
            for (int i = 0; i < qlen; ++i) {
                qdKeys[i] = quantiles[i] * ret[0];
                qiKeys[i] = (long) Math.ceil(qdKeys[i]);
            }
            // -1 marks "unassigned": 0 is a valid partition index and must not be
            // overwritten by later partitions
            Arrays.fill(partitionIDs, -1);

            // find, per quantile, the first partition whose cumulative weight reaches
            // the key; partitions without weight hold no rows to extract from
            double cumSum = 0.0;
            for (Tuple2<Integer, Double> psum : partWeights) {
                double partWeight = psum._2();
                double tmp = cumSum + partWeight;
                for (int i = 0; i < qlen; ++i) {
                    if (partitionIDs[i] < 0 && partWeight > 0 && tmp >= qiKeys[i]) {
                        partitionIDs[i] = psum._1();
                        offsets[i] = cumSum;
                    }
                }
                cumSum = tmp;
            }
            // NOTE(review): quantiles never assigned (e.g. key above total weight) yield
            // no extracted tuple and keep 0.0 in the summary.
            List<Tuple2<Integer, double[]>> qVals = sorted
                .mapPartitionsWithIndex(new ExtractWeightedQuantileFunction(mc, qdKeys, qiKeys, partitionIDs, offsets), false)
                .collect();
            for (Tuple2<Integer, double[]> qVal : qVals) {
                int qi = qVal._1();
                double[] v = qVal._2();
                ret[qi + 1] = v[0];
                ret[qi + qlen + 1] = v[1];
                ret[qi + 2 * qlen + 1] = v[2];
            }
        } else {
            ret[0] = mc.getRows();
            for (int i = 0; i < qlen; ++i) {
                ret[i + 1] = quantiles[i] * (double) mc.getRows();
                ret[i + qlen + 1] = Math.ceil(ret[i + 1]) - ret[i + 1];
                // quantile 0 would produce key 0, which maps to an invalid in-block
                // position; clamp to the first row (the minimum)
                long key = Math.max(1L, (long) Math.ceil(ret[i + 1]));
                ret[i + 2 * qlen + 1] = lookupKey(w, key, mc.getBlocksize());
            }
        }
        return ret;
    }

    /**
     * Looks up the value in column 0 of the row identified by {@code key}.
     *
     * @param in   blocked input, looked up by block index {@code (rix, 1)}
     * @param key  row key, converted via {@link UtilFunctions#computeBlockIndex} and
     *             {@link UtilFunctions#computeCellInBlock} (presumably 1-based)
     * @param blen block size
     * @return the value at the requested row
     * @throws DMLRuntimeException if the block is missing or the row is out of range
     */
    private static double lookupKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, long key, int blen) {
        long rix = UtilFunctions.computeBlockIndex(key, blen);
        int pos = UtilFunctions.computeCellInBlock(key, blen);
        List<MatrixBlock> val = in.lookup(new MatrixIndexes(rix, 1L));
        if (val.isEmpty()) {
            throw new DMLRuntimeException("Invalid key lookup in empty list.");
        }
        MatrixBlock tmp = val.get(0);
        if (pos < 0 || tmp.getNumRows() <= pos) {
            throw new DMLRuntimeException("Invalid key lookup for " + pos + " in block of size " + tmp.getNumRows() + "x" + tmp.getNumColumns());
        }
        return tmp.quickGetValue(pos, 0);
    }

    /**
     * Per-partition scan that, for every quantile assigned to this partition, emits
     * {@code (quantileIndex, [globalRowIndex, posPart, value])} for the first row at which
     * the running cumulative weight reaches the quantile's integer key.
     */
    private static class ExtractWeightedQuantileFunction
    implements Function2<Integer, Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, Iterator<Tuple2<Integer, double[]>>> {
        private static final long serialVersionUID = 4879975971050093739L;
        private final DataCharacteristics _mc;
        private final double[] _qdKeys;
        private final long[] _qiKeys;
        private final int[] _qPIDs;
        private final double[] _offsets;

        public ExtractWeightedQuantileFunction(DataCharacteristics mc, double[] qdKeys, long[] qiKeys, int[] qPIDs, double[] offsets) {
            this._mc = mc;
            this._qdKeys = qdKeys;
            this._qiKeys = qiKeys;
            this._qPIDs = qPIDs;
            this._offsets = offsets;
        }

        @Override
        public Iterator<Tuple2<Integer, double[]>> call(Integer v1, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> v2) throws Exception {
            final int pid = v1;
            // indexes of the quantiles assigned to this partition
            int[] qix = IntStream.range(0, _qPIDs.length).filter(i -> _qPIDs[i] == pid).toArray();
            if (qix.length == 0) {
                return Collections.emptyIterator();
            }
            // per-call copy: found quantiles are marked without mutating function state
            long[] qiKeys = _qiKeys.clone();
            // all quantiles of one partition share the same starting offset
            double offset = _offsets[qix[0]];
            List<Tuple2<Integer, double[]>> ret = new ArrayList<>();
            while (v2.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock> tmp = v2.next();
                MatrixIndexes ix = tmp._1();
                MatrixBlock mb = tmp._2();
                for (int i = 0; i < mb.getNumRows(); ++i) {
                    double val = mb.quickGetValue(i, 1); // weight column
                    for (int q : qix) {
                        if (offset + val >= qiKeys[q]) {
                            long pos = UtilFunctions.computeCellIndex(ix.getRowIndex(), _mc.getBlocksize(), i);
                            double posPart = offset + val - _qdKeys[q];
                            ret.add(new Tuple2<>(q, new double[]{pos, posPart, mb.quickGetValue(i, 0)}));
                            qiKeys[q] = Long.MAX_VALUE; // found; never match again
                        }
                    }
                    offset += val;
                }
            }
            return ret.iterator();
        }
    }

    /** Emits one {@code (partitionIndex, sumOfWeights)} tuple per partition. */
    private static class SumWeightsFunction
    implements Function2<Integer, Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, Iterator<Tuple2<Integer, Double>>> {
        private static final long serialVersionUID = 7169831202450745373L;

        private SumWeightsFunction() {
        }

        @Override
        public Iterator<Tuple2<Integer, Double>> call(Integer v1, Iterator<Tuple2<MatrixIndexes, MatrixBlock>> v2) throws Exception {
            double sum = 0.0;
            while (v2.hasNext()) {
                sum += v2.next()._2().sumWeightForQuantile();
            }
            return Collections.singletonList(new Tuple2<>(v1, sum)).iterator();
        }
    }

    /**
     * Maps a block to a partial sum over the rows within {@code [key25, key75]}. The sum
     * covers plain values for one-column blocks, and value*weight for two-column blocks.
     * The sum is stored in cell (0,0) of a 1x2 block keyed {@code (1,1)}.
     */
    private static class ExtractAndSumFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -584044441055250489L;
        private final long _minRowIndex;
        private final long _maxRowIndex;
        private final int _minPos;
        private final int _maxPos;

        public ExtractAndSumFunction(long key25, long key75, int blen) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(key25, blen);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, blen);
            this._minPos = UtilFunctions.computeCellInBlock(key25, blen);
            this._maxPos = UtilFunctions.computeCellInBlock(key75, blen);
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock mb = arg0._2();
            // clip the row range only in the boundary blocks
            int rl = ix.getRowIndex() == _minRowIndex ? _minPos : 0;
            int ru = ix.getRowIndex() == _maxRowIndex ? _maxPos + 1 : mb.getNumRows();
            MatrixBlock ret = new MatrixBlock(1, 2, false);
            ret.setValue(0, 0, mb.getNumColumns() == 1 ? sum(mb, rl, ru) : sumWeighted(mb, rl, ru));
            return new Tuple2<>(new MatrixIndexes(1L, 1L), ret);
        }

        /** Sum of column 0 over rows {@code [rl, ru)}. */
        private static double sum(MatrixBlock mb, int rl, int ru) {
            double sum = 0.0;
            for (int i = rl; i < ru; ++i) {
                sum += mb.quickGetValue(i, 0);
            }
            return sum;
        }

        /** Sum of column0*column1 over rows {@code [rl, ru)}. */
        private static double sumWeighted(MatrixBlock mb, int rl, int ru) {
            double sum = 0.0;
            for (int i = rl; i < ru; ++i) {
                sum += mb.quickGetValue(i, 0) * mb.quickGetValue(i, 1);
            }
            return sum;
        }
    }

    /** Keeps only the blocks whose row-block index lies between the blocks of the two keys. */
    private static class FilterFunction
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, Boolean> {
        private static final long serialVersionUID = -8249102381116157388L;
        private final long _minRowIndex;
        private final long _maxRowIndex;

        public FilterFunction(long key25, long key75, int blen) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(key25, blen);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, blen);
        }

        @Override
        public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            long rowIndex = arg0._1().getRowIndex();
            return rowIndex >= _minRowIndex && rowIndex <= _maxRowIndex;
        }
    }
}

