/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.physical;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.AbstractMapJoinOperator;
import org.apache.hadoop.hive.ql.exec.CommonMergeJoinOperator;
import org.apache.hadoop.hive.ql.exec.ConditionalTask;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.mr.MapRedTask;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.optimizer.physical.LlapClusterStateForCompile;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.MapredWork;
import org.apache.hadoop.hive.ql.plan.MergeJoinWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.ReduceWork;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.hive.ql.plan.TezEdgeProperty;
import org.apache.hadoop.hive.ql.plan.TezWork;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Physical-plan resolver that walks the task graph looking for cross products (joins with no
 * join keys) and prints a console warning for each one found.
 *
 * <p>Two kinds of join are inspected:
 * <ul>
 *   <li>map joins in MapReduce and Tez work, via {@link MapJoinCheck};</li>
 *   <li>shuffle joins whose reducer is a {@link JoinOperator} or {@link CommonMergeJoinOperator}.
 *       These are flagged when the first incoming reduce sink has no key columns.</li>
 * </ul>
 *
 * <p>On Tez, a cross-product shuffle join gets extra handling when all of the following hold:
 * {@link HiveConf.ConfVars#TEZ_CARTESIAN_PRODUCT_EDGE_ENABLED} is set, and the join has no
 * outer side. In that case its incoming {@code CUSTOM_SIMPLE_EDGE} / {@code CUSTOM_EDGE} edges
 * are rewritten to {@code XPROD_EDGE}, and the reducer's task counts are reset to -1.
 */
public class CrossProductHandler
implements PhysicalPlanResolver,
Dispatcher {
    protected static final transient Logger LOG = LoggerFactory.getLogger(CrossProductHandler.class);

    /** Configuration key holding the maximum parallelism of a cartesian-product edge. */
    private static final String XPROD_MAX_PARALLELISM = "tez.cartesian-product.max-parallelism";

    /**
     * Snapshot of {@code TEZ_CARTESIAN_PRODUCT_EDGE_ENABLED}, taken in {@link #resolve}.
     * It is a primitive, so a {@link #dispatch} without a prior {@code resolve} treats the
     * feature as disabled instead of throwing a NullPointerException.
     */
    private boolean cartesianProductEdgeEnabled = false;

    /**
     * Walks every root task of the plan and reports cross products.
     *
     * @param pctx the physical context; its configuration may be updated with a default
     *             cartesian-product parallelism when running under LLAP
     * @return the same {@code pctx}
     * @throws SemanticException if the task walk fails
     */
    @Override
    public PhysicalContext resolve(PhysicalContext pctx) throws SemanticException {
        HiveConf conf = pctx.getConf();
        this.cartesianProductEdgeEnabled = HiveConf.getBoolVar(conf, HiveConf.ConfVars.TEZ_CARTESIAN_PRODUCT_EDGE_ENABLED);
        // Under LLAP, default the cartesian-product parallelism to the known executor count,
        // unless it has already been set explicitly.
        if (this.cartesianProductEdgeEnabled
                && "llap".equals(HiveConf.getVar(conf, HiveConf.ConfVars.HIVE_EXECUTION_MODE))
                && conf.get(XPROD_MAX_PARALLELISM) == null) {
            LlapClusterStateForCompile llapInfo = LlapClusterStateForCompile.getClusterInfo(conf);
            llapInfo.initClusterInfo();
            if (llapInfo.hasClusterInfo()) {
                conf.setInt(XPROD_MAX_PARALLELISM, llapInfo.getKnownExecutorCount());
            }
        }
        TaskGraphWalker walker = new TaskGraphWalker(this);
        ArrayList<Node> topNodes = new ArrayList<Node>();
        topNodes.addAll(pctx.getRootTasks());
        walker.startWalking(topNodes, null);
        return pctx;
    }

    /**
     * Checks a single task. MapReduce and Tez tasks are analysed directly. Conditional tasks
     * are expanded, and each alternative is dispatched recursively. Other task types are ignored.
     *
     * @return always {@code null}
     */
    @Override
    public Object dispatch(Node nd, Stack<Node> stack, Object ... nodeOutputs) throws SemanticException {
        Task<?> currTask = (Task<?>) nd;
        if (currTask instanceof MapRedTask) {
            MapRedTask mrTsk = (MapRedTask) currTask;
            MapredWork mrWrk = (MapredWork) mrTsk.getWork();
            this.checkMapJoins(mrTsk);
            this.checkMRReducer(currTask.toString(), mrWrk);
        } else if (currTask instanceof ConditionalTask) {
            List<Task<? extends Serializable>> alternatives = ((ConditionalTask) currTask).getListTasks();
            for (Task<? extends Serializable> tsk : alternatives) {
                this.dispatch(tsk, stack, nodeOutputs);
            }
        } else if (currTask instanceof TezTask) {
            TezWork tezWork = (TezWork) ((TezTask) currTask).getWork();
            this.checkMapJoins(tezWork);
            this.checkTezReducer(tezWork);
        }
        return null;
    }

    /** Prints {@code msg} to the session console, prefixed with "Warning: ". */
    private void warn(String msg) {
        SessionState.getConsole().printInfo("Warning: " + msg, false);
    }

    /** Prints each message in {@code msgs} as a console warning, in list order. */
    private void warnAll(List<String> msgs) {
        for (String msg : msgs) {
            this.warn(msg);
        }
    }

    /** Reports map-join cross products in the map work and, if present, the reduce work of a MapReduce task. */
    private void checkMapJoins(MapRedTask mrTsk) throws SemanticException {
        MapredWork mrWrk = (MapredWork) mrTsk.getWork();
        String taskName = mrTsk.toString();
        this.warnAll(new MapJoinCheck(taskName).analyze(mrWrk.getMapWork()));
        ReduceWork redWork = mrWrk.getReduceWork();
        if (redWork != null) {
            this.warnAll(new MapJoinCheck(taskName).analyze(redWork));
        }
    }

    /** Reports map-join cross products in every work of a Tez DAG. For a MergeJoinWork, only its main work is checked. */
    private void checkMapJoins(TezWork tezWork) throws SemanticException {
        for (BaseWork wrk : tezWork.getAllWork()) {
            BaseWork target = wrk instanceof MergeJoinWork ? ((MergeJoinWork) wrk).getMainWork() : wrk;
            this.warnAll(new MapJoinCheck(target.getName()).analyze(target));
        }
    }

    /**
     * Reports shuffle-join cross products in a Tez DAG and, when enabled, converts the join's
     * incoming custom edges into cartesian-product edges.
     */
    private void checkTezReducer(TezWork tezWork) throws SemanticException {
        for (BaseWork work : tezWork.getAllWork()) {
            // Parent edges are looked up on the work exactly as it appears in the DAG (the
            // MergeJoinWork wrapper, if any). The reducer itself comes from the main work.
            BaseWork edgeTarget = work;
            BaseWork mainWork = work instanceof MergeJoinWork ? ((MergeJoinWork) work).getMainWork() : work;
            if (!(mainWork instanceof ReduceWork)) {
                continue;
            }
            ReduceWork rWork = (ReduceWork) mainWork;
            Operator<?> reducer = rWork.getReducer();
            if (!(reducer instanceof JoinOperator) && !(reducer instanceof CommonMergeJoinOperator)) {
                continue;
            }
            boolean noOuterJoin = ((JoinDesc) reducer.getConf()).isNoOuterJoin();

            // Gather reduce-sink info from every input vertex, ordered by tag.
            TreeMap<Integer, ExtractReduceSinkInfo.Info> rsInfo = new TreeMap<Integer, ExtractReduceSinkInfo.Info>();
            for (String parentVertex : rWork.getTagToInput().values()) {
                rsInfo.putAll(this.getReducerInfo(tezWork, rWork.getName(), parentVertex));
            }

            // Evaluate the check unconditionally: it emits the warning as a side effect.
            boolean crossProduct = this.checkForCrossProduct(rWork.getName(), reducer, rsInfo);
            if (!crossProduct || !this.cartesianProductEdgeEnabled || !noOuterJoin) {
                continue;
            }
            for (BaseWork parent : tezWork.getParents(edgeTarget)) {
                TezEdgeProperty prop = tezWork.getEdgeProperty(parent, edgeTarget);
                TezEdgeProperty.EdgeType edgeType = prop.getEdgeType();
                LOG.info("Edge Type: {}", edgeType);
                if (edgeType == TezEdgeProperty.EdgeType.CUSTOM_SIMPLE_EDGE
                        || edgeType == TezEdgeProperty.EdgeType.CUSTOM_EDGE) {
                    prop.setEdgeType(TezEdgeProperty.EdgeType.XPROD_EDGE);
                    rWork.setNumReduceTasks(-1);
                    rWork.setMaxReduceTasks(-1);
                    rWork.setMinReduceTasks(-1);
                }
            }
        }
    }

    /** Reports a shuffle-join cross product in the reduce side of a MapReduce task, if any. */
    private void checkMRReducer(String taskName, MapredWork mrWrk) throws SemanticException {
        ReduceWork rWrk = mrWrk.getReduceWork();
        if (rWrk == null) {
            return;
        }
        Operator<?> reducer = rWrk.getReducer();
        if (reducer instanceof JoinOperator || reducer instanceof CommonMergeJoinOperator) {
            // All reduce sinks in this task's map work are considered (null = no output-name filter).
            MapWork parentWork = mrWrk.getMapWork();
            this.checkForCrossProduct(taskName, reducer, new ExtractReduceSinkInfo(null).analyze(parentWork));
        }
    }

    /**
     * Decides whether a shuffle join is a cross product and warns if so.
     *
     * <p>Only the first entry of {@code rsInfo}, in the map's iteration order, is consulted for key
     * columns. The warning lists the input aliases of all entries.
     *
     * @return {@code true} if a cross product was detected (and warned about), {@code false}
     *         otherwise, including when {@code rsInfo} is empty
     */
    private boolean checkForCrossProduct(String taskName, Operator<? extends OperatorDesc> reducer, Map<Integer, ExtractReduceSinkInfo.Info> rsInfo) {
        if (rsInfo.isEmpty()) {
            return false;
        }
        ExtractReduceSinkInfo.Info first = rsInfo.values().iterator().next();
        if (!first.keyCols.isEmpty()) {
            return false;
        }
        List<String> aliases = new ArrayList<String>();
        for (ExtractReduceSinkInfo.Info info : rsInfo.values()) {
            aliases.addAll(info.inputAliases);
        }
        String warning = String.format("Shuffle Join %s[tables = %s] in Stage '%s' is a cross product", reducer.toString(), aliases, taskName);
        this.warn(warning);
        return true;
    }

    /**
     * Collects reduce-sink info from vertex {@code prntVertex}, keeping only the sinks whose
     * output name is {@code vertex}.
     */
    private Map<Integer, ExtractReduceSinkInfo.Info> getReducerInfo(TezWork tezWork, String vertex, String prntVertex) throws SemanticException {
        BaseWork parentWork = tezWork.getWorkMap().get(prntVertex);
        return new ExtractReduceSinkInfo(vertex).analyze(parentWork);
    }

    /** Default processor for operators that match no rule: returns the node unchanged. */
    static class NoopProcessor
    implements NodeProcessor {
        NoopProcessor() {
        }

        @Override
        public final Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            return nd;
        }
    }

    /**
     * Operator-tree walker that records, per tag, the key columns and input aliases of every
     * ReduceSink operator. If {@code outputTaskName} is non-null, only sinks whose output name
     * equals it are recorded.
     */
    public static class ExtractReduceSinkInfo
    implements NodeProcessor,
    NodeProcessorCtx {
        final String outputTaskName;
        final Map<Integer, Info> reduceSinkInfo;

        ExtractReduceSinkInfo(String parentTaskName) {
            this.outputTaskName = parentTaskName;
            this.reduceSinkInfo = new HashMap<Integer, Info>();
        }

        /** Walks all root operators of {@code work} and returns the collected tag-to-info map. */
        Map<Integer, Info> analyze(BaseWork work) throws SemanticException {
            LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("R1", ReduceSinkOperator.getOperatorName() + "%"), this);
            DefaultRuleDispatcher disp = new DefaultRuleDispatcher(new NoopProcessor(), opRules, this);
            DefaultGraphWalker walker = new DefaultGraphWalker(disp);
            ArrayList<Node> topNodes = new ArrayList<Node>();
            topNodes.addAll(work.getAllRootOperators());
            walker.startWalking(topNodes, null);
            return this.reduceSinkInfo;
        }

        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            ReduceSinkOperator rsOp = (ReduceSinkOperator) nd;
            ReduceSinkDesc rsDesc = (ReduceSinkDesc) rsOp.getConf();
            // When filtering by output, skip sinks that feed some other vertex (or have no output name).
            if (this.outputTaskName != null && !this.outputTaskName.equals(rsDesc.getOutputName())) {
                return null;
            }
            this.reduceSinkInfo.put(rsDesc.getTag(), new Info(rsDesc.getKeyCols(), rsOp.getInputAliases()));
            return null;
        }

        /** Key columns and input aliases of one ReduceSink. A null alias input becomes an empty list. */
        static class Info {
            List<ExprNodeDesc> keyCols;
            List<String> inputAliases;

            Info(List<ExprNodeDesc> keyCols, List<String> inputAliases) {
                this.keyCols = keyCols;
                this.inputAliases = inputAliases == null ? new ArrayList<String>() : inputAliases;
            }

            Info(List<ExprNodeDesc> keyCols, String[] inputAliases) {
                this.keyCols = keyCols;
                this.inputAliases = inputAliases == null ? new ArrayList<String>() : Arrays.asList(inputAliases);
            }
        }
    }

    /**
     * Operator-tree walker that collects a warning for every MapJoin whose first key list is
     * empty, i.e. a map-join cross product.
     */
    public static class MapJoinCheck
    implements NodeProcessor,
    NodeProcessorCtx {
        final List<String> warnings;
        final String taskName;

        MapJoinCheck(String taskName) {
            this.taskName = taskName;
            this.warnings = new ArrayList<String>();
        }

        /** Walks all root operators of {@code work} and returns the collected warnings. */
        List<String> analyze(BaseWork work) throws SemanticException {
            LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
            opRules.put(new RuleRegExp("R1", MapJoinOperator.getOperatorName() + "%"), this);
            DefaultRuleDispatcher disp = new DefaultRuleDispatcher(new NoopProcessor(), opRules, this);
            DefaultGraphWalker walker = new DefaultGraphWalker(disp);
            ArrayList<Node> topNodes = new ArrayList<Node>();
            topNodes.addAll(work.getAllRootOperators());
            walker.startWalking(topNodes, null);
            return this.warnings;
        }

        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            @SuppressWarnings("unchecked")
            AbstractMapJoinOperator<? extends MapJoinDesc> mjOp = (AbstractMapJoinOperator<? extends MapJoinDesc>) nd;
            MapJoinDesc mjDesc = mjOp.getConf();
            String bigTablAlias = mjDesc.getBigTableAlias();
            if (bigTablAlias == null) {
                // Fall back to the alias of the last TableScan among the join's parents, if any.
                TableScanOperator lastScan = null;
                for (Operator<? extends OperatorDesc> op : mjOp.getParentOperators()) {
                    if (op instanceof TableScanOperator) {
                        lastScan = (TableScanOperator) op;
                    }
                }
                if (lastScan != null) {
                    TableScanDesc tDesc = (TableScanDesc) lastScan.getConf();
                    bigTablAlias = tDesc.getAlias();
                }
            }
            if (bigTablAlias == null) {
                bigTablAlias = "?";
            }
            // Only the first key list is inspected. An empty key map is skipped rather than
            // letting this warning-only check fail with NoSuchElementException.
            Iterator<List<ExprNodeDesc>> keyLists = mjDesc.getKeys().values().iterator();
            if (keyLists.hasNext() && keyLists.next().isEmpty()) {
                this.warnings.add(String.format("Map Join %s[bigTable=%s] in task '%s' is a cross product", mjOp.toString(), bigTablAlias, this.taskName));
            }
            return null;
        }
    }
}

