2222import com .google .api .gax .grpc .InstantiatingGrpcChannelProvider ;
2323import com .google .protobuf .ByteString ;
2424import com .google .protobuf .Empty ;
25+ import com .google .protobuf .ListValue ;
26+ import com .google .protobuf .TextFormat ;
27+ import com .google .protobuf .Value ;
2528import com .google .spanner .v1 .BeginTransactionRequest ;
2629import com .google .spanner .v1 .CacheUpdate ;
2730import com .google .spanner .v1 .CommitRequest ;
2831import com .google .spanner .v1 .CommitResponse ;
2932import com .google .spanner .v1 .ExecuteSqlRequest ;
3033import com .google .spanner .v1 .Group ;
34+ import com .google .spanner .v1 .Mutation ;
3135import com .google .spanner .v1 .PartialResultSet ;
3236import com .google .spanner .v1 .Range ;
3337import com .google .spanner .v1 .ReadRequest ;
38+ import com .google .spanner .v1 .RecipeList ;
3439import com .google .spanner .v1 .ResultSet ;
3540import com .google .spanner .v1 .ResultSetMetadata ;
3641import com .google .spanner .v1 .RollbackRequest ;
@@ -274,6 +279,124 @@ public void resultSetCacheUpdateRoutesSubsequentRequest() throws Exception {
274279 assertThat (harness .endpointCache .callCountForAddress ("routed:1234" )).isEqualTo (1 );
275280 }
276281
282+ @ Test
283+ public void beginTransactionWithMutationKeyAddsRoutingHint () throws Exception {
284+ TestHarness harness = createHarness ();
285+ seedCache (harness , createMutationRoutingCacheUpdate ());
286+
287+ Mutation mutation = createInsertMutation ("b" );
288+ ClientCall <BeginTransactionRequest , Transaction > beginCall =
289+ harness .channel .newCall (SpannerGrpc .getBeginTransactionMethod (), CallOptions .DEFAULT );
290+ beginCall .start (new CapturingListener <Transaction >(), new Metadata ());
291+ beginCall .sendMessage (
292+ BeginTransactionRequest .newBuilder ().setSession (SESSION ).setMutationKey (mutation ).build ());
293+
294+ @ SuppressWarnings ("unchecked" )
295+ RecordingClientCall <BeginTransactionRequest , Transaction > beginDelegate =
296+ (RecordingClientCall <BeginTransactionRequest , Transaction >)
297+ harness .defaultManagedChannel .latestCall ();
298+
299+ assertThat (beginDelegate .lastMessage ).isNotNull ();
300+ assertThat (beginDelegate .lastMessage .getRoutingHint ().getDatabaseId ()).isEqualTo (7L );
301+ assertThat (beginDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ())
302+ .isEqualTo ("1" );
303+ assertThat (beginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()).isFalse ();
304+ }
305+
306+ @ Test
307+ public void transactionCacheUpdateEnablesCommitRoutingHint () throws Exception {
308+ TestHarness harness = createHarness ();
309+ ByteString transactionId = ByteString .copyFromUtf8 ("tx-with-cache-update" );
310+
311+ ClientCall <BeginTransactionRequest , Transaction > beginCall =
312+ harness .channel .newCall (SpannerGrpc .getBeginTransactionMethod (), CallOptions .DEFAULT );
313+ beginCall .start (new CapturingListener <Transaction >(), new Metadata ());
314+ beginCall .sendMessage (BeginTransactionRequest .newBuilder ().setSession (SESSION ).build ());
315+
316+ @ SuppressWarnings ("unchecked" )
317+ RecordingClientCall <BeginTransactionRequest , Transaction > beginDelegate =
318+ (RecordingClientCall <BeginTransactionRequest , Transaction >)
319+ harness .defaultManagedChannel .latestCall ();
320+ beginDelegate .emitOnMessage (
321+ Transaction .newBuilder ()
322+ .setId (transactionId )
323+ .setCacheUpdate (createMutationRoutingCacheUpdate ())
324+ .build ());
325+ beginDelegate .emitOnClose (Status .OK , new Metadata ());
326+
327+ ClientCall <CommitRequest , CommitResponse > commitCall =
328+ harness .channel .newCall (SpannerGrpc .getCommitMethod (), CallOptions .DEFAULT );
329+ commitCall .start (new CapturingListener <CommitResponse >(), new Metadata ());
330+ commitCall .sendMessage (
331+ CommitRequest .newBuilder ()
332+ .setSession (SESSION )
333+ .setTransactionId (transactionId )
334+ .addMutations (createInsertMutation ("b" ))
335+ .build ());
336+
337+ @ SuppressWarnings ("unchecked" )
338+ RecordingClientCall <CommitRequest , CommitResponse > commitDelegate =
339+ (RecordingClientCall <CommitRequest , CommitResponse >)
340+ harness .defaultManagedChannel .latestCall ();
341+
342+ assertThat (commitDelegate .lastMessage ).isNotNull ();
343+ assertThat (commitDelegate .lastMessage .getRoutingHint ().getDatabaseId ()).isEqualTo (7L );
344+ assertThat (commitDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ())
345+ .isEqualTo ("1" );
346+ assertThat (commitDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()).isFalse ();
347+ }
348+
349+ @ Test
350+ public void commitResponseCacheUpdateEnablesSubsequentBeginRoutingHint () throws Exception {
351+ TestHarness harness = createHarness ();
352+ ByteString transactionId = ByteString .copyFromUtf8 ("tx-before-commit-cache-update" );
353+
354+ ClientCall <BeginTransactionRequest , Transaction > beginCall =
355+ harness .channel .newCall (SpannerGrpc .getBeginTransactionMethod (), CallOptions .DEFAULT );
356+ beginCall .start (new CapturingListener <Transaction >(), new Metadata ());
357+ beginCall .sendMessage (BeginTransactionRequest .newBuilder ().setSession (SESSION ).build ());
358+
359+ @ SuppressWarnings ("unchecked" )
360+ RecordingClientCall <BeginTransactionRequest , Transaction > beginDelegate =
361+ (RecordingClientCall <BeginTransactionRequest , Transaction >)
362+ harness .defaultManagedChannel .latestCall ();
363+ beginDelegate .emitOnMessage (Transaction .newBuilder ().setId (transactionId ).build ());
364+ beginDelegate .emitOnClose (Status .OK , new Metadata ());
365+
366+ ClientCall <CommitRequest , CommitResponse > commitCall =
367+ harness .channel .newCall (SpannerGrpc .getCommitMethod (), CallOptions .DEFAULT );
368+ commitCall .start (new CapturingListener <CommitResponse >(), new Metadata ());
369+ commitCall .sendMessage (
370+ CommitRequest .newBuilder ().setSession (SESSION ).setTransactionId (transactionId ).build ());
371+
372+ @ SuppressWarnings ("unchecked" )
373+ RecordingClientCall <CommitRequest , CommitResponse > commitDelegate =
374+ (RecordingClientCall <CommitRequest , CommitResponse >)
375+ harness .defaultManagedChannel .latestCall ();
376+ commitDelegate .emitOnMessage (
377+ CommitResponse .newBuilder ().setCacheUpdate (createMutationRoutingCacheUpdate ()).build ());
378+ commitDelegate .emitOnClose (Status .OK , new Metadata ());
379+
380+ Mutation mutation = createInsertMutation ("b" );
381+ ClientCall <BeginTransactionRequest , Transaction > secondBeginCall =
382+ harness .channel .newCall (SpannerGrpc .getBeginTransactionMethod (), CallOptions .DEFAULT );
383+ secondBeginCall .start (new CapturingListener <Transaction >(), new Metadata ());
384+ secondBeginCall .sendMessage (
385+ BeginTransactionRequest .newBuilder ().setSession (SESSION ).setMutationKey (mutation ).build ());
386+
387+ @ SuppressWarnings ("unchecked" )
388+ RecordingClientCall <BeginTransactionRequest , Transaction > routedBeginDelegate =
389+ (RecordingClientCall <BeginTransactionRequest , Transaction >)
390+ harness .defaultManagedChannel .latestCall ();
391+
392+ assertThat (routedBeginDelegate .lastMessage ).isNotNull ();
393+ assertThat (routedBeginDelegate .lastMessage .getRoutingHint ().getDatabaseId ()).isEqualTo (7L );
394+ assertThat (
395+ routedBeginDelegate .lastMessage .getRoutingHint ().getSchemaGeneration ().toStringUtf8 ())
396+ .isEqualTo ("1" );
397+ assertThat (routedBeginDelegate .lastMessage .getRoutingHint ().getKey ().isEmpty ()).isFalse ();
398+ }
399+
277400 @ Test
278401 public void readOnlyTransactionRoutesEachReadIndependently () throws Exception {
279402 TestHarness harness = createHarness ();
@@ -635,6 +758,43 @@ private static CacheUpdate createTwoRangeCacheUpdate() {
635758 .build ();
636759 }
637760
761+ private static CacheUpdate createMutationRoutingCacheUpdate () throws TextFormat .ParseException {
762+ RecipeList keyRecipes =
763+ parseRecipeList (
764+ "schema_generation: \" 1\" \n "
765+ + "recipe {\n "
766+ + " table_name: \" T\" \n "
767+ + " part { tag: 1 }\n "
768+ + " part {\n "
769+ + " order: ASCENDING\n "
770+ + " null_order: NULLS_FIRST\n "
771+ + " type { code: STRING }\n "
772+ + " identifier: \" k\" \n "
773+ + " }\n "
774+ + "}\n " );
775+ return CacheUpdate .newBuilder ()
776+ .setDatabaseId (7L )
777+ .setKeyRecipes (keyRecipes )
778+ .addRange (
779+ Range .newBuilder ()
780+ .setStartKey (bytes ("a" ))
781+ .setLimitKey (bytes ("m" ))
782+ .setGroupUid (1L )
783+ .setSplitId (1L )
784+ .setGeneration (bytes ("1" )))
785+ .addGroup (
786+ Group .newBuilder ()
787+ .setGroupUid (1L )
788+ .setGeneration (bytes ("1" ))
789+ .addTablets (
790+ Tablet .newBuilder ()
791+ .setTabletUid (1L )
792+ .setServerAddress ("server-a:1234" )
793+ .setIncarnation (bytes ("1" ))
794+ .setDistance (0 )))
795+ .build ();
796+ }
797+
638798 private static void seedCache (TestHarness harness , CacheUpdate cacheUpdate ) {
639799 ClientCall <ExecuteSqlRequest , ResultSet > seedCall =
640800 harness .channel .newCall (SpannerGrpc .getExecuteSqlMethod (), CallOptions .DEFAULT );
@@ -652,6 +812,25 @@ private static void seedCache(TestHarness harness, CacheUpdate cacheUpdate) {
652812 seedDelegate .emitOnMessage (ResultSet .newBuilder ().setCacheUpdate (cacheUpdate ).build ());
653813 }
654814
815+ private static Mutation createInsertMutation (String keyValue ) {
816+ return Mutation .newBuilder ()
817+ .setInsert (
818+ Mutation .Write .newBuilder ()
819+ .setTable ("T" )
820+ .addColumns ("k" )
821+ .addValues (
822+ ListValue .newBuilder ()
823+ .addValues (Value .newBuilder ().setStringValue (keyValue ).build ())
824+ .build ()))
825+ .build ();
826+ }
827+
828+ private static RecipeList parseRecipeList (String text ) throws TextFormat .ParseException {
829+ RecipeList .Builder builder = RecipeList .newBuilder ();
830+ TextFormat .merge (text , builder );
831+ return builder .build ();
832+ }
833+
655834 private static TestHarness createHarness () throws IOException {
656835 FakeEndpointCache endpointCache = new FakeEndpointCache (DEFAULT_ADDRESS );
657836 InstantiatingGrpcChannelProvider provider =
@@ -841,6 +1020,7 @@ int callCount() {
8411020 private static final class RecordingClientCall <RequestT , ResponseT >
8421021 extends ClientCall <RequestT , ResponseT > {
8431022 @ Nullable private ClientCall .Listener <ResponseT > listener ;
1023+ @ Nullable private RequestT lastMessage ;
8441024 private boolean cancelCalled ;
8451025 @ Nullable private String cancelMessage ;
8461026 @ Nullable private Throwable cancelCause ;
@@ -864,7 +1044,9 @@ public void cancel(@Nullable String message, @Nullable Throwable cause) {
8641044 public void halfClose () {}
8651045
8661046 @ Override
867- public void sendMessage (RequestT message ) {}
1047+ public void sendMessage (RequestT message ) {
1048+ this .lastMessage = message ;
1049+ }
8681050
8691051 void emitOnMessage (ResponseT response ) {
8701052 if (listener != null ) {
0 commit comments