/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen.cplan;

import java.util.ArrayList;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.template.TemplateUtils;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.util.UtilFunctions;

public class CNodeRow
extends CNodeTpl {
    protected static final String JAVA_TEMPLATE = "package codegen;\nimport org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\nimport org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;\nimport org.apache.sysds.runtime.codegen.SpoofRowwise;\nimport org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;\nimport org.apache.commons.math3.util.FastMath;\n\npublic final class %TMP% extends SpoofRowwise { \n  public %TMP%() {\n    super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n  }\n  protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int len, long grix, int rix) { \n%BODY_dense%  }\n  protected void genexec(double[] avals, int[] aix, int ai, SideInput[] b, double[] scalars, double[] c, int ci, int alen, int len, long grix, int rix) { \n%BODY_sparse%  }\n}\n";
    private static final String TEMPLATE_ROWAGG_OUT = "    c[rix] = %IN%;\n";
    private static final String TEMPLATE_FULLAGG_OUT = "    c[0] += %IN%;\n";
    private static final String TEMPLATE_NOAGG_OUT = "    LibSpoofPrimitives.vectWrite(%IN%, c, ci, %LEN%);\n";
    private static final String TEMPLATE_NOAGG_CONST_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
    private static final String TEMPLATE_NOAGG_OUT_CUDA = "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n";
    private static final String TEMPLATE_ROWAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0){\n\t\t\t*(c.vals(rix)) = %IN%;\n//printf(\"rix=%d TMP7=%f TMP8=%f %IN%=%f\\n\",rix, TMP7, TMP8,%IN%);\n}\n";
    private static final String TEMPLATE_FULLAGG_OUT_CUDA = "\t\tif(threadIdx.x == 0) {\n\t\t\tT old = atomicAdd(c.vals(0), %IN%);\n//\t\t\tprintf(\"bid=%d full_agg add %f to %f\\n\",blockIdx.x, %IN%, old);\n\t\t}\n";
    private SpoofRowwise.RowType _type = null;
    private long _constDim2 = -1L;
    private int _numVectors = -1;
    private boolean _tb1 = false;

    public CNodeRow(ArrayList<CNode> inputs, CNode output) {
        super(inputs, output);
    }

    public void setRowType(SpoofRowwise.RowType type) {
        this._type = type;
        this._hash = 0;
    }

    public SpoofRowwise.RowType getRowType() {
        return this._type;
    }

    public void setNumVectorIntermediates(int num) {
        this._numVectors = num;
        this._hash = 0;
    }

    public int getNumVectorIntermediates() {
        return this._numVectors;
    }

    public void setConstDim2(long dim2) {
        this._constDim2 = dim2;
        this._hash = 0;
    }

    public long getConstDim2() {
        return this._constDim2;
    }

    @Override
    public void renameInputs() {
        this.rRenameDataNode(this._output, (CNode)this._inputs.get(0), "a");
        this.renameInputs(this._inputs, 1);
    }

    @Override
    public String codegen(boolean sparse, SpoofCompiler.GeneratorAPI _api) {
        this.api = _api;
        String tmp = this.getLanguageTemplate(this, this.api);
        String tmpDense = this._output.codegen(false, this.api) + this.getOutputStatement(this._output.getVarname());
        this._output.resetGenerated();
        String tmpSparse = this._output.codegen(true, this.api) + this.getOutputStatement(this._output.getVarname());
        this._output.resetGenerated();
        String varName = this.createVarname();
        tmp = tmp.replace(this.api.isJava() ? "%TMP%" : "//%TMP%", varName);
        if (!this.api.isJava()) {
            tmp = tmp.replace("/*%TMP%*/SPOOF_OP_NAME", varName);
        }
        String prefix = this.api.isJava() ? "" : "//";
        tmp = tmp.replace(prefix + "%BODY_dense%", tmpDense);
        tmp = tmp.replace(prefix + "%BODY_sparse%", tmpSparse);
        tmp = this.api.isJava() ? tmp.replace("%OUT%", "c") : tmp.replace("%OUT%", "c.vals(0)");
        tmp = tmp.replace("%POSOUT%", "0");
        tmp = tmp.replace("%LEN%", "a.cols()");
        tmp = tmp.replace("%TYPE%", this._type.name());
        tmp = tmp.replace("%CONST_DIM2%", String.valueOf(this._constDim2));
        this._tb1 = TemplateUtils.containsBinary(this._output, CNodeBinary.BinType.VECT_MATRIXMULT);
        tmp = tmp.replace("%TB1%", String.valueOf(this._tb1));
        if (this.api == SpoofCompiler.GeneratorAPI.CUDA && this._numVectors > 0) {
            tmp = tmp.replace("//%HAS_TEMP_VECT%", ": public TempStorageImpl<T, NUM_TMP_VECT, TMP_VECT_LEN>");
            tmp = tmp.replace("/*%INIT_TEMP_VECT%*/", ", TempStorageImpl<T, NUM_TMP_VECT, TMP_VECT_LEN>(tmp_stor)");
        } else {
            tmp = tmp.replace("//%HAS_TEMP_VECT%", "");
            tmp = tmp.replace("/*%INIT_TEMP_VECT%*/", "");
        }
        tmp = tmp.replace("%VECT_MEM%", String.valueOf(this._numVectors));
        return tmp;
    }

    private String getOutputStatement(String varName) {
        switch (this._type) {
            case NO_AGG: {
                if (this.api == SpoofCompiler.GeneratorAPI.CUDA) {
                    return "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n".replace("%IN%", varName + ".vals(0)").replaceAll("%LEN%", this._output.getVarname() + ".length");
                }
            }
            case NO_AGG_B1: 
            case NO_AGG_CONST: {
                if (this.api == SpoofCompiler.GeneratorAPI.JAVA) {
                    return TEMPLATE_NOAGG_OUT.replace("%IN%", varName).replace("%LEN%", this._output.getVarname() + ".length");
                }
                return "\t\tvectWrite(%IN%, c.vals(0), 0, ci, %LEN%);\n".replace("%IN%", varName + ".vals(0)").replaceAll("%LEN%", this._output.getVarname() + ".length");
            }
            case FULL_AGG: {
                if (this.api == SpoofCompiler.GeneratorAPI.JAVA) {
                    return TEMPLATE_FULLAGG_OUT.replace("%IN%", varName);
                }
                return TEMPLATE_FULLAGG_OUT_CUDA.replace("%IN%", varName);
            }
            case ROW_AGG: {
                if (this.api == SpoofCompiler.GeneratorAPI.JAVA) {
                    return TEMPLATE_ROWAGG_OUT.replace("%IN%", varName);
                }
                return TEMPLATE_ROWAGG_OUT_CUDA.replace("%IN%", varName);
            }
        }
        return "";
    }

    @Override
    public void setOutputDims() {
    }

    @Override
    public SpoofFusedOp.SpoofOutputDimsType getOutputDimType() {
        switch (this._type) {
            case NO_AGG: {
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS;
            }
            case NO_AGG_B1: {
                return SpoofFusedOp.SpoofOutputDimsType.ROW_RANK_DIMS;
            }
            case NO_AGG_CONST: {
                return SpoofFusedOp.SpoofOutputDimsType.INPUT_DIMS_CONST2;
            }
            case FULL_AGG: {
                return SpoofFusedOp.SpoofOutputDimsType.SCALAR;
            }
            case ROW_AGG: {
                return SpoofFusedOp.SpoofOutputDimsType.ROW_DIMS;
            }
            case COL_AGG: {
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_DIMS_COLS;
            }
            case COL_AGG_T: {
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_DIMS_ROWS;
            }
            case COL_AGG_B1: {
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_RANK_DIMS;
            }
            case COL_AGG_B1_T: {
                return SpoofFusedOp.SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
            }
            case COL_AGG_B1R: {
                return SpoofFusedOp.SpoofOutputDimsType.RANK_DIMS_COLS;
            }
            case COL_AGG_CONST: {
                return SpoofFusedOp.SpoofOutputDimsType.VECT_CONST2;
            }
        }
        throw new RuntimeException("Unsupported row type: " + this._type.toString());
    }

    @Override
    public CNodeTpl clone() {
        CNodeRow tmp = new CNodeRow(this._inputs, this._output);
        tmp.setRowType(this._type);
        tmp.setNumVectorIntermediates(this._numVectors);
        return tmp;
    }

    @Override
    public int hashCode() {
        if (this._hash == 0) {
            int h = UtilFunctions.intHashCode(super.hashCode(), this._type.hashCode());
            h = UtilFunctions.intHashCode(h, Long.hashCode(this._constDim2));
            this._hash = UtilFunctions.intHashCode(h, Integer.hashCode(this._numVectors));
        }
        return this._hash;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof CNodeRow)) {
            return false;
        }
        CNodeRow that = (CNodeRow)o;
        return super.equals(o) && this._type == that._type && this._numVectors == that._numVectors && this._constDim2 == that._constDim2 && CNodeRow.equalInputReferences(this._output, that._output, (ArrayList<CNode>)this._inputs, (ArrayList<CNode>)that._inputs);
    }

    @Override
    public String getTemplateInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("SPOOF ROWAGGREGATE [type=");
        sb.append(this._type.name());
        sb.append(", reqVectMem=");
        sb.append(this._numVectors);
        sb.append("]");
        return sb.toString();
    }

    @Override
    public boolean isSupported(SpoofCompiler.GeneratorAPI api) {
        return (api == SpoofCompiler.GeneratorAPI.CUDA || api == SpoofCompiler.GeneratorAPI.JAVA) && this._output.isSupported(api);
    }

    @Override
    public int compile(SpoofCompiler.GeneratorAPI api, String src) {
        if (api == SpoofCompiler.GeneratorAPI.CUDA) {
            return this.compile_nvrtc(SpoofCompiler.native_contexts.get((Object)api), this._genVar, src, this._type.getValue(), this._constDim2, this._numVectors, this._tb1);
        }
        return -1;
    }

    private native int compile_nvrtc(long var1, String var3, String var4, int var5, long var6, int var8, boolean var9);
}

