/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.HyperEdge;
import org.apache.calcite.rel.rules.HyperGraph;
import org.apache.calcite.rel.rules.ImmutableJoinToHyperGraphRule;
import org.apache.calcite.rel.rules.LongBitmap;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.immutables.value.Value;

@Value.Enclosing
public class JoinToHyperGraphRule
extends RelRule<Config>
implements TransformationRule {
    protected JoinToHyperGraphRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        int leftNodeCount;
        final Join origJoin = (Join)call.rel(0);
        Object left = call.rel(1);
        Object right = call.rel(2);
        if (origJoin.getJoinType() != JoinRelType.INNER) {
            return;
        }
        ArrayList<RelNode> inputs = new ArrayList<RelNode>();
        ArrayList<HyperEdge> edges = new ArrayList<HyperEdge>();
        ArrayList<RexNode> joinConds = new ArrayList<RexNode>();
        if (origJoin.getCondition().isAlwaysTrue()) {
            joinConds.add(origJoin.getCondition());
        } else {
            RelOptUtil.decomposeConjunction(origJoin.getCondition(), joinConds);
        }
        int leftFieldCount = left.getRowType().getFieldCount();
        if (left instanceof HyperGraph && right instanceof HyperGraph) {
            leftNodeCount = left.getInputs().size();
            inputs.addAll(left.getInputs());
            inputs.addAll(right.getInputs());
            edges.addAll(((HyperGraph)left).getEdges());
            edges.addAll(((HyperGraph)right).getEdges().stream().map(hyperEdge -> JoinToHyperGraphRule.adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)).collect(Collectors.toList()));
        } else if (left instanceof HyperGraph) {
            leftNodeCount = left.getInputs().size();
            inputs.addAll(left.getInputs());
            inputs.add((RelNode)right);
            edges.addAll(((HyperGraph)left).getEdges());
        } else if (right instanceof HyperGraph) {
            leftNodeCount = 1;
            inputs.add((RelNode)left);
            inputs.addAll(right.getInputs());
            edges.addAll(((HyperGraph)right).getEdges().stream().map(hyperEdge -> JoinToHyperGraphRule.adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)).collect(Collectors.toList()));
        } else {
            leftNodeCount = 1;
            inputs.add((RelNode)left);
            inputs.add((RelNode)right);
        }
        final HashMap<Integer, Integer> fieldIndexToNodeIndexMap = new HashMap<Integer, Integer>();
        int fieldCount = 0;
        for (int i = 0; i < inputs.size(); ++i) {
            for (int j = 0; j < ((RelNode)inputs.get(i)).getRowType().getFieldCount(); ++j) {
                fieldIndexToNodeIndexMap.put(fieldCount++, i);
            }
        }
        for (RexNode joinCond : joinConds) {
            long rightNodeBits;
            long leftNodeBits;
            final ArrayList<Integer> leftRefs = new ArrayList<Integer>();
            final ArrayList<Integer> rightRefs = new ArrayList<Integer>();
            RexVisitorImpl<Void> visitor = new RexVisitorImpl<Void>(true){

                @Override
                public Void visitInputRef(RexInputRef inputRef) {
                    Integer nodeIndex = (Integer)fieldIndexToNodeIndexMap.get(inputRef.getIndex());
                    if (nodeIndex == null) {
                        throw new IllegalArgumentException("RexInputRef refers a dummy field: " + inputRef + ", rowType is: " + origJoin.getRowType());
                    }
                    if (nodeIndex < leftNodeCount) {
                        leftRefs.add(nodeIndex);
                    } else {
                        rightRefs.add(nodeIndex);
                    }
                    return null;
                }
            };
            joinCond.accept(visitor);
            if (leftRefs.isEmpty() || rightRefs.isEmpty()) {
                leftNodeBits = LongBitmap.newBitmapBetween(0, leftNodeCount);
                rightNodeBits = LongBitmap.newBitmapBetween(leftNodeCount, inputs.size());
            } else {
                leftNodeBits = LongBitmap.newBitmapFromList(leftRefs);
                rightNodeBits = LongBitmap.newBitmapFromList(rightRefs);
            }
            edges.add(new HyperEdge(leftNodeBits, rightNodeBits, origJoin.getJoinType(), joinCond));
        }
        HyperGraph result = new HyperGraph(origJoin.getCluster(), origJoin.getTraitSet(), inputs, edges, origJoin.getRowType());
        call.transformTo(result);
    }

    private static HyperEdge adjustNodeBit(HyperEdge hyperEdge, int nodeOffset, int fieldOffset) {
        RexNode newCondition = RexUtil.shift(hyperEdge.getCondition(), fieldOffset);
        return new HyperEdge(hyperEdge.getLeftNodeBitmap() << nodeOffset, hyperEdge.getRightNodeBitmap() << nodeOffset, hyperEdge.getJoinType(), newCondition);
    }

    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableJoinToHyperGraphRule.Config.of().withOperandSupplier(b1 -> b1.operand(Join.class).inputs(b2 -> b2.operand(RelNode.class).anyInputs(), b3 -> b3.operand(RelNode.class).anyInputs()));

        @Override
        default public JoinToHyperGraphRule toRule() {
            return new JoinToHyperGraphRule(this);
        }
    }
}

