diff --git a/extended-it/src/test/java/apoc/neo4j/docker/CypherEnterpriseExtendedTest.java b/extended-it/src/test/java/apoc/neo4j/docker/CypherEnterpriseExtendedTest.java index a965cf46dd..81fb2896bd 100644 --- a/extended-it/src/test/java/apoc/neo4j/docker/CypherEnterpriseExtendedTest.java +++ b/extended-it/src/test/java/apoc/neo4j/docker/CypherEnterpriseExtendedTest.java @@ -31,7 +31,6 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -@Ignore public class CypherEnterpriseExtendedTest { private static final String CREATE_RETURNQUERY_NODES = "UNWIND range(0,3) as id \n" + "CREATE (n:ReturnQuery {id:id})-[:REL {idRel: id}]->(:Other {idOther: id})"; diff --git a/extended/src/main/java/apoc/cypher/CypherExtended.java b/extended/src/main/java/apoc/cypher/CypherExtended.java index ae3162f085..c5c459ab08 100644 --- a/extended/src/main/java/apoc/cypher/CypherExtended.java +++ b/extended/src/main/java/apoc/cypher/CypherExtended.java @@ -38,6 +38,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -262,25 +263,43 @@ private static boolean isCommentOrEmpty(String stmt) { private final static Pattern shellControl = Pattern.compile("^:?\\b(begin|commit|rollback)\\b", Pattern.CASE_INSENSITIVE); - private Object consumeResult(Result result, BlockingQueue queue, boolean addStatistics, Transaction tx, String fileName) { + private Object consumeResult(Result result, BlockingQueue queue, boolean addStatistics, Transaction transaction, String fileName) { try { long time = System.currentTimeMillis(); int row = 0; - while (result.hasNext()) { + AtomicBoolean closed = new AtomicBoolean(false); + while (isOpenAndHasNext(result, closed)) { terminationGuard.check(); - Map res = EntityUtil.anyRebind(tx, result.next()); + Map res = EntityUtil.anyRebind(transaction, result.next()); queue.put(new RowResult(row++, res, fileName)); } if (addStatistics) { Map mapResult = toMap(result.getQueryStatistics(), System.currentTimeMillis() - time, row); queue.put(new RowResult(-1, mapResult, fileName)); } + if (closed.get()) { + queue.put(RowResult.TOMBSTONE); + return null; + } return row; } catch (InterruptedException e) { throw new RuntimeException(e); } } + /** + * If the transaction is closed, result.hasNext() will throw an error. + * In that case, we set closed = true, to put a RowResult.TOMBSTONE and terminate the iteration + */ + private static boolean isOpenAndHasNext(Result result, AtomicBoolean closed) { + try { + return result.hasNext(); + } catch (Exception e) { + closed.set(true); + return false; + } + } + private String removeShellControlCommands(String stmt) { Matcher matcher = shellControl.matcher(stmt.trim()); if (matcher.find()) { @@ -389,6 +408,7 @@ public Stream mapParallel(@Name("fragment") String fragment, @Name("p .flatMap((partition) -> Iterators.asList(tx.execute(statement, parallelParams(params, "_", partition))).stream()) .map(MapResult::new); } + @Procedure @Description("apoc.cypher.mapParallel2(fragment, params, list-to-parallelize) yield value - executes fragment in parallel batches with the list segments being assigned to _") public Stream mapParallel2(@Name("fragment") String fragment, @Name("params") Map params, @Name("list") List data, @Name("partitions") long partitions,@Name(value = "timeout",defaultValue = "10") long timeout) { @@ -397,25 +417,31 @@ public Stream mapParallel2(@Name("fragment") String fragment, @Name(" int queueCapacity = 100000; BlockingQueue queue = new ArrayBlockingQueue<>(queueCapacity); ArrayBlockingQueue transactions = new ArrayBlockingQueue<>(queueCapacity); + ArrayBlockingQueue results = new ArrayBlockingQueue<>(queueCapacity); Stream> parallelPartitions = Util.partitionSubList(data, (int)(partitions <= 0 ? PARTITIONS : partitions), null); Util.inFuture(pools, () -> { long total = parallelPartitions - .map((List partition) -> { - Transaction transaction = db.beginTx(); - transactions.add(transaction); - try (Result result = transaction.execute(statement, parallelParams(params, "_", partition))) { - return consumeResult(result, queue, false, transaction, null); - } catch (Exception e) { - throw new RuntimeException(e); - }} + .map((List partition) -> { + Transaction transaction = db.beginTx(); + transactions.add(transaction); + Result result = transaction.execute(statement, parallelParams(params, "_", partition)); + results.add(result); + try { + return consumeResult(result, queue, false, transaction, null); + } catch (Exception e) { + throw new RuntimeException(e); + }} ).count(); queue.put(RowResult.TOMBSTONE); return total; }); - + return StreamSupport.stream(new QueueBasedSpliterator<>(queue, RowResult.TOMBSTONE, terminationGuard, (int)timeout),true) .map(rowResult -> new MapResult(rowResult.result)) - .onClose(() -> transactions.forEach(Transaction::close)); + .onClose(() -> { + transactions.forEach(i -> Util.close(i)); + results.forEach(i -> Util.close(i)); + }); } public Map parallelParams(@Name("params") Map params, String key, List partition) { diff --git a/extended/src/test/java/apoc/cypher/CypherExtendedTest.java b/extended/src/test/java/apoc/cypher/CypherExtendedTest.java index aad5dd2b97..87f50d83fb 100644 --- a/extended/src/test/java/apoc/cypher/CypherExtendedTest.java +++ b/extended/src/test/java/apoc/cypher/CypherExtendedTest.java @@ -49,7 +49,6 @@ * @since 08.05.16 */ -@Ignore public class CypherExtendedTest { public static final String IMPORT_DIR = "src/test/resources"; @ClassRule