diff --git a/core/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java b/core/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java index d9b8b468c664..2e44cfccf2bd 100644 --- a/core/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java +++ b/core/spring-boot-test/src/main/java/org/springframework/boot/test/system/OutputCapture.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Deque; import java.util.List; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Predicate; @@ -42,6 +43,7 @@ * @author Phillip Webb * @author Andy Wilkinson * @author Sam Brannen + * @author Daniel Schmidt * @see OutputCaptureExtension * @see OutputCaptureRule */ @@ -51,11 +53,14 @@ class OutputCapture implements CapturedOutput { private @Nullable AnsiOutputState ansiOutputState; - private final AtomicReference out = new AtomicReference<>(null); + private final AtomicLong outVersion = new AtomicLong(); + private final AtomicReference out = new AtomicReference<>(null); - private final AtomicReference err = new AtomicReference<>(null); + private final AtomicLong errVersion = new AtomicLong(); + private final AtomicReference err = new AtomicReference<>(null); - private final AtomicReference all = new AtomicReference<>(null); + private final AtomicLong allVersion = new AtomicLong(); + private final AtomicReference all = new AtomicReference<>(null); /** * Push a new system capture session onto the stack. @@ -108,7 +113,7 @@ public String toString() { */ @Override public String getAll() { - return get(this.all, (type) -> true); + return get(this.all, this.allVersion, (type) -> true); } /** @@ -117,7 +122,7 @@ public String getAll() { */ @Override public String getOut() { - return get(this.out, Type.OUT::equals); + return get(this.out, this.outVersion, Type.OUT::equals); } /** @@ -126,7 +131,7 @@ public String getOut() { */ @Override public String getErr() { - return get(this.err, Type.ERR::equals); + return get(this.err, this.errVersion, Type.ERR::equals); } /** @@ -138,19 +143,24 @@ void reset() { } void clearExisting() { + this.outVersion.incrementAndGet(); this.out.set(null); + this.errVersion.incrementAndGet(); this.err.set(null); + this.allVersion.incrementAndGet(); this.all.set(null); } - private String get(AtomicReference existing, Predicate filter) { + private String get(AtomicReference resultCache, AtomicLong version, Predicate filter) { Assert.state(!this.systemCaptures.isEmpty(), "No system captures found. Please check your output capture registration."); - String result = existing.get(); - if (result == null) { - result = build(filter); - existing.compareAndSet(null, result); + long currentVersion = version.get(); + VersionedCacheResult cached = resultCache.get(); + if (cached != null && cached.version == currentVersion) { + return cached.result; } + String result = build(filter); + resultCache.compareAndSet(null, new VersionedCacheResult(result, currentVersion)); return result; } @@ -162,6 +172,10 @@ String build(Predicate filter) { return builder.toString(); } + private record VersionedCacheResult(String result, long version) { + + } + /** * A capture session that captures {@link System#out System.out} and {@link System#out * System.err}. diff --git a/core/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java b/core/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java index 53d31364b15a..b9fb7e680a79 100644 --- a/core/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java +++ b/core/spring-boot-test/src/test/java/org/springframework/boot/test/system/OutputCaptureTests.java @@ -19,8 +19,13 @@ import java.io.ByteArrayOutputStream; import java.io.PrintStream; import java.util.NoSuchElementException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.function.Predicate; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -32,6 +37,7 @@ * Tests for {@link OutputCapture}. * * @author Phillip Webb + * @author Daniel Schmidt */ class OutputCaptureTests { @@ -188,6 +194,45 @@ void getErrUsesCache() { assertThat(this.output.buildCount).isEqualTo(2); } + @Test + void getOutCacheShouldNotReturnStaleDataWhenDataIsLoggedWhileReading() throws Exception { + this.output.push(); + System.out.print("A"); + this.output.waitAfterBuildLatch = new CountDownLatch(1); + + ExecutorService executorService = null; + try { + executorService = Executors.newFixedThreadPool(2); + var readingThreadFuture = executorService.submit(() -> { + // this will release the releaseAfterBuildLatch and block on the waitAfterBuildLatch + assertThat(this.output.getOut()).isEqualTo("A"); + }); + var writingThreadFuture = executorService.submit(() -> { + // wait until we finished building the first result (but did not yet update the cache) + try { + this.output.releaseAfterBuildLatch.await(); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + // print something else and then release the latch, for the other thread to continue + System.out.print("B"); + this.output.waitAfterBuildLatch.countDown(); + }); + readingThreadFuture.get(); + writingThreadFuture.get(); + } + finally { + if (executorService != null) { + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + } + } + + // If not synchronized correctly this will fail, because the second print did not clear the cache and the cache will return stale data. + assertThat(this.output.getOut()).isEqualTo("AB"); + } + private void pushAndPrint() { this.output.push(); System.out.print("A"); @@ -212,10 +257,26 @@ static class TestOutputCapture extends OutputCapture { int buildCount; + @Nullable + CountDownLatch waitAfterBuildLatch = null; + + CountDownLatch releaseAfterBuildLatch = new CountDownLatch(1); + @Override String build(Predicate filter) { this.buildCount++; - return super.build(filter); + var result = super.build(filter); + this.releaseAfterBuildLatch.countDown(); + if (this.waitAfterBuildLatch != null) { + try { + this.waitAfterBuildLatch.await(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + return result; } }