Skip to content

Fix race condition in OutputCapture #46685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Predicate;
Expand All @@ -42,6 +43,7 @@
* @author Phillip Webb
* @author Andy Wilkinson
* @author Sam Brannen
* @author Daniel Schmidt
* @see OutputCaptureExtension
* @see OutputCaptureRule
*/
Expand All @@ -51,11 +53,14 @@ class OutputCapture implements CapturedOutput {

private @Nullable AnsiOutputState ansiOutputState;

private final AtomicReference<String> out = new AtomicReference<>(null);
private final AtomicLong outVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> out = new AtomicReference<>(null);

private final AtomicReference<String> err = new AtomicReference<>(null);
private final AtomicLong errVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> err = new AtomicReference<>(null);

private final AtomicReference<String> all = new AtomicReference<>(null);
private final AtomicLong allVersion = new AtomicLong();
private final AtomicReference<VersionedCacheResult> all = new AtomicReference<>(null);

/**
* Push a new system capture session onto the stack.
Expand Down Expand Up @@ -108,7 +113,7 @@ public String toString() {
*/
@Override
public String getAll() {
return get(this.all, (type) -> true);
return get(this.all, this.allVersion, (type) -> true);
}

/**
Expand All @@ -117,7 +122,7 @@ public String getAll() {
*/
@Override
public String getOut() {
return get(this.out, Type.OUT::equals);
return get(this.out, this.outVersion, Type.OUT::equals);
}

/**
Expand All @@ -126,7 +131,7 @@ public String getOut() {
*/
@Override
public String getErr() {
return get(this.err, Type.ERR::equals);
return get(this.err, this.errVersion, Type.ERR::equals);
}

/**
Expand All @@ -138,19 +143,24 @@ void reset() {
}

void clearExisting() {
this.outVersion.incrementAndGet();
this.out.set(null);
this.errVersion.incrementAndGet();
this.err.set(null);
this.allVersion.incrementAndGet();
this.all.set(null);
}

private String get(AtomicReference<String> existing, Predicate<Type> filter) {
private String get(AtomicReference<VersionedCacheResult> resultCache, AtomicLong version, Predicate<Type> filter) {
Assert.state(!this.systemCaptures.isEmpty(),
"No system captures found. Please check your output capture registration.");
String result = existing.get();
if (result == null) {
result = build(filter);
existing.compareAndSet(null, result);
long currentVersion = version.get();
VersionedCacheResult cached = resultCache.get();
if (cached != null && cached.version == currentVersion) {
return cached.result;
}
String result = build(filter);
resultCache.compareAndSet(null, new VersionedCacheResult(result, currentVersion));
return result;
}

Expand All @@ -162,6 +172,10 @@ String build(Predicate<Type> filter) {
return builder.toString();
}

private record VersionedCacheResult(String result, long version) {

}

/**
* A capture session that captures {@link System#out System.out} and {@link System#out
* System.err}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.util.NoSuchElementException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;

import org.jspecify.annotations.Nullable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -32,6 +37,7 @@
* Tests for {@link OutputCapture}.
*
* @author Phillip Webb
* @author Daniel Schmidt
*/
class OutputCaptureTests {

Expand Down Expand Up @@ -188,6 +194,45 @@ void getErrUsesCache() {
assertThat(this.output.buildCount).isEqualTo(2);
}

@Test
void getOutCacheShouldNotReturnStaleDataWhenDataIsLoggedWhileReading() throws Exception {
this.output.push();
System.out.print("A");
this.output.waitAfterBuildLatch = new CountDownLatch(1);

ExecutorService executorService = null;
try {
executorService = Executors.newFixedThreadPool(2);
var readingThreadFuture = executorService.submit(() -> {
// this will release the releaseAfterBuildLatch and block on the waitAfterBuildLatch
assertThat(this.output.getOut()).isEqualTo("A");
});
var writingThreadFuture = executorService.submit(() -> {
// wait until we finished building the first result (but did not yet update the cache)
try {
this.output.releaseAfterBuildLatch.await();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
// print something else and then release the latch, for the other thread to continue
System.out.print("B");
this.output.waitAfterBuildLatch.countDown();
});
readingThreadFuture.get();
writingThreadFuture.get();
}
finally {
if (executorService != null) {
executorService.shutdown();
executorService.awaitTermination(10, TimeUnit.SECONDS);
}
}

// If not synchronized correctly this will fail, because the second print did not clear the cache and the cache will return stale data.
assertThat(this.output.getOut()).isEqualTo("AB");
}

private void pushAndPrint() {
this.output.push();
System.out.print("A");
Expand All @@ -212,10 +257,26 @@ static class TestOutputCapture extends OutputCapture {

int buildCount;

@Nullable
CountDownLatch waitAfterBuildLatch = null;

CountDownLatch releaseAfterBuildLatch = new CountDownLatch(1);

@Override
String build(Predicate<Type> filter) {
this.buildCount++;
return super.build(filter);
var result = super.build(filter);
this.releaseAfterBuildLatch.countDown();
if (this.waitAfterBuildLatch != null) {
try {
this.waitAfterBuildLatch.await();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
return result;
}

}
Expand Down