Skip to content

Commit d418612

Browse files
committed
refactor the ThreadLocal approach
1 parent 1284d59 commit d418612

File tree

2 files changed

+53
-38
lines changed

2 files changed

+53
-38
lines changed

driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public AsyncCallbackLoop(final LoopState state, final AsyncCallbackRunnable body
5151

5252
@Override
5353
public void run(final SingleResultCallback<Void> callback) {
54-
body.run(false, new ReusableLoopCallback(callback));
54+
body.run(false, callback);
5555
}
5656

5757
private static final class Body {
@@ -82,7 +82,7 @@ private Body(final LoopState state, final AsyncCallbackRunnable body) {
8282
*
8383
* <p>If another iteration is needed, it is initiated from the callback passed to
8484
* {@link #body}{@code .}{@link AsyncCallbackRunnable#run(SingleResultCallback) run}
85-
* by invoking {@link #run(boolean, ReusableLoopCallback)}.
85+
* by invoking {@link #run(boolean, SingleResultCallback)}.
8686
* Completing the initiated iteration is {@linkplain SingleResultCallback#onResult(Object, Throwable) invoking} the callback.
8787
* Thus, it is guaranteed that all iterations are executed sequentially with each other
8888
* (that is, completion of one iteration happens-before initiation of the next one)
@@ -109,7 +109,7 @@ private Body(final LoopState state, final AsyncCallbackRunnable body) {
109109
* <li>therefore, we would not have an iteration that is executed synchronously but in a different thread.</li>
110110
* </ul>
111111
*/
112-
boolean run(final boolean trampolining, final ReusableLoopCallback callback) {
112+
boolean run(final boolean trampolining, final SingleResultCallback<Void> afterLoopCallback) {
113113
// The `trampoliningResult` variable must be used only if the initiated iteration is executed synchronously with
114114
// the current method, which must be detected separately.
115115
//
@@ -122,7 +122,7 @@ boolean run(final boolean trampolining, final ReusableLoopCallback callback) {
122122
boolean[] trampoliningResult = {false};
123123
sameThreadDetector.set(SameThreadDetectionStatus.PROBING);
124124
body.run((r, t) -> {
125-
if (callback.onResult(state, r, t)) {
125+
if (completeIfNeeded(afterLoopCallback, r, t)) {
126126
// If we are trampolining, then here we bounce up, trampolining completes and so is the whole loop;
127127
// otherwise, the whole loop simply completes.
128128
return;
@@ -139,9 +139,10 @@ boolean run(final boolean trampolining, final ReusableLoopCallback callback) {
139139
sameThreadDetector.remove();
140140
}
141141
}
142+
// trampolining
142143
boolean anotherIterationNeeded;
143-
do { // trampolining
144-
anotherIterationNeeded = run(true, callback);
144+
do {
145+
anotherIterationNeeded = run(true, afterLoopCallback);
145146
} while (anotherIterationNeeded);
146147
});
147148
try {
@@ -150,39 +151,28 @@ boolean run(final boolean trampolining, final ReusableLoopCallback callback) {
150151
sameThreadDetector.remove();
151152
}
152153
}
153-
}
154-
155-
/**
156-
* This callback is allowed to be {@linkplain #onResult(LoopState, Void, Throwable) completed} more than once.
157-
*/
158-
@NotThreadSafe
159-
private static final class ReusableLoopCallback {
160-
private final SingleResultCallback<Void> wrapped;
161-
162-
ReusableLoopCallback(final SingleResultCallback<Void> callback) {
163-
wrapped = callback;
164-
}
165154

166155
/**
167-
* @return {@code true} iff the {@linkplain ReusableLoopCallback#ReusableLoopCallback(SingleResultCallback) wrapped}
168-
* {@link SingleResultCallback} is {@linkplain SingleResultCallback#onResult(Object, Throwable) completed}.
156+
* @return {@code true} iff the {@code afterLoopCallback} was
157+
* {@linkplain SingleResultCallback#onResult(Object, Throwable) completed}.
169158
*/
170-
public boolean onResult(final LoopState state, @Nullable final Void result, @Nullable final Throwable t) {
159+
private boolean completeIfNeeded(final SingleResultCallback<Void> afterLoopCallback,
160+
@Nullable final Void result, @Nullable final Throwable t) {
171161
if (t != null) {
172-
wrapped.onResult(null, t);
162+
afterLoopCallback.onResult(null, t);
173163
return true;
174164
} else {
175-
boolean continueLooping;
165+
boolean anotherIterationNeeded;
176166
try {
177-
continueLooping = state.advance();
167+
anotherIterationNeeded = state.advance();
178168
} catch (Throwable e) {
179-
wrapped.onResult(null, e);
169+
afterLoopCallback.onResult(null, e);
180170
return true;
181171
}
182-
if (continueLooping) {
172+
if (anotherIterationNeeded) {
183173
return false;
184174
} else {
185-
wrapped.onResult(result, null);
175+
afterLoopCallback.onResult(result, null);
186176
return true;
187177
}
188178
}

driver-core/src/test/unit/com/mongodb/internal/async/VakoTest.java

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import java.util.concurrent.CompletableFuture;
3333
import java.util.concurrent.Executors;
3434
import java.util.concurrent.ScheduledExecutorService;
35+
import java.util.concurrent.ThreadLocalRandom;
3536
import java.util.concurrent.TimeUnit;
3637

3738
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
@@ -77,28 +78,32 @@ private enum IterationExecutionType {
7778
SYNC_SAME_THREAD,
7879
SYNC_DIFFERENT_THREAD,
7980
ASYNC,
81+
MIXED_SYNC_SAME_AND_ASYNC
8082
}
8183

8284
@ParameterizedTest()
8385
@CsvSource({
84-
"10, 0, SYNC_SAME_THREAD, 0",
85-
// "10, 0, SYNC_DIFFERENT_THREAD, 0",
86-
"10, 0, ASYNC, 4",
87-
"10, 4, ASYNC, 0"
86+
"10, 0, SYNC_SAME_THREAD, 0, true",
87+
// "10, 0, SYNC_DIFFERENT_THREAD, 0, true",
88+
"10, 0, ASYNC, 4, true",
89+
"10, 4, ASYNC, 0, true",
90+
"1_000_000, 0, MIXED_SYNC_SAME_AND_ASYNC, 0, false",
8891
})
8992
void testThenRunDoWhileLoop(
9093
final int counterInitialValue,
9194
final int blockSyncPartOfIterationTotalSeconds,
9295
final IterationExecutionType executionType,
93-
final int delayAsyncExecutionTotalSeconds) throws Exception {
96+
final int delayAsyncExecutionTotalSeconds,
97+
final boolean verbose) throws Exception {
9498
System.err.printf("baselineStackDepth=%d%n%n", Thread.currentThread().getStackTrace().length);
9599
Duration blockSyncPartOfIterationTotalDuration = Duration.ofSeconds(blockSyncPartOfIterationTotalSeconds);
96100
com.mongodb.assertions.Assertions.assertTrue(
97101
executionType.equals(IterationExecutionType.ASYNC) || delayAsyncExecutionTotalSeconds == 0);
98102
Duration delayAsyncExecutionTotalDuration = Duration.ofSeconds(delayAsyncExecutionTotalSeconds);
99103
StartTime start = StartTime.now();
100104
CompletableFuture<Void> join = new CompletableFuture<>();
101-
asyncLoop(new Counter(counterInitialValue), blockSyncPartOfIterationTotalDuration, executionType, delayAsyncExecutionTotalDuration,
105+
asyncLoop(new Counter(counterInitialValue, verbose),
106+
blockSyncPartOfIterationTotalDuration, executionType, delayAsyncExecutionTotalDuration, verbose,
102107
(r, t) -> {
103108
System.err.printf("test callback completed callStackDepth=%s, r=%s, t=%s%n",
104109
Thread.currentThread().getStackTrace().length, r, exceptionToString(t));
@@ -114,25 +119,31 @@ private static void asyncLoop(
114119
final Duration blockSyncPartOfIterationTotalDuration,
115120
final IterationExecutionType executionType,
116121
final Duration delayAsyncExecutionTotalDuration,
122+
final boolean verbose,
117123
final SingleResultCallback<Void> callback) {
118124
beginAsync().thenRunDoWhileLoop(c -> {
119125
sleep(blockSyncPartOfIterationTotalDuration.dividedBy(counter.initial()));
120126
StartTime start = StartTime.now();
121-
asyncPartOfIteration(counter, executionType, delayAsyncExecutionTotalDuration, c);
122-
System.err.printf("\tasyncPartOfIteration returned in %s%n", start.elapsed());
127+
asyncPartOfIteration(counter, executionType, delayAsyncExecutionTotalDuration, verbose, c);
128+
if (verbose) {
129+
System.err.printf("\tasyncPartOfIteration returned in %s%n", start.elapsed());
130+
}
123131
}, () -> !counter.done()).finish(callback);
124132
}
125133

126134
private static void asyncPartOfIteration(
127135
final Counter counter,
128136
final IterationExecutionType executionType,
129137
final Duration delayAsyncExecutionTotalDuration,
138+
final boolean verbose,
130139
final SingleResultCallback<Void> callback) {
131140
Runnable asyncPartOfIteration = () -> {
132141
counter.countDown();
133142
StartTime start = StartTime.now();
134143
callback.complete(callback);
135-
System.err.printf("\tasyncPartOfIteration callback.complete returned in %s%n", start.elapsed());
144+
if (verbose) {
145+
System.err.printf("\tasyncPartOfIteration callback.complete returned in %s%n", start.elapsed());
146+
}
136147
};
137148
switch (executionType) {
138149
case SYNC_SAME_THREAD: {
@@ -150,6 +161,15 @@ private static void asyncPartOfIteration(
150161
delayAsyncExecutionTotalDuration.dividedBy(counter.initial()).toNanos(), TimeUnit.NANOSECONDS);
151162
break;
152163
}
164+
case MIXED_SYNC_SAME_AND_ASYNC: {
165+
if (ThreadLocalRandom.current().nextBoolean()) {
166+
asyncPartOfIteration.run();
167+
} else {
168+
executor.schedule(asyncPartOfIteration,
169+
delayAsyncExecutionTotalDuration.dividedBy(counter.initial()).toNanos(), TimeUnit.NANOSECONDS);
170+
}
171+
break;
172+
}
153173
default: {
154174
com.mongodb.assertions.Assertions.fail(executionType.toString());
155175
}
@@ -159,10 +179,12 @@ private static void asyncPartOfIteration(
159179
private static final class Counter {
160180
private final int initial;
161181
private int current;
182+
private final boolean verbose;
162183

163-
Counter(final int initial) {
184+
Counter(final int initial, final boolean verbose) {
164185
this.initial = initial;
165186
this.current = initial;
187+
this.verbose = verbose;
166188
}
167189

168190
int initial() {
@@ -173,7 +195,10 @@ void countDown() {
173195
com.mongodb.assertions.Assertions.assertTrue(current > 0);
174196
int previous = current;
175197
int decremented = --current;
176-
System.err.printf("counted %d->%d callStackDepth=%d %n", previous, decremented, Thread.currentThread().getStackTrace().length);
198+
if (verbose || decremented % 100_000 == 0) {
199+
System.err.printf("counted %d->%d callStackDepth=%d %n",
200+
previous, decremented, Thread.currentThread().getStackTrace().length);
201+
}
177202
}
178203

179204
boolean done() {

0 commit comments

Comments
 (0)