001/**
002 *
003 * Licensed to the Apache Software Foundation (ASF) under one
004 * or more contributor license agreements.  See the NOTICE file
005 * distributed with this work for additional information
006 * regarding copyright ownership.  The ASF licenses this file
007 * to you under the Apache License, Version 2.0 (the
008 * "License"); you may not use this file except in compliance
009 * with the License.  You may obtain a copy of the License at
010 *
011 *     http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 */
019package org.apache.hadoop.hbase.spark;
020
021import org.apache.commons.lang3.RandomStringUtils;
022import org.apache.hadoop.conf.Configuration;
023import org.apache.hadoop.fs.Path;
024
025import org.apache.hadoop.hbase.Cell;
026import org.apache.hadoop.hbase.CellUtil;
027import org.apache.hadoop.hbase.HBaseConfiguration;
028import org.apache.hadoop.hbase.HBaseTestingUtility;
029import org.apache.hadoop.hbase.HConstants;
030import org.apache.hadoop.hbase.HTableDescriptor;
031import org.apache.hadoop.hbase.IntegrationTestBase;
032import org.apache.hadoop.hbase.IntegrationTestingUtility;
033import org.apache.hadoop.hbase.TableName;
034
035import org.apache.hadoop.hbase.client.Admin;
036import org.apache.hadoop.hbase.client.Connection;
037import org.apache.hadoop.hbase.client.ConnectionFactory;
038import org.apache.hadoop.hbase.client.Consistency;
039import org.apache.hadoop.hbase.client.RegionLocator;
040import org.apache.hadoop.hbase.client.Result;
041import org.apache.hadoop.hbase.client.Scan;
042import org.apache.hadoop.hbase.client.Table;
043
044import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
045
046import org.apache.hadoop.hbase.mapreduce.IntegrationTestBulkLoad;
047import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles;
048import org.apache.hadoop.hbase.util.Bytes;
049import org.apache.hadoop.hbase.util.EnvironmentEdgeManager;
050import org.apache.hadoop.hbase.util.Pair;
051import org.apache.hadoop.hbase.util.RegionSplitter;
052
053import org.apache.hadoop.util.StringUtils;
054import org.apache.hadoop.util.ToolRunner;
055import org.apache.spark.SerializableWritable;
056import org.apache.spark.SparkConf;
057import org.apache.spark.api.java.JavaRDD;
058import org.apache.spark.api.java.JavaSparkContext;
059
060import org.apache.spark.Partitioner;
061
062import org.apache.spark.api.java.function.Function;
063import org.apache.spark.api.java.function.Function2;
064import org.apache.spark.api.java.function.PairFlatMapFunction;
065import org.apache.spark.api.java.function.VoidFunction;
066import org.junit.Test;
067import org.slf4j.Logger;
068import org.slf4j.LoggerFactory;
069import scala.Tuple2;
070import java.io.IOException;
071import java.util.ArrayList;
072import java.util.Arrays;
073import java.util.HashMap;
074import java.util.Iterator;
075import java.util.LinkedList;
076import java.util.List;
077import java.util.Map;
078import java.util.Random;
079import java.util.Set;
080
081import org.apache.hbase.thirdparty.com.google.common.collect.Sets;
082import org.apache.hbase.thirdparty.org.apache.commons.cli.CommandLine;
083
/**
 * Test Bulk Load and Spark on a distributed cluster.
 * It starts a Spark job that creates linked chains.
 * This test mimics {@link IntegrationTestBulkLoad} in mapreduce.
 *
 * Usage on cluster:
 *   First add the hbase related jars and hbase-spark.jar to the spark classpath.
 *
 *   spark-submit --class org.apache.hadoop.hbase.spark.IntegrationTestSparkBulkLoad
 *                HBASE_HOME/lib/hbase-spark-it-XXX-tests.jar -m slowDeterministic -Dhbase.spark.bulkload.chainlength=300
 */
095@org.junit.Ignore("CDH-35577 Our hbase-spark story is incompat with upstream. Fix after rebase.")
096public class IntegrationTestSparkBulkLoad extends IntegrationTestBase {
097
098  private static final Logger LOG = LoggerFactory.getLogger(IntegrationTestSparkBulkLoad.class);
099
100  // The number of partitions for random generated data
101  private static String BULKLOAD_PARTITIONS_NUM = "hbase.spark.bulkload.partitionsnum";
102  private static int DEFAULT_BULKLOAD_PARTITIONS_NUM = 3;
103
104  private static String BULKLOAD_CHAIN_LENGTH = "hbase.spark.bulkload.chainlength";
105  private static int DEFAULT_BULKLOAD_CHAIN_LENGTH = 200000;
106
107  private static String BULKLOAD_IMPORT_ROUNDS = "hbase.spark.bulkload.importround";
108  private static int DEFAULT_BULKLOAD_IMPORT_ROUNDS  = 1;
109
110  private static String CURRENT_ROUND_NUM = "hbase.spark.bulkload.current.roundnum";
111
112  private static String NUM_REPLICA_COUNT_KEY = "hbase.spark.bulkload.replica.countkey";
113  private static int DEFAULT_NUM_REPLICA_COUNT = 1;
114
115  private static String BULKLOAD_TABLE_NAME = "hbase.spark.bulkload.tableName";
116  private static String DEFAULT_BULKLOAD_TABLE_NAME = "IntegrationTestSparkBulkLoad";
117
118  private static String BULKLOAD_OUTPUT_PATH = "hbase.spark.bulkload.output.path";
119
120  private static final String OPT_LOAD = "load";
121  private static final String OPT_CHECK = "check";
122
123  private boolean load = false;
124  private boolean check = false;
125
126  private static final byte[] CHAIN_FAM  = Bytes.toBytes("L");
127  private static final byte[] SORT_FAM = Bytes.toBytes("S");
128  private static final byte[] DATA_FAM = Bytes.toBytes("D");
129
130  /**
131   * Running spark job to load data into hbase table
132   */
133  public void runLoad() throws Exception {
134    setupTable();
135    int numImportRounds = getConf().getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS);
136    LOG.info("Running load with numIterations:" + numImportRounds);
137    for (int i = 0; i < numImportRounds; i++) {
138      runLinkedListSparkJob(i);
139    }
140  }
141
142  /**
143   * Running spark job to create LinkedList for testing
144   * @param iteration iteration th of this job
145   * @throws Exception
146   */
147  public void runLinkedListSparkJob(int iteration) throws Exception {
148    String jobName =  IntegrationTestSparkBulkLoad.class.getSimpleName() + " _load " +
149        EnvironmentEdgeManager.currentTime();
150
151    LOG.info("Running iteration " + iteration + "in Spark Job");
152
153    Path output = null;
154    if (conf.get(BULKLOAD_OUTPUT_PATH) == null) {
155      output = util.getDataTestDirOnTestFS(getTablename() + "-" + iteration);
156    } else {
157      output = new Path(conf.get(BULKLOAD_OUTPUT_PATH));
158    }
159
160    SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local");
161    Configuration hbaseConf = new Configuration(getConf());
162    hbaseConf.setInt(CURRENT_ROUND_NUM, iteration);
163    int partitionNum = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
164
165
166    JavaSparkContext jsc = new JavaSparkContext(sparkConf);
167    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf);
168
169
170    LOG.info("Partition RDD into " + partitionNum + " parts");
171    List<String> temp = new ArrayList<>();
172    JavaRDD<List<byte[]>> rdd = jsc.parallelize(temp, partitionNum).
173        mapPartitionsWithIndex(new LinkedListCreationMapper(new SerializableWritable<>(hbaseConf)), false);
174
175    hbaseContext.bulkLoad(rdd, getTablename(), new ListToKeyValueFunc(), output.toUri().getPath(),
176        new HashMap<>(), false, HConstants.DEFAULT_MAX_FILE_SIZE);
177
178    try (Connection conn = ConnectionFactory.createConnection(conf);
179        Admin admin = conn.getAdmin();
180        Table table = conn.getTable(getTablename());
181        RegionLocator regionLocator = conn.getRegionLocator(getTablename())) {
182      // Create a new loader.
183      LoadIncrementalHFiles loader = new LoadIncrementalHFiles(conf);
184
185      // Load the HFiles into table.
186      loader.doBulkLoad(output, admin, table, regionLocator);
187    }
188
189
190    // Delete the files.
191    util.getTestFileSystem().delete(output, true);
192    jsc.close();
193  }
194
195  // See mapreduce.IntegrationTestBulkLoad#LinkedListCreationMapper
196  // Used to generate test data
197  public static class LinkedListCreationMapper implements
198      Function2<Integer, Iterator<String>, Iterator<List<byte[]>>> {
199
200    SerializableWritable swConfig = null;
201    private Random rand = new Random();
202
203    public LinkedListCreationMapper(SerializableWritable conf) {
204      this.swConfig = conf;
205    }
206
207    @Override
208    public Iterator<List<byte[]>> call(Integer v1, Iterator v2) throws Exception {
209      Configuration config = (Configuration) swConfig.value();
210      int partitionId = v1.intValue();
211      LOG.info("Starting create List in Partition " + partitionId);
212
213      int partitionNum = config.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
214      int chainLength = config.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH);
215      int iterationsNum = config.getInt(BULKLOAD_IMPORT_ROUNDS, DEFAULT_BULKLOAD_IMPORT_ROUNDS);
216      int iterationsCur = config.getInt(CURRENT_ROUND_NUM, 0);
217      List<List<byte[]>> res = new LinkedList<>();
218
219
220      long tempId = partitionId + iterationsCur * partitionNum;
221      long totalPartitionNum = partitionNum * iterationsNum;
222      long chainId = Math.abs(rand.nextLong());
223      chainId = chainId - (chainId % totalPartitionNum) + tempId;
224
225      byte[] chainIdArray = Bytes.toBytes(chainId);
226      long currentRow = 0;
227      long nextRow = getNextRow(0, chainLength);
228      for(long i = 0; i < chainLength; i++) {
229        byte[] rk = Bytes.toBytes(currentRow);
230        // Insert record into a list
231        List<byte[]> tmp1 = Arrays.asList(rk, CHAIN_FAM, chainIdArray, Bytes.toBytes(nextRow));
232        List<byte[]> tmp2 = Arrays.asList(rk, SORT_FAM, chainIdArray, Bytes.toBytes(i));
233        List<byte[]> tmp3 = Arrays.asList(rk, DATA_FAM, chainIdArray, Bytes.toBytes(
234            RandomStringUtils.randomAlphabetic(50)));
235        res.add(tmp1);
236        res.add(tmp2);
237        res.add(tmp3);
238
239        currentRow = nextRow;
240        nextRow = getNextRow(i+1, chainLength);
241      }
242      return res.iterator();
243    }
244
245    /** Returns a unique row id within this chain for this index */
246    private long getNextRow(long index, long chainLength) {
247      long nextRow = Math.abs(new Random().nextLong());
248      // use significant bits from the random number, but pad with index to ensure it is unique
249      // this also ensures that we do not reuse row = 0
250      // row collisions from multiple mappers are fine, since we guarantee unique chainIds
251      nextRow = nextRow - (nextRow % chainLength) + index;
252      return nextRow;
253    }
254  }
255
256
257
258  public static class ListToKeyValueFunc implements
259      Function<List<byte[]>, Pair<KeyFamilyQualifier, byte[]>> {
260    @Override
261    public Pair<KeyFamilyQualifier, byte[]> call(List<byte[]> v1) throws Exception {
262      if (v1 == null || v1.size() != 4) {
263        return null;
264      }
265      KeyFamilyQualifier kfq = new KeyFamilyQualifier(v1.get(0), v1.get(1), v1.get(2));
266
267      return new Pair<>(kfq, v1.get(3));
268    }
269  }
270
271  /**
272   * After adding data to the table start a mr job to
273   * @throws IOException
274   * @throws ClassNotFoundException
275   * @throws InterruptedException
276   */
277  public void runCheck() throws Exception {
278    LOG.info("Running check");
279    String jobName = IntegrationTestSparkBulkLoad.class.getSimpleName() + "_check" + EnvironmentEdgeManager.currentTime();
280
281    SparkConf sparkConf = new SparkConf().setAppName(jobName).setMaster("local");
282    Configuration hbaseConf = new Configuration(getConf());
283    JavaSparkContext jsc = new JavaSparkContext(sparkConf);
284    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, hbaseConf);
285
286    Scan scan = new Scan();
287    scan.addFamily(CHAIN_FAM);
288    scan.addFamily(SORT_FAM);
289    scan.setMaxVersions(1);
290    scan.setCacheBlocks(false);
291    scan.setBatch(1000);
292    int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
293    if (replicaCount != DEFAULT_NUM_REPLICA_COUNT) {
294      scan.setConsistency(Consistency.TIMELINE);
295    }
296
297    // 1. Using TableInputFormat to get data from HBase table
298    // 2. Mimic LinkedListCheckingMapper in mapreduce.IntegrationTestBulkLoad
299    // 3. Sort LinkKey by its order ID
300    // 4. Group LinkKey if they have same chainId, and repartition RDD by NaturalKeyPartitioner
301    // 5. Check LinkList in each Partition using LinkedListCheckingFlatMapFunc
302    hbaseContext.hbaseRDD(getTablename(), scan).flatMapToPair(new LinkedListCheckingFlatMapFunc())
303        .sortByKey()
304        .combineByKey(new createCombinerFunc(), new mergeValueFunc(), new mergeCombinersFunc(),
305            new NaturalKeyPartitioner(new SerializableWritable<>(hbaseConf)))
306        .foreach(new LinkedListCheckingForeachFunc(new SerializableWritable<>(hbaseConf)));
307    jsc.close();
308  }
309
310  private void runCheckWithRetry() throws Exception {
311    try {
312      runCheck();
313    } catch (Throwable t) {
314      LOG.warn("Received " + StringUtils.stringifyException(t));
315      LOG.warn("Running the check MR Job again to see whether an ephemeral problem or not");
316      runCheck();
317      throw t; // we should still fail the test even if second retry succeeds
318    }
319    // everything green
320  }
321
322  /**
323   * PairFlatMapFunction used to transfer <Row, Result> to Tuple <SparkLinkKey, SparkLinkChain>
324   */
325  public static class LinkedListCheckingFlatMapFunc implements
326      PairFlatMapFunction<Tuple2<ImmutableBytesWritable, Result>, SparkLinkKey, SparkLinkChain> {
327
328    @Override
329    public Iterator<Tuple2<SparkLinkKey, SparkLinkChain>> call(Tuple2<ImmutableBytesWritable, Result> v)
330        throws Exception {
331      Result value = v._2();
332      long longRk = Bytes.toLong(value.getRow());
333      List<Tuple2<SparkLinkKey, SparkLinkChain>> list = new LinkedList<>();
334
335      for (Map.Entry<byte[], byte[]> entry : value.getFamilyMap(CHAIN_FAM).entrySet()) {
336        long chainId = Bytes.toLong(entry.getKey());
337        long next = Bytes.toLong(entry.getValue());
338        Cell c = value.getColumnCells(SORT_FAM, entry.getKey()).get(0);
339        long order = Bytes.toLong(CellUtil.cloneValue(c));
340        Tuple2<SparkLinkKey, SparkLinkChain> tuple2 =
341            new Tuple2<>(new SparkLinkKey(chainId, order), new SparkLinkChain(longRk, next));
342        list.add(tuple2);
343      }
344      return list.iterator();
345    }
346  }
347
348  public static class createCombinerFunc implements
349      Function<SparkLinkChain, List<SparkLinkChain>> {
350    @Override
351    public List<SparkLinkChain> call(SparkLinkChain v1) throws Exception {
352      List<SparkLinkChain> list = new LinkedList<>();
353      list.add(v1);
354      return list;
355    }
356  }
357
358  public static class mergeValueFunc implements
359      Function2<List<SparkLinkChain>, SparkLinkChain, List<SparkLinkChain>> {
360    @Override
361    public List<SparkLinkChain> call(List<SparkLinkChain> v1, SparkLinkChain v2) throws Exception {
362      if (v1 == null)
363        v1 = new LinkedList<>();
364      v1.add(v2);
365      return v1;
366    }
367  }
368
369  public static class mergeCombinersFunc implements
370      Function2<List<SparkLinkChain>, List<SparkLinkChain>, List<SparkLinkChain>> {
371    @Override
372    public List<SparkLinkChain> call(List<SparkLinkChain> v1, List<SparkLinkChain> v2) throws Exception {
373      v1.addAll(v2);
374      return v1;
375    }
376  }
377
378  /**
379   * Class to figure out what partition to send a link in the chain to.  This is based upon
380   * the linkKey's ChainId.
381   */
382  public static class NaturalKeyPartitioner extends Partitioner {
383
384    private int numPartions = 0;
385    public NaturalKeyPartitioner(SerializableWritable swConf) {
386      Configuration hbaseConf = (Configuration) swConf.value();
387      numPartions = hbaseConf.getInt(BULKLOAD_PARTITIONS_NUM, DEFAULT_BULKLOAD_PARTITIONS_NUM);
388
389    }
390
391    @Override
392    public int numPartitions() {
393      return numPartions;
394    }
395
396    @Override
397    public int getPartition(Object key) {
398      if (!(key instanceof SparkLinkKey))
399        return -1;
400      int hash = ((SparkLinkKey) key).getChainId().hashCode();
401      return Math.abs(hash % numPartions);
402
403    }
404  }
405
406  /**
407   * Sort all LinkChain for one LinkKey, and test List<LinkChain>
408   */
409  public static class LinkedListCheckingForeachFunc
410      implements VoidFunction<Tuple2<SparkLinkKey, List<SparkLinkChain>>> {
411
412    private  SerializableWritable swConf = null;
413
414    public LinkedListCheckingForeachFunc(SerializableWritable conf) {
415      swConf = conf;
416    }
417
418    @Override
419    public void call(Tuple2<SparkLinkKey, List<SparkLinkChain>> v1) throws Exception {
420      long next = -1L;
421      long prev = -1L;
422      long count = 0L;
423
424      SparkLinkKey key = v1._1();
425      List<SparkLinkChain> values = v1._2();
426
427      for (SparkLinkChain lc : values) {
428
429        if (next == -1) {
430          if (lc.getRk() != 0L) {
431            String msg = "Chains should all start at rk 0, but read rk " + lc.getRk()
432                + ". Chain:" + key.getChainId() + ", order:" + key.getOrder();
433            throw new RuntimeException(msg);
434          }
435          next = lc.getNext();
436        } else {
437          if (next != lc.getRk()) {
438            String msg = "Missing a link in the chain. Prev rk " + prev + " was, expecting "
439                + next + " but got " + lc.getRk() + ". Chain:" + key.getChainId()
440                + ", order:" + key.getOrder();
441            throw new RuntimeException(msg);
442          }
443          prev = lc.getRk();
444          next = lc.getNext();
445        }
446        count++;
447      }
448      Configuration hbaseConf = (Configuration) swConf.value();
449      int expectedChainLen = hbaseConf.getInt(BULKLOAD_CHAIN_LENGTH, DEFAULT_BULKLOAD_CHAIN_LENGTH);
450      if (count != expectedChainLen) {
451        String msg = "Chain wasn't the correct length.  Expected " + expectedChainLen + " got "
452            + count + ". Chain:" + key.getChainId() + ", order:" + key.getOrder();
453        throw new RuntimeException(msg);
454      }
455    }
456  }
457
458  /**
459   * Writable class used as the key to group links in the linked list.
460   *
461   * Used as the key emited from a pass over the table.
462   */
463  public static class SparkLinkKey implements java.io.Serializable, Comparable<SparkLinkKey> {
464
465    private Long chainId;
466    private Long order;
467
468    public Long getOrder() {
469      return order;
470    }
471
472    public Long getChainId() {
473      return chainId;
474    }
475
476    public SparkLinkKey(long chainId, long order) {
477      this.chainId = chainId;
478      this.order = order;
479    }
480
481    @Override
482    public int hashCode() {
483      return this.getChainId().hashCode();
484    }
485
486    @Override
487    public boolean equals(Object other) {
488      if (!(other instanceof SparkLinkKey))
489        return false;
490      SparkLinkKey otherKey = (SparkLinkKey) other;
491      return this.getChainId().equals(otherKey.getChainId());
492    }
493
494    @Override
495    public int compareTo(SparkLinkKey other) {
496      int res = getChainId().compareTo(other.getChainId());
497      if (res == 0)
498        res= getOrder().compareTo(other.getOrder());
499      return res;
500    }
501  }
502
503  /**
504   * Writable used as the value emitted from a pass over the hbase table.
505   */
506  public static class SparkLinkChain implements java.io.Serializable, Comparable<SparkLinkChain>{
507
508    public Long getNext() {
509      return next;
510    }
511
512    public Long getRk() {
513      return rk;
514    }
515
516
517    public SparkLinkChain(Long rk, Long next) {
518      this.rk = rk;
519      this.next = next;
520    }
521
522    private Long rk;
523    private Long next;
524
525    @Override
526    public int compareTo(SparkLinkChain linkChain) {
527      int res = getRk().compareTo(linkChain.getRk());
528      if (res == 0) {
529        res = getNext().compareTo(linkChain.getNext());
530      }
531      return res;
532    }
533
534    @Override
535    public int hashCode() {
536      return getRk().hashCode() ^ getNext().hashCode();
537    }
538
539    @Override
540    public boolean equals(Object other) {
541      if (!(other instanceof SparkLinkChain))
542        return false;
543      SparkLinkChain otherKey = (SparkLinkChain) other;
544      return this.getRk().equals(otherKey.getRk()) && this.getNext().equals(otherKey.getNext());
545    }
546  }
547
548
549  /**
550   * Allow the scan to go to replica, this would not affect the runCheck()
551   * Since data are BulkLoaded from HFile into table
552   * @throws IOException
553   * @throws InterruptedException
554   */
555  private void installSlowingCoproc() throws IOException, InterruptedException {
556    int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
557    if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) return;
558
559    TableName t = getTablename();
560    Admin admin = util.getAdmin();
561    HTableDescriptor desc = admin.getTableDescriptor(t);
562    desc.addCoprocessor(IntegrationTestBulkLoad.SlowMeCoproScanOperations.class.getName());
563    HBaseTestingUtility.modifyTableSync(admin, desc);
564  }
565
566  @Test
567  public void testBulkLoad() throws Exception {
568    runLoad();
569    installSlowingCoproc();
570    runCheckWithRetry();
571  }
572
573
574  private byte[][] getSplits(int numRegions) {
575    RegionSplitter.UniformSplit split = new RegionSplitter.UniformSplit();
576    split.setFirstRow(Bytes.toBytes(0L));
577    split.setLastRow(Bytes.toBytes(Long.MAX_VALUE));
578    return split.split(numRegions);
579  }
580
581  private void setupTable() throws IOException, InterruptedException {
582    if (util.getAdmin().tableExists(getTablename())) {
583      util.deleteTable(getTablename());
584    }
585
586    util.createTable(
587        getTablename(),
588        new byte[][]{CHAIN_FAM, SORT_FAM, DATA_FAM},
589        getSplits(16)
590    );
591
592    int replicaCount = conf.getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
593    if (replicaCount == DEFAULT_NUM_REPLICA_COUNT) return;
594
595    TableName t = getTablename();
596    HBaseTestingUtility.setReplicas(util.getAdmin(), t, replicaCount);
597  }
598
599  @Override
600  public void setUpCluster() throws Exception {
601    util = getTestingUtil(getConf());
602    util.initializeCluster(1);
603    int replicaCount = getConf().getInt(NUM_REPLICA_COUNT_KEY, DEFAULT_NUM_REPLICA_COUNT);
604    if (LOG.isDebugEnabled() && replicaCount != DEFAULT_NUM_REPLICA_COUNT) {
605      LOG.debug("Region Replicas enabled: " + replicaCount);
606    }
607
608    // Scale this up on a real cluster
609    if (util.isDistributedCluster()) {
610      util.getConfiguration().setIfUnset(BULKLOAD_PARTITIONS_NUM, String.valueOf(DEFAULT_BULKLOAD_PARTITIONS_NUM));
611      util.getConfiguration().setIfUnset(BULKLOAD_IMPORT_ROUNDS, "1");
612    } else {
613      util.startMiniMapReduceCluster();
614    }
615  }
616
617  @Override
618  protected void addOptions() {
619    super.addOptions();
620    super.addOptNoArg(OPT_CHECK, "Run check only");
621    super.addOptNoArg(OPT_LOAD, "Run load only");
622  }
623
624  @Override
625  protected void processOptions(CommandLine cmd) {
626    super.processOptions(cmd);
627    check = cmd.hasOption(OPT_CHECK);
628    load = cmd.hasOption(OPT_LOAD);
629  }
630
631  @Override
632  public int runTestFromCommandLine() throws Exception {
633    if (load) {
634      runLoad();
635    } else if (check) {
636      installSlowingCoproc();
637      runCheckWithRetry();
638    } else {
639      testBulkLoad();
640    }
641    return 0;
642  }
643
644  @Override
645  public TableName getTablename() {
646    return getTableName(getConf());
647  }
648
649  public static TableName getTableName(Configuration conf) {
650    return TableName.valueOf(conf.get(BULKLOAD_TABLE_NAME, DEFAULT_BULKLOAD_TABLE_NAME));
651  }
652
653  @Override
654  protected Set<String> getColumnFamilies() {
655    return Sets.newHashSet(Bytes.toString(CHAIN_FAM) , Bytes.toString(DATA_FAM),
656        Bytes.toString(SORT_FAM));
657  }
658
659  public static void main(String[] args) throws Exception {
660    Configuration conf = HBaseConfiguration.create();
661    IntegrationTestingUtility.setUseDistributedCluster(conf);
662    int status =  ToolRunner.run(conf, new IntegrationTestSparkBulkLoad(), args);
663    System.exit(status);
664  }
665}