/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveBetween;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIn;
import org.apache.hadoop.hive.ql.parse.SemanticException;

public abstract class HivePointLookupOptimizerRule
extends RelOptRule {
    /**
     * Minimum number of operands an OR must have, and minimum size of a group of matching
     * disjuncts, before the OR-to-IN rewrite is attempted.
     */
    protected final int minNumORClauses;

    /** Logger shared by this rule and its nested rewriting shuttles. */
    protected static final Log LOG = LogFactory.getLog(HivePointLookupOptimizerRule.class);

    /**
     * Creates a point-lookup rule.
     *
     * @param operand         operand pattern describing which relational expressions the rule matches
     * @param minNumORClauses threshold forwarded to the OR-to-IN rewrite
     */
    protected HivePointLookupOptimizerRule(RelOptRuleOperand operand, int minNumORClauses) {
        super(operand);
        this.minNumORClauses = minNumORClauses;
    }

    /**
     * Applies the point-lookup rewrites to {@code condition}, in this order: OR of equalities to
     * IN, merging of IN clauses over the same expression, and range comparisons to BETWEEN.
     *
     * <p>Callers detect "nothing changed" by comparing the string forms of the input and the
     * returned expression.
     *
     * @param rexBuilder builder used to create rewritten expressions
     * @param condition  expression to rewrite
     * @return the rewritten expression
     */
    public RexNode analyzeRexNode(RexBuilder rexBuilder, RexNode condition) {
        RexNode rewritten = condition;
        rewritten = new RexTransformIntoInClause(rexBuilder, this.minNumORClauses).apply(rewritten);
        rewritten = new RexMergeInClause(rexBuilder).apply(rewritten);
        rewritten = new RexTranformIntoBetween(rexBuilder).apply(rewritten);
        return rewritten;
    }

    protected static class RexMergeInClause
    extends RexShuttle {
        private final RexBuilder rexBuilder;

        RexMergeInClause(RexBuilder rexBuilder) {
            this.rexBuilder = rexBuilder;
        }

        public RexNode visitCall(RexCall call) {
            RexNode node;
            HashMap<String, RexNode> stringToExpr = Maps.newHashMap();
            LinkedHashMultimap<String, String> inLHSExprToRHSExprs = LinkedHashMultimap.create();
            switch (call.getKind()) {
                case AND: {
                    ArrayList operands = Lists.newArrayList(RexUtil.flattenAnd((Iterable)call.getOperands()));
                    for (int i = 0; i < operands.size(); ++i) {
                        RexCall inCall;
                        RexNode operand = (RexNode)operands.get(i);
                        if (operand.getKind() != SqlKind.IN || !HiveCalciteUtil.isDeterministic((RexNode)(inCall = (RexCall)operand).getOperands().get(0))) continue;
                        String ref = ((RexNode)inCall.getOperands().get(0)).toString();
                        stringToExpr.put(ref, (RexNode)inCall.getOperands().get(0));
                        if (inLHSExprToRHSExprs.containsKey(ref)) {
                            HashSet<String> expressions = Sets.newHashSet();
                            for (int j = 1; j < inCall.getOperands().size(); ++j) {
                                String expr = ((RexNode)inCall.getOperands().get(j)).toString();
                                expressions.add(expr);
                                stringToExpr.put(expr, (RexNode)inCall.getOperands().get(j));
                            }
                            inLHSExprToRHSExprs.get(ref).retainAll(expressions);
                            if (!inLHSExprToRHSExprs.containsKey(ref)) {
                                return this.rexBuilder.makeLiteral(false);
                            }
                        } else {
                            for (int j = 1; j < inCall.getOperands().size(); ++j) {
                                String expr = ((RexNode)inCall.getOperands().get(j)).toString();
                                inLHSExprToRHSExprs.put(ref, expr);
                                stringToExpr.put(expr, (RexNode)inCall.getOperands().get(j));
                            }
                        }
                        operands.remove(i);
                        --i;
                    }
                    List<RexNode> newOperands = RexMergeInClause.createInClauses(this.rexBuilder, stringToExpr, inLHSExprToRHSExprs);
                    newOperands.addAll(operands);
                    node = RexUtil.composeConjunction((RexBuilder)this.rexBuilder, newOperands, (boolean)false);
                    break;
                }
                case OR: {
                    ArrayList operands = Lists.newArrayList(RexUtil.flattenOr((Iterable)call.getOperands()));
                    for (int i = 0; i < operands.size(); ++i) {
                        RexCall inCall;
                        RexNode operand = (RexNode)operands.get(i);
                        if (operand.getKind() != SqlKind.IN || !HiveCalciteUtil.isDeterministic((RexNode)(inCall = (RexCall)operand).getOperands().get(0))) continue;
                        String ref = ((RexNode)inCall.getOperands().get(0)).toString();
                        stringToExpr.put(ref, (RexNode)inCall.getOperands().get(0));
                        for (int j = 1; j < inCall.getOperands().size(); ++j) {
                            String expr = ((RexNode)inCall.getOperands().get(j)).toString();
                            inLHSExprToRHSExprs.put(ref, expr);
                            stringToExpr.put(expr, (RexNode)inCall.getOperands().get(j));
                        }
                        operands.remove(i);
                        --i;
                    }
                    List<RexNode> newOperands = RexMergeInClause.createInClauses(this.rexBuilder, stringToExpr, inLHSExprToRHSExprs);
                    newOperands.addAll(operands);
                    node = RexUtil.composeDisjunction((RexBuilder)this.rexBuilder, newOperands, (boolean)false);
                    break;
                }
                default: {
                    return super.visitCall(call);
                }
            }
            return node;
        }

        private static List<RexNode> createInClauses(RexBuilder rexBuilder, Map<String, RexNode> stringToExpr, Multimap<String, String> inLHSExprToRHSExprs) {
            ArrayList<RexNode> newExpressions = Lists.newArrayList();
            for (Map.Entry<String, Collection<String>> entry : inLHSExprToRHSExprs.asMap().entrySet()) {
                String ref = entry.getKey();
                Collection<String> exprs = entry.getValue();
                if (exprs.isEmpty()) {
                    newExpressions.add((RexNode)rexBuilder.makeLiteral(false));
                    continue;
                }
                ArrayList<RexNode> newOperands = new ArrayList<RexNode>(exprs.size() + 1);
                newOperands.add(stringToExpr.get(ref));
                for (String expr : exprs) {
                    newOperands.add(stringToExpr.get(expr));
                }
                newExpressions.add(rexBuilder.makeCall((SqlOperator)HiveIn.INSTANCE, newOperands));
            }
            return newExpressions;
        }
    }

    protected static class RexTransformIntoInClause
    extends RexShuttle {
        private final RexBuilder rexBuilder;
        private final int minNumORClauses;

        RexTransformIntoInClause(RexBuilder rexBuilder, int minNumORClauses) {
            this.rexBuilder = rexBuilder;
            this.minNumORClauses = minNumORClauses;
        }

        public RexNode visitCall(RexCall inputCall) {
            RexNode node = super.visitCall(inputCall);
            if (node instanceof RexCall) {
                RexCall call = (RexCall)node;
                switch (call.getKind()) {
                    case OR: {
                        try {
                            RexNode newNode = this.transformIntoInClauseCondition(this.rexBuilder, (RexNode)call, this.minNumORClauses);
                            if (newNode != null) {
                                return newNode;
                            }
                            break;
                        }
                        catch (SemanticException e) {
                            LOG.error((Object)"Exception in HivePointLookupOptimizerRule", (Throwable)e);
                            return call;
                        }
                    }
                }
            }
            return node;
        }

        private RexNode transformIntoInClauseCondition(RexBuilder rexBuilder, RexNode condition, int minNumORClauses) throws SemanticException {
            assert (condition.getKind() == SqlKind.OR);
            ImmutableList operands = RexUtil.flattenOr((Iterable)((RexCall)condition).getOperands());
            if (operands.size() < minNumORClauses) {
                return null;
            }
            ArrayList<Object> allNodes = new ArrayList<Object>();
            ArrayList processedNodes = new ArrayList();
            for (int i = 0; i < operands.size(); ++i) {
                ConstraintGroup m = new ConstraintGroup((RexNode)operands.get(i));
                allNodes.add(m);
            }
            ImmutableListMultimap<Set<RexNodeRef>, ConstraintGroup> assignmentGroups = Multimaps.index(allNodes, ConstraintGroup.KEY_FUNCTION);
            for (Map.Entry entry : assignmentGroups.asMap().entrySet()) {
                if (((Set)entry.getKey()).size() == 0 || ((Collection)entry.getValue()).size() < 2 || ((Collection)entry.getValue()).size() < minNumORClauses) continue;
                allNodes.add(new ConstraintGroup(this.buildInFor((Set)entry.getKey(), (Collection)entry.getValue())));
                processedNodes.addAll((Collection)entry.getValue());
            }
            if (processedNodes.isEmpty()) {
                return null;
            }
            allNodes.removeAll(processedNodes);
            ArrayList<RexNode> ops = new ArrayList<RexNode>();
            for (ConstraintGroup constraintGroup : allNodes) {
                ops.add(constraintGroup.originalRexNode);
            }
            if (ops.size() == 1) {
                return (RexNode)ops.get(0);
            }
            return rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.OR, ops);
        }

        private RexNode buildInFor(Set<RexNodeRef> set, Collection<ConstraintGroup> value) throws SemanticException {
            ArrayList<RexNodeRef> columns = new ArrayList<RexNodeRef>();
            columns.addAll(set);
            columns.sort(RexNodeRef.COMPARATOR);
            ArrayList<RexNode> operands = new ArrayList<RexNode>();
            List columnNodes = columns.stream().map(n -> n.getRexNode()).collect(Collectors.toList());
            operands.add(this.useStructIfNeeded(columnNodes));
            for (ConstraintGroup node : value) {
                List<RexNode> values = node.getValuesInOrder(columns);
                operands.add(this.useStructIfNeeded(values));
            }
            return this.rexBuilder.makeCall((SqlOperator)HiveIn.INSTANCE, operands);
        }

        private RexNode useStructIfNeeded(List<? extends RexNode> columns) {
            if (columns.size() == 1) {
                return columns.get(0);
            }
            return this.rexBuilder.makeCall((SqlOperator)SqlStdOperatorTable.ROW, columns);
        }

        static class ConstraintGroup {
            public static final Function<ConstraintGroup, Set<RexNodeRef>> KEY_FUNCTION = new Function<ConstraintGroup, Set<RexNodeRef>>(){

                @Override
                public Set<RexNodeRef> apply(ConstraintGroup cg) {
                    return cg.key;
                }
            };
            private Map<RexNodeRef, Constraint> constraints = new HashMap<RexNodeRef, Constraint>();
            private RexNode originalRexNode;
            private final Set<RexNodeRef> key;

            public ConstraintGroup(RexNode rexNode) {
                this.originalRexNode = rexNode;
                List conjunctions = RelOptUtil.conjunctions((RexNode)rexNode);
                for (RexNode n : conjunctions) {
                    Constraint c = Constraint.of(n);
                    if (c == null) {
                        this.key = Collections.emptySet();
                        return;
                    }
                    this.constraints.put(c.getKey(), c);
                }
                if (this.constraints.size() != conjunctions.size()) {
                    LOG.debug((Object)"unexpected situation; giving up on this branch");
                    this.key = Collections.emptySet();
                    return;
                }
                this.key = this.constraints.keySet();
            }

            public List<RexNode> getValuesInOrder(List<RexNodeRef> columns) throws SemanticException {
                ArrayList<RexNode> ret = new ArrayList<RexNode>();
                for (RexNodeRef rexInputRef : columns) {
                    Constraint constraint = this.constraints.get(rexInputRef);
                    if (constraint == null) {
                        throw new SemanticException("Unable to find constraint which was earlier added.");
                    }
                    ret.add(constraint.exprNode);
                }
                return ret;
            }
        }

        static class Constraint {
            private RexNode exprNode;
            private RexNode constNode;

            public Constraint(RexNode exprNode, RexNode constNode) {
                this.exprNode = constNode;
                this.constNode = exprNode;
            }

            public static Constraint of(RexNode n) {
                if (!(n instanceof RexCall)) {
                    return null;
                }
                RexCall call = (RexCall)n;
                if (call.getOperator().getKind() != SqlKind.EQUALS) {
                    return null;
                }
                RexNode opA = (RexNode)call.operands.get(0);
                RexNode opB = (RexNode)call.operands.get(1);
                if (RexUtil.isNull((RexNode)opA) || RexUtil.isNull((RexNode)opB)) {
                    return null;
                }
                if (Constraint.isConstExpr(opA) && Constraint.isColumnExpr(opB)) {
                    return new Constraint(opB, opA);
                }
                if (Constraint.isColumnExpr(opA) && Constraint.isConstExpr(opB)) {
                    return new Constraint(opA, opB);
                }
                return null;
            }

            private static boolean isColumnExpr(RexNode node) {
                return !node.getType().isStruct() && HiveCalciteUtil.getInputRefs(node).size() > 0 && HiveCalciteUtil.isDeterministic(node);
            }

            private static boolean isConstExpr(RexNode node) {
                return !node.getType().isStruct() && HiveCalciteUtil.getInputRefs(node).size() == 0 && HiveCalciteUtil.isDeterministic(node);
            }

            public RexNodeRef getKey() {
                return new RexNodeRef(this.constNode);
            }
        }
    }

    /**
     * Wraps a {@link RexNode} so that expressions with the same string form compare equal and
     * can be used as hash keys. {@link #COMPARATOR} orders wrappers by that same string form.
     */
    static class RexNodeRef {
        /** Orders references by the string form of the wrapped expression. */
        public static final Comparator<RexNodeRef> COMPARATOR = (o1, o2) -> o1.node.toString().compareTo(o2.node.toString());
        private final RexNode node;

        public RexNodeRef(RexNode node) {
            this.node = node;
        }

        public RexNode getRexNode() {
            return this.node;
        }

        @Override
        public int hashCode() {
            return this.node.toString().hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof RexNodeRef)) {
                return false;
            }
            RexNodeRef otherRef = (RexNodeRef) o;
            return this.node.toString().equals(otherRef.node.toString());
        }

        @Override
        public String toString() {
            return "ref for:" + this.node.toString();
        }
    }

    protected static class RexTranformIntoBetween
    extends RexShuttle {
        private final RexBuilder rexBuilder;

        RexTranformIntoBetween(RexBuilder rexBuilder) {
            this.rexBuilder = rexBuilder;
        }

        public RexNode visitCall(RexCall inputCall) {
            RexNode node = super.visitCall(inputCall);
            if (node instanceof RexCall) {
                RexCall call = (RexCall)node;
                switch (call.getKind()) {
                    case AND: {
                        return this.processComparisions(call, SqlKind.LESS_THAN_OR_EQUAL, false);
                    }
                    case OR: {
                        return this.processComparisions(call, SqlKind.GREATER_THAN, true);
                    }
                }
            }
            return node;
        }

        private RexNode processComparisions(RexCall call, SqlKind forwardEdge, boolean invert) {
            DiGraph<RexNodeRef, RexCall> g = this.buildComparisionGraph(call.getOperands(), forwardEdge);
            IdentityHashMap<RexNode, BetweenCandidate> replacedNodes = new IdentityHashMap<RexNode, BetweenCandidate>();
            for (RexNodeRef n : g.nodes()) {
                Set<RexNodeRef> pred = g.predecessors(n);
                Set<RexNodeRef> succ = g.successors(n);
                if (pred.size() <= 0 || succ.size() <= 0) continue;
                RexNodeRef p = pred.iterator().next();
                RexNodeRef s = succ.iterator().next();
                RexNode between = this.rexBuilder.makeCall((SqlOperator)HiveBetween.INSTANCE, new RexNode[]{this.rexBuilder.makeLiteral(invert), n.node, p.node, s.node});
                BetweenCandidate bc = new BetweenCandidate(between, (RexNode)g.removeEdge(p, n), (RexNode)g.removeEdge(n, s));
                for (RexNode node : bc.oldNodes) {
                    replacedNodes.put(node, bc);
                }
            }
            if (replacedNodes.isEmpty()) {
                return call;
            }
            ArrayList<RexNode> newOperands = new ArrayList<RexNode>();
            for (RexNode o : call.getOperands()) {
                BetweenCandidate candidate = (BetweenCandidate)replacedNodes.get(o);
                if (candidate == null) {
                    newOperands.add(o);
                    continue;
                }
                if (candidate.used) continue;
                newOperands.add(candidate.newNode);
                candidate.used = true;
            }
            if (newOperands.size() == 1) {
                return (RexNode)newOperands.get(0);
            }
            return this.rexBuilder.makeCall(call.getOperator(), newOperands);
        }

        private DiGraph<RexNodeRef, RexCall> buildComparisionGraph(List<RexNode> operands, SqlKind cmpForward) {
            DiGraph<RexNodeRef, RexCall> g = new DiGraph<RexNodeRef, RexCall>();
            for (RexNode node : operands) {
                RexNode opB;
                RexNode opA;
                if (!(node instanceof RexCall)) continue;
                RexCall rexCall = (RexCall)node;
                SqlKind kind = rexCall.getKind();
                if (kind == cmpForward) {
                    opA = (RexNode)rexCall.getOperands().get(0);
                    opB = (RexNode)rexCall.getOperands().get(1);
                    g.putEdgeValue(new RexNodeRef(opA), new RexNodeRef(opB), rexCall);
                    continue;
                }
                if (kind != cmpForward.reverse()) continue;
                opA = (RexNode)rexCall.getOperands().get(1);
                opB = (RexNode)rexCall.getOperands().get(0);
                g.putEdgeValue(new RexNodeRef(opA), new RexNodeRef(opB), rexCall);
            }
            return g;
        }

        static class BetweenCandidate {
            private final RexNode newNode;
            private final RexNode[] oldNodes;
            private boolean used;

            public BetweenCandidate(RexNode newNode, RexNode ... oldNodes) {
                this.newNode = newNode;
                this.oldNodes = oldNodes;
            }
        }

        static class DiGraph<V, E> {
            private final Map<V, Node<V, E>> nodes = new LinkedHashMap<V, Node<V, E>>();

            public void putEdgeValue(V s, V t, E e) {
                Node<V, E> nodeS = this.nodeOf(s);
                Node<V, E> nodeT = this.nodeOf(t);
                Edge<V, E> edge = new Edge<V, E>(nodeS, nodeT, e);
                nodeS.addEdge(edge);
                nodeT.addEdge(edge);
            }

            private Node<V, E> nodeOf(V s) {
                Node<V, E> node = this.nodes.get(s);
                if (node == null) {
                    node = new Node(s);
                    this.nodes.put(s, node);
                }
                return node;
            }

            public Set<V> nodes() {
                return this.nodes.keySet();
            }

            public Set<V> predecessors(V n) {
                LinkedHashSet ret = new LinkedHashSet();
                Node<V, E> node = this.nodes.get(n);
                if (node == null) {
                    return ret;
                }
                for (Edge edge : node.edges) {
                    if (!edge.t.v.equals(n)) continue;
                    ret.add(edge.s.v);
                }
                return ret;
            }

            public Set<V> successors(V n) {
                LinkedHashSet ret = new LinkedHashSet();
                Node<V, E> node = this.nodes.get(n);
                if (node == null) {
                    return ret;
                }
                for (Edge edge : node.edges) {
                    if (!edge.s.v.equals(n)) continue;
                    ret.add(edge.t.v);
                }
                return ret;
            }

            public E removeEdge(V s, V t) {
                this.nodeOf(s).removeEdge(s, t);
                return this.nodeOf(t).removeEdge(s, t);
            }

            static class Node<V, E> {
                final Set<Edge<V, E>> edges = new LinkedHashSet<Edge<V, E>>();
                final V v;

                public Node(V v) {
                    this.v = v;
                }

                public void addEdge(Edge<V, E> edge) {
                    this.edges.add(edge);
                }

                public E removeEdge(V s, V t) {
                    Iterator<Edge<V, E>> it = this.edges.iterator();
                    while (it.hasNext()) {
                        Edge<V, E> edge = it.next();
                        if (!edge.s.v.equals(s) || !edge.t.v.equals(t)) continue;
                        it.remove();
                        return edge.e;
                    }
                    return null;
                }
            }

            static class Edge<V, E> {
                final Node<V, E> s;
                final Node<V, E> t;
                final E e;

                public Edge(Node<V, E> s, Node<V, E> t, E e) {
                    this.s = s;
                    this.t = t;
                    this.e = e;
                }
            }
        }
    }

    public static class ProjectionExpressions
    extends HivePointLookupOptimizerRule {
        public ProjectionExpressions(int minNumORClauses) {
            super(ProjectionExpressions.operand(Project.class, (RelOptRuleOperandChildren)ProjectionExpressions.any()), minNumORClauses);
        }

        public void onMatch(RelOptRuleCall call) {
            Project project = (Project)call.rel(0);
            boolean changed = false;
            RexBuilder rexBuilder = project.getCluster().getRexBuilder();
            ArrayList<RexNode> newProjects = new ArrayList<RexNode>();
            for (RexNode oldNode : project.getProjects()) {
                RexNode newNode = this.analyzeRexNode(rexBuilder, oldNode);
                if (!newNode.toString().equals(oldNode.toString())) {
                    changed = true;
                    newProjects.add(newNode);
                    continue;
                }
                newProjects.add(oldNode);
            }
            if (!changed) {
                return;
            }
            Project newProject = project.copy(project.getTraitSet(), project.getInput(), newProjects, project.getRowType(), project.getFlags());
            call.transformTo((RelNode)newProject);
        }
    }

    public static class JoinCondition
    extends HivePointLookupOptimizerRule {
        public JoinCondition(int minNumORClauses) {
            super(JoinCondition.operand(Join.class, (RelOptRuleOperandChildren)JoinCondition.any()), minNumORClauses);
        }

        public void onMatch(RelOptRuleCall call) {
            RexNode condition;
            Join join = (Join)call.rel(0);
            RexBuilder rexBuilder = join.getCluster().getRexBuilder();
            RexNode newCondition = this.analyzeRexNode(rexBuilder, condition = RexUtil.pullFactors((RexBuilder)rexBuilder, (RexNode)join.getCondition()));
            if (newCondition.toString().equals(condition.toString())) {
                return;
            }
            Join newNode = join.copy(join.getTraitSet(), newCondition, join.getLeft(), join.getRight(), join.getJoinType(), join.isSemiJoinDone());
            call.transformTo((RelNode)newNode);
        }
    }

    public static class FilterCondition
    extends HivePointLookupOptimizerRule {
        public FilterCondition(int minNumORClauses) {
            super(FilterCondition.operand(Filter.class, (RelOptRuleOperandChildren)FilterCondition.any()), minNumORClauses);
        }

        public void onMatch(RelOptRuleCall call) {
            RexNode condition;
            Filter filter = (Filter)call.rel(0);
            RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
            RexNode newCondition = this.analyzeRexNode(rexBuilder, condition = RexUtil.pullFactors((RexBuilder)rexBuilder, (RexNode)filter.getCondition()));
            if (newCondition.toString().equals(condition.toString())) {
                return;
            }
            Filter newNode = filter.copy(filter.getTraitSet(), filter.getInput(), newCondition);
            call.transformTo((RelNode)newNode);
        }
    }
}

