diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index e404e2b8152..d9bbb56be55 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -243,6 +243,35 @@ default AsyncRunnable thenRunRetryingWhile( }); } + /** + * This method is equivalent to a while loop, where the condition is checked before each iteration. + * If the condition returns {@code false} on the first check, the body is never executed. + * + * @param loopBodyRunnable the asynchronous task to be executed in each iteration of the loop + * @param whileCheck a condition to check before each iteration; the loop continues as long as this condition returns true + * @return the composition of this and the looping branch + * @see AsyncCallbackLoop + */ + default AsyncRunnable thenRunWhileLoop(final BooleanSupplier whileCheck, final AsyncRunnable loopBodyRunnable) { + return thenRun(finalCallback -> { + LoopState loopState = new LoopState(); + new AsyncCallbackLoop(loopState, iterationCallback -> { + + if (loopState.breakAndCompleteIf(() -> !whileCheck.getAsBoolean(), iterationCallback)) { + return; + } + loopBodyRunnable.finish((result, t) -> { + if (t != null) { + iterationCallback.completeExceptionally(t); + return; + } + iterationCallback.complete(iterationCallback); + }); + + }).run(finalCallback); + }); + } + /** * This method is equivalent to a do-while loop, where the loop body is executed first and * then the condition is checked to determine whether the loop should continue. diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java b/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java new file mode 100644 index 00000000000..5fc074b7008 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncTrampoline.java @@ -0,0 +1,91 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.lang.Nullable; + +/** + * A trampoline that converts recursive callback invocations into an iterative loop, + * preventing stack overflow in async loops. + * + *

When async loop iterations complete synchronously on the same thread, callback + * recursion occurs: each iteration's {@code callback.onResult()} immediately triggers + * the next iteration, causing unbounded stack growth. For example, a 1000-iteration + * loop would create > 1000 stack frames and cause {@code StackOverflowError}.

+ * + *

The trampoline intercepts this recursion: instead of executing the next iteration + * immediately (which would deepen the stack), it enqueues the continuation and returns, allowing + * the stack to unwind. A flat loop at the top then processes enqueued continuation iteratively, + * maintaining constant stack depth regardless of iteration count.

+ * + *

Since async chains are sequential, at most one task is pending at any time. + * The trampoline uses a single slot rather than a queue.

+ * + * The first call on a thread becomes the "trampoline owner" and runs the drain loop. + * Subsequent (re-entrant) calls on the same thread enqueue their continuation and return immediately.

+ * + *

This class is not part of the public API and may be removed or changed at any time

+ */ +@NotThreadSafe +public final class AsyncTrampoline { + + private static final ThreadLocal TRAMPOLINE = new ThreadLocal<>(); + + private AsyncTrampoline() {} + + /** + * Execute continuation through the trampoline. If no trampoline is active, become the owner + * and drain all enqueued continuations. If a trampoline is already active, enqueue and return. + */ + public static void run(final Runnable continuation) { + ContinuationHolder continuationHolder = TRAMPOLINE.get(); + if (continuationHolder != null) { + continuationHolder.enqueue(continuation); + } else { + continuationHolder = new ContinuationHolder(); + TRAMPOLINE.set(continuationHolder); + try { + continuation.run(); + while (continuationHolder.continuation != null) { + Runnable continuationToRun = continuationHolder.continuation; + continuationHolder.continuation = null; + continuationToRun.run(); + } + } finally { + TRAMPOLINE.remove(); + } + } + } + + /** + * A single-slot container for continuation. + * At most one continuation is pending at any time in a sequential async chain. + */ + @NotThreadSafe + private static final class ContinuationHolder { + @Nullable + private Runnable continuation; + + void enqueue(final Runnable continuation) { + if (this.continuation != null) { + throw new AssertionError("Trampoline slot already occupied"); + } + this.continuation = continuation; + } + } +} diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java index a347a2a7e47..311892874b0 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java @@ -16,6 +16,7 @@ package com.mongodb.internal.async.function; import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.internal.async.AsyncTrampoline; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; @@ -62,9 +63,11 @@ public void run(final SingleResultCallback callback) { @NotThreadSafe private class LoopingCallback implements SingleResultCallback { private final SingleResultCallback wrapped; + private final Runnable nextIteration; LoopingCallback(final SingleResultCallback callback) { wrapped = callback; + nextIteration = () -> body.run(this); } @Override @@ -80,7 +83,7 @@ public void onResult(@Nullable final Void result, @Nullable final Throwable t) { return; } if (continueLooping) { - body.run(this); + AsyncTrampoline.run(nextIteration); } else { wrapped.onResult(result, null); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java index 9a9b7552d3e..8f6bc7046a2 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java @@ -26,6 +26,7 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase { private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0)); @@ -723,6 +724,120 @@ void testTryCatchTestAndRethrow() { }); } + @Test + void testWhile() { + // last iteration: 3 < 3 = 1 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10 + assertBehavesSameVariations(10, + () -> { + int counter = 0; + while (counter < 3 && plainTest(counter)) { + counter++; + sync(counter); + } + }, + (callback) -> { + MutableValue counter = new MutableValue<>(0); + beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> { + counter.set(counter.get() + 1); + async(counter.get(), c2); + }).finish(callback); + }); + } + + @Test + void testWhileWithThenRun() { + // while: last iteration: 3 < 3 = 1 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 1(transition to next iteration) = 4 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 4(transition to next iteration) = 7 + // 1(plainTest exception) + 1(plainTest false) + 1(sync exception) + 1(sync success) * 7(transition to next iteration) = 10 + // trailing sync: 1(exception) + 1(success) = 2 + // 6(while exception) + 4(while success) * 2(trailing sync) = 14 + assertBehavesSameVariations(14, + () -> { + int counter = 0; + while (counter < 3 && plainTest(counter)) { + counter++; + sync(counter); + } + sync(counter + 1); + }, + (callback) -> { + MutableValue counter = new MutableValue<>(0); + beginAsync().thenRun(c -> { + beginAsync().thenRunWhileLoop(() -> counter.get() < 3 && plainTest(counter.get()), c2 -> { + counter.set(counter.get() + 1); + async(counter.get(), c2); + }).finish(c); + }).thenRun(c -> { + async(counter.get() + 1, c); + }).finish(callback); + }); + } + + @Test + void testNestedWhileLoops() { + // inner while: 4 success + 6 exception = 10 + // last inner iteration: 3 < 3 = 1 + // 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 1(transition to next iteration) = 12 + // 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 12(transition to next iteration) = 56 + // 1(outer plainTest exception) + 1(outer plainTest false) + (inner while) * 56(transition to next iteration) = 232 + assertBehavesSameVariations(232, + () -> { + int outer = 0; + while (outer < 3 && plainTest(outer)) { + int inner = 0; + while (inner < 3 && plainTest(inner)) { + sync(outer + inner); + inner++; + } + outer++; + } + }, + (callback) -> { + MutableValue outer = new MutableValue<>(0); + beginAsync().thenRunWhileLoop(() -> outer.get() < 3 && plainTest(outer.get()), c -> { + MutableValue inner = new MutableValue<>(0); + beginAsync().thenRunWhileLoop( + () -> inner.get() < 3 && plainTest(inner.get()), + c2 -> { + beginAsync().thenRun(c3 -> { + async(outer.get() + inner.get(), c3); + }).thenRun(c3 -> { + inner.set(inner.get() + 1); + c3.complete(c3); + }).finish(c2); + } + ).thenRun(c2 -> { + outer.set(outer.get() + 1); + c2.complete(c2); + }).finish(c); + }).finish(callback); + }); + } + + @Test + void testWhileLoopStackConstant() { + int depthWith100 = maxStackDepthForIterations(100); + int depthWith10000 = maxStackDepthForIterations(10_000); + assertEquals(depthWith100, depthWith10000, "Stack depth should be constant regardless of iteration count (trampoline)"); + } + + private int maxStackDepthForIterations(final int iterations) { + MutableValue counter = new MutableValue<>(0); + MutableValue maxDepth = new MutableValue<>(0); + beginAsync().thenRunWhileLoop(() -> counter.get() < iterations, c -> { + maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length)); + counter.set(counter.get() + 1); + c.complete(c); + }).finish((v, t) -> {}); + + assertEquals(iterations, counter.get()); + return maxDepth.get(); + } + @Test void testRetryLoop() { assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, @@ -768,6 +883,65 @@ void testDoWhileLoop() { }); } + @Test + void testNestedDoWhileLoops() { + // inner do-while: 3 success + 5 exception = 8 + // last outer iteration: 3 < 3 = 1 + // 5(inner exception) + 3(inner success) * 1(transition to next iteration) = 8 + // 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 8(transition to next iteration)) = 35 + // 5(inner exception) + 3(inner success) * (1(outer plainTest exception) + 1(outer plainTest false) + 35(transition to next iteration)) = 116 + assertBehavesSameVariations(116, + () -> { + int outer = 0; + do { + int inner = 0; + do { + sync(outer + inner); + inner++; + } while (inner < 3 && plainTest(inner)); + outer++; + } while (outer < 3 && plainTest(outer)); + }, + (callback) -> { + MutableValue outer = new MutableValue<>(0); + beginAsync().thenRunDoWhileLoop(c -> { + MutableValue inner = new MutableValue<>(0); + beginAsync().thenRunDoWhileLoop(c2 -> { + beginAsync().thenRun(c3 -> { + async(outer.get() + inner.get(), c3); + }).thenRun(c3 -> { + inner.set(inner.get() + 1); + c3.complete(c3); + }).finish(c2); + }, () -> inner.get() < 3 && plainTest(inner.get()) + ).thenRun(c2 -> { + outer.set(outer.get() + 1); + c2.complete(c2); + }).finish(c); + }, () -> outer.get() < 3 && plainTest(outer.get())).finish(callback); + }); + } + + @Test + void testDoWhileLoopStackConstant() { + int depthWith100 = maxDoWhileStackDepthForIterations(100); + int depthWith10000 = maxDoWhileStackDepthForIterations(10_000); + assertEquals(depthWith100, depthWith10000, + "Stack depth should be constant regardless of iteration count"); + } + + private int maxDoWhileStackDepthForIterations(final int iterations) { + MutableValue counter = new MutableValue<>(0); + MutableValue maxDepth = new MutableValue<>(0); + beginAsync().thenRunDoWhileLoop(c -> { + maxDepth.set(Math.max(maxDepth.get(), Thread.currentThread().getStackTrace().length)); + counter.set(counter.get() + 1); + c.complete(c); + }, () -> counter.get() < iterations).finish((v, t) -> {}); + assertEquals(iterations, counter.get()); + return maxDepth.get(); + } + @Test void testFinallyWithPlainInsideTry() { // (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6 diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java index 10a58152d9f..73d9d59b4dc 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java @@ -32,6 +32,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; +import static java.lang.String.format; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -272,14 +273,16 @@ private void assertBehavesSame(final Supplier sync, final Runnable betwee } assertTrue(wasCalledFuture.isDone(), "callback should have been called"); - assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); - assertEquals(expectedValue, actualValue.get()); assertEquals(expectedException == null, actualException.get() == null, - "both or neither should have produced an exception"); + format("both or neither should have produced an exception. Expected exception: %s, actual exception: %s", + expectedException, + actualException)); if (expectedException != null) { assertEquals(expectedException.getMessage(), actualException.get().getMessage()); assertEquals(expectedException.getClass(), actualException.get().getClass()); } + assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched"); + assertEquals(expectedValue, actualValue.get()); listener.clear(); }