diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 39b2151388..c41e65d39f 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -53,6 +53,7 @@ def __init__( self._column_info = column_info # Column information self._field_decoders = None self._lazy_decode = lazy_decode # Return protobuf values + self._done = False @property def fields(self): @@ -154,11 +155,16 @@ def _consume_next(self): self._merge_values(values) + if response_pb.last: + self._done = True + def __iter__(self): while True: iter_rows, self._rows[:] = self._rows[:], () while iter_rows: yield iter_rows.pop(0) + if self._done: + return try: self._consume_next() except StopIteration: diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index f8971a6098..e3c2198d68 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -35,11 +35,17 @@ class MockSpanner: def __init__(self): self.results = {} + self.execute_streaming_sql_results = {} self.errors = {} def add_result(self, sql: str, result: result_set.ResultSet): self.results[sql.lower().strip()] = result + def add_execute_streaming_sql_results( + self, sql: str, partial_result_sets: list[result_set.PartialResultSet] + ): + self.execute_streaming_sql_results[sql.lower().strip()] = partial_result_sets + def get_result(self, sql: str) -> result_set.ResultSet: result = self.results.get(sql.lower().strip()) if result is None: @@ -55,9 +61,20 @@ def pop_error(self, context): if error: context.abort_with_status(error) - def get_result_as_partial_result_sets( + def get_execute_streaming_sql_results( self, sql: str, started_transaction: transaction.Transaction - ) -> [result_set.PartialResultSet]: + ) -> list[result_set.PartialResultSet]: + if self.execute_streaming_sql_results.get(sql.lower().strip()): + partials = self.execute_streaming_sql_results[sql.lower().strip()] + else: + partials = self.get_result_as_partial_result_sets(sql) + if started_transaction: + partials[0].metadata.transaction = started_transaction + return partials + + def get_result_as_partial_result_sets( + self, sql: str + ) -> list[result_set.PartialResultSet]: result: result_set.ResultSet = self.get_result(sql) partials = [] first = True @@ -70,11 +87,10 @@ def get_result_as_partial_result_sets( partial = result_set.PartialResultSet() if first: partial.metadata = ResultSetMetadata(result.metadata) + first = False partial.values.extend(row) partials.append(partial) partials[len(partials) - 1].stats = result.stats - if started_transaction: - partials[0].metadata.transaction = started_transaction return partials @@ -149,7 +165,7 @@ def ExecuteStreamingSql(self, request, context): self._requests.append(request) self.mock_spanner.pop_error(context) started_transaction = self.__maybe_create_transaction(request) - partials = self.mock_spanner.get_result_as_partial_result_sets( + partials = self.mock_spanner.get_execute_streaming_sql_results( request.sql, started_transaction ) for result in partials: diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 7b4538d601..1b56ca6aa0 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -14,27 +14,34 @@ import unittest -from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode -from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer -from google.cloud.spanner_v1.testing.mock_spanner import ( - start_mock_server, - SpannerServicer, -) -import google.cloud.spanner_v1.types.type as spanner_type -import google.cloud.spanner_v1.types.result_set as result_set +import grpc from google.api_core.client_options import ClientOptions from google.auth.credentials import AnonymousCredentials -from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance -import grpc -from google.rpc import code_pb2 -from google.rpc import status_pb2 -from google.rpc.error_details_pb2 import RetryInfo +from google.cloud.spanner_v1 import Type + +from google.cloud.spanner_v1 import StructType +from google.cloud.spanner_v1._helpers import _make_value_pb + +from google.cloud.spanner_v1 import PartialResultSet from google.protobuf.duration_pb2 import Duration +from google.rpc import code_pb2, status_pb2 + +from google.rpc.error_details_pb2 import RetryInfo from grpc_status._common import code_to_grpc_status_code from grpc_status.rpc_status import _Status +import google.cloud.spanner_v1.types.result_set as result_set +import google.cloud.spanner_v1.types.type as spanner_type +from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode +from google.cloud.spanner_v1 import Client, FixedSizePool, ResultSetMetadata, TypeCode +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer +from google.cloud.spanner_v1.testing.mock_spanner import ( + SpannerServicer, + start_mock_server, +) + # Creates an aborted status with the smallest possible retry delay. def aborted_status() -> _Status: @@ -57,6 +64,27 @@ def aborted_status() -> _Status: return status +def _make_partial_result_sets( + fields: list[tuple[str, TypeCode]], results: list[dict] +) -> list[result_set.PartialResultSet]: + partial_result_sets = [] + for result in results: + partial_result_set = PartialResultSet() + if len(partial_result_sets) == 0: + # setting the metadata + metadata = ResultSetMetadata(row_type=StructType(fields=[])) + for field in fields: + metadata.row_type.fields.append( + StructType.Field(name=field[0], type_=Type(code=field[1])) + ) + partial_result_set.metadata = metadata + for value in result["values"]: + partial_result_set.values.append(_make_value_pb(value)) + partial_result_set.last = result.get("last") or False + partial_result_sets.append(partial_result_set) + return partial_result_sets + + # Creates an UNAVAILABLE status with the smallest possible retry delay. def unavailable_status() -> _Status: error = status_pb2.Status( @@ -101,6 +129,14 @@ def add_select1_result(): add_single_result("select 1", "c", TypeCode.INT64, [("1",)]) +def add_execute_streaming_sql_results( + sql: str, partial_result_sets: list[result_set.PartialResultSet] +): + MockServerTestBase.spanner_service.mock_spanner.add_execute_streaming_sql_results( + sql, partial_result_sets + ) + + def add_single_result( sql: str, column_name: str, type_code: spanner_type.TypeCode, row ): diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index 9db84b117f..0dab935a16 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -17,22 +17,24 @@ from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, - ExecuteSqlRequest, BeginTransactionRequest, - TransactionOptions, ExecuteBatchDmlRequest, + ExecuteSqlRequest, + TransactionOptions, TypeCode, ) -from google.cloud.spanner_v1.transaction import Transaction from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer +from google.cloud.spanner_v1.transaction import Transaction from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, + _make_partial_result_sets, add_select1_result, + add_single_result, add_update_count, add_error, unavailable_status, - add_single_result, + add_execute_streaming_sql_results, ) @@ -176,6 +178,31 @@ def test_last_statement_query(self): self.assertEqual(1, len(requests), msg=requests) self.assertTrue(requests[0].last_statement, requests[0]) + def test_execute_streaming_sql_last_field(self): + partial_result_sets = _make_partial_result_sets( + [("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)], + [ + {"values": ["1", "ABC", "2", "DEF"]}, + {"values": ["3", "GHI"], "last": True}, + ], + ) + + sql = "select * from my_table" + add_execute_streaming_sql_results(sql, partial_result_sets) + count = 1 + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql(sql) + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(count, row[0]) + count += 1 + self.assertEqual(3, len(result_list)) + requests = self.spanner_service.requests + self.assertEqual(2, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + def _execute_query(transaction: Transaction, sql: str): rows = transaction.execute_sql(sql, last_statement=True) diff --git a/tests/unit/test_streamed.py b/tests/unit/test_streamed.py index e02afbede7..529bb0ef3f 100644 --- a/tests/unit/test_streamed.py +++ b/tests/unit/test_streamed.py @@ -122,12 +122,12 @@ def _make_result_set_stats(query_plan=None, **kw): @staticmethod def _make_partial_result_set( - values, metadata=None, stats=None, chunked_value=False + values, metadata=None, stats=None, chunked_value=False, last=False ): from google.cloud.spanner_v1 import PartialResultSet results = PartialResultSet( - metadata=metadata, stats=stats, chunked_value=chunked_value + metadata=metadata, stats=stats, chunked_value=chunked_value, last=last ) for v in values: results.values.append(v) @@ -162,6 +162,40 @@ def test__merge_chunk_bool(self): with self.assertRaises(Unmergeable): streamed._merge_chunk(chunk) + def test__PartialResultSetWithLastFlag(self): + from google.cloud.spanner_v1 import TypeCode + + fields = [ + self._make_scalar_field("ID", TypeCode.INT64), + self._make_scalar_field("NAME", TypeCode.STRING), + ] + for length in range(4, 6): + metadata = self._make_result_set_metadata(fields) + result_sets = [ + self._make_partial_result_set( + [self._make_value(0), "google_0"], metadata=metadata + ) + ] + for i in range(1, 5): + bares = [i] + values = [ + [self._make_value(bare), "google_" + str(bare)] for bare in bares + ] + result_sets.append( + self._make_partial_result_set( + *values, metadata=metadata, last=(i == length - 1) + ) + ) + + iterator = _MockCancellableIterator(*result_sets) + streamed = self._make_one(iterator) + count = 0 + for row in streamed: + self.assertEqual(row[0], count) + self.assertEqual(row[1], "google_" + str(count)) + count += 1 + self.assertEqual(count, length) + def test__merge_chunk_numeric(self): from google.cloud.spanner_v1 import TypeCode