diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 3ea19ad2422a..a34cabe1bbcb 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -608,6 +608,7 @@ private static void checkStreamException( private Deque requests = new ConcurrentLinkedDeque<>(); private volatile CountDownLatch freezeLock = new CountDownLatch(0); private final AtomicInteger freezeAfterReturningNumRows = new AtomicInteger(); + private final AtomicInteger freezeAfterNumRequests = new AtomicInteger(-1); private Queue exceptions = new ConcurrentLinkedQueue<>(); private boolean stickyGlobalExceptions = false; private ConcurrentMap statementResults = new ConcurrentHashMap<>(); @@ -813,17 +814,37 @@ public void setIgnoreInlineBeginRequest(boolean ignore) { } public void freeze() { - freezeLock = new CountDownLatch(1); + synchronized (lock) { + freezeLock = new CountDownLatch(1); + } } public void unfreeze() { - freezeLock.countDown(); + synchronized (lock) { + freezeAfterNumRequests.set(-1); + freezeLock.countDown(); + } } public void freezeAfterReturningNumRows(int numRows) { freezeAfterReturningNumRows.set(numRows); } + public void freezeAfter(int numRequests) { + freezeAfterNumRequests.set(numRequests); + } + + private void maybeFreezeAndRecordRequest(AbstractMessage request) { + synchronized (lock) { + if (freezeAfterNumRequests.get() >= 0) { + if (freezeAfterNumRequests.decrementAndGet() == -1) { + freeze(); + } + } + requests.add(request); + } + } + public void setMaxSessionsInOneBatch(int max) { this.maxNumSessionsInOneBatch = max; } @@ -836,7 +857,7 @@ public void setMaxTotalSessions(int max) { public void batchCreateSessions( BatchCreateSessionsRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getDatabase()); String name = null; try { @@ -898,7 +919,7 @@ public void batchCreateSessions( @Override public void createSession( CreateSessionRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getDatabase()); Preconditions.checkNotNull(request.getSession()); String name = generateSessionName(request.getDatabase()); @@ -938,7 +959,7 @@ public void createSession( @Override public void getSession(GetSessionRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getName()); try { getSessionExecutionTime.simulateExecutionTime(exceptions, stickyGlobalExceptions, freezeLock); @@ -983,7 +1004,7 @@ private void setSessionNotFound(String name, StreamObserver responseObser @Override public void listSessions( ListSessionsRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); try { listSessionsExecutionTime.simulateExecutionTime( exceptions, stickyGlobalExceptions, freezeLock); @@ -1006,7 +1027,7 @@ public void listSessions( @Override public void deleteSession(DeleteSessionRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getName()); try { deleteSessionExecutionTime.simulateExecutionTime( @@ -1035,7 +1056,7 @@ void doDeleteSession(Session session) { @Override public void executeSql(ExecuteSqlRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -1133,7 +1154,7 @@ private void returnResultSet( @Override public void executeBatchDml( ExecuteBatchDmlRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -1241,7 +1262,7 @@ public void executeStreamingSql( || !request .getSql() .equals(MultiplexedSessionDatabaseClient.DETERMINE_DIALECT_STATEMENT.getSql())) { - requests.add(request); + maybeFreezeAndRecordRequest(request); } Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); @@ -1687,7 +1708,7 @@ private void throwTransactionAborted(ByteString transactionId) { @Override public void read(final ReadRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -1720,7 +1741,7 @@ public void read(final ReadRequest request, StreamObserver responseOb @Override public void streamingRead( final ReadRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -1945,7 +1966,7 @@ public void beginTransaction( .getRequestOptions() .getTransactionTag() .equals("multiplexed-rw-background-begin-txn")) { - requests.add(request); + maybeFreezeAndRecordRequest(request); } Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); @@ -2080,7 +2101,7 @@ private void ensureMostRecentTransaction(Session session, ByteString transaction @Override public void commit(CommitRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -2152,7 +2173,7 @@ public void commit(CommitRequest request, StreamObserver respons @Override public void batchWrite( BatchWriteRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getSession()); Session session = getSession(request.getSession()); if (session == null) { @@ -2181,7 +2202,7 @@ private void commitTransaction(ByteString transactionId) { @Override public void rollback(RollbackRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); Preconditions.checkNotNull(request.getTransactionId()); Session session = getSession(request.getSession()); if (session == null) { @@ -2230,7 +2251,7 @@ public void markCommitRetryOnTransaction(ByteString transactionId) { @Override public void partitionQuery( PartitionQueryRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); try { partitionQueryExecutionTime.simulateExecutionTime( exceptions, stickyGlobalExceptions, freezeLock); @@ -2249,7 +2270,7 @@ public void partitionQuery( @Override public void partitionRead( PartitionReadRequest request, StreamObserver responseObserver) { - requests.add(request); + maybeFreezeAndRecordRequest(request); try { partitionReadExecutionTime.simulateExecutionTime( exceptions, stickyGlobalExceptions, freezeLock); @@ -2418,23 +2439,27 @@ public ServerServiceDefinition getServiceDefinition() { /** Removes all sessions and transactions. Mocked results are not removed. */ @Override public void reset() { - requests = new ConcurrentLinkedDeque<>(); - exceptions = new ConcurrentLinkedQueue<>(); - statementGetCounts = new ConcurrentHashMap<>(); - sessions = new ConcurrentHashMap<>(); - sessionLastUsed = new ConcurrentHashMap<>(); - transactions = new ConcurrentHashMap<>(); - transactionsStarted.clear(); - isPartitionedDmlTransaction = new ConcurrentHashMap<>(); - abortedTransactions = new ConcurrentHashMap<>(); - transactionCounters = new ConcurrentHashMap<>(); - partitionTokens = new ConcurrentHashMap<>(); - transactionLastUsed = new ConcurrentHashMap<>(); - transactionSequenceNo = new ConcurrentHashMap<>(); - - numSessionsCreated.set(0); - stickyGlobalExceptions = false; - freezeLock.countDown(); + synchronized (lock) { + requests = new ConcurrentLinkedDeque<>(); + exceptions = new ConcurrentLinkedQueue<>(); + statementGetCounts = new ConcurrentHashMap<>(); + sessions = new ConcurrentHashMap<>(); + sessionLastUsed = new ConcurrentHashMap<>(); + transactions = new ConcurrentHashMap<>(); + transactionsStarted.clear(); + isPartitionedDmlTransaction = new ConcurrentHashMap<>(); + abortedTransactions = new ConcurrentHashMap<>(); + transactionCounters = new ConcurrentHashMap<>(); + partitionTokens = new ConcurrentHashMap<>(); + transactionLastUsed = new ConcurrentHashMap<>(); + transactionSequenceNo = new ConcurrentHashMap<>(); + + numSessionsCreated.set(0); + freezeAfterNumRequests.set(-1); + freezeAfterReturningNumRows.set(0); + stickyGlobalExceptions = false; + freezeLock.countDown(); + } } public void removeAllExecutionTimes() { diff --git a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java index 61d06da821c0..3915efcf6095 100644 --- a/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java +++ b/java-spanner/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java @@ -102,10 +102,11 @@ public void createSpannerInstance() { } @Test - public void testCreateSessionDeadlineExceeded() { + public void testCreateSessionDeadlineExceeded() throws Exception { // Simulate a problem with the CreateSession RPC making it slow. mockSpanner.setCreateSessionExecutionTime( - SimulatedExecutionTime.ofStickyException(Status.DEADLINE_EXCEEDED.asRuntimeException())); + SimulatedExecutionTime.ofException(Status.DEADLINE_EXCEEDED.asRuntimeException())); + mockSpanner.freezeAfter(1); Spanner testSpanner = SpannerOptions.newBuilder() @@ -123,13 +124,16 @@ public void testCreateSessionDeadlineExceeded() { assertEquals(ErrorCode.DEADLINE_EXCEEDED, exception.getErrorCode()); } - // Remove the simulated problem on the mock server. // The next attempt should then succeed. - mockSpanner.removeAllExecutionTimes(); + mockSpanner.unfreeze(); + DatabaseClientImpl clientImpl = (DatabaseClientImpl) client; + assertNotNull(clientImpl.multiplexedSessionDatabaseClient.getCurrentSessionReference()); + try (ResultSet resultSet = client.singleUse().executeQuery(STATEMENT)) { //noinspection StatementWithEmptyBody while (resultSet.next()) {} } + testSpanner.close(); } @Test