diff --git a/.asf.yaml b/.asf.yaml index 83de7215a..1d9d56e20 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -49,14 +49,12 @@ github: required_pull_request_reviews: dismiss_stale_reviews: true require_code_owner_reviews: false - required_approving_review_count: 2 + required_approving_review_count: 1 # (for non-committer): assign/edit/close issues & PR, without write access to the code collaborators: - - Pengzna + - kenssa4eedfd - haohao0103 - - Thespica - FrostyHec - - MuLeiSY2021 notifications: # use https://selfserve.apache.org to manage it diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java index 2fb9eb4aa..dab3fb579 100644 --- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java +++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/HugeGraphLoader.java @@ -662,24 +662,21 @@ private List prepareTaskItems(List structs, } private void loadStructs(List structs) { - int parallelCount = this.context.options().parallelCount; + int parseThreads = this.context.options().parseThreads; if (structs.size() == 0) { return; } - if (parallelCount <= 0) { - parallelCount = Math.min(structs.size(), Runtime.getRuntime().availableProcessors() * 2); - } boolean scatter = this.context.options().scatterSources; - LOG.info("{} threads for loading {} structs, from {} to {} in {} mode", - parallelCount, structs.size(), this.context.options().startFile, + LOG.info("{} parser threads for loading {} structs, from {} to {} in {} mode", + parseThreads, structs.size(), this.context.options().startFile, this.context.options().endFile, scatter ? 
"scatter" : "sequential"); ExecutorService loadService = null; try { - loadService = ExecutorUtil.newFixedThreadPool(parallelCount, "loader"); + loadService = ExecutorUtil.newFixedThreadPool(parseThreads, "loader"); List taskItems = prepareTaskItems(structs, scatter); List> loadTasks = new ArrayList<>(); diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java index 95babb557..f0ea30b7b 100644 --- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java +++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/executor/LoadOptions.java @@ -20,6 +20,7 @@ import java.io.File; import java.lang.reflect.Field; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; @@ -45,6 +46,8 @@ public final class LoadOptions implements Cloneable { public static final String HTTPS_SCHEMA = "https"; public static final String HTTP_SCHEMA = "http"; private static final int CPUS = Runtime.getRuntime().availableProcessors(); + private static final int DEFAULT_MAX_CONNECTIONS = CPUS * 4; + private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = CPUS * 2; private static final int MINIMUM_REQUIRED_ARGS = 3; @Parameter(names = {"-f", "--file"}, required = true, arity = 1, @@ -156,7 +159,7 @@ public final class LoadOptions implements Cloneable { @Parameter(names = {"--batch-insert-threads"}, arity = 1, validateWith = {PositiveValidator.class}, - description = "The number of threads to execute batch insert") + description = "The number of threads to execute batch insert (default: CPUS)") public int batchInsertThreads = CPUS; @Parameter(names = {"--single-insert-threads"}, arity = 1, @@ -165,21 +168,27 @@ public final class LoadOptions implements Cloneable { public int singleInsertThreads = 8; @Parameter(names = {"--max-conn"}, arity = 1, - description = "Max number of HTTP 
connections to server") - public int maxConnections = CPUS * 4; + validateWith = {PositiveValidator.class}, + description = "Max HTTP connections (default: CPUS*4; auto-adjusted by " + + "--batch-insert-threads)") + public int maxConnections = DEFAULT_MAX_CONNECTIONS; @Parameter(names = {"--max-conn-per-route"}, arity = 1, - description = "Max number of HTTP connections to each route") - public int maxConnectionsPerRoute = CPUS * 2; + validateWith = {PositiveValidator.class}, + description = "Max HTTP connections per route (default: CPUS*2; " + + "auto-adjusted by --batch-insert-threads)") + public int maxConnectionsPerRoute = DEFAULT_MAX_CONNECTIONS_PER_ROUTE; @Parameter(names = {"--batch-size"}, arity = 1, validateWith = {PositiveValidator.class}, description = "The number of lines in each submit") public int batchSize = 500; - @Parameter(names = {"--parallel-count"}, arity = 1, - description = "The number of parallel read pipelines") - public int parallelCount = 1; + @Parameter(names = {"--parallel-count", "--parser-threads"}, arity = 1, + validateWith = {PositiveValidator.class}, + description = "Parallel read pipelines (default: max(2, CPUS/2); " + + "--parallel-count is deprecated)") + public int parseThreads = Math.max(2, CPUS / 2); @Parameter(names = {"--start-file"}, arity = 1, description = "start file index for partial loading") @@ -329,6 +338,11 @@ public final class LoadOptions implements Cloneable { description = "The task scheduler type (when creating graph if not exists") public String schedulerType = "distributed"; + @Parameter(names = {"--batch-failure-fallback"}, arity = 1, + description = "Whether to fallback to single insert when batch insert fails. 
" + + "Default: true") + public boolean batchFailureFallback = true; + public String workModeString() { if (this.incrementalMode) { return "INCREMENTAL MODE"; @@ -406,9 +420,32 @@ public static LoadOptions parseOptions(String[] args) { options.maxParseErrors = Constants.NO_LIMIT; options.maxInsertErrors = Constants.NO_LIMIT; } + if (Arrays.asList(args).contains("--parallel-count")) { + LOG.warn("Parameter --parallel-count is deprecated, " + + "please use --parser-threads instead"); + } + adjustConnectionPoolIfDefault(options); return options; } + private static void adjustConnectionPoolIfDefault(LoadOptions options) { + int batchThreads = options.batchInsertThreads; + int maxConn = options.maxConnections; + int maxConnPerRoute = options.maxConnectionsPerRoute; + + if (maxConn == DEFAULT_MAX_CONNECTIONS && maxConn < batchThreads * 4) { + options.maxConnections = batchThreads * 4; + LOG.info("Auto adjusted max-conn to {} based on batch-insert-threads({})", + options.maxConnections, batchThreads); + } + + if (maxConnPerRoute == DEFAULT_MAX_CONNECTIONS_PER_ROUTE && maxConnPerRoute < batchThreads * 2) { + options.maxConnectionsPerRoute = batchThreads * 2; + LOG.info("Auto adjusted max-conn-per-route to {} based on batch-insert-threads({})", + options.maxConnectionsPerRoute, batchThreads); + } + } + public ShortIdConfig getShortIdConfig(String vertexLabel) { for (ShortIdConfig config: shorterIDConfigs) { if (config.getVertexLabel().equals(vertexLabel)) { diff --git a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java index ce4d77a92..7d0793955 100644 --- a/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java +++ b/hugegraph-loader/src/main/java/org/apache/hugegraph/loader/task/TaskManager.java @@ -26,6 +26,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import 
org.apache.hugegraph.loader.util.Printer; import org.slf4j.Logger; import org.apache.hugegraph.loader.builder.Record; @@ -140,6 +141,10 @@ public void submitBatch(InputStruct struct, ElementMapping mapping, long start = System.currentTimeMillis(); try { this.batchSemaphore.acquire(); + if (this.context.stopped()) { + this.batchSemaphore.release(); + return; + } } catch (InterruptedException e) { throw new LoadException("Interrupted while waiting to submit %s " + "batch in batch mode", e, mapping.type()); @@ -152,10 +157,18 @@ public void submitBatch(InputStruct struct, ElementMapping mapping, CompletableFuture.runAsync(task, this.batchService).whenComplete( (r, e) -> { if (e != null) { - LOG.warn("Batch insert {} error, try single insert", - mapping.type(), e); - // The time of single insert is counted separately - this.submitInSingle(struct, mapping, batch); + if (this.options.batchFailureFallback) { + LOG.warn("Batch insert {} error, try single insert", + mapping.type(), e); + this.submitInSingle(struct, mapping, batch); + } else { + summary.metrics(struct).minusFlighting(batch.size()); + this.context.occurredError(); + this.context.stopLoading(); + LOG.error("Batch insert {} error, interrupting import", mapping.type(), e); + Printer.printError("Batch insert %s failed, stop loading. 
Please check the logs", + mapping.type().string()); + } } else { summary.metrics(struct).minusFlighting(batch.size()); } diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java index d069aaecf..5be6a61ea 100644 --- a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java +++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/FileLoadTest.java @@ -1197,7 +1197,9 @@ public void testMultiFilesHaveHeader() { "-s", configPath("multi_files_have_header/schema.groovy"), "-g", GRAPH, "-h", SERVER, - "--test-mode", "true" + "--test-mode", "true", + // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE). + "--parser-threads", "1" }; loadWithAuth(args); @@ -1332,7 +1334,9 @@ public void testDirHasMultiFiles() { "-s", configPath("dir_has_multi_files/schema.groovy"), "-g", GRAPH, "-h", SERVER, - "--test-mode", "true" + "--test-mode", "true", + // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE). + "--parser-threads", "1" }; loadWithAuth(args); @@ -1628,7 +1632,9 @@ public void testFilterPathBySuffix() { "-s", configPath("filter_path_by_suffix/schema.groovy"), "-g", GRAPH, "-h", SERVER, - "--test-mode", "true" + "--test-mode", "true", + // FIXME: Set parser-threads to 1 because values > 1 currently trigger a NullPointerException (NPE). 
+ "--parser-threads", "1" }; loadWithAuth(args); @@ -2058,7 +2064,8 @@ public void testLoadIncrementalModeAndLoadFailure() "-h", SERVER, "--batch-insert-threads", "2", "--max-parse-errors", "1", - "--test-mode", "false" + "--test-mode", "false", + "--parser-threads", "1" )); argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa")); @@ -2259,7 +2266,8 @@ public void testReloadJsonFailureFiles() throws IOException, "-h", SERVER, "--check-vertex", "true", "--batch-insert-threads", "2", - "--test-mode", "false" + "--test-mode", "false", + "--parser-threads", "1" )); argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa")); HugeGraphLoader loader = new HugeGraphLoader(argsList.toArray(new String[0])); @@ -2564,7 +2572,8 @@ public void testSourceOrTargetPrimaryValueNull() { "-g", GRAPH, "-h", SERVER, "--batch-insert-threads", "2", - "--test-mode", "true" + "--test-mode", "true", + "--parser-threads", "1" )); argsList.addAll(Arrays.asList("--username", "admin", "--password", "pa")); @@ -3047,7 +3056,8 @@ public void testReadReachedMaxLines() { "-h", SERVER, "--max-read-lines", "4", "--batch-insert-threads", "2", - "--test-mode", "true" + "--test-mode", "true", + "--parser-threads", "1" }; loadWithAuth(args); @@ -3061,7 +3071,8 @@ public void testReadReachedMaxLines() { "-h", SERVER, "--max-read-lines", "6", "--batch-insert-threads", "2", - "--test-mode", "true" + "--test-mode", "true", + "--parser-threads", "1" }; loadWithAuth(args); diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java index c6c31520a..b44ffbd8f 100644 --- a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java +++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/functional/KafkaLoadTest.java @@ -85,7 +85,8 @@ public void testCustomizedSchema() { "-h", SERVER, "-p", 
String.valueOf(PORT), "--batch-insert-threads", "2", - "--test-mode", "true" + "--test-mode", "true", + "--parser-threads", "1" }; loadWithAuth(args); diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java new file mode 100644 index 000000000..b327f59ef --- /dev/null +++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/LoadOptionsTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.hugegraph.loader.test.unit; + +import java.io.File; +import java.io.FileWriter; +import java.lang.reflect.Field; +import java.lang.reflect.Method; + +import org.apache.log4j.AppenderSkeleton; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; + +import org.apache.hugegraph.loader.executor.LoadOptions; +import org.junit.Test; + +import org.apache.hugegraph.testutil.Assert; + +public class LoadOptionsTest { + + @Test + public void testConnectionPoolAutoAdjustWithDefaultBatchThreads() throws Exception { + int cpus = readStaticInt(LoadOptions.class, "CPUS"); + LoadOptions options = new LoadOptions(); + + Assert.assertEquals(cpus * 4, options.maxConnections); + Assert.assertEquals(cpus * 2, options.maxConnectionsPerRoute); + } + + @Test + public void testConnectionPoolAutoAdjustWithCustomBatchThreads() throws Exception { + int cpus = readStaticInt(LoadOptions.class, "CPUS"); + int defaultMaxConn = readStaticInt(LoadOptions.class, "DEFAULT_MAX_CONNECTIONS"); + int defaultMaxConnPerRoute = readStaticInt(LoadOptions.class, + "DEFAULT_MAX_CONNECTIONS_PER_ROUTE"); + LoadOptions options = new LoadOptions(); + options.batchInsertThreads = 20; + + CapturingAppender appender = attachAppender(); + try { + invokeAdjustConnectionPool(options); + } finally { + detachAppender(appender); + } + + int expectedMaxConn = defaultMaxConn; + int expectedMaxConnPerRoute = defaultMaxConnPerRoute; + if (defaultMaxConn == cpus * 4 && defaultMaxConn < 80) { + expectedMaxConn = 80; + } + if (defaultMaxConnPerRoute == cpus * 2 && defaultMaxConnPerRoute < 40) { + expectedMaxConnPerRoute = 40; + } + + Assert.assertEquals(expectedMaxConn, options.maxConnections); + Assert.assertEquals(expectedMaxConnPerRoute, options.maxConnectionsPerRoute); + if (expectedMaxConn == 80 || expectedMaxConnPerRoute == 40) { + Assert.assertTrue(appender.contains("Auto adjusted max-conn")); + } + } + + @Test + public void 
testConnectionPoolNoAdjustWithCustomMaxConn() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchInsertThreads = 20; + options.maxConnections = 100; + options.maxConnectionsPerRoute = 50; + + CapturingAppender appender = attachAppender(); + try { + invokeAdjustConnectionPool(options); + } finally { + detachAppender(appender); + } + + Assert.assertEquals(100, options.maxConnections); + Assert.assertEquals(50, options.maxConnectionsPerRoute); + Assert.assertFalse(appender.contains("Auto adjusted max-conn")); + } + + @Test + public void testParseThreadsMinValue() { + LoadOptions.PositiveValidator validator = + new LoadOptions.PositiveValidator(); + + validator.validate("--parser-threads", "1"); + + Assert.assertTrue(validateFails(validator, "--parser-threads", "0")); + Assert.assertTrue(validateFails(validator, "--parser-threads", "-1")); + } + + @Test + public void testParseThreadsDefaultValue() throws Exception { + int cpus = readStaticInt(LoadOptions.class, "CPUS"); + LoadOptions options = new LoadOptions(); + Assert.assertEquals(Math.max(2, cpus / 2), options.parseThreads); + } + + @Test + public void testDeprecatedParallelCountParameter() throws Exception { + File mapping = createTempMapping(); + String[] args = new String[]{ + "-f", mapping.getPath(), + "-g", "g", + "-h", "localhost", + "--parallel-count", "4" + }; + + CapturingAppender appender = attachAppender(); + try { + LoadOptions options = LoadOptions.parseOptions(args); + Assert.assertEquals(4, options.parseThreads); + Assert.assertTrue(appender.contains("deprecated")); + } finally { + detachAppender(appender); + mapping.delete(); + } + } + + private static int readStaticInt(Class type, String name) + throws Exception { + Field field = type.getDeclaredField(name); + field.setAccessible(true); + return field.getInt(null); + } + + private static void invokeAdjustConnectionPool(LoadOptions options) + throws Exception { + Method method = LoadOptions.class + 
.getDeclaredMethod("adjustConnectionPoolIfDefault", + LoadOptions.class); + method.setAccessible(true); + method.invoke(null, options); + } + + private static boolean validateFails(LoadOptions.PositiveValidator validator, + String name, String value) { + try { + validator.validate(name, value); + return false; + } catch (Exception ignored) { + return true; + } + } + + private static File createTempMapping() throws Exception { + File file = File.createTempFile("load-options-", ".json", new File(".")); + try (FileWriter writer = new FileWriter(file)) { + writer.write("{\"version\":\"2.0\",\"structs\":[]}"); + } + return file; + } + + private static CapturingAppender attachAppender() { + Logger logger = Logger.getLogger(LoadOptions.class.getName()); + CapturingAppender appender = new CapturingAppender(); + appender.setThreshold(Level.INFO); + logger.addAppender(appender); + return appender; + } + + private static void detachAppender(CapturingAppender appender) { + if (appender == null) { + return; + } + Logger logger = Logger.getLogger(LoadOptions.class.getName()); + logger.removeAppender(appender); + } + + private static final class CapturingAppender extends AppenderSkeleton { + + private final StringBuilder buffer = new StringBuilder(); + + @Override + protected void append(LoggingEvent event) { + if (event == null || event.getRenderedMessage() == null) { + return; + } + buffer.append(event.getRenderedMessage()).append('\n'); + } + + boolean contains(String text) { + return this.buffer.toString().contains(text); + } + + @Override + public void close() { + // No-op. 
+ } + + @Override + public boolean requiresLayout() { + return false; + } + } + +} diff --git a/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java new file mode 100644 index 000000000..08b6b9043 --- /dev/null +++ b/hugegraph-loader/src/test/java/org/apache/hugegraph/loader/test/unit/TaskManagerTest.java @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to You under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.hugegraph.loader.test.unit; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.LongAdder; + +import org.apache.hugegraph.driver.GraphManager; +import org.apache.hugegraph.driver.HugeClient; +import org.apache.hugegraph.loader.builder.Record; +import org.apache.hugegraph.loader.executor.LoadContext; +import org.apache.hugegraph.loader.executor.LoadOptions; +import org.apache.hugegraph.loader.mapping.EdgeMapping; +import org.apache.hugegraph.loader.mapping.InputStruct; +import org.apache.hugegraph.loader.metrics.LoadMetrics; +import org.apache.hugegraph.loader.metrics.LoadSummary; +import org.apache.hugegraph.loader.progress.LoadProgress; +import org.apache.hugegraph.loader.task.TaskManager; +import org.apache.hugegraph.structure.graph.Edge; +import org.junit.Test; + +import org.apache.hugegraph.testutil.Assert; + +public class TaskManagerTest { + + @Test + public void testBatchInsertFailureWithFallbackDisabled() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchFailureFallback = false; + Assert.assertFalse(options.batchFailureFallback); + + LoadContext context = newTestContext(options); + TaskManager taskManager = new TaskManager(context); + + EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false, + Arrays.asList("t"), false); + mapping.label("knows"); + + InputStruct struct = new InputStruct(new ArrayList<>(), + new 
ArrayList<>()); + struct.id("1"); + struct.add(mapping); + + LoadSummary summary = context.summary(); + summary.inputMetricsMap() + .put(struct.id(), new LoadMetrics(struct)); + LoadMetrics metrics = summary.metrics(struct); + + setField(context.client(), "graph", newFailingBatchGraphManager()); + + List batch = new ArrayList<>(); + batch.add(new Record("line1", new Edge("knows"))); + batch.add(new Record("line2", new Edge("knows"))); + + ByteArrayOutputStream errOutput = new ByteArrayOutputStream(); + PrintStream originalErr = System.err; + System.setErr(new PrintStream(errOutput, true, + StandardCharsets.UTF_8.name())); + try { + taskManager.submitBatch(struct, mapping, batch); + taskManager.waitFinished(); + + Assert.assertEquals(0L, flightingCount(metrics)); + Assert.assertTrue(context.stopped()); + Assert.assertFalse(context.noError()); + + String errText = errOutput.toString(StandardCharsets.UTF_8.name()); + Assert.assertTrue(errText.contains( + "Batch insert edges failed, stop loading.")); + + long before = flightingCount(metrics); + taskManager.submitBatch(struct, mapping, batch); + taskManager.waitFinished(); + Assert.assertEquals(before, flightingCount(metrics)); + } finally { + System.setErr(originalErr); + taskManager.shutdown(); + } + } + + @Test + public void testBatchInsertFailureWithFallbackEnabled() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchFailureFallback = true; + + LoadContext context = newTestContext(options); + TaskManager taskManager = new TaskManager(context); + + EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false, + Arrays.asList("t"), false); + mapping.label("knows"); + + InputStruct struct = new InputStruct(new ArrayList<>(), + new ArrayList<>()); + struct.id("1"); + struct.add(mapping); + + LoadSummary summary = context.summary(); + summary.inputMetricsMap() + .put(struct.id(), new LoadMetrics(struct)); + LoadMetrics metrics = summary.metrics(struct); + + 
FailingBatchGraphManager.BATCH_CALLS.set(0); + FailingBatchGraphManager.SINGLE_CALLS.set(0); + setField(context.client(), "graph", newFailingBatchGraphManager()); + + List batch = new ArrayList<>(); + batch.add(new Record("line1", new Edge("knows"))); + batch.add(new Record("line2", new Edge("knows"))); + + try { + taskManager.submitBatch(struct, mapping, batch); + taskManager.waitFinished(); + + Assert.assertEquals(1, FailingBatchGraphManager.BATCH_CALLS.get()); + Assert.assertEquals(2, FailingBatchGraphManager.SINGLE_CALLS.get()); + Assert.assertEquals(0L, flightingCount(metrics)); + Assert.assertFalse(context.stopped()); + Assert.assertTrue(context.noError()); + } finally { + taskManager.shutdown(); + } + } + + @Test + public void testMultipleBatchFailuresCounterConsistency() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchFailureFallback = true; + + LoadContext context = newTestContext(options); + TaskManager taskManager = new TaskManager(context); + + EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false, + Arrays.asList("t"), false); + mapping.label("knows"); + + InputStruct struct = new InputStruct(new ArrayList<>(), + new ArrayList<>()); + struct.id("1"); + struct.add(mapping); + + LoadSummary summary = context.summary(); + summary.inputMetricsMap() + .put(struct.id(), new LoadMetrics(struct)); + LoadMetrics metrics = summary.metrics(struct); + + FailingBatchGraphManager.BATCH_CALLS.set(0); + FailingBatchGraphManager.SINGLE_CALLS.set(0); + setField(context.client(), "graph", newFailingBatchGraphManager()); + + List batch1 = new ArrayList<>(); + batch1.add(new Record("line1", new Edge("knows"))); + batch1.add(new Record("line2", new Edge("knows"))); + + List batch2 = new ArrayList<>(); + batch2.add(new Record("line3", new Edge("knows"))); + batch2.add(new Record("line4", new Edge("knows"))); + + try { + taskManager.submitBatch(struct, mapping, batch1); + taskManager.submitBatch(struct, mapping, batch2); + 
taskManager.waitFinished(); + + Assert.assertEquals(2, FailingBatchGraphManager.BATCH_CALLS.get()); + Assert.assertEquals(4, FailingBatchGraphManager.SINGLE_CALLS.get()); + Assert.assertEquals(0L, flightingCount(metrics)); + Assert.assertFalse(context.stopped()); + Assert.assertTrue(context.noError()); + + int expectedBatchPermits = 1 + options.batchInsertThreads; + int expectedSinglePermits = 2 * options.singleInsertThreads; + Assert.assertEquals(expectedBatchPermits, + getSemaphorePermits(taskManager, "batchSemaphore")); + Assert.assertEquals(expectedSinglePermits, + getSemaphorePermits(taskManager, "singleSemaphore")); + } finally { + taskManager.shutdown(); + } + } + + @Test + public void testConcurrentSubmitWhenStopping() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchFailureFallback = false; + options.batchInsertThreads = 2; + options.singleInsertThreads = 1; + + LoadContext context = newTestContext(options); + TaskManager taskManager = new TaskManager(context); + + EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false, + Arrays.asList("t"), false); + mapping.label("knows"); + + InputStruct struct = new InputStruct(new ArrayList<>(), + new ArrayList<>()); + struct.id("1"); + struct.add(mapping); + + LoadSummary summary = context.summary(); + summary.inputMetricsMap() + .put(struct.id(), new LoadMetrics(struct)); + LoadMetrics metrics = summary.metrics(struct); + + CountDownLatch firstStarted = new CountDownLatch(1); + CountDownLatch allowFirstFinish = new CountDownLatch(1); + CountDownLatch failureCalled = new CountDownLatch(1); + FailingConcurrentGraphManager.BATCH_CALLS.set(0); + FailingConcurrentGraphManager.FIRST_STARTED = firstStarted; + FailingConcurrentGraphManager.ALLOW_FIRST_FINISH = allowFirstFinish; + FailingConcurrentGraphManager.FAILURE_CALLED = failureCalled; + setField(context.client(), "graph", newFailingConcurrentGraphManager()); + + List batch = new ArrayList<>(); + batch.add(new Record("line1", new 
Edge("knows"))); + batch.add(new Record("line2", new Edge("knows"))); + + ExecutorService executor = Executors.newFixedThreadPool(10); + List> futures = new ArrayList<>(); + try { + for (int i = 0; i < 10; i++) { + futures.add(executor.submit(() -> { + taskManager.submitBatch(struct, mapping, batch); + })); + } + + Assert.assertTrue(firstStarted.await(5, TimeUnit.SECONDS)); + Assert.assertTrue(failureCalled.await(5, TimeUnit.SECONDS)); + waitStopped(context, 5, TimeUnit.SECONDS); + allowFirstFinish.countDown(); + + for (Future future : futures) { + future.get(5, TimeUnit.SECONDS); + } + + taskManager.waitFinished(); + + int batchCalls = FailingConcurrentGraphManager.BATCH_CALLS.get(); + Assert.assertTrue(batchCalls >= 2 && batchCalls <= 3); + Assert.assertEquals(0L, flightingCount(metrics)); + Assert.assertTrue(context.stopped()); + Assert.assertFalse(context.noError()); + + long before = FailingConcurrentGraphManager.BATCH_CALLS.get(); + taskManager.submitBatch(struct, mapping, batch); + taskManager.waitFinished(); + Assert.assertEquals(before, FailingConcurrentGraphManager.BATCH_CALLS.get()); + + /* NOTE(review): test expects idle batch-semaphore permits = 1 + batchInsertThreads and single permits = 2 * singleInsertThreads; TaskManager's sizing is not visible in this patch — confirm against TaskManager */ int expectedBatchPermits = 1 + options.batchInsertThreads; + int expectedSinglePermits = 2 * options.singleInsertThreads; + Assert.assertEquals(expectedBatchPermits, + getSemaphorePermits(taskManager, "batchSemaphore")); + Assert.assertEquals(expectedSinglePermits, + getSemaphorePermits(taskManager, "singleSemaphore")); + } finally { + allowFirstFinish.countDown(); + executor.shutdownNow(); + taskManager.shutdown(); + } + } + + @Test + public void testStopCheckTimingInSubmitBatch() throws Exception { + LoadOptions options = new LoadOptions(); + options.batchFailureFallback = false; + options.batchInsertThreads = 1; + options.singleInsertThreads = 1; + + LoadContext context = newTestContext(options); + TaskManager taskManager = new TaskManager(context); + + EdgeMapping mapping = new EdgeMapping(Arrays.asList("s"), false, + Arrays.asList("t"), false); + 
mapping.label("knows"); + + InputStruct struct = new InputStruct(new ArrayList<>(), + new ArrayList<>()); + struct.id("1"); + struct.add(mapping); + + LoadSummary summary = context.summary(); + summary.inputMetricsMap() + .put(struct.id(), new LoadMetrics(struct)); + LoadMetrics metrics = summary.metrics(struct); + + setField(context.client(), "graph", newSimpleGraphManager()); + + List batch = new ArrayList<>(); + batch.add(new Record("line1", new Edge("knows"))); + batch.add(new Record("line2", new Edge("knows"))); + + ExecutorService executor = Executors.newFixedThreadPool(2); + try { + taskManager.submitBatch(struct, mapping, batch); + taskManager.waitFinished(); + + Semaphore semaphore = getSemaphore(taskManager, "batchSemaphore"); + /* NOTE(review): hold a batch permit so the submitBatch on the executor below is expected to block until release() after stopLoading() */ semaphore.acquire(); + + Future blocked = executor.submit(() -> { + taskManager.submitBatch(struct, mapping, batch); + }); + + Thread.sleep(50); + context.stopLoading(); + semaphore.release(); + + blocked.get(5, TimeUnit.SECONDS); + + taskManager.waitFinished(); + + Assert.assertTrue(context.stopped()); + Assert.assertEquals(0L, flightingCount(metrics)); + int expectedPermits = 1 + options.batchInsertThreads; + Assert.assertEquals(expectedPermits, + getSemaphorePermits(taskManager, "batchSemaphore")); + } finally { + executor.shutdownNow(); + taskManager.shutdown(); + } + } + + private static void waitStopped(LoadContext context, long timeout, + TimeUnit unit) throws Exception { + long deadline = System.nanoTime() + unit.toNanos(timeout); + while (!context.stopped() && System.nanoTime() < deadline) { + Thread.sleep(10); + } + Assert.assertTrue(context.stopped()); + } + + private static long flightingCount(LoadMetrics metrics) + throws Exception { + Field field = LoadMetrics.class.getDeclaredField("flightingNums"); + field.setAccessible(true); + LongAdder adder = (LongAdder) field.get(metrics); + return adder.longValue(); + } + + private static LoadContext newTestContext(LoadOptions options) + throws Exception { + LoadContext 
context = (LoadContext) allocateInstance(LoadContext.class); + setField(context, "timestamp", "test"); + setField(context, "closed", false); + setField(context, "stopped", false); + setField(context, "noError", true); + setField(context, "options", options); + setField(context, "summary", new LoadSummary()); + setField(context, "oldProgress", new LoadProgress()); + setField(context, "newProgress", new LoadProgress()); + setField(context, "loggers", new ConcurrentHashMap<>()); + + HugeClient client = (HugeClient) allocateInstance(HugeClient.class); + setField(context, "client", client); + setField(context, "indirectClient", client); + setField(context, "schemaCache", null); + setField(context, "parseGroup", null); + return context; + } + + /* Creates an instance without running any constructor, via reflective Unsafe.allocateInstance (works for both sun.misc and jdk.internal.misc variants — see unsafe() below). */ private static Object allocateInstance(Class type) throws Exception { + Object unsafe = unsafe(); + Method method = unsafe.getClass() + .getMethod("allocateInstance", Class.class); + return method.invoke(unsafe, type); + } + + private static Object unsafe() throws Exception { + Class unsafeClass; + try { + unsafeClass = Class.forName("sun.misc.Unsafe"); + } catch (ClassNotFoundException e) { + unsafeClass = Class.forName("jdk.internal.misc.Unsafe"); + } + Field field = unsafeClass.getDeclaredField("theUnsafe"); + field.setAccessible(true); + return field.get(null); + } + + private static void setField(Object target, String name, Object value) + throws Exception { + Field field = target.getClass().getDeclaredField(name); + field.setAccessible(true); + field.set(target, value); + } + + private static int getSemaphorePermits(Object target, String name) + throws Exception { + Field field = target.getClass().getDeclaredField(name); + field.setAccessible(true); + Semaphore semaphore = (Semaphore) field.get(target); + return semaphore.availablePermits(); + } + + private static Semaphore getSemaphore(Object target, String name) + throws Exception { + Field field = target.getClass().getDeclaredField(name); + field.setAccessible(true); + 
return (Semaphore) field.get(target); + } + + private static GraphManager newFailingConcurrentGraphManager() + throws Exception { + return (GraphManager) allocateInstance(FailingConcurrentGraphManager.class); + } + + private static GraphManager newFailingBatchGraphManager() throws Exception { + return (GraphManager) allocateInstance(FailingBatchGraphManager.class); + } + + private static GraphManager newSimpleGraphManager() throws Exception { + return (GraphManager) allocateInstance(SimpleGraphManager.class); + } + + private static final class SimpleGraphManager extends GraphManager { + + private SimpleGraphManager() { + super(null, null, null); + } + + @Override + public List addEdges(List edges, boolean checkVertex) { + return this.addEdges(edges); + } + + @Override + public List addEdges(List edges) { + return edges; + } + } + + private static final class FailingConcurrentGraphManager extends GraphManager { + + private static final AtomicInteger BATCH_CALLS = new AtomicInteger(); + private static volatile CountDownLatch FIRST_STARTED; + private static volatile CountDownLatch ALLOW_FIRST_FINISH; + private static volatile CountDownLatch FAILURE_CALLED; + + private FailingConcurrentGraphManager() { + super(null, null, null); + } + + @Override + public List addEdges(List edges, boolean checkVertex) { + return this.addEdges(edges); + } + + @Override + public List addEdges(List edges) { + /* Call 1 blocks until ALLOW_FIRST_FINISH is released, call 2 throws to simulate a failing batch, later calls succeed. */ int call = BATCH_CALLS.incrementAndGet(); + if (call == 1) { + CountDownLatch started = FIRST_STARTED; + if (started != null) { + started.countDown(); + } + await(ALLOW_FIRST_FINISH); + return edges; + } + if (call == 2) { + CountDownLatch failed = FAILURE_CALLED; + if (failed != null) { + failed.countDown(); + } + throw new RuntimeException("batch insert failure"); + } + return edges; + } + + private void await(CountDownLatch latch) { + if (latch == null) { + return; + } + try { + latch.await(5, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + // Let the task finish 
on interruption. + Thread.currentThread().interrupt(); + } + } + } + + /* Fails every multi-record batch; single-record inserts succeed and are counted in SINGLE_CALLS. */ private static final class FailingBatchGraphManager extends GraphManager { + + private static final AtomicInteger BATCH_CALLS = new AtomicInteger(); + private static final AtomicInteger SINGLE_CALLS = new AtomicInteger(); + + private FailingBatchGraphManager() { + super(null, null, null); + } + + @Override + public List addEdges(List edges, boolean checkVertex) { + return this.addEdges(edges); + } + + @Override + public List addEdges(List edges) { + if (edges.size() > 1) { + BATCH_CALLS.incrementAndGet(); + throw new RuntimeException("batch insert failure"); + } + SINGLE_CALLS.addAndGet(edges.size()); + return edges; + } + } +}