/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLCompressionException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.lib.BitmapEncoder;
import org.apache.sysds.runtime.compress.lib.BitmapLossyEncoder;
import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Library routine that "squashes" a {@link CompressedMatrixBlock}: every column of the input is
 * re-encoded into a freshly built DDC column group (via a lossy bitmap), and the resulting groups
 * form a new, non-overlapping compressed matrix block.
 */
public class CLALibSquash {

    /** Number of columns placed into each newly created column group. */
    private static final int SQUASH_BLOCK_SIZE = 1;

    /**
     * Squash the given compressed matrix into a new compressed matrix whose column groups are
     * re-compressed as DDC groups.
     *
     * @param m the compressed input; every column group must be a {@link ColGroupValue}
     * @param k degree of parallelism; values {@code <= 1} use the single-threaded path
     * @return a new compressed matrix block with the squashed column groups
     * @throws DMLCompressionException if an input column group is not a {@link ColGroupValue}, or
     *         if the squashed result reports itself as overlapping
     * @throws DMLRuntimeException if parallel execution fails or is interrupted
     */
    public static CompressedMatrixBlock squash(CompressedMatrixBlock m, int k) {
        CompressedMatrixBlock ret = new CompressedMatrixBlock(m.getNumRows(), m.getNumColumns());
        CompressionSettings cs = new CompressionSettingsBuilder().create();
        double[] minMaxes = CLALibSquash.extractMinMaxes(m);
        List<AColGroup> retCg = k <= 1 ? CLALibSquash.singleThreadSquash(m, cs, minMaxes) : CLALibSquash.multiThreadSquash(m, cs, k, minMaxes);
        ret.allocateColGroupList(retCg);
        ret.recomputeNonZeros();
        // Squashing is expected to yield disjoint column groups; anything else is an internal error.
        if (ret.isOverlapping()) {
            throw new DMLCompressionException("Squash should output compressed nonOverlapping matrix");
        }
        return ret;
    }

    /**
     * Collect per-column statistics from all column groups of {@code m}.
     *
     * @param m the compressed input
     * @return an array with two slots per column, filled by {@link ColGroupValue#addMinMax}
     *         (presumably a min/max pair per column -- NOTE(review): confirm the exact layout)
     * @throws DMLCompressionException if any column group is not a {@link ColGroupValue}
     */
    private static double[] extractMinMaxes(CompressedMatrixBlock m) {
        double[] ret = new double[m.getNumColumns() * 2];
        for (AColGroup g : m.getColGroups()) {
            if (!(g instanceof ColGroupValue)) {
                throw new DMLCompressionException("Not valid to squash if not all colGroups are of ColGroupValue type.");
            }
            ((ColGroupValue) g).addMinMax(ret);
        }
        return ret;
    }

    /**
     * Partition the column indices {@code [0, numCols)} into consecutive blocks of at most
     * {@code blkSz} columns each.
     *
     * @param numCols total number of columns
     * @param blkSz maximum number of columns per block (must be positive)
     * @return the column-index blocks in ascending order
     */
    private static List<int[]> createColumnBlocks(int numCols, int blkSz) {
        List<int[]> blocks = new ArrayList<>();
        for (int i = 0; i < numCols; i += blkSz) {
            int[] columnIds = new int[Math.min(blkSz, numCols - i)];
            for (int j = 0; j < columnIds.length; j++) {
                columnIds[j] = i + j;
            }
            blocks.add(columnIds);
        }
        return blocks;
    }

    /**
     * Build the squashed column groups sequentially, one per column block.
     */
    private static List<AColGroup> singleThreadSquash(CompressedMatrixBlock m, CompressionSettings cs, double[] minMaxes) {
        List<AColGroup> retCg = new ArrayList<>();
        for (int[] columnIds : createColumnBlocks(m.getNumColumns(), SQUASH_BLOCK_SIZE)) {
            retCg.add(CLALibSquash.extractNewGroup(m, cs, columnIds, minMaxes));
        }
        return retCg;
    }

    /**
     * Build the squashed column groups in parallel on a pool of {@code k} threads. Results are
     * returned in column-block order. The pool is always shut down, even on failure.
     *
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static List<AColGroup> multiThreadSquash(CompressedMatrixBlock m, CompressionSettings cs, int k, double[] minMaxes) {
        List<SquashTask> tasks = new ArrayList<>();
        for (int[] columnIds : createColumnBlocks(m.getNumColumns(), SQUASH_BLOCK_SIZE)) {
            tasks.add(new SquashTask(m, cs, columnIds, minMaxes));
        }

        ExecutorService pool = CommonThreadPool.get(k);
        try {
            List<AColGroup> retCg = new ArrayList<>(tasks.size());
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                retCg.add(future.get());
            }
            return retCg;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            // Previously only reached on success; shutting down here avoids leaking the pool.
            pool.shutdown();
        }
    }

    /**
     * Re-compress the given columns of {@code m} into a new DDC column group.
     *
     * @param m the compressed input
     * @param cs compression settings used for the new group
     * @param columnIds the columns to encode together
     * @param minMaxes per-column statistics from {@link #extractMinMaxes}; NOTE(review): currently
     *        not used by this method
     * @return the newly compressed column group
     */
    private static AColGroup extractNewGroup(CompressedMatrixBlock m, CompressionSettings cs, int[] columnIds, double[] minMaxes) {
        ABitmap map = CLALibSquash.extractBitmap(columnIds, m);
        // NOTE(review): meaning of the trailing 1.0 argument is defined by ColGroupFactory -- confirm.
        AColGroup newGroup = ColGroupFactory.compress(columnIds, m.getNumRows(), map, AColGroup.CompressionType.DDC, cs, m, 1.0);
        return newGroup;
    }

    /**
     * Extract a bitmap for the given columns by reading them through a compressed reader, then
     * convert it with {@link BitmapLossyEncoder#makeBitmapLossy}.
     */
    private static ABitmap extractBitmap(int[] colIndices, CompressedMatrixBlock compressedBlock) {
        ABitmap x = BitmapEncoder.extractBitmap(colIndices, ReaderColumnSelection.createCompressedReader(compressedBlock, colIndices), compressedBlock.getNumRows());
        return BitmapLossyEncoder.makeBitmapLossy(x, compressedBlock.getNumRows());
    }

    /**
     * Task that builds one squashed column group via {@link CLALibSquash#extractNewGroup}.
     * All inputs are shared (not copied) between tasks.
     */
    private static class SquashTask
    implements Callable<AColGroup> {
        private final CompressedMatrixBlock _m;
        private final CompressionSettings _cs;
        private final int[] _columnIds;
        private final double[] _minMaxes;

        protected SquashTask(CompressedMatrixBlock m, CompressionSettings cs, int[] columnIds, double[] minMaxes) {
            this._m = m;
            this._cs = cs;
            this._columnIds = columnIds;
            this._minMaxes = minMaxes;
        }

        @Override
        public AColGroup call() {
            return CLALibSquash.extractNewGroup(this._m, this._cs, this._columnIds, this._minMaxes);
        }
    }
}

