diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/logging/GrpcLoggingServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/logging/GrpcLoggingServiceTest.java index fc5043710232..02e500f3e85a 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/logging/GrpcLoggingServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/logging/GrpcLoggingServiceTest.java @@ -22,8 +22,6 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -63,39 +61,42 @@ public void testMultipleClientsSuccessfullyProcessed() throws Exception { GrpcLoggingService.forWriter(new CollectionAppendingLogWriter(logs)); try (GrpcFnServer server = GrpcFnServer.allocatePortAndCreateFor(service, InProcessServerFactory.create())) { - - Collection> tasks = new ArrayList<>(); + ExecutorService executorService = Executors.newCachedThreadPool(); + Collection> futures = new ArrayList<>(); + CountDownLatch waitForServerHangup = new CountDownLatch(3); for (int i = 1; i <= 3; ++i) { final int instructionId = i; - tasks.add( - () -> { - CountDownLatch waitForServerHangup = new CountDownLatch(1); - String url = server.getApiServiceDescriptor().getUrl(); - ManagedChannel channel = InProcessChannelBuilder.forName(url).build(); - StreamObserver outboundObserver = - BeamFnLoggingGrpc.newStub(channel) - .logging( - TestStreams.withOnNext(messageDiscarder) - .withOnCompleted(new CountDown(waitForServerHangup)) - .build()); - outboundObserver.onNext(createLogsWithIds(instructionId, -instructionId)); - outboundObserver.onCompleted(); - waitForServerHangup.await(); - return null; - }); + futures.add( + executorService.submit( + () -> { + String url = server.getApiServiceDescriptor().getUrl(); + ManagedChannel channel = InProcessChannelBuilder.forName(url).build(); + StreamObserver outboundObserver = + BeamFnLoggingGrpc.newStub(channel) + .logging( + TestStreams.withOnNext(messageDiscarder) + .withOnCompleted(new CountDown(waitForServerHangup)) + .build()); + outboundObserver.onNext(createLogsWithIds(instructionId, -instructionId)); + outboundObserver.onCompleted(); + })); } - ExecutorService executorService = Executors.newCachedThreadPool(); - executorService.invokeAll(tasks); - assertThat( - logs, - containsInAnyOrder( - createLogWithId(1L), - createLogWithId(2L), - createLogWithId(3L), - createLogWithId(-1L), - createLogWithId(-2L), - createLogWithId(-3L))); + // Make sure all streams were created and issued client operations. + for (Future f : futures) { + f.get(); + } + // Ensure all the streams were completed as expected before closing the server. + waitForServerHangup.await(); } + assertThat( + logs, + containsInAnyOrder( + createLogWithId(1L), + createLogWithId(2L), + createLogWithId(3L), + createLogWithId(-1L), + createLogWithId(-2L), + createLogWithId(-3L))); } @Test @@ -107,32 +108,23 @@ public void testMultipleClientsFailingIsHandledGracefullyByServer() throws Excep GrpcFnServer.allocatePortAndCreateFor(service, InProcessServerFactory.create())) { CountDownLatch waitForTermination = new CountDownLatch(3); - final BlockingQueue> outboundObservers = - new LinkedBlockingQueue<>(); - Collection> tasks = new ArrayList<>(); - for (int i = 1; i <= 3; ++i) { - final int instructionId = i; - tasks.add( - () -> { - ManagedChannel channel = - InProcessChannelBuilder.forName(server.getApiServiceDescriptor().getUrl()) - .build(); - StreamObserver outboundObserver = - BeamFnLoggingGrpc.newStub(channel) - .logging( - TestStreams.withOnNext(messageDiscarder) - .withOnError(new CountDown(waitForTermination)) - .build()); - outboundObserver.onNext(createLogsWithIds(instructionId, -instructionId)); - outboundObservers.add(outboundObserver); - return null; - }); + final Collection> outboundObservers = new ArrayList<>(); + // Create all the streams + for (int instructionId = 1; instructionId <= 3; ++instructionId) { + ManagedChannel channel = + InProcessChannelBuilder.forName(server.getApiServiceDescriptor().getUrl()).build(); + StreamObserver outboundObserver = + BeamFnLoggingGrpc.newStub(channel) + .logging( + TestStreams.withOnNext(messageDiscarder) + .withOnError(new CountDown(waitForTermination)) + .build()); + outboundObserver.onNext(createLogsWithIds(instructionId, -instructionId)); + outboundObservers.add(outboundObserver); } - ExecutorService executorService = Executors.newCachedThreadPool(); - executorService.invokeAll(tasks); - for (int i = 1; i <= 3; ++i) { - outboundObservers.take().onError(new RuntimeException("Client " + i)); + for (StreamObserver outboundObserver : outboundObservers) { + outboundObserver.onError(new RuntimeException("Client")); } waitForTermination.await(); } @@ -142,19 +134,19 @@ public void testMultipleClientsFailingIsHandledGracefullyByServer() throws Excep public void testServerCloseHangsUpClients() throws Exception { LinkedBlockingQueue logs = new LinkedBlockingQueue<>(); ExecutorService executorService = Executors.newCachedThreadPool(); - Collection> futures = new ArrayList<>(); final GrpcLoggingService service = GrpcLoggingService.forWriter(new CollectionAppendingLogWriter(logs)); + CountDownLatch waitForServerHangup = new CountDownLatch(3); try (GrpcFnServer server = GrpcFnServer.allocatePortAndCreateFor(service, InProcessServerFactory.create())) { + Collection> futures = new ArrayList<>(); for (int i = 1; i <= 3; ++i) { final long instructionId = i; futures.add( executorService.submit( () -> { { - CountDownLatch waitForServerHangup = new CountDownLatch(1); ManagedChannel channel = InProcessChannelBuilder.forName(server.getApiServiceDescriptor().getUrl()) .build(); @@ -165,19 +157,21 @@ public void testServerCloseHangsUpClients() throws Exception { .withOnCompleted(new CountDown(waitForServerHangup)) .build()); outboundObserver.onNext(createLogsWithIds(instructionId)); - waitForServerHangup.await(); return null; } })); } + // Ensure all the streams have started and sent their instruction. + for (Future f : futures) { + f.get(); + } // Wait till each client has sent their message showing that they have connected. for (int i = 1; i <= 3; ++i) { logs.take(); } + // Close the server without closing the streams and ensure they observe the hangup. } - for (Future future : futures) { - future.get(); - } + waitForServerHangup.await(); } private BeamFnApi.LogEntry.List createLogsWithIds(long... ids) {