From d35a0cc7f7c22ac60bcf5fda1e3bcbbba53ec47c Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 29 May 2025 12:41:52 -0700 Subject: [PATCH 01/41] Update `SessionOptions` to support `GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS` and add unit tests. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session_options.py | 89 +++++++----- tests/_builders.py | 22 +++ tests/unit/test_session_options.py | 160 +++++++++++++++++++++ 3 files changed, 233 insertions(+), 38 deletions(-) create mode 100644 tests/_builders.py create mode 100644 tests/unit/test_session_options.py diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py index 12af15f8d1..a3042142cd 100644 --- a/google/cloud/spanner_v1/session_options.py +++ b/google/cloud/spanner_v1/session_options.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from enum import Enum -from logging import Logger class TransactionType(Enum): @@ -26,7 +26,7 @@ class TransactionType(Enum): class SessionOptions(object): """Represents the session options for the Cloud Spanner Python client. - We can use ::class::`SessionOptions` to determine whether multiplexed sessions + We can use :class:`SessionOptions` to determine whether multiplexed sessions should be used for a specific transaction type with :meth:`use_multiplexed`. The use of multiplexed session can be disabled for a specific transaction type or for all transaction types with :meth:`disable_multiplexed`. 
@@ -40,6 +40,9 @@ class SessionOptions(object): ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" ) + ENV_VAR_FORCE_DISABLE_MULTIPLEXED = ( + "GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS" + ) def __init__(self): # Internal overrides to disable the use of multiplexed @@ -52,76 +55,86 @@ def __init__(self): def use_multiplexed(self, transaction_type: TransactionType) -> bool: """Returns whether to use multiplexed sessions for the given transaction type. + Multiplexed sessions are enabled for read-only transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; and - * multiplexed sessions have not been disabled for read-only transactions. + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for read-only transactions. + Multiplexed sessions are enabled for partitioned transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; and - * multiplexed sessions have not been disabled for partitioned transactions. - Multiplexed sessions are **currently disabled** for read / write. + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for partitioned transactions. + + Multiplexed sessions are enabled for read/write transactions if: + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE is set to true; + * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * multiplexed sessions have not been disabled for read/write transactions. + :type transaction_type: :class:`TransactionType` - :param transaction_type: the type of transaction to check whether - multiplexed sessions should be used. 
+ :param transaction_type: the type of transaction """ if transaction_type is TransactionType.READ_ONLY: - return self._is_multiplexed_enabled[transaction_type] and self._getenv( - self.ENV_VAR_ENABLE_MULTIPLEXED + return ( + self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + and self._is_multiplexed_enabled[transaction_type] ) elif transaction_type is TransactionType.PARTITIONED: return ( - self._is_multiplexed_enabled[transaction_type] - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + and self._is_multiplexed_enabled[transaction_type] ) elif transaction_type is TransactionType.READ_WRITE: - return False + return ( + self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) + and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) + and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) + and self._is_multiplexed_enabled[transaction_type] + ) raise ValueError(f"Transaction type {transaction_type} is not supported.") def disable_multiplexed( - self, logger: Logger = None, transaction_type: TransactionType = None + self, logger: logging.Logger = None, transaction_type: TransactionType = None ) -> None: """Disables the use of multiplexed sessions for the given transaction type. If no transaction type is specified, disables the use of multiplexed sessions for all transaction types. + :type logger: :class:`Logger` - :param logger: logger to use for logging the disabling the use of multiplexed - sessions. + :param logger: logger for logging disabling the use of multiplexed sessions. + :type transaction_type: :class:`TransactionType` :param transaction_type: (Optional) the type of transaction for which to disable the use of multiplexed sessions. 
""" - disable_multiplexed_log_msg_fstring = ( - "Disabling multiplexed sessions for {transaction_type_value} transactions" - ) - import logging + if transaction_type and transaction_type not in self._is_multiplexed_enabled: + raise ValueError(f"Transaction type '{transaction_type}' is not supported.") - if logger is None: - logger = logging.getLogger(__name__) + logger = logger or logging.getLogger(__name__) - if transaction_type is None: - logger.warning( - disable_multiplexed_log_msg_fstring.format(transaction_type_value="all") - ) - for transaction_type in TransactionType: - self._is_multiplexed_enabled[transaction_type] = False - return + transaction_types_to_disable = ( + [transaction_type] + if transaction_type is not None + else list(TransactionType) + ) - elif transaction_type in self._is_multiplexed_enabled.keys(): + for transaction_type_to_disable in transaction_types_to_disable: logger.warning( - disable_multiplexed_log_msg_fstring.format( - transaction_type_value=transaction_type.value - ) + f"Disabling multiplexed sessions for {transaction_type_to_disable.value} transactions" ) - self._is_multiplexed_enabled[transaction_type] = False - return + self._is_multiplexed_enabled[transaction_type_to_disable] = False - raise ValueError(f"Transaction type '{transaction_type}' is not supported.") + return @staticmethod def _getenv(name: str) -> bool: diff --git a/tests/_builders.py b/tests/_builders.py new file mode 100644 index 0000000000..f5632786bb --- /dev/null +++ b/tests/_builders.py @@ -0,0 +1,22 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import create_autospec + + +def build_logger(): + """Builds and returns a logger for testing.""" + from logging import Logger + + return create_autospec(Logger, instance=True) diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py new file mode 100644 index 0000000000..393df401f5 --- /dev/null +++ b/tests/unit/test_session_options.py @@ -0,0 +1,160 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from logging import Logger +from os import environ +from unittest import TestCase + +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType +from tests._builders import build_logger + + +class TestSessionOptions(TestCase): + @classmethod + def setUpClass(cls): + # Save the original environment variables. + cls._original_env = dict(environ) + + @classmethod + def tearDownClass(cls): + # Restore environment variables. 
+ environ.clear() + environ.update(cls._original_env) + + def setUp(self): + self.logger: Logger = build_logger() + + def test_use_multiplexed_for_read_only(self): + session_options = SessionOptions() + transaction_type = TransactionType.READ_ONLY + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self.logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + self.logger.warning.assert_called_once_with( + "Disabling multiplexed sessions for read-only transactions" + ) + + def test_use_multiplexed_for_partitioned(self): + session_options = SessionOptions() + transaction_type = TransactionType.PARTITIONED + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self.logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + 
self.logger.warning.assert_called_once_with( + "Disabling multiplexed sessions for partitioned transactions" + ) + + def test_use_multiplexed_for_read_write(self): + session_options = SessionOptions() + transaction_type = TransactionType.READ_WRITE + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "false" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + self.assertTrue(session_options.use_multiplexed(transaction_type)) + + session_options.disable_multiplexed(self.logger, transaction_type) + self.assertFalse(session_options.use_multiplexed(transaction_type)) + + self.logger.warning.assert_called_once_with( + "Disabling multiplexed sessions for read/write transactions" + ) + + def test_disable_multiplexed_all(self): + session_options = SessionOptions() + + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + session_options.disable_multiplexed(self.logger) + + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) + + warning = self.logger.warning + self.assertEqual(warning.call_count, 3) + warning.assert_any_call( + "Disabling 
multiplexed sessions for read-only transactions" + ) + warning.assert_any_call( + "Disabling multiplexed sessions for partitioned transactions" + ) + warning.assert_any_call( + "Disabling multiplexed sessions for read/write transactions" + ) + + def test_unsupported_transaction_type(self): + session_options = SessionOptions() + unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" + + with self.assertRaises(ValueError): + session_options.use_multiplexed(unsupported_type) + + with self.assertRaises(ValueError): + session_options.disable_multiplexed(self.logger, unsupported_type) + + def test_env_var_values(self): + session_options = SessionOptions() + + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] + for value in true_values: + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value + self.assertTrue(session_options.use_multiplexed(TransactionType.READ_ONLY)) + + false_values = ["", "0", "false", "False", "FALSE", " false "] + for value in false_values: + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) + + del environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] + self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) From 34baadfd4dc169c42b584172844676c6432aefa1 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 29 May 2025 14:00:16 -0700 Subject: [PATCH 02/41] feat: Multiplexed sessions - Remove handling of `MethodNotImplemented` exception from `DatabaseSessionManager` and add unit tests. 
Signed-off-by: Taylor Curran --- .../spanner_v1/database_sessions_manager.py | 115 ++++---- tests/_builders.py | 66 +++++ tests/unit/test_database_session_manager.py | 246 ++++++++++++++++++ 3 files changed, 358 insertions(+), 69 deletions(-) create mode 100644 tests/unit/test_database_session_manager.py diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index d9a0c06f52..7e9e08175d 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -11,38 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import threading -import time -import weakref - -from google.api_core.exceptions import MethodNotImplemented +from datetime import timedelta +from threading import Event, Lock, Thread +from time import sleep, time +from weakref import ref +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1._opentelemetry_tracing import ( get_current_span, add_span_event, ) -from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import TransactionType class DatabaseSessionsManager(object): """Manages sessions for a Cloud Spanner database. + Sessions can be checked out from the database session manager for a specific transaction type using :meth:`get_session`, and returned to the session manager using :meth:`put_session`. - The sessions returned by the session manager depend on the client's session options (see - :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the provided session - pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). 
+ + The sessions returned by the session manager depend on the client's session options + (see :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the + provided session pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). + :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to manage sessions for. + :type pool: :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` :param pool: The pool to get non-multiplexed sessions from. """ # Intervals for the maintenance thread to check and refresh the multiplexed session. - _MAINTENANCE_THREAD_POLLING_INTERVAL = datetime.timedelta(minutes=10) - _MAINTENANCE_THREAD_REFRESH_INTERVAL = datetime.timedelta(days=7) + _MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10) + _MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7) def __init__(self, database, pool): self._database = database @@ -56,21 +58,13 @@ def __init__(self, database, pool): # so that the thread can terminate if the use of multiplexed session has been # disabled for all transactions. self._multiplexed_session = None - self._multiplexed_session_maintenance_thread = None - self._multiplexed_session_lock = threading.Lock() - self._is_multiplexed_sessions_disabled_event = threading.Event() - - @property - def _logger(self): - """The logger used by this database session manager. - - :rtype: :class:`logging.Logger` - :returns: The logger. - """ - return self._database.logger + self._multiplexed_session_thread = None + self._multiplexed_session_lock = Lock() + self._multiplexed_session_disabled_event = Event() def get_session(self, transaction_type: TransactionType) -> Session: """Returns a session for the given transaction type from the database session manager. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a session for the given transaction type. 
""" @@ -78,23 +72,15 @@ def get_session(self, transaction_type: TransactionType) -> Session: session_options = self._database.session_options use_multiplexed = session_options.use_multiplexed(transaction_type) + # TODO multiplexed: enable for read/write transactions if use_multiplexed and transaction_type == TransactionType.READ_WRITE: raise NotImplementedError( f"Multiplexed sessions are not yet supported for {transaction_type} transactions." ) - if use_multiplexed: - try: - session = self._get_multiplexed_session() - - # If multiplexed sessions are not supported, disable - # them for all transactions and return a non-multiplexed session. - except MethodNotImplemented: - self._disable_multiplexed_sessions() - session = self._pool.get() - - else: - session = self._pool.get() + session = ( + self._get_multiplexed_session() if use_multiplexed else self._pool.get() + ) add_span_event( get_current_span(), @@ -106,6 +92,7 @@ def get_session(self, transaction_type: TransactionType) -> Session: def put_session(self, session: Session) -> None: """Returns the session to the database session manager. + :type session: :class:`~google.cloud.spanner_v1.session.Session` :param session: The session to return to the database session manager. """ @@ -124,12 +111,12 @@ def put_session(self, session: Session) -> None: def _get_multiplexed_session(self) -> Session: """Returns a multiplexed session from the database session manager. + If the multiplexed session is not defined, creates a new multiplexed session and starts a maintenance thread to periodically delete and recreate it so that it remains valid. Otherwise, simply returns the current multiplexed session. - :raises MethodNotImplemented: - if multiplexed sessions are not supported. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a multiplexed session. 
""" @@ -138,18 +125,14 @@ def _get_multiplexed_session(self) -> Session: if self._multiplexed_session is None: self._multiplexed_session = self._build_multiplexed_session() - # Build and start a thread to maintain the multiplexed session. - self._multiplexed_session_maintenance_thread = ( - self._build_maintenance_thread() - ) - self._multiplexed_session_maintenance_thread.start() + self._multiplexed_session_thread = self._build_maintenance_thread() + self._multiplexed_session_thread.start() return self._multiplexed_session def _build_multiplexed_session(self) -> Session: """Builds and returns a new multiplexed session for the database session manager. - :raises MethodNotImplemented: - if multiplexed sessions are not supported. + :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: a new multiplexed session. """ @@ -159,10 +142,9 @@ def _build_multiplexed_session(self) -> Session: database_role=self._database.database_role, is_multiplexed=True, ) - session.create() - self._logger.info("Created multiplexed session.") + self._database.logger.info("Created multiplexed session.") return session @@ -170,13 +152,14 @@ def _disable_multiplexed_sessions(self) -> None: """Disables multiplexed sessions for all transactions.""" self._multiplexed_session = None - self._is_multiplexed_sessions_disabled_event.set() - self._database.session_options.disable_multiplexed(self._logger) + self._multiplexed_session_disabled_event.set() + self._database.session_options.disable_multiplexed(self._database.logger) - def _build_maintenance_thread(self) -> threading.Thread: + def _build_maintenance_thread(self) -> Thread: """Builds and returns a multiplexed session maintenance thread for the database session manager. This thread will periodically delete and recreate the multiplexed session to ensure that it is always valid. + :rtype: :class:`threading.Thread` :returns: a multiplexed session maintenance thread. 
""" @@ -184,9 +167,9 @@ def _build_maintenance_thread(self) -> threading.Thread: # Use a weak reference to the database session manager to avoid # creating a circular reference that would prevent the database # session manager from being garbage collected. - session_manager_ref = weakref.ref(self) + session_manager_ref = ref(self) - return threading.Thread( + return Thread( target=self._maintain_multiplexed_session, name=f"maintenance-multiplexed-session-{self._multiplexed_session.name}", args=[session_manager_ref], @@ -196,10 +179,12 @@ def _build_maintenance_thread(self) -> threading.Thread: @staticmethod def _maintain_multiplexed_session(session_manager_ref) -> None: """Maintains the multiplexed session for the database session manager. + This method will delete and recreate the referenced database session manager's multiplexed session to ensure that it is always valid. The method will run until the database session manager is deleted, the multiplexed session is deleted, or building a multiplexed session fails. + :type session_manager_ref: :class:`_weakref.ReferenceType` :param session_manager_ref: A weak reference to the database session manager. """ @@ -215,7 +200,7 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: session_manager._MAINTENANCE_THREAD_REFRESH_INTERVAL.total_seconds() ) - session_created_time = time.time() + session_created_time = time() while True: # Terminate the thread is the database session manager has been deleted. @@ -224,26 +209,18 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: return # Terminate the thread if the use of multiplexed sessions has been disabled. - if session_manager._is_multiplexed_sessions_disabled_event.is_set(): + if session_manager._multiplexed_session_disabled_event.is_set(): return # Wait for until the refresh interval has elapsed. 
- if time.time() - session_created_time < refresh_interval_seconds: - time.sleep(polling_interval_seconds) + if time() - session_created_time < refresh_interval_seconds: + sleep(polling_interval_seconds) continue with session_manager._multiplexed_session_lock: session_manager._multiplexed_session.delete() + session_manager._multiplexed_session = ( + session_manager._build_multiplexed_session() + ) - try: - session_manager._multiplexed_session = ( - session_manager._build_multiplexed_session() - ) - - # Disable multiplexed sessions for all transactions and terminate - # the thread if building a multiplexed session fails. - except MethodNotImplemented: - session_manager._disable_multiplexed_sessions() - return - - session_created_time = time.time() + session_created_time = time() diff --git a/tests/_builders.py b/tests/_builders.py index f5632786bb..7044bb6ecb 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -14,9 +14,75 @@ from mock import create_autospec +# Default values used to populate required or expected attributes. +# Tests should not depend on them: if a test requires a specific +# identifier or name, it should set it explicitly. +_PROJECT_ID = "default-project-id" +_INSTANCE_ID = "default-instance-id" +_DATABASE_ID = "default-database-id" + def build_logger(): """Builds and returns a logger for testing.""" from logging import Logger return create_autospec(Logger, instance=True) + + +# Client objects +# -------------- + + +def build_client(**kwargs): + """Builds and returns a client for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1 import Client + + if "project" not in kwargs: + kwargs["project"] = _PROJECT_ID + + return Client(**kwargs) + + +def build_database(**kwargs): + """Builds and returns a database for testing using the given arguments. 
+ If a required argument is not provided, a default value will be used..""" + from google.cloud.spanner_v1.database import Database + + if "database_id" not in kwargs: + kwargs["database_id"] = _DATABASE_ID + + if "logger" not in kwargs: + kwargs["logger"] = build_logger() + + if "instance" not in kwargs or isinstance(kwargs["instance"], dict): + instance_args = kwargs.pop("instance", {}) + kwargs["instance"] = build_instance(**instance_args) + + database = Database(**kwargs) + database._spanner_api = build_spanner_api() + + return database + + +def build_instance(**kwargs): + """Builds and returns an instance for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + from google.cloud.spanner_v1.instance import Instance + + if "instance_id" not in kwargs: + kwargs["instance_id"] = _INSTANCE_ID + + if "client" not in kwargs or isinstance(kwargs["client"], dict): + client_args = kwargs.pop("client", {}) + kwargs["client"] = build_client(**client_args) + + return Instance(**kwargs) + + +def build_spanner_api(): + """Builds and returns a mock Spanner Client API for testing using the given arguments. + Commonly used methods are mocked to return default values.""" + from google.cloud.spanner_v1 import SpannerClient + + return create_autospec(SpannerClient, instance=True) diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py new file mode 100644 index 0000000000..c967dd9705 --- /dev/null +++ b/tests/unit/test_database_session_manager.py @@ -0,0 +1,246 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import timedelta +from mock import Mock, patch +from os import environ +from threading import Thread +from time import time, sleep +from typing import Callable +from unittest import TestCase + +from google.api_core.exceptions import BadRequest, FailedPrecondition +from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType +from tests._builders import build_database + + +# Shorten polling and refresh intervals for testing. +@patch.multiple( + DatabaseSessionsManager, + _MAINTENANCE_THREAD_POLLING_INTERVAL=timedelta(seconds=1), + _MAINTENANCE_THREAD_REFRESH_INTERVAL=timedelta(seconds=2), +) +class TestDatabaseSessionManager(TestCase): + def setUp(self): + self._original_env = dict(environ) + self._sessions_manager = self._build_sessions_manager() + + def tearDown(self): + self._cleanup_database_sessions_manager() + environ.clear() + environ.update(self._original_env) + + def test_read_only_pooled(self): + self._disable_multiplexed_sessions() + manager = self._sessions_manager + + # Get session from pool. + session = manager.get_session(TransactionType.READ_ONLY) + self.assertFalse(session.is_multiplexed) + manager._pool.get.assert_called_once() + + # Return session to pool. + manager.put_session(session) + manager._pool.put.assert_called_once_with(session) + + def test_read_only_multiplexed(self): + self._enable_multiplexed_sessions() + manager = self._sessions_manager + + # Session is created. 
+ session_1 = manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + manager.put_session(session_1) + + # Session is re-used. + session_2 = manager.get_session(TransactionType.READ_ONLY) + self.assertEqual(session_1, session_2) + manager.put_session(session_2) + + # Verify that pool was not used. + manager._pool.get.assert_not_called() + manager._pool.put.assert_not_called() + + # Verify logger calls. + info = manager._database.logger.info + info.assert_called_once_with("Created multiplexed session.") + + def test_partitioned_pooled(self): + self._disable_multiplexed_sessions() + manager = self._sessions_manager + + # Get session from pool. + session = manager.get_session(TransactionType.PARTITIONED) + self.assertFalse(session.is_multiplexed) + manager._pool.get.assert_called_once() + + # Return session to pool. + manager.put_session(session) + manager._pool.put.assert_called_once_with(session) + + def test_partitioned_multiplexed(self): + self._enable_multiplexed_sessions() + manager = self._sessions_manager + + # Session is created. + session_1 = manager.get_session(TransactionType.PARTITIONED) + self.assertTrue(session_1.is_multiplexed) + manager.put_session(session_1) + + # Session is re-used. + session_2 = manager.get_session(TransactionType.PARTITIONED) + self.assertEqual(session_1, session_2) + manager.put_session(session_2) + + # Verify that pool was not used. + pool = manager._pool + pool.get.assert_not_called() + pool.put.assert_not_called() + + # Verify logger calls. + info = manager._database.logger.info + info.assert_called_once_with("Created multiplexed session.") + + def test_read_write_pooled(self): + self._disable_multiplexed_sessions() + manager = self._sessions_manager + + # Get session from pool. + session = manager.get_session(TransactionType.READ_WRITE) + self.assertFalse(session.is_multiplexed) + manager._pool.get.assert_called_once() + + # Return session to pool. 
+ manager.put_session(session) + manager._pool.put.assert_called_once_with(session) + + # TODO multiplexed: implement support for read/write transactions. + def test_read_write_multiplexed(self): + self._enable_multiplexed_sessions() + + with self.assertRaises(NotImplementedError): + self._sessions_manager.get_session(TransactionType.READ_WRITE) + + def test_multiplexed_maintenance(self, *_): + self._enable_multiplexed_sessions() + manager = self._sessions_manager + + # Maintenance thread is started. + session_1 = manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + self.assertTrue(manager._multiplexed_session_thread.is_alive()) + + # Wait for maintenance thread to execute. + self._assert_true_with_timeout( + lambda: manager._database.spanner_api.create_session.call_count > 1 + ) + + # Verify that maintenance thread created new multiplexed session. + session_2 = manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_2.is_multiplexed) + self.assertNotEqual(session_1, session_2) + + def test_multiplexed_maintenance_terminates_disabled(self): + self._enable_multiplexed_sessions() + manager = self._sessions_manager + + # Maintenance thread is started. + session_1 = manager.get_session(TransactionType.READ_ONLY) + self.assertTrue(session_1.is_multiplexed) + + manager._multiplexed_session_disabled_event.set() + + thread = manager._multiplexed_session_thread + self._assert_thread_terminated(thread) + + def test_exception_bad_request(self): + manager = self._sessions_manager + api = manager._database.spanner_api + api.create_session.side_effect = BadRequest("") + + # Verify that BadRequest is not caught. 
+ with self.assertRaises(BadRequest): + manager.get_session(TransactionType.READ_ONLY) + + def test_exception_failed_precondition(self): + manager = self._sessions_manager + api = manager._database.spanner_api + api.create_session.side_effect = FailedPrecondition("") + + # Verify that FailedPrecondition is not caught. + with self.assertRaises(FailedPrecondition): + manager.get_session(TransactionType.READ_ONLY) + + def _cleanup_database_sessions_manager(self) -> None: + """Cleans up the database session manager after testing.""" + + # If the maintenance thread is still alive, disable multiplexed sessions and + # wait for the thread to terminate. We need to do this to ensure that the + # thread is properly cleaned up and does not interfere with other tests. + sessions_manager = self._sessions_manager + thread = sessions_manager._multiplexed_session_thread + + if thread and thread.is_alive(): + sessions_manager._multiplexed_session_disabled_event.set() + self._assert_thread_terminated(thread) + + def _assert_true_with_timeout(self, condition: Callable) -> None: + """Asserts that the given condition is met within a timeout period.""" + + sleep_seconds = 0.1 + timeout_seconds = 10 + + start_time = time() + while not condition() and time() - start_time < timeout_seconds: + sleep(sleep_seconds) + + self.assertTrue(condition()) + + def _assert_thread_terminated(self, thread: Thread) -> None: + """Asserts that the given thread is terminated.""" + + def _is_thread_terminated(): + return not thread.is_alive() + + self._assert_true_with_timeout(_is_thread_terminated) + + @staticmethod + def _build_sessions_manager() -> DatabaseSessionsManager: + """Builds and returns a new database session manager for testing. + + :rtype: :class:`~google.cloud.spanner_v1.database_sessions_manager.DatabaseSessionsManager` + :returns: a new database session manager. + """ + database = build_database() + sessions_manager = database._sessions_manager + + # Mock the session pool. 
+ pool = sessions_manager._pool + pool.get = Mock(wraps=pool.get) + pool.put = Mock(wraps=pool.put) + + return sessions_manager + + @staticmethod + def _disable_multiplexed_sessions() -> None: + """Sets environment variables to disable multiplexed sessions for all transaction types.""" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" + + @staticmethod + def _enable_multiplexed_sessions() -> None: + """Sets environment variables to enable multiplexed sessions for all transaction types.""" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" + environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" + environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" From 998f23ff65aa160240e3830db6b0ccf77322b88b Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 10:07:59 -0700 Subject: [PATCH 03/41] feat: Multiplexed sessions - Update `Connection` to use multiplexed sessions, add unit tests.
Signed-off-by: Taylor Curran --- google/cloud/spanner_dbapi/connection.py | 17 ++++- tests/_builders.py | 73 +++++++++++++-------- tests/unit/spanner_dbapi/test_connection.py | 54 ++++++++------- 3 files changed, 91 insertions(+), 53 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 6a21769f13..ef0db6f784 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -28,6 +28,7 @@ from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_v1 import RequestOptions, TransactionOptions +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_dbapi.exceptions import ( @@ -356,8 +357,16 @@ def _session_checkout(self): """ if self.database is None: raise ValueError("Database needs to be passed for this operation") + if not self._session: - self._session = self.database._pool.get() + transaction_type = ( + TransactionType.READ_ONLY + if self.read_only + else TransactionType.READ_WRITE + ) + self._session = self.database._sessions_manager.get_session( + transaction_type + ) return self._session @@ -368,9 +377,11 @@ def _release_session(self): """ if self._session is None: return + if self.database is None: raise ValueError("Database needs to be passed for this operation") - self.database._pool.put(self._session) + + self.database._sessions_manager.put_session(self._session) self._session = None def transaction_checkout(self): @@ -432,7 +443,7 @@ def close(self): self._transaction.rollback() if self._own_pool and self.database: - self.database._pool.clear() + self.database._sessions_manager._pool.clear() self.is_closed = True diff --git a/tests/_builders.py b/tests/_builders.py index 7044bb6ecb..cc9e1ddebf 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -11,8 +11,16 
@@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from logging import Logger from mock import create_autospec +from typing import Mapping + +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance +from google.cloud.spanner_v1.session import Session # Default values used to populate required or expected attributes. # Tests should not depend on them: if a test requires a specific @@ -22,21 +30,9 @@ _DATABASE_ID = "default-database-id" -def build_logger(): - """Builds and returns a logger for testing.""" - from logging import Logger - - return create_autospec(Logger, instance=True) - - -# Client objects -# -------------- - - -def build_client(**kwargs): +def build_client(**kwargs: Mapping) -> Client: """Builds and returns a client for testing using the given arguments. If a required argument is not provided, a default value will be used.""" - from google.cloud.spanner_v1 import Client if "project" not in kwargs: kwargs["project"] = _PROJECT_ID @@ -44,10 +40,22 @@ def build_client(**kwargs): return Client(**kwargs) -def build_database(**kwargs): +def build_connection(**kwargs: Mapping) -> Connection: + """Builds and returns a connection for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + + if "instance" not in kwargs: + kwargs["instance"] = build_instance() + + if "database" not in kwargs: + kwargs["database"] = build_database(instance=kwargs["instance"]) + + return Connection(**kwargs) + + +def build_database(**kwargs: Mapping) -> Database: """Builds and returns a database for testing using the given arguments. 
- If a required argument is not provided, a default value will be used..""" - from google.cloud.spanner_v1.database import Database + If a required argument is not provided, a default value will be used.""" if "database_id" not in kwargs: kwargs["database_id"] = _DATABASE_ID @@ -55,9 +63,8 @@ def build_database(**kwargs): if "logger" not in kwargs: kwargs["logger"] = build_logger() - if "instance" not in kwargs or isinstance(kwargs["instance"], dict): - instance_args = kwargs.pop("instance", {}) - kwargs["instance"] = build_instance(**instance_args) + if "instance" not in kwargs: + kwargs["instance"] = build_instance() database = Database(**kwargs) database._spanner_api = build_spanner_api() @@ -65,24 +72,36 @@ def build_database(**kwargs): return database -def build_instance(**kwargs): +def build_instance(**kwargs: Mapping) -> Instance: """Builds and returns an instance for testing using the given arguments. If a required argument is not provided, a default value will be used.""" - from google.cloud.spanner_v1.instance import Instance if "instance_id" not in kwargs: kwargs["instance_id"] = _INSTANCE_ID - if "client" not in kwargs or isinstance(kwargs["client"], dict): - client_args = kwargs.pop("client", {}) - kwargs["client"] = build_client(**client_args) + if "client" not in kwargs: + kwargs["client"] = build_client() return Instance(**kwargs) -def build_spanner_api(): +def build_logger() -> Logger: + """Builds and returns a logger for testing.""" + return create_autospec(Logger, instance=True) + + +def build_session(**kwargs: Mapping) -> Session: + """Builds and returns a session for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + + if "database" not in kwargs: + kwargs["database"] = build_database() + + return Session(**kwargs) + + +def build_spanner_api() -> SpannerClient: """Builds and returns a mock Spanner Client API for testing using the given arguments. 
Commonly used methods are mocked to return default values.""" - from google.cloud.spanner_v1 import SpannerClient return create_autospec(SpannerClient, instance=True) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 04434195db..4a9be916ce 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -37,6 +37,8 @@ ClientSideStatementType, AutocommitDmlMode, ) +from google.cloud.spanner_v1.session_options import TransactionType +from tests._builders import build_connection, build_session PROJECT = "test-project" INSTANCE = "test-instance" @@ -151,25 +153,31 @@ def test_read_only_connection(self): connection.read_only = False self.assertFalse(connection.read_only) - @staticmethod - def _make_pool(): - from google.cloud.spanner_v1.pool import AbstractSessionPool + def test__session_checkout_read_only(self): + connection = build_connection(read_only=True) + database = connection._database + sessions_manager = database._sessions_manager - return mock.create_autospec(AbstractSessionPool) + expected_session = build_session(database=database) + sessions_manager.get_session = mock.MagicMock(return_value=expected_session) - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__session_checkout(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) + actual_session = connection._session_checkout() + + self.assertEqual(actual_session, expected_session) + sessions_manager.get_session.assert_called_once_with(TransactionType.READ_ONLY) + + def test__session_checkout_read_write(self): + connection = build_connection(read_only=False) + database = connection._database + sessions_manager = database._sessions_manager + + expected_session = build_session(database=database) + sessions_manager.get_session = mock.MagicMock(return_value=expected_session) - connection._session_checkout() - 
pool.get.assert_called_once_with() - self.assertEqual(connection._session, pool.get.return_value) + actual_session = connection._session_checkout() - connection._session = "db_session" - connection._session_checkout() - self.assertEqual(connection._session, "db_session") + self.assertEqual(actual_session, expected_session) + sessions_manager.get_session.assert_called_once_with(TransactionType.READ_WRITE) def test_session_checkout_database_error(self): connection = Connection(INSTANCE) @@ -177,16 +185,16 @@ def test_session_checkout_database_error(self): with pytest.raises(ValueError): connection._session_checkout() - @mock.patch("google.cloud.spanner_v1.database.Database") - def test__release_session(self, mock_database): - pool = self._make_pool() - mock_database._pool = pool - connection = Connection(INSTANCE, mock_database) - connection._session = "session" + def test__release_session(self): + connection = build_connection() + sessions_manager = connection._database._sessions_manager + + session = connection._session = build_session(database=connection._database) + put_session = sessions_manager.put_session = mock.MagicMock() connection._release_session() - pool.put.assert_called_once_with("session") - self.assertIsNone(connection._session) + + put_session.assert_called_once_with(session) def test_release_session_database_error(self): connection = Connection(INSTANCE) From ec19f2d02fcf263fec479fb3f2e41d1001ae5ad6 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 10:20:19 -0700 Subject: [PATCH 04/41] cleanup: Rename `beforeNextRetry` to `before_next_retry`. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/_helpers.py | 6 +++--- google/cloud/spanner_v1/transaction.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 7b86a5653f..00a69d462b 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -535,7 +535,7 @@ def _retry( retry_count=5, delay=2, allowed_exceptions=None, - beforeNextRetry=None, + before_next_retry=None, ): """ Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions. @@ -552,8 +552,8 @@ def _retry( """ retries = 0 while retries <= retry_count: - if retries > 0 and beforeNextRetry: - beforeNextRetry(retries, delay) + if retries > 0 and before_next_retry: + before_next_retry(retries, delay) try: return func() diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 795e158f6a..396a61ada8 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -199,17 +199,17 @@ def wrapped_method(*args, **kwargs): ) return method(*args, **kwargs) - def beforeNextRetry(nthRetry, delayInSeconds): + def before_next_retry(nth_retry, delay_in_seconds): add_span_event( span, "Transaction Begin Attempt Failed. Retrying", - {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, ) response = _retry( wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, - beforeNextRetry=beforeNextRetry, + before_next_retry=before_next_retry, ) self._transaction_id = response.id return self._transaction_id @@ -348,17 +348,17 @@ def wrapped_method(*args, **kwargs): ) return method(*args, **kwargs) - def beforeNextRetry(nthRetry, delayInSeconds): + def before_next_retry(nth_retry, delay_in_seconds): add_span_event( span, "Transaction Commit Attempt Failed. 
Retrying", - {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, ) response = _retry( wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, - beforeNextRetry=beforeNextRetry, + before_next_retry=before_next_retry, ) add_span_event(span, "Commit Done") From 25d0943389ec199a4ba2b131dce4a9389a8e8655 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 10:24:58 -0700 Subject: [PATCH 05/41] cleanup: Fix a few unrelated typos. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 2 +- tests/unit/spanner_dbapi/test_connect.py | 2 +- tests/unit/test_metrics.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1273e016da..dda1e13c6c 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1253,7 +1253,7 @@ def observability_options(self): return opts @property - def sessions_manager(self): + def sessions_manager(self) -> DatabaseSessionsManager: """Returns the database sessions manager. 
:rtype: :class:`~google.cloud.spanner_v1.database_sessions_manager.DatabaseSessionsManager` diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 34d3d942ad..5e748eaf66 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -69,7 +69,7 @@ def test_w_implicit(self, mock_client): instance.database.assert_called_once_with( DATABASE, pool=None, database_role=None ) - # Datbase constructs its own pool + # Database constructs its own pool self.assertIsNotNone(connection.database._pool) self.assertTrue(connection.instance._client.route_to_leader_enabled) diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 59fe6d2f61..5e37e7cfe2 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -27,7 +27,7 @@ from opentelemetry import metrics pytest.importorskip("opentelemetry") -# Skip if semconv attributes are not present, as tracing wont' be enabled either +# Skip if semconv attributes are not present, as tracing won't be enabled either # pytest.importorskip("opentelemetry.semconv.attributes.otel_attributes") From fca6f06f5bda4ea827340db5e493949aa70ca4b2 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 14:33:00 -0700 Subject: [PATCH 06/41] feat: Multiplexed sessions - Add ingest of precommit tokens to `_SnapshotBase` and update attributes and tests. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/snapshot.py | 368 +++++++++++++++---------- google/cloud/spanner_v1/streamed.py | 3 + google/cloud/spanner_v1/transaction.py | 272 ++++++++++-------- tests/unit/test_snapshot.py | 11 +- tests/unit/test_spanner.py | 10 +- tests/unit/test_transaction.py | 26 +- 6 files changed, 396 insertions(+), 294 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index b8131db18a..311381c160 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -14,11 +14,17 @@ """Model a set of read-only queries to a database as a snapshot.""" -from datetime import datetime import functools import threading +from typing import List, Union + from google.protobuf.struct_pb2 import Struct -from google.cloud.spanner_v1 import ExecuteSqlRequest +from google.cloud.spanner_v1 import ( + ExecuteSqlRequest, + PartialResultSet, + ResultSet, + Transaction, +) from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1 import TransactionSelector @@ -45,6 +51,7 @@ from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture +from google.cloud.spanner_v1.types import MultiplexedSessionPrecommitToken _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES = ( "RST_STREAM", @@ -80,8 +87,8 @@ def _restart_on_unavailable( if both transaction_selector and transaction are passed, then transaction is given priority. """ - resume_token = b"" - item_buffer = [] + resume_token: bytes = b"" + item_buffer: List[PartialResultSet] = [] if transaction is not None: transaction_selector = transaction._make_txn_selector() @@ -97,6 +104,7 @@ def _restart_on_unavailable( while True: try: + # Get results iterator. if iterator is None: with trace_call( trace_name, @@ -114,20 +122,20 @@ def _restart_on_unavailable( span, ), ) + + # Add items from iterator to buffer. 
+ item: PartialResultSet for item in iterator: item_buffer.append(item) - # Setting the transaction id because the transaction begin was inlined for first rpc. - if ( - transaction is not None - and transaction._transaction_id is None - and item.metadata is not None - and item.metadata.transaction is not None - and item.metadata.transaction.id is not None - ): - transaction._transaction_id = item.metadata.transaction.id + + # Update the transaction from the response. + if transaction is not None: + transaction._update_for_result_set_pb(item) + if item.resume_token: resume_token = item.resume_token break + except ServiceUnavailable: del item_buffer[:] with trace_call( @@ -152,6 +160,7 @@ def _restart_on_unavailable( ), ) continue + except InternalServerError as exc: resumable_error = any( resumable_message in exc.message @@ -198,15 +207,34 @@ class _SnapshotBase(_SessionWrapper): Allows reuse of API request methods with different transaction selector. :type session: :class:`~google.cloud.spanner_v1.session.Session` - :param session: the session used to perform the commit + :param session: the session used to perform transaction operations. """ - _multi_use = False _read_only: bool = True - _transaction_id = None - _read_request_count = 0 - _execute_sql_count = 0 - _lock = threading.Lock() + _multi_use: bool = False + + def __init__(self, session): + super().__init__(session) + + # Counts for execute SQL requests and total read requests (including + # execute SQL requests). Used to provide sequence numbers for + # :class:`google.cloud.spanner_v1.types.ExecuteSqlRequest` and to + # verify that single-use transactions are not used more than once, + # respectively. + self._execute_sql_request_count: int = 0 + self._read_request_count: int = 0 + + # Identifier for the transaction. + self._transaction_id: bytes = None + + # Precommit tokens are returned for transactions with + # multiplexed sessions. 
The precommit token with the + # highest sequence number is included in the commit request. + self._precommit_token: MultiplexedSessionPrecommitToken = None + + # Operations within a transaction can be performed using multiple + # threads, so we need to use a lock when updating the transaction. + self._lock: threading.Lock = threading.Lock() def _make_txn_selector(self): """Helper for :meth:`read` / :meth:`execute_sql`. @@ -317,14 +345,18 @@ def read( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. """ + + # TODO multiplexed - cleanup if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") if self._transaction_id is None and self._read_only: raise ValueError("Transaction ID pending.") - database = self._session._database + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if not self._read_only and database._route_to_leader_enabled: metadata.append( @@ -347,8 +379,8 @@ def read( elif self.transaction_tag is not None: request_options.transaction_tag = self.transaction_tag - request = ReadRequest( - session=self._session.name, + read_request = ReadRequest( + session=session.name, table=table, columns=columns, key_set=keyset._to_pb(), @@ -360,67 +392,22 @@ def read( directed_read_options=directed_read_options, ) - restart = functools.partial( + streaming_read_method = functools.partial( api.streaming_read, - request=request, + request=read_request, metadata=metadata, retry=retry, timeout=timeout, ) - trace_attributes = {"table_id": table, "columns": columns} - observability_options = getattr(database, "observability_options", None) - - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, 
- trace_attributes, - transaction=self, - observability_options=observability_options, - request_id_manager=self._session._database, - ) - self._read_request_count += 1 - if self._multi_use: - return StreamedResultSet( - iterator, - source=self, - column_info=column_info, - lazy_decode=lazy_decode, - ) - else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) - else: - iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.read", - self._session, - trace_attributes, - transaction=self, - observability_options=observability_options, - request_id_manager=self._session._database, - ) - - self._read_request_count += 1 - self._session._last_use_time = datetime.now() - - if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) - else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + return self._get_streamed_result_set( + method=streaming_read_method, + request=read_request, + metadata=metadata, + trace_attributes={"table_id": table, "columns": columns}, + column_info=column_info, + lazy_decode=lazy_decode, + ) def execute_sql( self, @@ -539,6 +526,8 @@ def execute_sql( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. 
""" + + # TODO multiplexed - cleanup if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") @@ -552,15 +541,16 @@ def execute_sql( else: params_pb = {} - database = self._session._database + session = self._session + database = session._database + api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if not self._read_only and database._route_to_leader_enabled: metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - api = database.spanner_api - # Query-level options have higher precedence than client-level and # environment-level options default_query_options = database._instance._client._query_options @@ -581,14 +571,14 @@ def execute_sql( elif self.transaction_tag is not None: request_options.transaction_tag = self.transaction_tag - request = ExecuteSqlRequest( - session=self._session.name, + execute_sql_request = ExecuteSqlRequest( + session=session.name, sql=sql, params=params_pb, param_types=param_types, query_mode=query_mode, partition_token=partition, - seqno=self._execute_sql_count, + seqno=self._execute_sql_request_count, query_options=query_options, request_options=request_options, last_statement=last_statement, @@ -596,74 +586,79 @@ def execute_sql( directed_read_options=directed_read_options, ) - def wrapped_restart(*args, **kwargs): - restart = functools.partial( - api.execute_streaming_sql, - request=request, - metadata=kwargs.get("metadata", metadata), - retry=retry, - timeout=timeout, - ) - return restart(*args, **kwargs) - - trace_attributes = {"db.statement": sql} - observability_options = getattr(database, "observability_options", None) + execute_streaming_sql_method = functools.partial( + api.execute_streaming_sql, + request=execute_sql_request, + metadata=metadata, + retry=retry, + timeout=timeout, + ) - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - return 
self._get_streamed_result_set( - wrapped_restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) - else: - return self._get_streamed_result_set( - wrapped_restart, - request, - metadata, - trace_attributes, - column_info, - observability_options, - lazy_decode=lazy_decode, - ) + return self._get_streamed_result_set( + method=execute_streaming_sql_method, + request=execute_sql_request, + metadata=metadata, + trace_attributes={"db.statement": sql}, + column_info=column_info, + lazy_decode=lazy_decode, + ) def _get_streamed_result_set( self, - restart, + method, request, metadata, trace_attributes, column_info, - observability_options=None, - lazy_decode=False, + lazy_decode, ): + """Returns the streamed result set for a read or execute SQL request with the given arguments.""" + + session = self._session + database = session._database + + is_execute_sql_request = isinstance(request, ExecuteSqlRequest) + + trace_method_name = "execute_sql" if is_execute_sql_request else "read" + trace_name = f"CloudSpanner.{type(self).__name__}.{trace_method_name}" + + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. 
+ is_inline_begin = False + + if self._transaction_id is None: + is_inline_begin = True + self._lock.acquire() + iterator = _restart_on_unavailable( - restart, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_sql", - self._session, - trace_attributes, + method=method, + request=request, + session=session, + metadata=metadata, + trace_name=trace_name, + attributes=trace_attributes, transaction=self, - observability_options=observability_options, - request_id_manager=self._session._database, + observability_options=getattr(database, "observability_options", None), + request_id_manager=database, ) + + if is_inline_begin: + self._lock.release() + + if is_execute_sql_request: + self._execute_sql_request_count += 1 self._read_request_count += 1 - self._execute_sql_count += 1 + + streamed_result_set_args = { + "response_iterator": iterator, + "column_info": column_info, + "lazy_decode": lazy_decode, + } if self._multi_use: - return StreamedResultSet( - iterator, source=self, column_info=column_info, lazy_decode=lazy_decode - ) - else: - return StreamedResultSet( - iterator, column_info=column_info, lazy_decode=lazy_decode - ) + streamed_result_set_args["source"] = self + + return StreamedResultSet(**streamed_result_set_args) def partition_read( self, @@ -716,14 +711,18 @@ def partition_read( for single-use snapshots, or if a transaction ID is already associated with the snapshot. 
""" + + # TODO multiplexed - cleanup if not self._multi_use: raise ValueError("Cannot use single-use snapshot.") if self._transaction_id is None: raise ValueError("Transaction not started.") - database = self._session._database + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( @@ -733,8 +732,9 @@ def partition_read( partition_options = PartitionOptions( partition_size_bytes=partition_size_bytes, max_partitions=max_partitions ) - request = PartitionReadRequest( - session=self._session.name, + + partition_read_request = PartitionReadRequest( + session=session.name, table=table, columns=columns, key_set=keyset._to_pb(), @@ -750,7 +750,7 @@ def partition_read( with trace_call( f"CloudSpanner.{type(self).__name__}.partition_read", - self._session, + session, extra_attributes=trace_attributes, observability_options=getattr(database, "observability_options", None), metadata=metadata, @@ -765,14 +765,14 @@ def attempt_tracking_method(): metadata, span, ) - method = functools.partial( + partition_read_method = functools.partial( api.partition_read, - request=request, + request=partition_read_request, metadata=all_metadata, retry=retry, timeout=timeout, ) - return method() + return partition_read_method() response = _retry( attempt_tracking_method, @@ -830,6 +830,8 @@ def partition_query( for single-use snapshots, or if a transaction ID is already associated with the snapshot. 
""" + + # TODO multiplexed - cleanup if not self._multi_use: raise ValueError("Cannot use single-use snapshot.") @@ -843,8 +845,10 @@ def partition_query( else: params_pb = Struct() - database = self._session._database + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( @@ -854,8 +858,9 @@ def partition_query( partition_options = PartitionOptions( partition_size_bytes=partition_size_bytes, max_partitions=max_partitions ) - request = PartitionQueryRequest( - session=self._session.name, + + partition_query_request = PartitionQueryRequest( + session=session.name, sql=sql, transaction=transaction, params=params_pb, @@ -866,7 +871,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( f"CloudSpanner.{type(self).__name__}.partition_query", - self._session, + session, trace_attributes, observability_options=getattr(database, "observability_options", None), metadata=metadata, @@ -881,14 +886,14 @@ def attempt_tracking_method(): metadata, span, ) - method = functools.partial( + partition_query_method = functools.partial( api.partition_query, - request=request, + request=partition_query_request, metadata=all_metadata, retry=retry, timeout=timeout, ) - return method() + return partition_query_method() response = _retry( attempt_tracking_method, @@ -897,6 +902,55 @@ def attempt_tracking_method(): return [partition.partition_token for partition in response.partitions] + def _update_for_result_set_pb( + self, result_set_pb: Union[ResultSet, PartialResultSet] + ) -> None: + """Updates the snapshot for the given result set. + + :type result_set_pb: :class:`~google.cloud.spanner_v1.ResultSet` or + :class:`~google.cloud.spanner_v1.PartialResultSet` + :param result_set_pb: The result set to update the snapshot with. 
+ """ + + if result_set_pb.metadata and result_set_pb.metadata.transaction: + self._update_for_transaction_pb(result_set_pb.metadata.transaction) + + if result_set_pb.precommit_token: + self._update_for_precommit_token_pb(result_set_pb.precommit_token) + + def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: + """Updates the snapshot for the given transaction. + + :type transaction_pb: :class:`~google.cloud.spanner_v1.Transaction` + :param transaction_pb: The transaction to update the snapshot with. + """ + + # The transaction ID should only be updated when the transaction is + # begun: either explicitly with a begin transaction request, or implicitly + # with read, execute SQL, batch update, or execute update requests. The + # caller is responsible for locking until the transaction ID is updated. + if self._transaction_id is None and transaction_pb.id: + self._transaction_id = transaction_pb.id + + if transaction_pb.precommit_token: + self._update_for_precommit_token_pb(transaction_pb.precommit_token) + + def _update_for_precommit_token_pb( + self, precommit_token_pb: MultiplexedSessionPrecommitToken + ) -> None: + """Updates the snapshot for the given multiplexed session precommit token. + :type precommit_token_pb: :class:`~google.cloud.spanner_v1.MultiplexedSessionPrecommitToken` + :param precommit_token_pb: The multiplexed session precommit token to update the snapshot with. + """ + + # Because multiple threads can be used to perform operations within a + # transaction, we need to use a lock when updating the precommit token. + with self._lock: + if self._precommit_token is None or ( + precommit_token_pb.seq_num > self._precommit_token.seq_num + ): + self._precommit_token = precommit_token_pb + class Snapshot(_SnapshotBase): """Allow a set of reads / SQL statements with shared staleness. 
@@ -966,6 +1020,7 @@ def __init__( self._multi_use = multi_use self._transaction_id = transaction_id + # TODO multiplexed - refactor to base class def _make_txn_selector(self): """Helper for :meth:`read`.""" if self._transaction_id is not None: @@ -998,6 +1053,7 @@ def _make_txn_selector(self): else: return TransactionSelector(single_use=options) + # TODO multiplexed - move to base class def begin(self): """Begin a read-only transaction on the database. @@ -1055,3 +1111,15 @@ def attempt_tracking_method(): self._transaction_id = response.id self._transaction_read_timestamp = response.read_timestamp return self._transaction_id + + def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: + """Updates the snapshot for the given transaction. + + :type transaction_pb: :class:`~google.cloud.spanner_v1.Transaction` + :param transaction_pb: The transaction to update the snapshot with. + """ + + super(Snapshot, self)._update_for_transaction_pb(transaction_pb) + + if transaction_pb.read_timestamp is not None: + self._transaction_read_timestamp = transaction_pb.read_timestamp diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 5de843e103..a4e30ae2fc 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -50,7 +50,10 @@ def __init__( self._stats = None # Until set from last PRS self._current_row = [] # Accumulated values for incomplete row self._pending_chunk = None # Incomplete value + + # TODO multiplexed - remove self._source = source # Source snapshot + self._column_info = column_info # Column information self._field_decoders = None self._lazy_decode = lazy_decode # Return protobuf values diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 396a61ada8..6e67cf0299 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -14,7 +14,6 @@ """Spanner read-write transaction support.""" import 
functools -import threading from google.protobuf.struct_pb2 import Struct from typing import Optional @@ -27,7 +26,12 @@ _check_rst_stream_error, _merge_Transaction_Options, ) -from google.cloud.spanner_v1 import CommitRequest +from google.cloud.spanner_v1 import ( + CommitRequest, + CommitResponse, + ResultSet, + ExecuteBatchDmlResponse, +) from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import TransactionSelector @@ -57,19 +61,21 @@ class Transaction(_SnapshotBase, _BatchBase): """Timestamp at which the transaction was successfully committed.""" rolled_back = False commit_stats = None - _multi_use = True - _execute_sql_count = 0 - _lock = threading.Lock() - _read_only = False exclude_txn_from_change_streams = False isolation_level = TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED + # Override defaults from _SnapshotBase. + _multi_use: bool = True + _read_only: bool = False + def __init__(self, session): + # TODO multiplexed - remove if session._transaction is not None: raise ValueError("Session has existing transaction.") super(Transaction, self).__init__(session) + # TODO multiplexed - remove def _check_state(self): """Helper for :meth:`commit` et al. @@ -83,6 +89,7 @@ def _check_state(self): if self.rolled_back: raise ValueError("Transaction is already rolled back") + # TODO multiplexed - refactor to base class def _make_txn_selector(self): """Helper for :meth:`read`. @@ -113,9 +120,7 @@ def _execute_request( request, metadata, trace_name=None, - session=None, attributes=None, - observability_options=None, ): """Helper method to execute request after fetching transaction selector. 
@@ -125,13 +130,18 @@ def _execute_request( :type request: proto :param request: request proto to call the method with """ + + session = self._session transaction = self._make_txn_selector() request.transaction = transaction + with trace_call( trace_name, session, attributes, - observability_options=observability_options, + observability_options=getattr( + session._database, "observability_options", None + ), metadata=metadata, ), MetricsCapture(): method = functools.partial(method, request=request) @@ -142,6 +152,7 @@ def _execute_request( return response + # TODO multiplexed - move to base class def begin(self): """Begin a transaction on the database. @@ -214,13 +225,17 @@ def before_next_retry(nth_retry, delay_in_seconds): self._transaction_id = response.id return self._transaction_id - def rollback(self): + def rollback(self) -> None: """Roll back a transaction on the database.""" + + # TODO multiplexed - cleanup self._check_state() if self._transaction_id is not None: - database = self._session._database + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( @@ -232,7 +247,7 @@ def rollback(self): observability_options = getattr(database, "observability_options", None) with trace_call( f"CloudSpanner.{type(self).__name__}.rollback", - self._session, + session, observability_options=observability_options, metadata=metadata, ) as span, MetricsCapture(): @@ -241,7 +256,7 @@ def rollback(self): def wrapped_method(*args, **kwargs): attempt.increment() - method = functools.partial( + rollback_method = functools.partial( api.rollback, session=self._session.name, transaction_id=self._transaction_id, @@ -252,7 +267,7 @@ def wrapped_method(*args, **kwargs): span, ), ) - return method(*args, **kwargs) + return rollback_method(*args, **kwargs) _retry( wrapped_method, @@ -260,6 +275,8 @@ def wrapped_method(*args, **kwargs): ) 
self.rolled_back = True + + # TODO multiplexed - remove del self._session._transaction def commit( @@ -288,28 +305,36 @@ def commit( :returns: timestamp of the committed changes. :raises ValueError: if there are no mutations to commit. """ - database = self._session._database - trace_attributes = {"num_mutations": len(self._mutations)} - observability_options = getattr(database, "observability_options", None) + + mutations = self._mutations + num_mutations = len(mutations) + + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) + with trace_call( - f"CloudSpanner.{type(self).__name__}.commit", - self._session, - trace_attributes, - observability_options, + name=f"CloudSpanner.{type(self).__name__}.commit", + session=session, + extra_attributes={"num_mutations": num_mutations}, + observability_options=getattr(database, "observability_options", None), metadata=metadata, ) as span, MetricsCapture(): + # TODO multiplexed - cleanup self._check_state() - if self._transaction_id is None and len(self._mutations) > 0: - self.begin() - elif self._transaction_id is None and len(self._mutations) == 0: + if self._transaction_id is None and len(self._mutations) == 0: raise ValueError("Transaction is not begun") + # TODO multiplexed - begin transaction + if self._transaction_id is None and num_mutations > 0: + self.begin() + if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: @@ -320,9 +345,9 @@ def commit( # Request tags are not supported for commit requests. 
request_options.request_tag = None - request = CommitRequest( - session=self._session.name, - mutations=self._mutations, + commit_request = CommitRequest( + session=session.name, + mutations=mutations, transaction_id=self._transaction_id, return_commit_stats=return_commit_stats, max_commit_delay=max_commit_delay, @@ -336,9 +361,9 @@ def commit( def wrapped_method(*args, **kwargs): attempt.increment() - method = functools.partial( + commit_method = functools.partial( api.commit, - request=request, + request=commit_request, metadata=database.metadata_with_request_id( nth_request, attempt.value, @@ -346,7 +371,7 @@ def wrapped_method(*args, **kwargs): span, ), ) - return method(*args, **kwargs) + return commit_method(*args, **kwargs) def before_next_retry(nth_retry, delay_in_seconds): add_span_event( @@ -355,18 +380,23 @@ def before_next_retry(nth_retry, delay_in_seconds): {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, ) - response = _retry( + response_pb: CommitResponse = _retry( wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, before_next_retry=before_next_retry, ) + # TODO multiplexed - retry commit if precommit token. + add_span_event(span, "Commit Done") - self.committed = response.commit_timestamp + self.committed = response_pb.commit_timestamp if return_commit_stats: - self.commit_stats = response.commit_stats + self.commit_stats = response_pb.commit_stats + + # TODO multiplexed - remove del self._session._transaction + return self.committed @staticmethod @@ -463,27 +493,28 @@ def execute_update( :rtype: int :returns: Count of rows affected by the DML statement. 
""" + + session = self._session + database = session._database + api = database.spanner_api + params_pb = self._make_params_pb(params, param_types) - database = self._session._database + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - api = database.spanner_api - seqno, self._execute_sql_count = ( - self._execute_sql_count, - self._execute_sql_count + 1, + seqno, self._execute_sql_request_count = ( + self._execute_sql_request_count, + self._execute_sql_request_count + 1, ) # Query-level options have higher precedence than client-level and # environment-level options default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) - observability_options = getattr( - database._instance._client, "observability_options", None - ) if request_options is None: request_options = RequestOptions() @@ -493,8 +524,17 @@ def execute_update( trace_attributes = {"db.statement": dml} - request = ExecuteSqlRequest( - session=self._session.name, + # If this request begins the transaction, we need to lock + # the transaction until the transaction ID is updated. 
+ is_inline_begin = False + + if self._transaction_id is None: + is_inline_begin = True + self._lock.acquire() + + execute_sql_request = ExecuteSqlRequest( + session=session.name, + transaction=self._make_txn_selector(), sql=dml, params=params_pb, param_types=param_types, @@ -510,49 +550,31 @@ def execute_update( def wrapped_method(*args, **kwargs): attempt.increment() - method = functools.partial( + execute_sql_method = functools.partial( api.execute_sql, - request=request, + request=execute_sql_request, metadata=database.metadata_with_request_id( nth_request, attempt.value, metadata ), retry=retry, timeout=timeout, ) - return method(*args, **kwargs) + return execute_sql_method(*args, **kwargs) - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - response = self._execute_request( - wrapped_method, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_update", - self._session, - trace_attributes, - observability_options=observability_options, - ) - # Setting the transaction id because the transaction begin was inlined for first rpc. 
- if ( - self._transaction_id is None - and response is not None - and response.metadata is not None - and response.metadata.transaction is not None - ): - self._transaction_id = response.metadata.transaction.id - else: - response = self._execute_request( - wrapped_method, - request, - metadata, - f"CloudSpanner.{type(self).__name__}.execute_update", - self._session, - trace_attributes, - observability_options=observability_options, - ) + result_set_pb: ResultSet = self._execute_request( + wrapped_method, + execute_sql_request, + metadata, + f"CloudSpanner.{type(self).__name__}.execute_update", + trace_attributes, + ) + + self._update_for_result_set_pb(result_set_pb) - return response.stats.row_count_exact + if is_inline_begin: + self._lock.release() + + return result_set_pb.stats.row_count_exact def batch_update( self, @@ -610,6 +632,11 @@ def batch_update( statement triggering the error will not have an entry in the list, nor will any statements following that one. """ + + session = self._session + database = session._database + api = database.spanner_api + parsed = [] for statement in statements: if isinstance(statement, str): @@ -623,18 +650,15 @@ def batch_update( ) ) - database = self._session._database metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - api = database.spanner_api - observability_options = getattr(database, "observability_options", None) - seqno, self._execute_sql_count = ( - self._execute_sql_count, - self._execute_sql_count + 1, + seqno, self._execute_sql_request_count = ( + self._execute_sql_request_count, + self._execute_sql_request_count + 1, ) if request_options is None: @@ -647,8 +671,18 @@ def batch_update( # Get just the queries from the DML statement batch "db.statement": ";".join([statement.sql for statement in parsed]) } - request = ExecuteBatchDmlRequest( - session=self._session.name, + + # If this request 
begins the transaction, we need to lock + # the transaction until the transaction ID is updated. + is_inline_begin = False + + if self._transaction_id is None: + is_inline_begin = True + self._lock.acquire() + + execute_batch_dml_request = ExecuteBatchDmlRequest( + session=session.name, + transaction=self._make_txn_selector(), statements=parsed, seqno=seqno, request_options=request_options, @@ -660,54 +694,50 @@ def batch_update( def wrapped_method(*args, **kwargs): attempt.increment() - method = functools.partial( + execute_batch_dml_method = functools.partial( api.execute_batch_dml, - request=request, + request=execute_batch_dml_request, metadata=database.metadata_with_request_id( nth_request, attempt.value, metadata ), retry=retry, timeout=timeout, ) - return method(*args, **kwargs) + return execute_batch_dml_method(*args, **kwargs) - if self._transaction_id is None: - # lock is added to handle the inline begin for first rpc - with self._lock: - response = self._execute_request( - wrapped_method, - request, - metadata, - "CloudSpanner.DMLTransaction", - self._session, - trace_attributes, - observability_options=observability_options, - ) - # Setting the transaction id because the transaction begin was inlined for first rpc. 
- for result_set in response.result_sets: - if ( - self._transaction_id is None - and result_set.metadata is not None - and result_set.metadata.transaction is not None - ): - self._transaction_id = result_set.metadata.transaction.id - break - else: - response = self._execute_request( - wrapped_method, - request, - metadata, - "CloudSpanner.DMLTransaction", - self._session, - trace_attributes, - observability_options=observability_options, - ) + response_pb: ExecuteBatchDmlResponse = self._execute_request( + wrapped_method, + execute_batch_dml_request, + metadata, + "CloudSpanner.DMLTransaction", + trace_attributes, + ) + + self._update_for_execute_batch_dml_response_pb(response_pb) + + if is_inline_begin: + self._lock.release() row_counts = [ - result_set.stats.row_count_exact for result_set in response.result_sets + result_set.stats.row_count_exact for result_set in response_pb.result_sets ] - return response.status, row_counts + return response_pb.status, row_counts + + def _update_for_execute_batch_dml_response_pb( + self, response_pb: ExecuteBatchDmlResponse + ) -> None: + """Update the transaction for the given execute batch DML response. + + :type response_pb: :class:`~google.cloud.spanner_v1.types.ExecuteBatchDmlResponse` + :param response_pb: The execute batch DML response to update the transaction with. + """ + if response_pb.precommit_token: + self._update_for_precommit_token_pb(response_pb.precommit_token) + + # Only the first result set contains the result set metadata. 
+ if len(response_pb.result_sets) > 0: + self._update_for_result_set_pb(response_pb.result_sets[0]) def __enter__(self): """Begin ``with`` block.""" diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bb0db5db0f..0a4e789fd7 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -148,7 +148,8 @@ def _make_item(self, value, resume_token=b"", metadata=None): value=value, resume_token=resume_token, metadata=metadata, - spec=["value", "resume_token", "metadata"], + precommit_token=None, + spec=["value", "resume_token", "metadata", "precommit_token"], ) def test_iteration_w_empty_raw(self): @@ -666,7 +667,7 @@ def test_ctor(self): session = _Session() base = self._make_one(session) self.assertIs(base._session, session) - self.assertEqual(base._execute_sql_count, 0) + self.assertEqual(base._execute_sql_request_count, 0) self.assertNoSpans() @@ -953,7 +954,7 @@ def test_execute_sql_other_error(self): with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) - self.assertEqual(derived._execute_sql_count, 1) + self.assertEqual(derived._execute_sql_request_count, 1) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertSpanAttributes( @@ -1027,7 +1028,7 @@ def _execute_sql_helper( derived = self._makeDerived(session) derived._multi_use = multi_use derived._read_request_count = count - derived._execute_sql_count = sql_count + derived._execute_sql_request_count = sql_count if not first: derived._transaction_id = TXN_ID @@ -1120,7 +1121,7 @@ def _execute_sql_helper( retry=retry, ) - self.assertEqual(derived._execute_sql_count, sql_count + 1) + self.assertEqual(derived._execute_sql_request_count, sql_count + 1) self.assertSpanAttributes( "CloudSpanner._Derived.execute_sql", diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 4acd7d3798..1d772228b9 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -152,7 +152,7 @@ def 
_execute_update_helper( transaction.transaction_tag = self.TRANSACTION_TAG transaction.exclude_txn_from_change_streams = exclude_txn_from_change_streams transaction.isolation_level = isolation_level - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count row_count = transaction.execute_update( DML_QUERY_WITH_PARAM, @@ -246,7 +246,7 @@ def _execute_sql_helper( result_sets[i].values.extend(VALUE_PBS[i]) iterator = _MockIterator(*result_sets) api.execute_streaming_sql.return_value = iterator - transaction._execute_sql_count = sql_count + transaction._execute_sql_request_count = sql_count transaction._read_request_count = count result_set = transaction.execute_sql( @@ -267,7 +267,7 @@ def _execute_sql_helper( self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - self.assertEqual(transaction._execute_sql_count, sql_count + 1) + self.assertEqual(transaction._execute_sql_request_count, sql_count + 1) def _execute_sql_expected_request( self, @@ -464,7 +464,7 @@ def _batch_update_helper( api.execute_batch_dml.return_value = response transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count status, row_counts = transaction.batch_update( dml_statements, request_options=RequestOptions() @@ -472,7 +472,7 @@ def _batch_update_helper( self.assertEqual(status, expected_status) self.assertEqual(row_counts, expected_row_counts) - self.assertEqual(transaction._execute_sql_count, count + 1) + self.assertEqual(transaction._execute_sql_request_count, count + 1) def _batch_update_expected_request(self, begin=True, count=0): if begin is True: diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e477ef27c6..71f8d956a8 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -104,7 +104,7 @@ def test_ctor_defaults(self): 
self.assertIsNone(transaction.committed) self.assertFalse(transaction.rolled_back) self.assertTrue(transaction._multi_use) - self.assertEqual(transaction._execute_sql_count, 0) + self.assertEqual(transaction._execute_sql_request_count, 0) def test__check_state_already_committed(self): session = _Session() @@ -349,7 +349,7 @@ def test_commit_not_begun(self): span_list = self.get_finished_spans() got_span_names = [span.name for span in span_list] want_span_names = ["CloudSpanner.Transaction.commit"] - assert got_span_names == want_span_names + self.assertEqual(got_span_names, want_span_names) got_span_events_statuses = self.finished_spans_events_statuses() want_span_events_statuses = [ @@ -363,7 +363,7 @@ def test_commit_not_begun(self): }, ) ] - assert got_span_events_statuses == want_span_events_statuses + self.assertEqual(got_span_events_statuses, want_span_events_statuses) def test_commit_already_committed(self): database = _Database() @@ -381,7 +381,7 @@ def test_commit_already_committed(self): span_list = self.get_finished_spans() got_span_names = [span.name for span in span_list] want_span_names = ["CloudSpanner.Transaction.commit"] - assert got_span_names == want_span_names + self.assertEqual(got_span_names, want_span_names) got_span_events_statuses = self.finished_spans_events_statuses() want_span_events_statuses = [ @@ -395,7 +395,7 @@ def test_commit_already_committed(self): }, ) ] - assert got_span_events_statuses == want_span_events_statuses + self.assertEqual(got_span_events_statuses, want_span_events_statuses) def test_commit_already_rolled_back(self): database = _Database() @@ -413,7 +413,7 @@ def test_commit_already_rolled_back(self): span_list = self.get_finished_spans() got_span_names = [span.name for span in span_list] want_span_names = ["CloudSpanner.Transaction.commit"] - assert got_span_names == want_span_names + self.assertEqual(got_span_names, want_span_names) got_span_events_statuses = self.finished_spans_events_statuses() 
want_span_events_statuses = [ @@ -427,7 +427,7 @@ def test_commit_already_rolled_back(self): }, ) ] - assert got_span_events_statuses == want_span_events_statuses + self.assertEqual(got_span_events_statuses, want_span_events_statuses) def test_commit_w_other_error(self): database = _Database() @@ -652,7 +652,7 @@ def _execute_update_helper( transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count if request_options is None: request_options = RequestOptions() @@ -710,7 +710,7 @@ def _execute_update_helper( ], ) - self.assertEqual(transaction._execute_sql_count, count + 1) + self.assertEqual(transaction._execute_sql_request_count, count + 1) want_span_attributes = dict(TestTransaction.BASE_ATTRIBUTES) want_span_attributes["db.statement"] = DML_QUERY_WITH_PARAM self.assertSpanAttributes( @@ -773,7 +773,7 @@ def test_execute_update_error(self): with self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) - self.assertEqual(transaction._execute_sql_count, 1) + self.assertEqual(transaction._execute_sql_request_count, 1) def test_execute_update_w_query_options(self): from google.cloud.spanner_v1 import ExecuteSqlRequest @@ -853,7 +853,7 @@ def _batch_update_helper( transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID transaction.transaction_tag = self.TRANSACTION_TAG - transaction._execute_sql_count = count + transaction._execute_sql_request_count = count if request_options is None: request_options = RequestOptions() @@ -909,7 +909,7 @@ def _batch_update_helper( timeout=timeout, ) - self.assertEqual(transaction._execute_sql_count, count + 1) + self.assertEqual(transaction._execute_sql_request_count, count + 1) def test_batch_update_wo_errors(self): self._batch_update_helper( @@ -978,7 +978,7 @@ def test_batch_update_error(self): with 
self.assertRaises(RuntimeError): transaction.batch_update(dml_statements) - self.assertEqual(transaction._execute_sql_count, 1) + self.assertEqual(transaction._execute_sql_request_count, 1) def test_batch_update_w_timeout_param(self): self._batch_update_helper(timeout=2.0) From 56001b915c0620373bf134e52e584c74838b5e81 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 15:58:27 -0700 Subject: [PATCH 07/41] feat: Multiplexed sessions - Deprecate `StreamedResultSet._source` (redundant as transaction ID is set via `_restart_on_unavailable`) Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/streamed.py | 12 ++---------- tests/unit/test_snapshot.py | 10 ---------- tests/unit/test_spanner.py | 2 -- tests/unit/test_streamed.py | 3 --- 4 files changed, 2 insertions(+), 25 deletions(-) diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index a4e30ae2fc..39b2151388 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -34,7 +34,7 @@ class StreamedResultSet(object): instances. :type source: :class:`~google.cloud.spanner_v1.snapshot.Snapshot` - :param source: Snapshot from which the result set was fetched. + :param source: Deprecated. Snapshot from which the result set was fetched. 
""" def __init__( @@ -50,10 +50,6 @@ def __init__( self._stats = None # Until set from last PRS self._current_row = [] # Accumulated values for incomplete row self._pending_chunk = None # Incomplete value - - # TODO multiplexed - remove - self._source = source # Source snapshot - self._column_info = column_info # Column information self._field_decoders = None self._lazy_decode = lazy_decode # Return protobuf values @@ -144,11 +140,7 @@ def _consume_next(self): response_pb = PartialResultSet.pb(response) if self._metadata is None: # first response - metadata = self._metadata = response_pb.metadata - - source = self._source - if source is not None and source._transaction_id is None: - source._transaction_id = metadata.transaction.id + self._metadata = response_pb.metadata if response_pb.HasField("stats"): # last response self._stats = response.stats diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 0a4e789fd7..27fbf6841d 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -796,11 +796,6 @@ def _read_helper( self.assertEqual(derived._read_request_count, count + 1) - if multi_use: - self.assertIs(result_set._source, derived) - else: - self.assertIsNone(result_set._source) - self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) @@ -1052,11 +1047,6 @@ def _execute_sql_helper( self.assertEqual(derived._read_request_count, count + 1) - if multi_use: - self.assertIs(result_set._source, derived) - else: - self.assertIsNone(result_set._source) - self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index 1d772228b9..eedf49d3ff 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -381,8 +381,6 @@ def _read_helper( self.assertEqual(transaction._read_request_count, 
count + 1) - self.assertIs(result_set._source, transaction) - self.assertEqual(list(result_set), VALUES) self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) diff --git a/tests/unit/test_streamed.py b/tests/unit/test_streamed.py index 83aa25a9d1..e02afbede7 100644 --- a/tests/unit/test_streamed.py +++ b/tests/unit/test_streamed.py @@ -31,7 +31,6 @@ def test_ctor_defaults(self): iterator = _MockCancellableIterator() streamed = self._make_one(iterator) self.assertIs(streamed._response_iterator, iterator) - self.assertIsNone(streamed._source) self.assertEqual(list(streamed), []) self.assertIsNone(streamed.metadata) self.assertIsNone(streamed.stats) @@ -41,7 +40,6 @@ def test_ctor_w_source(self): source = object() streamed = self._make_one(iterator, source=source) self.assertIs(streamed._response_iterator, iterator) - self.assertIs(streamed._source, source) self.assertEqual(list(streamed), []) self.assertIsNone(streamed.metadata) self.assertIsNone(streamed.stats) @@ -807,7 +805,6 @@ def test_consume_next_first_set_partial(self): self.assertEqual(list(streamed), []) self.assertEqual(streamed._current_row, BARE) self.assertEqual(streamed.metadata, metadata) - self.assertEqual(source._transaction_id, TXN_ID) def test_consume_next_first_set_partial_existing_txn_id(self): from google.cloud.spanner_v1 import TypeCode From b4eadcac3272c55a928e844f48fce1bd1eb35cea Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 16:32:27 -0700 Subject: [PATCH 08/41] feat: Multiplexed sessions - Move `_session_options` from `Database` to `Client` so that multiplexed are disabled for _all_ databases. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/client.py | 3 + google/cloud/spanner_v1/database.py | 67 +++++++++---------- .../spanner_v1/database_sessions_manager.py | 6 +- tests/unit/test_database.py | 2 + 4 files changed, 42 insertions(+), 36 deletions(-) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index e0e8c44058..10db8c136e 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -60,6 +60,7 @@ from google.cloud.spanner_v1.metrics.metrics_exporter import ( CloudMonitoringMetricsExporter, ) +from google.cloud.spanner_v1.session_options import SessionOptions try: from opentelemetry import metrics @@ -269,6 +270,8 @@ def __init__( self._nth_client_id = Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter(0) + self._session_options = SessionOptions() + @property def _next_nth_request(self): return self._nth_request.increment() diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index dda1e13c6c..a289df8e2f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -62,7 +62,7 @@ from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import SessionOptions +from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot @@ -202,7 +202,6 @@ def __init__( self._pool = pool pool.bind(self) - self.session_options = SessionOptions() self._sessions_manager = DatabaseSessionsManager(self, pool) @classmethod @@ -764,11 +763,9 @@ def execute_pdml(): "CloudSpanner.Database.execute_partitioned_pdml", observability_options=self.observability_options, ) as 
span, MetricsCapture(): - from google.cloud.spanner_v1.session_options import TransactionType + transaction_type = TransactionType.PARTITIONED + session = self._sessions_manager.get_session(transaction_type) - session = self._sessions_manager.get_session( - TransactionType.PARTITIONED - ) try: add_span_event(span, "Starting BeginTransaction") txn = api.begin_transaction( @@ -1296,8 +1293,10 @@ def __init__( isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, **kw, ): - self._database = database - self._session = self._batch = None + self._database: Database = database + self._session: Session = None + self._batch: Batch = None + if request_options is None: self._request_options = RequestOptions() elif type(request_options) is dict: @@ -1311,16 +1310,19 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - from google.cloud.spanner_v1.session_options import TransactionType + transaction_type = TransactionType.READ_WRITE + self._session = self._database.sessions_manager.get_session(transaction_type) - current_span = get_current_span() - session = self._session = self._database.sessions_manager.get_session( - TransactionType.READ_WRITE + add_span_event( + span=get_current_span(), + event_name="Using session", + event_attributes={"id": self._session.session_id}, ) - add_span_event(current_span, "Using session", {"id": session.session_id}) - batch = self._batch = Batch(session) + + batch = self._batch = Batch(session=self._session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag + return batch def __exit__(self, exc_type, exc_val, exc_tb): @@ -1364,17 +1366,15 @@ class MutationGroupsCheckout(object): """ def __init__(self, database): - self._database = database - self._session = None + self._database: Database = database + self._session: Session = None def __enter__(self): """Begin ``with`` block.""" - from google.cloud.spanner_v1.session_options import TransactionType + 
transaction_type = TransactionType.READ_WRITE + self._session = self._database.sessions_manager.get_session(transaction_type) - session = self._session = self._database.sessions_manager.get_session( - TransactionType.READ_WRITE - ) - return MutationGroups(session) + return MutationGroups(session=self._session) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" @@ -1406,18 +1406,16 @@ class SnapshotCheckout(object): """ def __init__(self, database, **kw): - self._database = database - self._session = None - self._kw = kw + self._database: Database = database + self._session: Session = None + self._kw: dict = kw def __enter__(self): """Begin ``with`` block.""" - from google.cloud.spanner_v1.session_options import TransactionType + transaction_type = TransactionType.READ_ONLY + self._session = self._database.sessions_manager.get_session(transaction_type) - session = self._session = self._database.sessions_manager.get_session( - TransactionType.READ_ONLY - ) - return Snapshot(session, **self._kw) + return Snapshot(session=self._session, **self._kw) def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" @@ -1507,14 +1505,15 @@ def _get_session(self): all partitions have been processed. 
""" if self._session is None: - from google.cloud.spanner_v1.session_options import TransactionType - # Use sessions manager for partition operations - session = self._session = self._database.sessions_manager.get_session( - TransactionType.PARTITIONED + transaction_type = TransactionType.PARTITIONED + self._session = self._database.sessions_manager.get_session( + transaction_type ) + if self._session_id is not None: - session._session_id = self._session_id + self._session._session_id = self._session_id + return self._session def _get_snapshot(self): diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 7e9e08175d..46862b6250 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -69,7 +69,7 @@ def get_session(self, transaction_type: TransactionType) -> Session: :returns: a session for the given transaction type. """ - session_options = self._database.session_options + session_options = self._database._instance._client._session_options use_multiplexed = session_options.use_multiplexed(transaction_type) # TODO multiplexed: enable for read/write transactions @@ -153,7 +153,9 @@ def _disable_multiplexed_sessions(self) -> None: self._multiplexed_session = None self._multiplexed_session_disabled_event.set() - self._database.session_options.disable_multiplexed(self._database.logger) + + session_options = self._database._instance._client._session_options + session_options.disable_multiplexed(self._database.logger) def _build_maintenance_thread(self) -> Thread: """Builds and returns a multiplexed session maintenance thread for diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index aee1c83f62..296c2d11dd 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -35,6 +35,7 @@ _metadata_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from 
google.cloud.spanner_v1.session_options import SessionOptions DML_WO_PARAM = """ DELETE FROM citizens @@ -3517,6 +3518,7 @@ def __init__( self.observability_options = observability_options self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() + self._session_options = SessionOptions() @property def _next_nth_request(self): From 68e9b675bec7f4b1d4bf8b93362ce9adfa8fd8e2 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 16:56:46 -0700 Subject: [PATCH 09/41] feat: Multiplexed sessions - Deprecate `SessionCheckout` and update `Database.run_in_transaction` to not use it. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 12 ++++++++---- google/cloud/spanner_v1/pool.py | 8 +++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index a289df8e2f..8d5a576045 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -60,7 +60,6 @@ from google.cloud.spanner_v1.keyset import KeySet from google.cloud.spanner_v1.merged_result_set import MergedResultSet from google.cloud.spanner_v1.pool import BurstyPool -from google.cloud.spanner_v1.pool import SessionCheckout from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager @@ -999,15 +998,20 @@ def run_in_transaction(self, func, *args, **kw): # is running. if getattr(self._local, "transaction_running", False): raise RuntimeError("Spanner does not support nested transactions.") + self._local.transaction_running = True # Check out a session and run the function in a transaction; once - # done, flip the sanity check bit back. + # done, flip the sanity check bit back and return the session. 
+ transaction_type = TransactionType.READ_WRITE + session = self._sessions_manager.get_session(transaction_type) + try: - with SessionCheckout(self._pool) as session: - return session.run_in_transaction(func, *args, **kw) + return session.run_in_transaction(func, *args, **kw) + finally: self._local.transaction_running = False + self._sessions_manager.put_session(session) def restore(self, source): """Restore from a backup to this database. diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 1c82f66ed0..2f21b46d25 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -137,6 +137,9 @@ def _new_session(self): def session(self, **kwargs): """Check out a session from the pool. + Deprecated. Sessions should be checked out using context + managers, rather than directly from the pool. + :param kwargs: (optional) keyword arguments, passed through to the returned checkout. @@ -792,6 +795,9 @@ def begin_pending_transactions(self): class SessionCheckout(object): """Context manager: hold session checked out from a pool. + Deprecated. Sessions should be checked out using context + managers, rather than directly from the pool. + :type pool: concrete subclass of :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` :param pool: Pool from which to check out a session. @@ -799,7 +805,7 @@ class SessionCheckout(object): :param kwargs: extra keyword arguments to be passed to :meth:`pool.get`. """ - _session = None # Not checked out until '__enter__'. + _session = None def __init__(self, pool, **kwargs): self._pool = pool From 6ca0d3f2b9ca37e6c6f4eaee27273bd1f6d51c5e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 17:06:23 -0700 Subject: [PATCH 10/41] feat: Multiplexed sessions - Deprecate `Database.session()` and minor cleanup. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 8d5a576045..f84b82d9a8 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -796,8 +796,9 @@ def execute_pdml(): iterator = _restart_on_unavailable( method=method, - trace_name="CloudSpanner.ExecuteStreamingSql", request=request, + trace_name="CloudSpanner.ExecuteStreamingSql", + session=session, metadata=metadata, transaction_selector=txn_selector, observability_options=self.observability_options, @@ -828,6 +829,9 @@ def _nth_client_id(self): def session(self, labels=None, database_role=None): """Factory to create a session for this database. + Deprecated. Sessions should be checked out using context + managers, rather than retrieved directly from the database. + :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for the session. @@ -1314,7 +1318,10 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - transaction_type = TransactionType.READ_WRITE + + # Batch transactions are performed as blind writes, + # which are treated as read-only transactions. + transaction_type = TransactionType.READ_ONLY self._session = self._database.sessions_manager.get_session(transaction_type) add_span_event( From 9057a64b094bd128ae25056192603eaa86ba650e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Fri, 30 May 2025 18:06:29 -0700 Subject: [PATCH 11/41] feat: Multiplexed sessions - Update `BatchSnapshot` to use database session manager. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 35 ++++++++++++++++++-------- tests/unit/test_database.py | 38 +++++++++++++---------------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f84b82d9a8..822531a435 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -1481,11 +1481,15 @@ def from_dict(cls, database, mapping): :rtype: :class:`BatchSnapshot` """ + instance = cls(database) - session = instance._session = database.session() - session._session_id = mapping["session_id"] + + session = instance._session = Session(database=database) + instance._session_id = session._session_id = mapping["session_id"] + snapshot = instance._snapshot = session.snapshot() - snapshot._transaction_id = mapping["transaction_id"] + instance._transaction_id = snapshot._transaction_id = mapping["transaction_id"] + return instance def to_dict(self): @@ -1516,19 +1520,28 @@ def _get_session(self): all partitions have been processed. """ if self._session is None: - # Use sessions manager for partition operations - transaction_type = TransactionType.PARTITIONED - self._session = self._database.sessions_manager.get_session( - transaction_type - ) + database = self._database + + # If the session ID is not specified, check out a new session for + # partitioned transactions from the database session manager; otherwise, + # the session has already been checked out, so just create a session to + # represent it. 
+ if self._session_id is None: + transaction_type = TransactionType.PARTITIONED + session = database.sessions_manager.get_session(transaction_type) + self._session_id = session.session_id + + else: + session = Session(database=database) + session._session_id = self._session_id - if self._session_id is not None: - self._session._session_id = self._session_id + self._session = session return self._session def _get_snapshot(self): """Create snapshot if needed.""" + if self._snapshot is None: self._snapshot = self._get_session().snapshot( read_timestamp=self._read_timestamp, @@ -1536,8 +1549,10 @@ def _get_snapshot(self): multi_use=True, transaction_id=self._transaction_id, ) + if self._transaction_id is None: self._snapshot.begin() + return self._snapshot def get_batch_transaction_id(self): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 296c2d11dd..b74c5cef2f 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -35,7 +35,9 @@ _metadata_with_request_id, ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID -from google.cloud.spanner_v1.session_options import SessionOptions +from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType +from tests._builders import build_spanner_api DML_WO_PARAM = """ DELETE FROM citizens @@ -1509,8 +1511,6 @@ def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self): ) def test_session_factory_defaults(self): - from google.cloud.spanner_v1.session import Session - client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() @@ -1524,8 +1524,6 @@ def test_session_factory_defaults(self): self.assertEqual(session.labels, {}) def test_session_factory_w_labels(self): - from google.cloud.spanner_v1.session import Session - client = _Client() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() @@ -2475,8 +2473,6 @@ def 
_make_database(**kwargs): @staticmethod def _make_session(**kwargs): - from google.cloud.spanner_v1.session import Session - return mock.create_autospec(Session, instance=True, **kwargs) @staticmethod @@ -2533,20 +2529,22 @@ def test_ctor_w_exact_staleness(self): def test_from_dict(self): klass = self._get_target_class() database = self._make_database() - session = database.session.return_value = self._make_session() - snapshot = session.snapshot.return_value = self._make_snapshot() - api_repr = { - "session_id": self.SESSION_ID, - "transaction_id": self.TRANSACTION_ID, - } + api = database.spanner_api = build_spanner_api() + + batch_txn = klass.from_dict( + database, + { + "session_id": self.SESSION_ID, + "transaction_id": self.TRANSACTION_ID, + }, + ) - batch_txn = klass.from_dict(database, api_repr) self.assertIs(batch_txn._database, database) - self.assertIs(batch_txn._session, session) - self.assertEqual(session._session_id, self.SESSION_ID) - self.assertEqual(snapshot._transaction_id, self.TRANSACTION_ID) - snapshot.begin.assert_not_called() - self.assertIs(batch_txn._snapshot, snapshot) + self.assertEqual(batch_txn._session._session_id, self.SESSION_ID) + self.assertEqual(batch_txn._snapshot._transaction_id, self.TRANSACTION_ID) + + api.create_session.assert_not_called() + api.begin_transaction.assert_not_called() def test_to_dict(self): database = self._make_database() @@ -2574,8 +2572,6 @@ def test__get_session_new(self): batch_txn = self._make_one(database) self.assertIs(batch_txn._get_session(), session) # Verify that sessions_manager.get_session was called with PARTITIONED transaction type - from google.cloud.spanner_v1.session_options import TransactionType - database.sessions_manager.get_session.assert_called_once_with( TransactionType.PARTITIONED ) From c9dd818f54d05be2b1766f8a12ca1dd6f73abffc Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 2 Jun 2025 10:36:37 -0700 Subject: [PATCH 12/41] feat: Multiplexed sessions - Move `Batch` and 
`Transaction` attributes from class attributes to instance attributes. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 43 +++++++++++++++++--------- google/cloud/spanner_v1/transaction.py | 13 ++++---- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 2194cb9c0d..6056018160 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -14,8 +14,10 @@ """Context manager for Cloud Spanner batched writes.""" import functools +from datetime import datetime +from typing import List -from google.cloud.spanner_v1 import CommitRequest +from google.cloud.spanner_v1 import CommitRequest, CommitResponse from google.cloud.spanner_v1 import Mutation from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1 import BatchWriteRequest @@ -47,13 +49,17 @@ class _BatchBase(_SessionWrapper): :param session: the session used to perform the commit """ - transaction_tag = None - _read_only = False - def __init__(self, session): super(_BatchBase, self).__init__(session) - self._mutations = [] + self._mutations: List[Mutation] = [] + self.transaction_tag: str = None + + self.committed: datetime = None + """Timestamp at which the batch was successfully committed.""" + self.commit_stats: CommitResponse.CommitStats = None + + # TODO multiplexed - cleanup def _check_state(self): """Helper for :meth:`commit` et al. @@ -148,10 +154,7 @@ def delete(self, table, keyset): class Batch(_BatchBase): """Accumulate mutations for transmission during :meth:`commit`.""" - committed = None - commit_stats = None - """Timestamp at which the batch was successfully committed.""" - + # TODO multiplexed - cleanup def _check_state(self): """Helper for :meth:`commit` et al. 
@@ -163,6 +166,7 @@ def _check_state(self): if self.committed is not None: raise ValueError("Batch already committed") + # TODO multiplexed - cleanup kwargs def commit( self, return_commit_stats=False, @@ -205,7 +209,10 @@ def commit( :rtype: datetime :returns: timestamp of the committed changes. """ + + # TODO multiplexed - cleanup self._check_state() + database = self._session._database api = database.spanner_api metadata = _metadata_with_prefix(database.name) @@ -282,6 +289,8 @@ def wrapped_method(*args, **kwargs): def __enter__(self): """Begin ``with`` block.""" + + # TODO multiplexed - cleanup self._check_state() return self @@ -317,11 +326,10 @@ class MutationGroups(_SessionWrapper): :param session: the session used to perform the commit """ - committed = None - def __init__(self, session): super(MutationGroups, self).__init__(session) - self._mutation_groups = [] + self._mutation_groups: List[MutationGroup] = [] + self.committed: bool = False def _check_state(self): """Checks if the object's state is valid for making API requests. @@ -329,7 +337,7 @@ def _check_state(self): :raises: :exc:`ValueError` if the object's state is invalid for making API requests. """ - if self.committed is not None: + if self.committed: raise ValueError("MutationGroups already committed") def group(self): @@ -358,10 +366,14 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals :rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]` :returns: a sequence of responses for each batch. 
""" + + # TODO multiplexed - cleanup self._check_state() - database = self._session._database + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( @@ -374,7 +386,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals request_options = RequestOptions(request_options) request = BatchWriteRequest( - session=self._session.name, + session=session.name, mutation_groups=self._mutation_groups, request_options=request_options, exclude_txn_from_change_streams=exclude_txn_from_change_streams, @@ -409,6 +421,7 @@ def wrapped_method(*args, **kwargs): InternalServerError: _check_rst_stream_error, }, ) + self.committed = True return response diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 6e67cf0299..63e970750d 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -57,12 +57,10 @@ class Transaction(_SnapshotBase, _BatchBase): :raises ValueError: if session has an existing transaction """ - committed = None - """Timestamp at which the transaction was successfully committed.""" - rolled_back = False - commit_stats = None - exclude_txn_from_change_streams = False - isolation_level = TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED + exclude_txn_from_change_streams: bool = False + isolation_level: TransactionOptions.IsolationLevel = ( + TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED + ) # Override defaults from _SnapshotBase. 
_multi_use: bool = True @@ -74,6 +72,7 @@ def __init__(self, session): raise ValueError("Session has existing transaction.") super(Transaction, self).__init__(session) + self.rolled_back: bool = False # TODO multiplexed - remove def _check_state(self): @@ -97,6 +96,8 @@ def _make_txn_selector(self): :class:`~.transaction_pb2.TransactionSelector` :returns: a selector configured for read-write transaction semantics. """ + + # TODO multiplexed - remove self._check_state() if self._transaction_id is None: From 599939a979934e511989f1192ab2b3e1188eff5a Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 2 Jun 2025 15:53:48 -0700 Subject: [PATCH 13/41] feat: Multiplexed sessions - Update pools so they don't use deprecated `database.session()` Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/pool.py | 15 ++++---- tests/unit/test_pool.py | 67 ++++++++++++++++----------------- 2 files changed, 40 insertions(+), 42 deletions(-) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 2f21b46d25..0257cf1211 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -20,7 +20,8 @@ from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest -from google.cloud.spanner_v1 import Session +from google.cloud.spanner_v1 import Session as SessionProto +from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, @@ -130,9 +131,9 @@ def _new_session(self): :rtype: :class:`~google.cloud.spanner_v1.session.Session` :returns: new session instance. """ - return self._database.session( - labels=self.labels, database_role=self.database_role - ) + + role = self.database_role or self._database.database_role + return Session(database=self._database, labels=self.labels, database_role=role) def session(self, **kwargs): """Check out a session from the pool. 
@@ -240,7 +241,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=requested_session_count, - session_template=Session(creator_role=self.database_role), + session_template=SessionProto(creator_role=self.database_role), ) observability_options = getattr(self._database, "observability_options", None) @@ -322,7 +323,7 @@ def get(self, timeout=None): "Session is not valid, recreating it", span_event_attributes, ) - session = self._database.session() + session = self._new_session() session.create() # Replacing with the updated session.id. span_event_attributes["session.id"] = session._session_id @@ -540,7 +541,7 @@ def bind(self, database): request = BatchCreateSessionsRequest( database=database.name, session_count=self.size, - session_template=Session(creator_role=self.database_role), + session_template=SessionProto(creator_role=self.database_role), ) span_event_attributes = {"kind": type(self).__name__} diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 7c643bc0ea..409f4b043b 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -26,6 +26,7 @@ from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from tests._builders import build_database from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -94,38 +95,35 @@ def test_clear_abstract(self): def test__new_session_wo_labels(self): pool = self._make_one() - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + database = pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, database_role=None) + self.assertEqual(new_session._database, database) + self.assertEqual(new_session.labels, {}) + self.assertIsNone(new_session.database_role) def 
test__new_session_w_labels(self): labels = {"foo": "bar"} pool = self._make_one(labels=labels) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + database = pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels=labels, database_role=None) + self.assertEqual(new_session._database, database) + self.assertEqual(new_session.labels, labels) + self.assertIsNone(new_session.database_role) def test__new_session_w_database_role(self): database_role = "dummy-role" pool = self._make_one(database_role=database_role) - database = pool._database = _make_database("name") - session = _make_session() - database.session.return_value = session + database = pool._database = build_database() new_session = pool._new_session() - self.assertIs(new_session, session) - database.session.assert_called_once_with(labels={}, database_role=database_role) + self.assertEqual(new_session._database, database) + self.assertEqual(new_session.labels, {}) + self.assertEqual(new_session.database_role, database_role) def test_session_wo_kwargs(self): from google.cloud.spanner_v1.pool import SessionCheckout @@ -215,7 +213,7 @@ def test_get_active(self): pool = self._make_one(size=4) database = _Database("name") SESSIONS = sorted([_Session(database) for i in range(0, 4)]) - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) # check if sessions returned in LIFO order @@ -232,7 +230,7 @@ def test_get_non_expired(self): SESSIONS = sorted( [_Session(database, last_use_time=last_use_time) for i in range(0, 4)] ) - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) # check if sessions returned in LIFO order @@ -339,8 +337,7 @@ def test_spans_pool_bind(self): # you have an empty pool. 
pool = self._make_one(size=1) database = _Database("name") - SESSIONS = [] - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=Exception("test")) fauxSession = mock.Mock() setattr(fauxSession, "_database", database) try: @@ -386,8 +383,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -397,8 +394,8 @@ def test_spans_pool_bind(self): ( "exception", { - "exception.type": "IndexError", - "exception.message": "pop from empty list", + "exception.type": "Exception", + "exception.message": "test", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -412,7 +409,7 @@ def test_get_expired(self): last_use_time = datetime.utcnow() - timedelta(minutes=65) SESSIONS = [_Session(database, last_use_time=last_use_time)] * 5 SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) session = pool.get() @@ -475,7 +472,7 @@ def test_clear(self): pool = self._make_one() database = _Database("name") SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) self.assertTrue(pool._sessions.full()) @@ -539,7 +536,7 @@ def test_ctor_explicit_w_database_role_in_db(self): def test_get_empty(self): pool = self._make_one() database = _Database("name") - database._sessions.append(_Session(database)) + pool._new_session = mock.Mock(return_value=_Session(database)) pool.bind(database) session = pool.get() @@ -559,7 +556,7 @@ def test_spans_get_empty_pool(self): pool = self._make_one() database = _Database("name") session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(return_value=session1) pool.bind(database) with 
trace_call("pool.Get", session1): @@ -630,7 +627,7 @@ def test_get_non_empty_session_expired(self): database = _Database("name") previous = _Session(database, exists=False) newborn = _Session(database) - database._sessions.append(newborn) + pool._new_session = mock.Mock(return_value=newborn) pool.bind(database) pool.put(previous) @@ -811,7 +808,7 @@ def test_get_hit_no_ping(self): pool = self._make_one(size=4) database = _Database("name") SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) self.reset() @@ -830,7 +827,7 @@ def test_get_hit_w_ping(self): pool = self._make_one(size=4) database = _Database("name") SESSIONS = [_Session(database)] * 4 - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -855,7 +852,7 @@ def test_get_hit_w_ping_expired(self): database = _Database("name") SESSIONS = [_Session(database)] * 5 SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) sessions_created = datetime.datetime.utcnow() - datetime.timedelta(seconds=4000) @@ -974,7 +971,7 @@ def test_clear(self): pool = self._make_one() database = _Database("name") SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) self.reset() self.assertTrue(pool._sessions.full()) @@ -1016,7 +1013,7 @@ def test_ping_oldest_stale_but_exists(self): pool = self._make_one(size=1) database = _Database("name") SESSIONS = [_Session(database)] * 1 - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) later = datetime.datetime.utcnow() + datetime.timedelta(seconds=4000) @@ -1034,7 +1031,7 @@ def test_ping_oldest_stale_and_not_exists(self): database = _Database("name") SESSIONS = 
[_Session(database)] * 2 SESSIONS[0]._exists = False - database._sessions.extend(SESSIONS) + pool._new_session = mock.Mock(side_effect=SESSIONS) pool.bind(database) self.reset() @@ -1055,7 +1052,7 @@ def test_spans_get_and_leave_empty_pool(self): pool = self._make_one() database = _Database("name") session1 = _Session(database) - database._sessions.append(session1) + pool._new_session = mock.Mock(side_effect=[session1, Exception]) try: pool.bind(database) except Exception: From 2065e52efe28de1485f353b19db8ae2c615a16ee Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 2 Jun 2025 15:57:48 -0700 Subject: [PATCH 14/41] feat: Multiplexed sessions - Update session to remove class attributes, add TODOs, and make `Session._transaction` default to None. Plus add some `Optional` typing hints. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 8 +-- google/cloud/spanner_v1/session.py | 72 ++++++++++++++++---------- google/cloud/spanner_v1/snapshot.py | 6 +-- google/cloud/spanner_v1/transaction.py | 6 +-- 4 files changed, 54 insertions(+), 38 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 6056018160..ea738f944f 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -15,7 +15,7 @@ """Context manager for Cloud Spanner batched writes.""" import functools from datetime import datetime -from typing import List +from typing import List, Optional from google.cloud.spanner_v1 import CommitRequest, CommitResponse from google.cloud.spanner_v1 import Mutation @@ -53,11 +53,11 @@ def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations: List[Mutation] = [] - self.transaction_tag: str = None + self.transaction_tag: Optional[str] = None - self.committed: datetime = None + self.committed: Optional[datetime] = None """Timestamp at which the batch was successfully committed.""" - self.commit_stats: CommitResponse.CommitStats = None + self.commit_stats: 
Optional[CommitResponse.CommitStats] = None # TODO multiplexed - cleanup def _check_state(self): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 78db192f30..51900344ab 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -17,6 +17,7 @@ from functools import total_ordering import time from datetime import datetime +from typing import MutableMapping, Optional from google.api_core.exceptions import Aborted from google.api_core.exceptions import GoogleAPICallError @@ -69,17 +70,20 @@ class Session(object): :param is_multiplexed: (Optional) whether this session is a multiplexed session. """ - _session_id = None - _transaction = None - def __init__(self, database, labels=None, database_role=None, is_multiplexed=False): self._database = database + self._session_id: Optional[str] = None + + # TODO multiplexed - remove + self._transaction: Optional[Transaction] = None + if labels is None: labels = {} - self._labels = labels - self._database_role = database_role - self._is_multiplexed = is_multiplexed - self._last_use_time = datetime.utcnow() + + self._labels: MutableMapping[str, str] = labels + self._database_role: Optional[str] = database_role + self._is_multiplexed: bool = is_multiplexed + self._last_use_time: datetime = datetime.utcnow() def __lt__(self, other): return self._session_id < other._session_id @@ -100,7 +104,7 @@ def is_multiplexed(self): @property def last_use_time(self): - """ "Approximate last use time of this session + """Approximate last use time of this session :rtype: datetime :returns: the approximate last use time of this session""" @@ -157,27 +161,28 @@ def create(self): if self._session_id is not None: raise ValueError("Session ID already set by back-end") - api = self._database.spanner_api - metadata = _metadata_with_prefix(self._database.name) - if self._database._route_to_leader_enabled: + + database = self._database + api = database.spanner_api + + metadata = 
_metadata_with_prefix(database.name) + if database._route_to_leader_enabled: metadata.append( - _metadata_with_leader_aware_routing( - self._database._route_to_leader_enabled - ) + _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - request = CreateSessionRequest(database=self._database.name) - if self._database.database_role is not None: - request.session.creator_role = self._database.database_role + create_session_request = CreateSessionRequest(database=database.name) + if database.database_role is not None: + create_session_request.session.creator_role = database.database_role if self._labels: - request.session.labels = self._labels + create_session_request.session.labels = self._labels # Set the multiplexed field for multiplexed sessions if self._is_multiplexed: - request.session.multiplexed = True + create_session_request.session.multiplexed = True - observability_options = getattr(self._database, "observability_options", None) + observability_options = getattr(database, "observability_options", None) span_name = ( "CloudSpanner.CreateMultiplexedSession" if self._is_multiplexed @@ -191,9 +196,9 @@ def create(self): metadata=metadata, ) as span, MetricsCapture(): session_pb = api.create_session( - request=request, - metadata=self._database.metadata_with_request_id( - self._database._next_nth_request, + request=create_session_request, + metadata=database.metadata_with_request_id( + database._next_nth_request, 1, metadata, span, @@ -462,6 +467,7 @@ def batch(self): return Batch(self) + # TODO multiplexed - remove def transaction(self): """Create a transaction to perform a set of reads with shared staleness. 
@@ -474,7 +480,7 @@ def transaction(self): if self._transaction is not None: self._transaction.rolled_back = True - del self._transaction + self._transaction = None txn = self._transaction = Transaction(self) return txn @@ -531,6 +537,7 @@ def run_in_transaction(self, func, *args, **kw): observability_options=observability_options, ) as span, MetricsCapture(): while True: + # TODO multiplexed - remove if self._transaction is None: txn = self.transaction() txn.transaction_tag = transaction_tag @@ -552,8 +559,11 @@ def run_in_transaction(self, func, *args, **kw): return_value = func(txn, *args, **kw) + # TODO multiplexed: store previous transaction ID. except Aborted as exc: - del self._transaction + # TODO multiplexed - remove + self._transaction = None + if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -573,7 +583,9 @@ def run_in_transaction(self, func, *args, **kw): ) continue except GoogleAPICallError: - del self._transaction + # TODO multiplexed - remove + self._transaction = None + add_span_event( span, "User operation failed due to GoogleAPICallError, not retrying", @@ -596,7 +608,9 @@ def run_in_transaction(self, func, *args, **kw): max_commit_delay=max_commit_delay, ) except Aborted as exc: - del self._transaction + # TODO multiplexed - remove + self._transaction = None + if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -615,7 +629,9 @@ def run_in_transaction(self, func, *args, **kw): exc, deadline, attempts, default_retry_delay=default_retry_delay ) except GoogleAPICallError: - del self._transaction + # TODO multiplexed - remove + self._transaction = None + add_span_event( span, "Transaction.commit failed due to GoogleAPICallError, not retrying", diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 311381c160..6d3178c420 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -16,7 +16,7 @@ import functools import threading -from typing import List, Union 
+from typing import List, Union, Optional from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -225,12 +225,12 @@ def __init__(self, session): self._read_request_count: int = 0 # Identifier for the transaction. - self._transaction_id: bytes = None + self._transaction_id: Optional[bytes] = None # Precommit tokens are returned for transactions with # multiplexed sessions. The precommit token with the # highest sequence number is included in the commit request. - self._precommit_token: MultiplexedSessionPrecommitToken = None + self._precommit_token: Optional[MultiplexedSessionPrecommitToken] = None # Operations within a transaction can be performed using multiple # threads, so we need to use a lock when updating the transaction. diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 63e970750d..b8d3b5fe8e 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -259,7 +259,7 @@ def wrapped_method(*args, **kwargs): attempt.increment() rollback_method = functools.partial( api.rollback, - session=self._session.name, + session=session.name, transaction_id=self._transaction_id, metadata=database.metadata_with_request_id( nth_request, @@ -278,7 +278,7 @@ def wrapped_method(*args, **kwargs): self.rolled_back = True # TODO multiplexed - remove - del self._session._transaction + self._session._transaction = None def commit( self, return_commit_stats=False, request_options=None, max_commit_delay=None @@ -396,7 +396,7 @@ def before_next_retry(nth_retry, delay_in_seconds): self.commit_stats = response_pb.commit_stats # TODO multiplexed - remove - del self._session._transaction + self._session._transaction = None return self.committed From 7b925b33c9935c0441d84453ef0453a01c0912a5 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 2 Jun 2025 22:47:27 -0700 Subject: [PATCH 15/41] feat: Multiplexed sessions - Move begin transaction logic from `Snapshot` to 
`_SnapshotBase` and update unit tests. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/snapshot.py | 144 ++++++++++++-------- tests/_builders.py | 14 +- tests/unit/spanner_dbapi/test_connection.py | 6 +- tests/unit/test_snapshot.py | 19 ++- 4 files changed, 114 insertions(+), 69 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6d3178c420..3bab95c20d 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -24,6 +24,8 @@ PartialResultSet, ResultSet, Transaction, + Mutation, + BeginTransactionRequest, ) from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import TransactionOptions @@ -32,7 +34,7 @@ from google.cloud.spanner_v1 import PartitionQueryRequest from google.cloud.spanner_v1 import PartitionReadRequest -from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InternalServerError, Aborted from google.api_core.exceptions import ServiceUnavailable from google.api_core.exceptions import InvalidArgument from google.api_core import gapic_v1 @@ -247,6 +249,16 @@ def _make_txn_selector(self): """ raise NotImplementedError + def begin(self) -> bytes: + """Begins a transaction on the database. + + :rtype: bytes + :returns: identifier for the transaction. + + :raises ValueError: if the transaction has already begun. + """ + return self._begin_transaction() + def read( self, table, @@ -902,6 +914,77 @@ def attempt_tracking_method(): return [partition.partition_token for partition in response.partitions] + def _begin_transaction(self, mutation: Mutation = None) -> bytes: + """Begins a transaction on the database. + + :type mutation: :class:`~google.cloud.spanner_v1.mutation.Mutation` + :param mutation: (Optional) Mutation to include in the begin transaction + request. Required for mutation-only transactions with multiplexed sessions. + + :rtype: bytes + :returns: identifier for the transaction. 
+ + :raises ValueError: if the transaction has already begun or is single-use. + """ + + if self._transaction_id is not None: + raise ValueError("Transaction has already begun.") + if not self._multi_use: + raise ValueError("Cannot begin a single-use transaction.") + if self._read_request_count > 0: + raise ValueError("Read-only transaction already pending") + + session = self._session + database = session._database + api = database.spanner_api + + metadata = _metadata_with_prefix(database.name) + if not self._read_only and database._route_to_leader_enabled: + metadata.append( + (_metadata_with_leader_aware_routing(database._route_to_leader_enabled)) + ) + + with trace_call( + name=f"CloudSpanner.{type(self).__name__}.begin", + session=session, + observability_options=getattr(database, "observability_options", None), + metadata=metadata, + ) as span, MetricsCapture(): + nth_request = getattr(database, "_next_nth_request", 0) + attempt = AtomicCounter() + + def attempt_tracking_method(): + all_metadata = database.metadata_with_request_id( + nth_request, + attempt.increment(), + metadata, + span, + ) + begin_transaction_request = BeginTransactionRequest( + session=session.name, + options=self._make_txn_selector().begin, + mutation_key=mutation, + ) + begin_transaction_method = functools.partial( + api.begin_transaction, + request=begin_transaction_request, + metadata=all_metadata, + ) + return begin_transaction_method() + + # An aborted transaction may be raised by a mutations-only + # transaction with a multiplexed session. 
+ transaction_pb: Transaction = _retry( + attempt_tracking_method, + allowed_exceptions={ + InternalServerError: _check_rst_stream_error, + Aborted: None, + }, + ) + + self._update_for_transaction_pb(transaction_pb) + return self._transaction_id + def _update_for_result_set_pb( self, result_set_pb: Union[ResultSet, PartialResultSet] ) -> None: @@ -1053,65 +1136,6 @@ def _make_txn_selector(self): else: return TransactionSelector(single_use=options) - # TODO multiplexed - move to base class - def begin(self): - """Begin a read-only transaction on the database. - - :rtype: bytes - :returns: the ID for the newly-begun transaction. - - :raises ValueError: - if the transaction is already begun, committed, or rolled back. - """ - if not self._multi_use: - raise ValueError("Cannot call 'begin' on single-use snapshots") - - if self._transaction_id is not None: - raise ValueError("Read-only transaction already begun") - - if self._read_request_count > 0: - raise ValueError("Read-only transaction already pending") - - database = self._session._database - api = database.spanner_api - metadata = _metadata_with_prefix(database.name) - if not self._read_only and database._route_to_leader_enabled: - metadata.append( - (_metadata_with_leader_aware_routing(database._route_to_leader_enabled)) - ) - txn_selector = self._make_txn_selector() - with trace_call( - f"CloudSpanner.{type(self).__name__}.begin", - self._session, - observability_options=getattr(database, "observability_options", None), - metadata=metadata, - ) as span, MetricsCapture(): - nth_request = getattr(database, "_next_nth_request", 0) - attempt = AtomicCounter() - - def attempt_tracking_method(): - all_metadata = database.metadata_with_request_id( - nth_request, - attempt.increment(), - metadata, - span, - ) - method = functools.partial( - api.begin_transaction, - session=self._session.name, - options=txn_selector.begin, - metadata=all_metadata, - ) - return method() - - response = _retry( - attempt_tracking_method, - 
allowed_exceptions={InternalServerError: _check_rst_stream_error}, - ) - self._transaction_id = response.id - self._transaction_read_timestamp = response.read_timestamp - return self._transaction_id - def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: """Updates the snapshot for the given transaction. diff --git a/tests/_builders.py b/tests/_builders.py index cc9e1ddebf..8944e378bd 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -21,6 +21,7 @@ from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.types import Session as SessionPB # Default values used to populate required or expected attributes. # Tests should not depend on them: if a test requires a specific @@ -28,6 +29,12 @@ _PROJECT_ID = "default-project-id" _INSTANCE_ID = "default-instance-id" _DATABASE_ID = "default-database-id" +_SESSION_ID = "default-session-id" + +_PROJECT_NAME = "projects/" + _PROJECT_ID +_INSTANCE_NAME = _PROJECT_NAME + "/instances/" + _INSTANCE_ID +_DATABASE_NAME = _INSTANCE_NAME + "/databases/" + _DATABASE_ID +_SESSION_NAME = _DATABASE_NAME + "/sessions/" + _SESSION_ID def build_client(**kwargs: Mapping) -> Client: @@ -104,4 +111,9 @@ def build_spanner_api() -> SpannerClient: """Builds and returns a mock Spanner Client API for testing using the given arguments. Commonly used methods are mocked to return default values.""" - return create_autospec(SpannerClient, instance=True) + api = create_autospec(SpannerClient, instance=True) + + # Mock API calls with default return values. 
+ api.create_session.return_value = SessionPB(name=_SESSION_NAME) + + return api diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4a9be916ce..d2501be20e 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -221,12 +221,12 @@ def test_transaction_checkout(self): self.assertIsNone(connection.transaction_checkout()) def test_snapshot_checkout(self): - connection = Connection(INSTANCE, DATABASE, read_only=True) + connection = build_connection(read_only=True) connection.autocommit = False - session_checkout = mock.MagicMock(autospec=True) + session_checkout = mock.Mock(wraps=connection._session_checkout) + release_session = mock.Mock(wraps=connection._release_session) connection._session_checkout = session_checkout - release_session = mock.MagicMock() connection._release_session = release_session snapshot = connection.snapshot_checkout() diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 27fbf6841d..ae8f1b72da 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -16,7 +16,11 @@ from google.api_core import gapic_v1 import mock -from google.cloud.spanner_v1 import RequestOptions, DirectedReadOptions +from google.cloud.spanner_v1 import ( + RequestOptions, + DirectedReadOptions, + BeginTransactionRequest, +) from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -1885,8 +1889,11 @@ def test_begin_ok_exact_staleness(self): req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" api.begin_transaction.assert_called_once_with( - session=session.name, - options=expected_txn_options, + request=BeginTransactionRequest( + session=session.name, + options=expected_txn_options, + mutation_key=None, + ), metadata=[ ("google-cloud-resource-prefix", database.name), ( @@ -1928,8 +1935,10 @@ def test_begin_ok_exact_strong(self): req_id = 
f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" api.begin_transaction.assert_called_once_with( - session=session.name, - options=expected_txn_options, + request=BeginTransactionRequest( + session=session.name, + options=expected_txn_options, + ), metadata=[ ("google-cloud-resource-prefix", database.name), ( From 9246dd237493b8cc3750dd9c15db74583cb8c059 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 10:58:24 -0700 Subject: [PATCH 16/41] feat: Multiplexed sessions - Remove begin transaction logic from `Transaction`, move to base class, update tests. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 3 +- google/cloud/spanner_v1/snapshot.py | 30 +++-- google/cloud/spanner_v1/transaction.py | 94 ++++---------- tests/_builders.py | 94 +++++++++++++- tests/unit/test_session.py | 96 +++++++++------ tests/unit/test_transaction.py | 162 +++++++++++++------------ 6 files changed, 273 insertions(+), 206 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index ea738f944f..0aa66c4e39 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -14,7 +14,6 @@ """Context manager for Cloud Spanner batched writes.""" import functools -from datetime import datetime from typing import List, Optional from google.cloud.spanner_v1 import CommitRequest, CommitResponse @@ -55,7 +54,7 @@ def __init__(self, session): self._mutations: List[Mutation] = [] self.transaction_tag: Optional[str] = None - self.committed: Optional[datetime] = None + self.committed = None """Timestamp at which the batch was successfully committed.""" self.commit_stats: Optional[CommitResponse.CommitStats] = None diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 3bab95c20d..6094cd6d9c 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -48,7 +48,7 @@ _SessionWrapper, AtomicCounter, ) -from 
google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call, add_span_event from google.cloud.spanner_v1.streamed import StreamedResultSet from google.cloud.spanner_v1 import RequestOptions @@ -953,13 +953,7 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes: nth_request = getattr(database, "_next_nth_request", 0) attempt = AtomicCounter() - def attempt_tracking_method(): - all_metadata = database.metadata_with_request_id( - nth_request, - attempt.increment(), - metadata, - span, - ) + def wrapped_method(): begin_transaction_request = BeginTransactionRequest( session=session.name, options=self._make_txn_selector().begin, @@ -968,14 +962,30 @@ def attempt_tracking_method(): begin_transaction_method = functools.partial( api.begin_transaction, request=begin_transaction_request, - metadata=all_metadata, + metadata=database.metadata_with_request_id( + nth_request, + attempt.increment(), + metadata, + span, + ), ) return begin_transaction_method() + def before_next_retry(nth_retry, delay_in_seconds): + add_span_event( + span=span, + event_name="Transaction Begin Attempt Failed. Retrying", + event_attributes={ + "attempt": nth_retry, + "sleep_seconds": delay_in_seconds, + }, + ) + # An aborted transaction may be raised by a mutations-only # transaction with a multiplexed session. 
transaction_pb: Transaction = _retry( - attempt_tracking_method, + wrapped_method, + before_next_retry=before_next_retry, allowed_exceptions={ InternalServerError: _check_rst_stream_error, Aborted: None, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index b8d3b5fe8e..4bca7f588c 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -31,6 +31,7 @@ CommitResponse, ResultSet, ExecuteBatchDmlResponse, + Mutation, ) from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteSqlRequest @@ -153,79 +154,6 @@ def _execute_request( return response - # TODO multiplexed - move to base class - def begin(self): - """Begin a transaction on the database. - - :rtype: bytes - :returns: the ID for the newly-begun transaction. - :raises ValueError: - if the transaction is already begun, committed, or rolled back. - """ - if self._transaction_id is not None: - raise ValueError("Transaction already begun") - - if self.committed is not None: - raise ValueError("Transaction already committed") - - if self.rolled_back: - raise ValueError("Transaction is already rolled back") - - database = self._session._database - api = database.spanner_api - metadata = _metadata_with_prefix(database.name) - if database._route_to_leader_enabled: - metadata.append( - _metadata_with_leader_aware_routing(database._route_to_leader_enabled) - ) - txn_options = TransactionOptions( - read_write=TransactionOptions.ReadWrite(), - exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, - isolation_level=self.isolation_level, - ) - txn_options = _merge_Transaction_Options( - database.default_transaction_options.default_read_write_transaction_options, - txn_options, - ) - observability_options = getattr(database, "observability_options", None) - with trace_call( - f"CloudSpanner.{type(self).__name__}.begin", - self._session, - observability_options=observability_options, 
- metadata=metadata, - ) as span, MetricsCapture(): - attempt = AtomicCounter(0) - nth_request = database._next_nth_request - - def wrapped_method(*args, **kwargs): - method = functools.partial( - api.begin_transaction, - session=self._session.name, - options=txn_options, - metadata=database.metadata_with_request_id( - nth_request, - attempt.increment(), - metadata, - span, - ), - ) - return method(*args, **kwargs) - - def before_next_retry(nth_retry, delay_in_seconds): - add_span_event( - span, - "Transaction Begin Attempt Failed. Retrying", - {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, - ) - - response = _retry( - wrapped_method, - allowed_exceptions={InternalServerError: _check_rst_stream_error}, - before_next_retry=before_next_retry, - ) - self._transaction_id = response.id - return self._transaction_id - def rollback(self) -> None: """Roll back a transaction on the database.""" @@ -725,6 +653,26 @@ def wrapped_method(*args, **kwargs): return response_pb.status, row_counts + def _begin_transaction(self, mutation: Mutation = None) -> bytes: + """Begins a transaction on the database. + + :type mutation: :class:`~google.cloud.spanner_v1.mutation.Mutation` + :param mutation: (Optional) Mutation to include in the begin transaction + request. Required for mutation-only transactions with multiplexed sessions. + + :rtype: bytes + :returns: identifier for the transaction. + + :raises ValueError: if the transaction has already begun or is single-use. 
+ """ + + if self.committed is not None: + raise ValueError("Transaction is already committed") + if self.rolled_back: + raise ValueError("Transaction is already rolled back") + + return super(Transaction, self)._begin_transaction() + def _update_for_execute_batch_dml_response_pb( self, response_pb: ExecuteBatchDmlResponse ) -> None: diff --git a/tests/_builders.py b/tests/_builders.py index 8944e378bd..50816efed3 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime from logging import Logger from mock import create_autospec from typing import Mapping @@ -21,7 +22,14 @@ from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.transaction import Transaction + from google.cloud.spanner_v1.types import Session as SessionPB +from google.cloud.spanner_v1.types import Transaction as TransactionPB +from google.cloud.spanner_v1.types import CommitResponse as CommitResponsePB + +from google.cloud._helpers import _datetime_to_pb_timestamp +from tests._helpers import HAS_OPENTELEMETRY_INSTALLED, get_test_ot_exporter # Default values used to populate required or expected attributes. 
# Tests should not depend on them: if a test requires a specific @@ -30,12 +38,52 @@ _INSTANCE_ID = "default-instance-id" _DATABASE_ID = "default-database-id" _SESSION_ID = "default-session-id" +_TRANSACTION_ID = b"default-transaction-id" _PROJECT_NAME = "projects/" + _PROJECT_ID _INSTANCE_NAME = _PROJECT_NAME + "/instances/" + _INSTANCE_ID _DATABASE_NAME = _INSTANCE_NAME + "/databases/" + _DATABASE_ID _SESSION_NAME = _DATABASE_NAME + "/sessions/" + _SESSION_ID +_TIMESTAMP = _datetime_to_pb_timestamp(datetime.now()) + +# Protocol buffers +# ---------------- + + +def _build_commit_response_pb(**kwargs) -> CommitResponsePB: + """Builds and returns a commit response protocol buffer for testing using the given arguments. + If an expected argument is not provided, a default value will be used.""" + + if "commit_timestamp" not in kwargs: + kwargs["commit_timestamp"] = _TIMESTAMP + + return CommitResponsePB(**kwargs) + + +def build_session_pb(**kwargs) -> SessionPB: + """Builds and returns a session protocol buffer for testing using the given arguments. + If an expected argument is not provided, a default value will be used.""" + + if "name" not in kwargs: + kwargs["name"] = _SESSION_NAME + + return SessionPB(**kwargs) + + +def build_transaction_pb(**kwargs) -> TransactionPB: + """Builds and returns a transaction protocol buffer for testing using the given arguments.. + If an expected argument is not provided, a default value will be used.""" + + if "id" not in kwargs: + kwargs["id"] = _TRANSACTION_ID + + return TransactionPB(**kwargs) + + +# Client classes +# -------------- + def build_client(**kwargs: Mapping) -> Client: """Builds and returns a client for testing using the given arguments. 
@@ -92,11 +140,6 @@ def build_instance(**kwargs: Mapping) -> Instance: return Instance(**kwargs) -def build_logger() -> Logger: - """Builds and returns a logger for testing.""" - return create_autospec(Logger, instance=True) - - def build_session(**kwargs: Mapping) -> Session: """Builds and returns a session for testing using the given arguments. If a required argument is not provided, a default value will be used.""" @@ -107,6 +150,30 @@ def build_session(**kwargs: Mapping) -> Session: return Session(**kwargs) +def build_transaction(session=None) -> Transaction: + """Builds and returns a transaction for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + + session = session or build_session() + + # Ensure session exists. + if session.session_id is None: + session.create() + + _clear_spans() + return session.transaction() + + +# Other classes +# ------------- + + +def build_logger() -> Logger: + """Builds and returns a logger for testing.""" + + return create_autospec(Logger, instance=True) + + def build_spanner_api() -> SpannerClient: """Builds and returns a mock Spanner Client API for testing using the given arguments. Commonly used methods are mocked to return default values.""" @@ -114,6 +181,21 @@ def build_spanner_api() -> SpannerClient: api = create_autospec(SpannerClient, instance=True) # Mock API calls with default return values. - api.create_session.return_value = SessionPB(name=_SESSION_NAME) + api.begin_transaction.return_value = build_transaction_pb() + api.commit.return_value = _build_commit_response_pb() + api.create_session.return_value = build_session_pb() return api + + +# Helper functions +# ---------------- + + +def _clear_spans() -> None: + """Clears the spans collected by the OpenTelemetry exporter. 
+ This ensures that spans generated while building test objects + do not interfere with the tests.""" + + if HAS_OPENTELEMETRY_INSTALLED: + get_test_ot_exporter().clear() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 010d59e198..cba26d51e6 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -28,6 +28,7 @@ Session as SessionRequestProto, ExecuteSqlRequest, TypeCode, + BeginTransactionRequest, ) from google.cloud._helpers import UTC, _datetime_to_pb_timestamp from google.cloud.spanner_v1._helpers import _delay_until_retry @@ -1089,8 +1090,9 @@ def unit_of_work(txn, *args, **kw): expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1217,8 +1219,9 @@ def unit_of_work(txn, *args, **kw): gax_api.begin_transaction.call_args_list, [ mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1229,8 +1232,9 @@ def unit_of_work(txn, *args, **kw): ], ), mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1331,8 +1335,9 @@ def unit_of_work(txn, *args, **kw): gax_api.begin_transaction.call_args_list, [ mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ 
("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1343,8 +1348,9 @@ def unit_of_work(txn, *args, **kw): ], ), mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1444,8 +1450,9 @@ def unit_of_work(txn, *args, **kw): # First call was aborted before commit operation, therefore no begin rpc was made during first attempt. gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1528,8 +1535,9 @@ def _time(_results=[1, 1.5]): expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1608,8 +1616,9 @@ def _time(_results=[1, 2, 4, 8]): gax_api.begin_transaction.call_args_list, [ mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1620,8 +1629,9 @@ def _time(_results=[1, 2, 4, 8]): ], ), mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), 
("x-goog-spanner-route-to-leader", "true"), @@ -1632,8 +1642,9 @@ def _time(_results=[1, 2, 4, 8]): ], ), mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1731,8 +1742,9 @@ def unit_of_work(txn, *args, **kw): expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1801,8 +1813,9 @@ def unit_of_work(txn, *args, **kw): expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1875,8 +1888,9 @@ def unit_of_work(txn, *args, **kw): expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1948,8 +1962,9 @@ def unit_of_work(txn, *args, **kw): exclude_txn_from_change_streams=True, ) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, 
options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -2042,8 +2057,9 @@ def unit_of_work(txn, *args, **kw): gax_api.begin_transaction.call_args_list, [ mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -2054,8 +2070,9 @@ def unit_of_work(txn, *args, **kw): ], ), mock.call( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -2125,8 +2142,9 @@ def unit_of_work(txn, *args, **kw): isolation_level=TransactionOptions.IsolationLevel.SERIALIZABLE, ) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -2163,8 +2181,9 @@ def unit_of_work(txn, *args, **kw): isolation_level=TransactionOptions.IsolationLevel.SERIALIZABLE, ) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -2205,8 +2224,9 @@ def unit_of_work(txn, *args, **kw): isolation_level=TransactionOptions.IsolationLevel.REPEATABLE_READ, ) gax_api.begin_transaction.assert_called_once_with( - session=self.SESSION_NAME, - options=expected_options, + request=BeginTransactionRequest( + session=self.SESSION_NAME, 
options=expected_options + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 71f8d956a8..85849c08a6 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Mapping import mock -from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1 import ( + RequestOptions, + BeginTransactionRequest, + TransactionOptions, + CommitRequest, +) from google.cloud.spanner_v1 import DefaultTransactionOptions from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode @@ -25,7 +30,9 @@ AtomicCounter, _metadata_with_request_id, ) +from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from tests._builders import build_transaction, build_transaction_pb from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -64,17 +71,6 @@ class TestTransaction(OpenTelemetryBase): TRANSACTION_ID = b"DEADBEEF" TRANSACTION_TAG = "transaction-tag" - BASE_ATTRIBUTES = { - "db.type": "spanner", - "db.url": "spanner.googleapis.com", - "db.instance": "testing", - "net.host.name": "spanner.googleapis.com", - "gcp.client.service": "spanner", - "gcp.client.version": LIB_VERSION, - "gcp.client.repo": "googleapis/python-spanner", - } - enrich_with_otel_scope(BASE_ATTRIBUTES) - def _getTargetClass(self): from google.cloud.spanner_v1.transaction import Transaction @@ -176,47 +172,42 @@ def test_begin_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Transaction.begin", status=StatusCode.ERROR, - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id + 
attributes=self._build_span_attributes( + database, x_goog_spanner_request_id=req_id ), ) def test_begin_ok(self): - from google.cloud.spanner_v1 import Transaction as TransactionPB + transaction = build_transaction() + session = transaction._session + database = session._database - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) - database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb - ) - session = _Session(database) - transaction = self._make_one(session) + begin_transaction = database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=self.TRANSACTION_ID) - txn_id = transaction.begin() + transaction_id = transaction.begin() - self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(transaction_id, self.TRANSACTION_ID) self.assertEqual(transaction._transaction_id, self.TRANSACTION_ID) - session_id, txn_options, metadata = api._begun - self.assertEqual(session_id, session.name) - self.assertTrue(type(txn_options).pb(txn_options).HasField("read_write")) - req_id = f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1" - self.assertEqual( - metadata, - [ + request_id = self._build_request_id(database) + + begin_transaction.assert_called_once_with( + request=BeginTransactionRequest( + session=session.name, + options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), + ), + metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), - ( - "x-goog-spanner-request-id", - req_id, - ), + ("x-goog-spanner-request-id", request_id), ], ) self.assertSpanAttributes( "CloudSpanner.Transaction.begin", - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id + attributes=self._build_span_attributes( + database, x_goog_spanner_request_id=request_id ), ) @@ -291,8 +282,8 @@ def test_rollback_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Transaction.rollback", 
status=StatusCode.ERROR, - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id + attributes=self._build_span_attributes( + database, x_goog_spanner_request_id=req_id ), ) @@ -330,8 +321,8 @@ def test_rollback_ok(self): self.assertSpanAttributes( "CloudSpanner.Transaction.rollback", - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id + attributes=self._build_span_attributes( + database, x_goog_spanner_request_id=req_id ), ) @@ -447,10 +438,10 @@ def test_commit_w_other_error(self): self.assertSpanAttributes( "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, - num_mutations=1, + attributes=self._build_span_attributes( + database, x_goog_spanner_request_id=req_id, + num_mutations=1, ), ) @@ -537,8 +528,8 @@ def _commit_helper( self.assertSpanAttributes( "CloudSpanner.Transaction.commit", - attributes=dict( - TestTransaction.BASE_ATTRIBUTES, + attributes=self._build_span_attributes( + database, num_mutations=len(transaction._mutations), x_goog_spanner_request_id=req_id, ), @@ -711,12 +702,11 @@ def _execute_update_helper( ) self.assertEqual(transaction._execute_sql_request_count, count + 1) - want_span_attributes = dict(TestTransaction.BASE_ATTRIBUTES) - want_span_attributes["db.statement"] = DML_QUERY_WITH_PARAM self.assertSpanAttributes( "CloudSpanner.Transaction.execute_update", - status=StatusCode.OK, - attributes=want_span_attributes, + attributes=self._build_span_attributes( + database, **{"db.statement": DML_QUERY_WITH_PARAM} + ), ) def test_execute_update_new_transaction(self): @@ -990,39 +980,27 @@ def test_batch_update_w_timeout_and_retry_params(self): self._batch_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0) def test_context_mgr_success(self): - import datetime - from google.cloud.spanner_v1 import CommitResponse - from google.cloud.spanner_v1 import Transaction as TransactionPB - from google.cloud._helpers import 
UTC - - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) - now = datetime.datetime.utcnow().replace(tzinfo=UTC) - response = CommitResponse(commit_timestamp=now) - database = _Database() - api = database.spanner_api = _FauxSpannerAPI( - _begin_transaction_response=transaction_pb, _commit_response=response - ) - session = _Session(database) - transaction = self._make_one(session) + transaction = build_transaction() + session = transaction._session + database = session._database + commit = database.spanner_api.commit with transaction: transaction.insert(TABLE_NAME, COLUMNS, VALUES) - self.assertEqual(transaction.committed, now) + self.assertEqual(transaction.committed, commit.return_value.commit_timestamp) - session_id, mutations, txn_id, _, _, metadata = api._committed - self.assertEqual(session_id, self.SESSION_NAME) - self.assertEqual(txn_id, self.TRANSACTION_ID) - self.assertEqual(mutations, transaction._mutations) - self.assertEqual( - metadata, - [ + commit.assert_called_once_with( + request=CommitRequest( + session=session.name, + transaction_id=transaction._transaction_id, + request_options=RequestOptions(), + mutations=transaction._mutations, + ), + metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), - ( - "x-goog-spanner-request-id", - f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", - ), + ("x-goog-spanner-request-id", self._build_request_id(database)), ], ) @@ -1051,6 +1029,36 @@ def test_context_mgr_failure(self): self.assertEqual(len(transaction._mutations), 1) self.assertEqual(api._committed, None) + @staticmethod + def _build_span_attributes( + database: Database, **extra_attributes + ) -> Mapping[str, str]: + """Builds the attributes for spans using the given database and extra attributes.""" + + attributes = enrich_with_otel_scope( + { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": database.name, + "net.host.name": "spanner.googleapis.com", + 
"gcp.client.service": "spanner", + "gcp.client.version": LIB_VERSION, + "gcp.client.repo": "googleapis/python-spanner", + } + ) + + if extra_attributes: + attributes.update(extra_attributes) + + return attributes + + @staticmethod + def _build_request_id(database: Database, attempt: int = 1) -> str: + """Builds a request ID for an Spanner Client API request with the given database and attempt number.""" + + client = database._instance._client + return f"1.{REQ_RAND_PROCESS_ID}.{client._nth_client_id}.{database._channel_id}.{client._nth_request.value}.{attempt}" + class _Client(object): NTH_CLIENT = AtomicCounter() From f1b3fdbdf05c161ebef2a57d0a2cdd3e789f2a6f Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 11:30:15 -0700 Subject: [PATCH 17/41] feat: Multiplexed sessions - Add logic for beginning mutations-only transactions. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/transaction.py | 47 ++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 4bca7f588c..a016a35394 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -260,9 +260,8 @@ def commit( if self._transaction_id is None and len(self._mutations) == 0: raise ValueError("Transaction is not begun") - # TODO multiplexed - begin transaction if self._transaction_id is None and num_mutations > 0: - self.begin() + self._begin_mutations_only_transaction() if request_options is None: request_options = RequestOptions() @@ -671,7 +670,49 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes: if self.rolled_back: raise ValueError("Transaction is already rolled back") - return super(Transaction, self)._begin_transaction() + return super(Transaction, self)._begin_transaction(mutation=mutation) + + def _begin_mutations_only_transaction(self) -> None: + """Begins a mutations-only transaction on the database.""" + + 
mutation = self._get_mutation_for_begin_mutations_only_transaction() + self._begin_transaction(mutation=mutation) + + def _get_mutation_for_begin_mutations_only_transaction(self) -> Optional[Mutation]: + """Returns a mutation to use for beginning a mutations-only transaction. + Returns None if a mutation does not need to be included. + + :rtype: :class:`~google.cloud.spanner_v1.types.Mutation` + :returns: A mutation to use for beginning a mutations-only transaction. + """ + + # A mutation only needs to be included + # for transaction with multiplexed sessions. + if not self._session.is_multiplexed: + return None + + mutations: list[Mutation] = self._mutations + + # If there are multiple mutations, select the mutation as follows: + # 1. Choose a delete, update, or replace mutation instead + # of an insert mutation (since inserts could involve an auto- + # generated column and the client doesn't have that information). + # 2. If there are no delete, update, or replace mutations, choose + # the insert mutation that includes the largest number of values. + + insert_mutation: Mutation = None + max_insert_values: int = -1 + + for mut in mutations: + if mut.insert: + num_values = len(mut.insert.values) + if num_values > max_insert_values: + insert_mutation = mut + max_insert_values = num_values + else: + return mut + + return insert_mutation def _update_for_execute_batch_dml_response_pb( self, response_pb: ExecuteBatchDmlResponse From 98c477d5dae0d811f4602cf5660c87e61235695b Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 12:04:52 -0700 Subject: [PATCH 18/41] feat: Multiplexed sessions - Cleanup and improve consistency of state checks, add `raises` documentation. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 34 +++------------ google/cloud/spanner_v1/snapshot.py | 44 +++++++------------ google/cloud/spanner_v1/transaction.py | 58 +++++++++++++------------- tests/unit/test_batch.py | 6 --- tests/unit/test_transaction.py | 28 ++----------- 5 files changed, 53 insertions(+), 117 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0aa66c4e39..c84b4f91e8 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -58,17 +58,6 @@ def __init__(self, session): """Timestamp at which the batch was successfully committed.""" self.commit_stats: Optional[CommitResponse.CommitStats] = None - # TODO multiplexed - cleanup - def _check_state(self): - """Helper for :meth:`commit` et al. - - Subclasses must override - - :raises: :exc:`ValueError` if the object's state is invalid for making - API requests. - """ - raise NotImplementedError - def insert(self, table, columns, values): """Insert one or more new table rows. @@ -153,18 +142,6 @@ def delete(self, table, keyset): class Batch(_BatchBase): """Accumulate mutations for transmission during :meth:`commit`.""" - # TODO multiplexed - cleanup - def _check_state(self): - """Helper for :meth:`commit` et al. - - Subclasses must override - - :raises: :exc:`ValueError` if the object's state is invalid for making - API requests. - """ - if self.committed is not None: - raise ValueError("Batch already committed") - # TODO multiplexed - cleanup kwargs def commit( self, @@ -207,10 +184,12 @@ def commit( :rtype: datetime :returns: timestamp of the committed changes. + + :raises: ValueError: if the transaction is not ready to commit. 
""" - # TODO multiplexed - cleanup - self._check_state() + if self.committed is not None: + raise ValueError("Transaction already committed.") database = self._session._database api = database.spanner_api @@ -288,9 +267,8 @@ def wrapped_method(*args, **kwargs): def __enter__(self): """Begin ``with`` block.""" - - # TODO multiplexed - cleanup - self._check_state() + if self.committed is not None: + raise ValueError("Transaction already committed") return self diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6094cd6d9c..fa613bc572 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -353,17 +353,15 @@ def read( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. - :raises ValueError: - for reuse of single-use snapshots, or if a transaction ID is - already pending for multiple-use snapshots. + :raises ValueError: if the Transaction already used to execute a + read request, but is not a multi-use transaction or has not begun. """ - # TODO multiplexed - cleanup if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") - if self._transaction_id is None and self._read_only: - raise ValueError("Transaction ID pending.") + if self._transaction_id is None: + raise ValueError("Transaction has not begun.") session = self._session database = session._database @@ -534,17 +532,15 @@ def execute_sql( objects. ``iterator.decode_column(row, column_index)`` decodes one specific column in the given row. - :raises ValueError: - for reuse of single-use snapshots, or if a transaction ID is - already pending for multiple-use snapshots. + :raises ValueError: if the Transaction already used to execute a + read request, but is not a multi-use transaction or has not begun. 
""" - # TODO multiplexed - cleanup if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") - if self._transaction_id is None and self._read_only: - raise ValueError("Transaction ID pending.") + if self._transaction_id is None: + raise ValueError("Transaction has not begun.") if params is not None: params_pb = Struct( @@ -719,17 +715,13 @@ def partition_read( :rtype: iterable of bytes :returns: a sequence of partition tokens - :raises ValueError: - for single-use snapshots, or if a transaction ID is - already associated with the snapshot. + :raises ValueError: if the transaction has not begun or is single-use. """ - # TODO multiplexed - cleanup - if not self._multi_use: - raise ValueError("Cannot use single-use snapshot.") - if self._transaction_id is None: - raise ValueError("Transaction not started.") + raise ValueError("Transaction has not begun.") + if not self._multi_use: + raise ValueError("Cannot partition a single-use transaction.") session = self._session database = session._database @@ -838,17 +830,13 @@ def partition_query( :rtype: iterable of bytes :returns: a sequence of partition tokens - :raises ValueError: - for single-use snapshots, or if a transaction ID is - already associated with the snapshot. + :raises ValueError: if the transaction has not begun or is single-use. 
""" - # TODO multiplexed - cleanup - if not self._multi_use: - raise ValueError("Cannot use single-use snapshot.") - if self._transaction_id is None: - raise ValueError("Transaction not started.") + raise ValueError("Transaction has not begun.") + if not self._multi_use: + raise ValueError("Cannot partition a single-use transaction.") if params is not None: params_pb = Struct( diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index a016a35394..01df3cfe5b 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -75,32 +75,13 @@ def __init__(self, session): super(Transaction, self).__init__(session) self.rolled_back: bool = False - # TODO multiplexed - remove - def _check_state(self): - """Helper for :meth:`commit` et al. - - :raises: :exc:`ValueError` if the object's state is invalid for making - API requests. - """ - - if self.committed is not None: - raise ValueError("Transaction is already committed") - - if self.rolled_back: - raise ValueError("Transaction is already rolled back") - - # TODO multiplexed - refactor to base class def _make_txn_selector(self): """Helper for :meth:`read`. - :rtype: - :class:`~.transaction_pb2.TransactionSelector` + :rtype: :class:`~.transaction_pb2.TransactionSelector` :returns: a selector configured for read-write transaction semantics. """ - # TODO multiplexed - remove - self._check_state() - if self._transaction_id is None: txn_options = TransactionOptions( read_write=TransactionOptions.ReadWrite(), @@ -131,8 +112,15 @@ def _execute_request( :type request: proto :param request: request proto to call the method with + + :raises: ValueError: if the transaction is not ready to update. 
""" + if self.committed is not None: + raise ValueError("Transaction already committed.") + if self.rolled_back: + raise ValueError("Transaction already rolled back.") + session = self._session transaction = self._make_txn_selector() request.transaction = transaction @@ -155,10 +143,15 @@ def _execute_request( return response def rollback(self) -> None: - """Roll back a transaction on the database.""" + """Roll back a transaction on the database. - # TODO multiplexed - cleanup - self._check_state() + :raises: ValueError: if the transaction is not ready to roll back. + """ + + if self.committed is not None: + raise ValueError("Transaction already committed.") + if self.rolled_back: + raise ValueError("Transaction already rolled back.") if self._transaction_id is not None: session = self._session @@ -232,7 +225,8 @@ def commit( :rtype: datetime :returns: timestamp of the committed changes. - :raises ValueError: if there are no mutations to commit. + + :raises: ValueError: if the transaction is not ready to commit. 
""" mutations = self._mutations @@ -255,13 +249,17 @@ def commit( observability_options=getattr(database, "observability_options", None), metadata=metadata, ) as span, MetricsCapture(): - # TODO multiplexed - cleanup - self._check_state() - if self._transaction_id is None and len(self._mutations) == 0: - raise ValueError("Transaction is not begun") - if self._transaction_id is None and num_mutations > 0: - self._begin_mutations_only_transaction() + if self.committed is not None: + raise ValueError("Transaction already committed.") + if self.rolled_back: + raise ValueError("Transaction already rolled back.") + + if self._transaction_id is None: + if num_mutations > 0: + self._begin_mutations_only_transaction() + else: + raise ValueError("Transaction has not begun.") if request_options is None: request_options = RequestOptions() diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index cb3dc7e2cd..2056581d6f 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -94,12 +94,6 @@ def test_ctor(self): self.assertIs(base._session, session) self.assertEqual(len(base._mutations), 0) - def test__check_state_virtual(self): - session = _Session() - base = self._make_one(session) - with self.assertRaises(NotImplementedError): - base._check_state() - def test_insert(self): session = _Session() base = self._make_one(session) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 85849c08a6..201cf5de6c 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -102,28 +102,6 @@ def test_ctor_defaults(self): self.assertTrue(transaction._multi_use) self.assertEqual(transaction._execute_sql_request_count, 0) - def test__check_state_already_committed(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.committed = object() - with self.assertRaises(ValueError): - transaction._check_state() - - def 
test__check_state_already_rolled_back(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.rolled_back = True - with self.assertRaises(ValueError): - transaction._check_state() - - def test__check_state_ok(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction._check_state() # does not raise - def test__make_txn_selector(self): session = _Session() transaction = self._make_one(session) @@ -348,7 +326,7 @@ def test_commit_not_begun(self): "exception", { "exception.type": "ValueError", - "exception.message": "Transaction is not begun", + "exception.message": "Transaction has not begun.", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -380,7 +358,7 @@ def test_commit_already_committed(self): "exception", { "exception.type": "ValueError", - "exception.message": "Transaction is already committed", + "exception.message": "Transaction already committed.", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, @@ -412,7 +390,7 @@ def test_commit_already_rolled_back(self): "exception", { "exception.type": "ValueError", - "exception.message": "Transaction is already rolled back", + "exception.message": "Transaction already rolled back.", "exception.stacktrace": "EPHEMERAL", "exception.escaped": "False", }, From 052f3e1100498cd77e2f1a3af79dcae950daa571 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 12:20:50 -0700 Subject: [PATCH 19/41] feat: Multiplexed sessions - Cleanup documentation for `Batch.commit`, some minor cleanup. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 101 ++++++++++++------------- google/cloud/spanner_v1/transaction.py | 1 - 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index c84b4f91e8..0856b90d5f 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -150,7 +150,8 @@ def commit( max_commit_delay=None, exclude_txn_from_change_streams=False, isolation_level=TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, - **kwargs, + timeout_secs=DEFAULT_RETRY_TIMEOUT_SECS, + default_retry_delay=None, ): """Commit mutations to the database. @@ -182,6 +183,12 @@ def commit( :param isolation_level: (Optional) Sets isolation level for the transaction. + :type timeout_secs: int + :param timeout_secs: (Optional) The maximum time in seconds to wait for the commit to complete. + + :type default_retry_delay: int + :param timeout_secs: (Optional) The default time in seconds to wait before re-trying the commit.. + :rtype: datetime :returns: timestamp of the committed changes. @@ -191,8 +198,11 @@ def commit( if self.committed is not None: raise ValueError("Transaction already committed.") - database = self._session._database + mutations = self._mutations + session = self._session + database = session._database api = database.spanner_api + metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: metadata.append( @@ -208,7 +218,6 @@ def commit( database.default_transaction_options.default_read_write_transaction_options, txn_options, ) - trace_attributes = {"num_mutations": len(self._mutations)} if request_options is None: request_options = RequestOptions() @@ -219,27 +228,26 @@ def commit( # Request tags are not supported for commit requests. 
request_options.request_tag = None - request = CommitRequest( - session=self._session.name, - mutations=self._mutations, - single_use_transaction=txn_options, - return_commit_stats=return_commit_stats, - max_commit_delay=max_commit_delay, - request_options=request_options, - ) - observability_options = getattr(database, "observability_options", None) with trace_call( - f"CloudSpanner.{type(self).__name__}.commit", - self._session, - trace_attributes, - observability_options=observability_options, + name=f"CloudSpanner.{type(self).__name__}.commit", + session=session, + extra_attributes={"num_mutations": len(mutations)}, + observability_options=getattr(database, "observability_options", None), metadata=metadata, ) as span, MetricsCapture(): - def wrapped_method(*args, **kwargs): - method = functools.partial( + def wrapped_method(): + commit_request = CommitRequest( + session=session.name, + mutations=mutations, + single_use_transaction=txn_options, + return_commit_stats=return_commit_stats, + max_commit_delay=max_commit_delay, + request_options=request_options, + ) + commit_method = functools.partial( api.commit, - request=request, + request=commit_request, metadata=database.metadata_with_request_id( # This code is retried due to ABORTED, hence nth_request # should be increased. 
attempt can only be increased if @@ -250,19 +258,17 @@ def wrapped_method(*args, **kwargs): span, ), ) - return method(*args, **kwargs) + return commit_method() - deadline = time.time() + kwargs.get( - "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS - ) - default_retry_delay = kwargs.get("default_retry_delay", None) response = _retry_on_aborted_exception( wrapped_method, - deadline=deadline, + deadline=time.time() + timeout_secs, default_retry_delay=default_retry_delay, ) + self.committed = response.commit_timestamp self.commit_stats = response.commit_stats + return self.committed def __enter__(self): @@ -308,15 +314,6 @@ def __init__(self, session): self._mutation_groups: List[MutationGroup] = [] self.committed: bool = False - def _check_state(self): - """Checks if the object's state is valid for making API requests. - - :raises: :exc:`ValueError` if the object's state is invalid for making - API requests. - """ - if self.committed: - raise ValueError("MutationGroups already committed") - def group(self): """Returns a new `MutationGroup` to which mutations can be added.""" mutation_group = BatchWriteRequest.MutationGroup() @@ -344,9 +341,10 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals :returns: a sequence of responses for each batch. 
""" - # TODO multiplexed - cleanup - self._check_state() + if self.committed: + raise ValueError("MutationGroups already committed") + mutation_groups = self._mutation_groups session = self._session database = session._database api = database.spanner_api @@ -356,33 +354,32 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - trace_attributes = {"num_mutation_groups": len(self._mutation_groups)} + if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: request_options = RequestOptions(request_options) - request = BatchWriteRequest( - session=session.name, - mutation_groups=self._mutation_groups, - request_options=request_options, - exclude_txn_from_change_streams=exclude_txn_from_change_streams, - ) - observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.batch_write", - self._session, - trace_attributes, - observability_options=observability_options, + name="CloudSpanner.batch_write", + session=session, + extra_attributes={"num_mutation_groups": len(mutation_groups)}, + observability_options=getattr(database, "observability_options", None), metadata=metadata, ) as span, MetricsCapture(): attempt = AtomicCounter(0) nth_request = getattr(database, "_next_nth_request", 0) - def wrapped_method(*args, **kwargs): - method = functools.partial( + def wrapped_method(): + batch_write_request = BatchWriteRequest( + session=session.name, + mutation_groups=mutation_groups, + request_options=request_options, + exclude_txn_from_change_streams=exclude_txn_from_change_streams, + ) + batch_write_method = functools.partial( api.batch_write, - request=request, + request=batch_write_request, metadata=database.metadata_with_request_id( nth_request, attempt.increment(), @@ -390,7 +387,7 @@ def wrapped_method(*args, **kwargs): span, ), ) - return method(*args, **kwargs) + 
return batch_write_method() response = _retry( wrapped_method, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 01df3cfe5b..ceed7dec1b 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -249,7 +249,6 @@ def commit( observability_options=getattr(database, "observability_options", None), metadata=metadata, ) as span, MetricsCapture(): - if self.committed is not None: raise ValueError("Transaction already committed.") if self.rolled_back: From 2b9f212cae57d83b24cea90d6017ffc8a5917d01 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 12:46:18 -0700 Subject: [PATCH 20/41] feat: Multiplexed sessions - Add logic for retrying commits if precommit token returned. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/batch.py | 1 - google/cloud/spanner_v1/transaction.py | 52 ++++++++++++++++++-------- tests/unit/test_session.py | 19 ++++------ 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 0856b90d5f..ab58bdec7a 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -142,7 +142,6 @@ def delete(self, table, keyset): class Batch(_BatchBase): """Accumulate mutations for transmission during :meth:`commit`.""" - # TODO multiplexed - cleanup kwargs def commit( self, return_commit_stats=False, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index ceed7dec1b..8dfb0281e4 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -270,14 +270,13 @@ def commit( # Request tags are not supported for commit requests. 
request_options.request_tag = None - commit_request = CommitRequest( - session=session.name, - mutations=mutations, - transaction_id=self._transaction_id, - return_commit_stats=return_commit_stats, - max_commit_delay=max_commit_delay, - request_options=request_options, - ) + common_commit_request_args = { + "session": session.name, + "transaction_id": self._transaction_id, + "return_commit_stats": return_commit_stats, + "max_commit_delay": max_commit_delay, + "request_options": request_options, + } add_span_event(span, "Starting Commit") @@ -288,7 +287,11 @@ def wrapped_method(*args, **kwargs): attempt.increment() commit_method = functools.partial( api.commit, - request=commit_request, + request=CommitRequest( + mutations=mutations, + precommit_token=self._precommit_token, + **common_commit_request_args, + ), metadata=database.metadata_with_request_id( nth_request, attempt.value, @@ -298,26 +301,43 @@ def wrapped_method(*args, **kwargs): ) return commit_method(*args, **kwargs) + commit_retry_event_name = "Transaction Commit Attempt Failed. Retrying" + def before_next_retry(nth_retry, delay_in_seconds): add_span_event( - span, - "Transaction Commit Attempt Failed. Retrying", - {"attempt": nth_retry, "sleep_seconds": delay_in_seconds}, + span=span, + event_name=commit_retry_event_name, + event_attributes={ + "attempt": nth_retry, + "sleep_seconds": delay_in_seconds, + }, ) - response_pb: CommitResponse = _retry( + commit_response_pb: CommitResponse = _retry( wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, before_next_retry=before_next_retry, ) - # TODO multiplexed - retry commit if precommit token. + # If the response contains a precommit token, the transaction did not + # successfully commit, and must be retried with the new precommit token. + # The mutations should not be included in the new request, and no further + # retries or exception handling should be performed. 
+ if commit_response_pb.precommit_token: + add_span_event(span, commit_retry_event_name) + commit_response_pb = api.commit( + request=CommitRequest( + precommit_token=commit_response_pb.precommit_token, + **common_commit_request_args, + ), + metadata=metadata, + ) add_span_event(span, "Commit Done") - self.committed = response_pb.commit_timestamp + self.committed = commit_response_pb.commit_timestamp if return_commit_stats: - self.commit_stats = response_pb.commit_stats + self.commit_stats = commit_response_pb.commit_stats # TODO multiplexed - remove self._session._transaction = None diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index cba26d51e6..1052d21dcd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -33,6 +33,7 @@ from google.cloud._helpers import UTC, _datetime_to_pb_timestamp from google.cloud.spanner_v1._helpers import _delay_until_retry from google.cloud.spanner_v1.transaction import Transaction +from tests._builders import build_spanner_api from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -2119,10 +2120,8 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_isolation_level_at_request(self): - gax_api = self._make_spanner_api() - gax_api.begin_transaction.return_value = TransactionPB(id=b"FACEDACE") database = self._make_database() - database.spanner_api = gax_api + api = database.spanner_api = build_spanner_api() session = self._make_one(database) session._session_id = self.SESSION_ID @@ -2141,7 +2140,7 @@ def unit_of_work(txn, *args, **kw): read_write=TransactionOptions.ReadWrite(), isolation_level=TransactionOptions.IsolationLevel.SERIALIZABLE, ) - gax_api.begin_transaction.assert_called_once_with( + api.begin_transaction.assert_called_once_with( request=BeginTransactionRequest( session=self.SESSION_NAME, options=expected_options ), @@ -2156,14 +2155,12 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_isolation_level_at_client(self): - gax_api = 
self._make_spanner_api() - gax_api.begin_transaction.return_value = TransactionPB(id=b"FACEDACE") database = self._make_database( default_transaction_options=DefaultTransactionOptions( isolation_level="SERIALIZABLE" ) ) - database.spanner_api = gax_api + api = database.spanner_api = build_spanner_api() session = self._make_one(database) session._session_id = self.SESSION_ID @@ -2180,7 +2177,7 @@ def unit_of_work(txn, *args, **kw): read_write=TransactionOptions.ReadWrite(), isolation_level=TransactionOptions.IsolationLevel.SERIALIZABLE, ) - gax_api.begin_transaction.assert_called_once_with( + api.begin_transaction.assert_called_once_with( request=BeginTransactionRequest( session=self.SESSION_NAME, options=expected_options ), @@ -2195,14 +2192,12 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_isolation_level_at_request_overrides_client(self): - gax_api = self._make_spanner_api() - gax_api.begin_transaction.return_value = TransactionPB(id=b"FACEDACE") database = self._make_database( default_transaction_options=DefaultTransactionOptions( isolation_level="SERIALIZABLE" ) ) - database.spanner_api = gax_api + api = database.spanner_api = build_spanner_api() session = self._make_one(database) session._session_id = self.SESSION_ID @@ -2223,7 +2218,7 @@ def unit_of_work(txn, *args, **kw): read_write=TransactionOptions.ReadWrite(), isolation_level=TransactionOptions.IsolationLevel.REPEATABLE_READ, ) - gax_api.begin_transaction.assert_called_once_with( + api.begin_transaction.assert_called_once_with( request=BeginTransactionRequest( session=self.SESSION_NAME, options=expected_options ), From a77cc2b628eb6a7b23fbf0a65c9a3f81f422186e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 18:20:29 -0700 Subject: [PATCH 21/41] feat: Multiplexed sessions - Remove `GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS` and update tests. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session_options.py | 15 +++------------ tests/unit/test_session_options.py | 15 --------------- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py index a3042142cd..35939dc469 100644 --- a/google/cloud/spanner_v1/session_options.py +++ b/google/cloud/spanner_v1/session_options.py @@ -40,9 +40,6 @@ class SessionOptions(object): ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" ) - ENV_VAR_FORCE_DISABLE_MULTIPLEXED = ( - "GOOGLE_CLOUD_SPANNER_FORCE_DISABLE_MULTIPLEXED_SESSIONS" - ) def __init__(self): # Internal overrides to disable the use of multiplexed @@ -57,20 +54,17 @@ def use_multiplexed(self, transaction_type: TransactionType) -> bool: """Returns whether to use multiplexed sessions for the given transaction type. Multiplexed sessions are enabled for read-only transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * ENV_VAR_ENABLE_MULTIPLEXED is set to true; and * multiplexed sessions have not been disabled for read-only transactions. Multiplexed sessions are enabled for partitioned transactions if: * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; - * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; and * multiplexed sessions have not been disabled for partitioned transactions. Multiplexed sessions are enabled for read/write transactions if: * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE is set to true; - * ENV_VAR_FORCE_DISABLE_MULTIPLEXED is not set to true; and + * ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE is set to true; and * multiplexed sessions have not been disabled for read/write transactions. 
:type transaction_type: :class:`TransactionType` @@ -80,7 +74,6 @@ def use_multiplexed(self, transaction_type: TransactionType) -> bool: if transaction_type is TransactionType.READ_ONLY: return ( self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) and self._is_multiplexed_enabled[transaction_type] ) @@ -88,7 +81,6 @@ def use_multiplexed(self, transaction_type: TransactionType) -> bool: return ( self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) and self._is_multiplexed_enabled[transaction_type] ) @@ -96,7 +88,6 @@ def use_multiplexed(self, transaction_type: TransactionType) -> bool: return ( self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) - and not self._getenv(self.ENV_VAR_FORCE_DISABLE_MULTIPLEXED) and self._is_multiplexed_enabled[transaction_type] ) diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py index 393df401f5..18291eae34 100644 --- a/tests/unit/test_session_options.py +++ b/tests/unit/test_session_options.py @@ -42,10 +42,6 @@ def test_use_multiplexed_for_read_only(self): self.assertFalse(session_options.use_multiplexed(transaction_type)) environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" self.assertTrue(session_options.use_multiplexed(transaction_type)) session_options.disable_multiplexed(self.logger, transaction_type) @@ -67,10 +63,6 @@ def test_use_multiplexed_for_partitioned(self): self.assertFalse(session_options.use_multiplexed(transaction_type)) environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" - 
environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" self.assertTrue(session_options.use_multiplexed(transaction_type)) session_options.disable_multiplexed(self.logger, transaction_type) @@ -92,10 +84,6 @@ def test_use_multiplexed_for_read_write(self): self.assertFalse(session_options.use_multiplexed(transaction_type)) environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" self.assertTrue(session_options.use_multiplexed(transaction_type)) session_options.disable_multiplexed(self.logger, transaction_type) @@ -111,7 +99,6 @@ def test_disable_multiplexed_all(self): environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" session_options.disable_multiplexed(self.logger) @@ -144,8 +131,6 @@ def test_unsupported_transaction_type(self): def test_env_var_values(self): session_options = SessionOptions() - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" - true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] for value in true_values: environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value From 5615f2c1de3dbbdf0e013328c58d0ade2c84a372 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 18:31:20 -0700 Subject: [PATCH 22/41] feat: Multiplexed sessions - Cleanup `TestDatabaseSessionManager` so that it doesn't depend on environment variable values. 
Signed-off-by: Taylor Curran --- tests/unit/test_database_session_manager.py | 121 +++++++++----------- 1 file changed, 55 insertions(+), 66 deletions(-) diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index c967dd9705..3fa1252837 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import timedelta from mock import Mock, patch -from os import environ from threading import Thread from time import time, sleep from typing import Callable @@ -21,7 +20,7 @@ from google.api_core.exceptions import BadRequest, FailedPrecondition from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager -from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType +from google.cloud.spanner_v1.session_options import TransactionType from tests._builders import build_database @@ -33,30 +32,46 @@ ) class TestDatabaseSessionManager(TestCase): def setUp(self): - self._original_env = dict(environ) - self._sessions_manager = self._build_sessions_manager() + # Build session manager. + database = build_database() + self._manager = database._sessions_manager + + # Mock the session pool. + pool = self._manager._pool + pool.get = Mock(wraps=pool.get) + pool.put = Mock(wraps=pool.put) def tearDown(self): - self._cleanup_database_sessions_manager() - environ.clear() - environ.update(self._original_env) + # If the maintenance thread is still alive, disable multiplexed sessions and + # wait for the thread to terminate. We need to do this to ensure that the + # thread is properly cleaned up and does not interfere with other tests. 
+ sessions_manager = self._manager + thread = sessions_manager._multiplexed_session_thread + + if thread and thread.is_alive(): + sessions_manager._multiplexed_session_disabled_event.set() + self._assert_thread_terminated(thread) def test_read_only_pooled(self): + manager = self._manager + pool = manager._pool + self._disable_multiplexed_sessions() - manager = self._sessions_manager # Get session from pool. session = manager.get_session(TransactionType.READ_ONLY) self.assertFalse(session.is_multiplexed) - manager._pool.get.assert_called_once() + pool.get.assert_called_once() # Return session to pool. manager.put_session(session) - manager._pool.put.assert_called_once_with(session) + pool.put.assert_called_once_with(session) def test_read_only_multiplexed(self): + manager = self._manager + pool = manager._pool + self._enable_multiplexed_sessions() - manager = self._sessions_manager # Session is created. session_1 = manager.get_session(TransactionType.READ_ONLY) @@ -69,29 +84,33 @@ def test_read_only_multiplexed(self): manager.put_session(session_2) # Verify that pool was not used. - manager._pool.get.assert_not_called() - manager._pool.put.assert_not_called() + pool.get.assert_not_called() + pool.put.assert_not_called() # Verify logger calls. info = manager._database.logger.info info.assert_called_once_with("Created multiplexed session.") def test_partitioned_pooled(self): + manager = self._manager + pool = manager._pool + self._disable_multiplexed_sessions() - manager = self._sessions_manager # Get session from pool. session = manager.get_session(TransactionType.PARTITIONED) self.assertFalse(session.is_multiplexed) - manager._pool.get.assert_called_once() + pool.get.assert_called_once() # Return session to pool. 
manager.put_session(session) - manager._pool.put.assert_called_once_with(session) + pool.put.assert_called_once_with(session) def test_partitioned_multiplexed(self): + manager = self._manager + pool = manager._pool + self._enable_multiplexed_sessions() - manager = self._sessions_manager # Session is created. session_1 = manager.get_session(TransactionType.PARTITIONED) @@ -104,7 +123,6 @@ def test_partitioned_multiplexed(self): manager.put_session(session_2) # Verify that pool was not used. - pool = manager._pool pool.get.assert_not_called() pool.put.assert_not_called() @@ -113,28 +131,30 @@ def test_partitioned_multiplexed(self): info.assert_called_once_with("Created multiplexed session.") def test_read_write_pooled(self): + manager = self._manager + pool = manager._pool + self._disable_multiplexed_sessions() - manager = self._sessions_manager # Get session from pool. session = manager.get_session(TransactionType.READ_WRITE) self.assertFalse(session.is_multiplexed) - manager._pool.get.assert_called_once() + pool.get.assert_called_once() # Return session to pool. manager.put_session(session) - manager._pool.put.assert_called_once_with(session) + pool.put.assert_called_once_with(session) # TODO multiplexed: implement support for read/write transactions. def test_read_write_multiplexed(self): self._enable_multiplexed_sessions() with self.assertRaises(NotImplementedError): - self._sessions_manager.get_session(TransactionType.READ_WRITE) + self._manager.get_session(TransactionType.READ_WRITE) - def test_multiplexed_maintenance(self, *_): + def test_multiplexed_maintenance(self): + manager = self._manager self._enable_multiplexed_sessions() - manager = self._sessions_manager # Maintenance thread is started. 
session_1 = manager.get_session(TransactionType.READ_ONLY) @@ -152,8 +172,8 @@ def test_multiplexed_maintenance(self, *_): self.assertNotEqual(session_1, session_2) def test_multiplexed_maintenance_terminates_disabled(self): + manager = self._manager self._enable_multiplexed_sessions() - manager = self._sessions_manager # Maintenance thread is started. session_1 = manager.get_session(TransactionType.READ_ONLY) @@ -165,7 +185,7 @@ def test_multiplexed_maintenance_terminates_disabled(self): self._assert_thread_terminated(thread) def test_exception_bad_request(self): - manager = self._sessions_manager + manager = self._manager api = manager._database.spanner_api api.create_session.side_effect = BadRequest("") @@ -174,7 +194,7 @@ def test_exception_bad_request(self): manager.get_session(TransactionType.READ_ONLY) def test_exception_failed_precondition(self): - manager = self._sessions_manager + manager = self._manager api = manager._database.spanner_api api.create_session.side_effect = FailedPrecondition("") @@ -182,19 +202,6 @@ def test_exception_failed_precondition(self): with self.assertRaises(FailedPrecondition): manager.get_session(TransactionType.READ_ONLY) - def _cleanup_database_sessions_manager(self) -> None: - """Cleans up the database session manager after testing.""" - - # If the maintenance thread is still alive, disable multiplexed sessions and - # wait for the thread to terminate. We need to do this to ensure that the - # thread is properly cleaned up and does not interfere with other tests. 
- sessions_manager = self._sessions_manager - thread = sessions_manager._multiplexed_session_thread - - if thread and thread.is_alive(): - sessions_manager._multiplexed_session_disabled_event.set() - self._assert_thread_terminated(thread) - def _assert_true_with_timeout(self, condition: Callable) -> None: """Asserts that the given condition is met within a timeout period.""" @@ -215,32 +222,14 @@ def _is_thread_terminated(): self._assert_true_with_timeout(_is_thread_terminated) - @staticmethod - def _build_sessions_manager() -> DatabaseSessionsManager: - """Builds and returns a new database session manager for testing. - - :rtype: :class:`~google.cloud.spanner_v1.database_sessions_manager.DatabaseSessionsManager` - :returns: a new database session manager. - """ - database = build_database() - sessions_manager = database._sessions_manager - - # Mock the session pool. - pool = sessions_manager._pool - pool.get = Mock(wraps=pool.get) - pool.put = Mock(wraps=pool.put) - - return sessions_manager - - @staticmethod - def _disable_multiplexed_sessions() -> None: + def _disable_multiplexed_sessions(self) -> None: """Sets environment variables to disable multiplexed sessions for all transactions types.""" - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "true" - @staticmethod - def _enable_multiplexed_sessions() -> None: + options = self._manager._database._instance._client._session_options + options.use_multiplexed = Mock(return_value=False) + + def _enable_multiplexed_sessions(self) -> None: """Sets environment variables to enable multiplexed sessions for all transaction types.""" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - environ[SessionOptions.ENV_VAR_FORCE_DISABLE_MULTIPLEXED] = "false" + + options = self._manager._database._instance._client._session_options + 
options.use_multiplexed = Mock(return_value=True) From 00059f942e6dbdc317bdfe7d0cfae8c16b7afdec Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 19:10:27 -0700 Subject: [PATCH 23/41] feat: Multiplexed sessions - Add type hints for `SessionOptions` and `DatabaseSessionManager`. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database_sessions_manager.py | 9 +++++---- google/cloud/spanner_v1/session_options.py | 3 ++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 46862b6250..44ca8502c0 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -14,6 +14,7 @@ from datetime import timedelta from threading import Event, Lock, Thread from time import sleep, time +from typing import Optional from weakref import ref from google.cloud.spanner_v1.session import Session @@ -57,10 +58,10 @@ def __init__(self, database, pool): # the multiplexed session to avoid any race conditions. We also create an event # so that the thread can terminate if the use of multiplexed session has been # disabled for all transactions. - self._multiplexed_session = None - self._multiplexed_session_thread = None - self._multiplexed_session_lock = Lock() - self._multiplexed_session_disabled_event = Event() + self._multiplexed_session: Optional[Session] = None + self._multiplexed_session_thread: Optional[Thread] = None + self._multiplexed_session_lock: Lock = Lock() + self._multiplexed_session_disabled_event: Event = Event() def get_session(self, transaction_type: TransactionType) -> Session: """Returns a session for the given transaction type from the database session manager. 
diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py index 35939dc469..7e68b235a3 100644 --- a/google/cloud/spanner_v1/session_options.py +++ b/google/cloud/spanner_v1/session_options.py @@ -14,6 +14,7 @@ import logging import os from enum import Enum +from typing import Mapping class TransactionType(Enum): @@ -44,7 +45,7 @@ class SessionOptions(object): def __init__(self): # Internal overrides to disable the use of multiplexed # sessions in case of runtime errors. - self._is_multiplexed_enabled = { + self._is_multiplexed_enabled: Mapping[TransactionType, str] = { TransactionType.READ_ONLY: True, TransactionType.PARTITIONED: True, TransactionType.READ_WRITE: True, From 083d6bc326c492127d514beedbc1cc3ae8b2e9c3 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 19:10:59 -0700 Subject: [PATCH 24/41] feat: Multiplexed sessions - Fix `test_observability_options` Signed-off-by: Taylor Curran --- tests/system/test_observability_options.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index c3eabffe12..7f818c8f13 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -13,7 +13,9 @@ # limitations under the License. import pytest +from mock import PropertyMock, patch +from google.cloud.spanner_v1.session import Session from . 
import _helpers from google.cloud.spanner_v1 import Client from google.api_core.exceptions import Aborted @@ -210,9 +212,15 @@ def create_db_trace_exporter(): not HAS_OTEL_INSTALLED, reason="Tracing requires OpenTelemetry", ) -def test_transaction_abort_then_retry_spans(): +@patch.object(Session, "session_id", new_callable=PropertyMock) +@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) +def test_transaction_abort_then_retry_spans(mock_session_multiplexed, mock_session_id): from opentelemetry.trace.status import StatusCode + # Mock session properties for testing. + mock_session_multiplexed.return_value = session_multiplexed = False + mock_session_id.return_value = session_id = "session-id" + db, trace_exporter = create_db_trace_exporter() counters = dict(aborted=0) @@ -239,6 +247,8 @@ def select_in_txn(txn): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), ( "Transaction was aborted in user operation, retrying", {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, From 6e33b1d1f0444c54c4883f563dd985ea1a4fda10 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 19:11:17 -0700 Subject: [PATCH 25/41] feat: Multiplexed sessions - Update `_builders` to use mock scoped credentials. 
Signed-off-by: Taylor Curran --- tests/_builders.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/_builders.py b/tests/_builders.py index 50816efed3..70b92706c2 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -16,6 +16,7 @@ from mock import create_autospec from typing import Mapping +from google.auth.credentials import Credentials, Scoped from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1 import SpannerClient from google.cloud.spanner_v1.client import Client @@ -92,6 +93,9 @@ def build_client(**kwargs: Mapping) -> Client: if "project" not in kwargs: kwargs["project"] = _PROJECT_ID + if "credentials" not in kwargs: + kwargs["credentials"] = build_scoped_credentials() + return Client(**kwargs) @@ -174,6 +178,15 @@ def build_logger() -> Logger: return create_autospec(Logger, instance=True) +def build_scoped_credentials() -> Credentials: + """Builds and returns a mock scoped credentials for testing.""" + + class _ScopedCredentials(Credentials, Scoped): + pass + + return create_autospec(spec=_ScopedCredentials, instance=True) + + def build_spanner_api() -> SpannerClient: """Builds and returns a mock Spanner Client API for testing using the given arguments. Commonly used methods are mocked to return default values.""" From 65042abdec0bf92fd27c059645c7fcbd3c458738 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 19:48:17 -0700 Subject: [PATCH 26/41] feat: Multiplexed sessions - Add helpers for mock scoped credentials for testing. 
Signed-off-by: Taylor Curran --- tests/unit/spanner_dbapi/test_connect.py | 14 +---- tests/unit/spanner_dbapi/test_connection.py | 9 --- tests/unit/test_client.py | 62 +++++++++------------ tests/unit/test_database.py | 11 ---- 4 files changed, 29 insertions(+), 67 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 5e748eaf66..7f4fb4c7f3 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -17,24 +17,16 @@ import unittest from unittest import mock -import google from google.auth.credentials import AnonymousCredentials +from tests._builders import build_scoped_credentials + INSTANCE = "test-instance" DATABASE = "test-database" PROJECT = "test-project" USER_AGENT = "user-agent" -def _make_credentials(): - class _CredentialsWithScopes( - google.auth.credentials.Credentials, google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - @mock.patch("google.cloud.spanner_v1.Client") class Test_connect(unittest.TestCase): def test_w_implicit(self, mock_client): @@ -79,7 +71,7 @@ def test_w_explicit(self, mock_client): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_dbapi.version import PY_VERSION - credentials = _make_credentials() + credentials = build_scoped_credentials() pool = mock.create_autospec(AbstractSessionPool) client = mock_client.return_value instance = client.instance.return_value diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index d2501be20e..dbef230417 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -46,15 +46,6 @@ USER_AGENT = "user-agent" -def _make_credentials(): - from google.auth import credentials - - class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - class 
TestConnection(unittest.TestCase): def setUp(self): self._under_test = self._make_connection() diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 6084224a84..dd6e6a6b8d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -19,17 +19,7 @@ from google.auth.credentials import AnonymousCredentials from google.cloud.spanner_v1 import DirectedReadOptions, DefaultTransactionOptions - - -def _make_credentials(): - import google.auth.credentials - - class _CredentialsWithScopes( - google.auth.credentials.Credentials, google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) +from tests._builders import build_scoped_credentials class TestClient(unittest.TestCase): @@ -148,7 +138,7 @@ def test_constructor_emulator_host_warning(self, mock_warn, mock_em): from google.auth.credentials import AnonymousCredentials expected_scopes = None - creds = _make_credentials() + creds = build_scoped_credentials() mock_em.return_value = "http://emulator.host.com" with mock.patch("google.cloud.spanner_v1.client.AnonymousCredentials") as patch: expected_creds = patch.return_value = AnonymousCredentials() @@ -159,7 +149,7 @@ def test_constructor_default_scopes(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper(expected_scopes, creds) def test_constructor_custom_client_info(self): @@ -167,7 +157,7 @@ def test_constructor_custom_client_info(self): client_info = mock.Mock() expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper(expected_scopes, creds, client_info=client_info) # Disable metrics to avoid google.auth.default calls from Metric Exporter @@ -175,7 +165,7 @@ def test_constructor_custom_client_info(self): def test_constructor_implicit_credentials(self): from google.cloud.spanner_v1 
import client as MUT - creds = _make_credentials() + creds = build_scoped_credentials() patch = mock.patch("google.auth.default", return_value=(creds, None)) with patch as default: @@ -186,7 +176,7 @@ def test_constructor_implicit_credentials(self): default.assert_called_once_with(scopes=(MUT.SPANNER_ADMIN_SCOPE,)) def test_constructor_credentials_wo_create_scoped(self): - creds = _make_credentials() + creds = build_scoped_credentials() expected_scopes = None self._constructor_test_helper(expected_scopes, creds) @@ -195,7 +185,7 @@ def test_constructor_custom_client_options_obj(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper( expected_scopes, creds, @@ -206,7 +196,7 @@ def test_constructor_custom_client_options_dict(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper( expected_scopes, creds, client_options={"api_endpoint": "endpoint"} ) @@ -216,7 +206,7 @@ def test_constructor_custom_query_options_client_config(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() query_options = expected_query_options = ExecuteSqlRequest.QueryOptions( optimizer_version="1", optimizer_statistics_package="auto_20191128_14_47_22UTC", @@ -237,7 +227,7 @@ def test_constructor_custom_query_options_env_config(self, mock_ver, mock_stats) from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() mock_ver.return_value = "2" mock_stats.return_value = "auto_20191128_14_47_22UTC" query_options = ExecuteSqlRequest.QueryOptions( @@ -259,7 +249,7 @@ def 
test_constructor_w_directed_read_options(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper( expected_scopes, creds, directed_read_options=self.DIRECTED_READ_OPTIONS ) @@ -268,7 +258,7 @@ def test_constructor_route_to_leader_disbled(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper( expected_scopes, creds, route_to_leader_enabled=False ) @@ -277,7 +267,7 @@ def test_constructor_w_default_transaction_options(self): from google.cloud.spanner_v1 import client as MUT expected_scopes = (MUT.SPANNER_ADMIN_SCOPE,) - creds = _make_credentials() + creds = build_scoped_credentials() self._constructor_test_helper( expected_scopes, creds, @@ -291,7 +281,7 @@ def test_instance_admin_api(self, mock_em): mock_em.return_value = None - credentials = _make_credentials() + credentials = build_scoped_credentials() client_info = mock.Mock() client_options = ClientOptions(quota_project_id="QUOTA-PROJECT") client = self._make_one( @@ -325,7 +315,7 @@ def test_instance_admin_api_emulator_env(self, mock_em): from google.api_core.client_options import ClientOptions mock_em.return_value = "emulator.host" - credentials = _make_credentials() + credentials = build_scoped_credentials() client_info = mock.Mock() client_options = ClientOptions(api_endpoint="endpoint") client = self._make_one( @@ -391,7 +381,7 @@ def test_database_admin_api(self, mock_em): from google.api_core.client_options import ClientOptions mock_em.return_value = None - credentials = _make_credentials() + credentials = build_scoped_credentials() client_info = mock.Mock() client_options = ClientOptions(quota_project_id="QUOTA-PROJECT") client = self._make_one( @@ -425,7 +415,7 @@ def test_database_admin_api_emulator_env(self, mock_em): 
from google.api_core.client_options import ClientOptions mock_em.return_value = "host:port" - credentials = _make_credentials() + credentials = build_scoped_credentials() client_info = mock.Mock() client_options = ClientOptions(api_endpoint="endpoint") client = self._make_one( @@ -486,7 +476,7 @@ def test_database_admin_api_emulator_code(self): self.assertNotIn("credentials", called_kw) def test_copy(self): - credentials = _make_credentials() + credentials = build_scoped_credentials() # Make sure it "already" is scoped. credentials.requires_scopes = False @@ -497,12 +487,12 @@ def test_copy(self): self.assertEqual(new_client.project, client.project) def test_credentials_property(self): - credentials = _make_credentials() + credentials = build_scoped_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) self.assertIs(client.credentials, credentials.with_scopes.return_value) def test_project_name_property(self): - credentials = _make_credentials() + credentials = build_scoped_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) project_name = "projects/" + self.PROJECT self.assertEqual(client.project_name, project_name) @@ -516,7 +506,7 @@ def test_list_instance_configs(self): from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsResponse api = InstanceAdminClient(credentials=AnonymousCredentials()) - credentials = _make_credentials() + credentials = build_scoped_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api @@ -562,7 +552,7 @@ def test_list_instance_configs_w_options(self): from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsRequest from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsResponse - credentials = _make_credentials() + credentials = build_scoped_credentials() api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, 
credentials=credentials) client._instance_admin_api = api @@ -597,7 +587,7 @@ def test_instance_factory_defaults(self): from google.cloud.spanner_v1.instance import DEFAULT_NODE_COUNT from google.cloud.spanner_v1.instance import Instance - credentials = _make_credentials() + credentials = build_scoped_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) instance = client.instance(self.INSTANCE_ID) @@ -613,7 +603,7 @@ def test_instance_factory_defaults(self): def test_instance_factory_explicit(self): from google.cloud.spanner_v1.instance import Instance - credentials = _make_credentials() + credentials = build_scoped_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) instance = client.instance( @@ -638,7 +628,7 @@ def test_list_instances(self): from google.cloud.spanner_admin_instance_v1 import ListInstancesRequest from google.cloud.spanner_admin_instance_v1 import ListInstancesResponse - credentials = _make_credentials() + credentials = build_scoped_credentials() api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api @@ -686,7 +676,7 @@ def test_list_instances_w_options(self): from google.cloud.spanner_admin_instance_v1 import ListInstancesRequest from google.cloud.spanner_admin_instance_v1 import ListInstancesResponse - credentials = _make_credentials() + credentials = build_scoped_credentials() api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index b74c5cef2f..7c5b9691fd 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -63,17 +63,6 @@ } -def _make_credentials(): # pragma: NO COVER - import google.auth.credentials - - class _CredentialsWithScopes( - google.auth.credentials.Credentials, 
google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - class _BaseTest(unittest.TestCase): PROJECT_ID = "project-id" PARENT = "projects/" + PROJECT_ID From 9df088d1ca47bd9fc52be2a3c08f1d6e75ca6e4e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 21:45:57 -0700 Subject: [PATCH 27/41] feat: Multiplexed sessions - Fix failing `test_batch_insert_then_read`. Signed-off-by: Taylor Curran --- tests/system/test_session_api.py | 39 ++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 26b389090f..d160349083 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -458,16 +458,32 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "").lower() == "true" ) - assert_span_attributes( - ot_exporter, - "CloudSpanner.GetSession", - attributes=_make_attributes( - db_name, - session_found=True, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 0}.1", - ), - span=span_list[0], - ) + # [A] Verify batch checkout spans + # ------------------------------- + + request_id_1 = f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 0}.1" + + if multiplexed_enabled: + assert_span_attributes( + ot_exporter, + "CloudSpanner.CreateMultiplexedSession", + attributes=_make_attributes( + db_name, x_goog_spanner_request_id=request_id_1 + ), + span=span_list[0], + ) + else: + assert_span_attributes( + ot_exporter, + "CloudSpanner.GetSession", + attributes=_make_attributes( + db_name, + session_found=True, + x_goog_spanner_request_id=request_id_1, + ), + span=span_list[0], + ) + assert_span_attributes( ot_exporter, "CloudSpanner.Batch.commit", @@ -479,6 +495,9 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): span=span_list[1], ) + # [B] 
Verify snapshot checkout spans + # ---------------------------------- + if len(span_list) == 4: if multiplexed_enabled: expected_snapshot_span_name = "CloudSpanner.CreateMultiplexedSession" From 607df64846cd327d211d28396d4cadb08dd635b1 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Tue, 3 Jun 2025 22:10:58 -0700 Subject: [PATCH 28/41] feat: Multiplexed sessions - Fix failing `test_transaction_read_and_insert_then_rollback`. Signed-off-by: Taylor Curran --- tests/system/test_session_api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index d160349083..f11cdc99e6 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -701,12 +701,11 @@ def transaction_work(transaction): if multiplexed_enabled: # With multiplexed sessions enabled: - # - Batch operations still use regular sessions (GetSession) + # - Batch operations use multiplexed sessions (GetSession) # - run_in_transaction uses regular sessions (GetSession) - # - Snapshot (read-only) can use multiplexed sessions (CreateMultiplexedSession) + # - Snapshot (read-only) re-use existing multiplexed sessions # Note: Session creation span may not appear if session is reused from pool expected_span_names = [ - "CloudSpanner.GetSession", # Batch operation "CloudSpanner.Batch.commit", # Batch commit "CloudSpanner.GetSession", # Transaction session "CloudSpanner.Transaction.read", # First read From 0b6f5dfeb79c4a9d1baabce87783739b233a1244 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 08:02:43 -0700 Subject: [PATCH 29/41] feat: Multiplexed sessions - Add test helper for multiplexed env vars. 
Signed-off-by: Taylor Curran --- tests/_helpers.py | 30 ++++++++++++++++++++++ tests/system/test_observability_options.py | 29 ++++++++------------- tests/system/test_session_api.py | 23 ++++------------- tests/unit/test_database.py | 11 +++----- 4 files changed, 49 insertions(+), 44 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index 667f9f8be1..2f5eed98de 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -1,7 +1,10 @@ import unittest +from os import getenv + import mock from google.cloud.spanner_v1 import gapic_version +from google.cloud.spanner_v1.session_options import TransactionType LIB_VERSION = gapic_version.__version__ @@ -31,6 +34,33 @@ _TEST_OT_EXPORTER = None _TEST_OT_PROVIDER_INITIALIZED = False +# Environment variables for enabling multiplexed sessions +"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" +ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED = ( + "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" +) +ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( + "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" +) + + +def is_multiplexed_enabled(transaction_type: TransactionType) -> bool: + """Returns whether multiplexed sessions are enabled for the given transaction type.""" + + env_var = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" + env_var_partitioned = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" + env_var_read_write = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" + + def _getenv(val: str) -> bool: + return getenv(val, "false").lower() == "true" + + if transaction_type is TransactionType.READ_ONLY: + return _getenv(env_var) + elif transaction_type is TransactionType.PARTITIONED: + return _getenv(env_var) and _getenv(env_var_partitioned) + else: + return _getenv(env_var) and _getenv(env_var_read_write) + def get_test_ot_exporter(): global _TEST_OT_EXPORTER diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 7f818c8f13..d25f5e73d7 100644 --- 
a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -16,12 +16,15 @@ from mock import PropertyMock, patch from google.cloud.spanner_v1.session import Session +from google.cloud.spanner_v1.session_options import TransactionType from . import _helpers from google.cloud.spanner_v1 import Client from google.api_core.exceptions import Aborted from google.auth.credentials import AnonymousCredentials from google.rpc import code_pb2 +from .._helpers import is_multiplexed_enabled + HAS_OTEL_INSTALLED = False try: @@ -113,11 +116,7 @@ def test_propagation(enable_extended_tracing): gotNames = [span.name for span in from_inject_spans] # Check if multiplexed sessions are enabled - import os - - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "").lower() == "true" - ) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) # Determine expected session span name based on multiplexed sessions expected_session_span_name = ( @@ -213,13 +212,11 @@ def create_db_trace_exporter(): reason="Tracing requires OpenTelemetry", ) @patch.object(Session, "session_id", new_callable=PropertyMock) -@patch.object(Session, "is_multiplexed", new_callable=PropertyMock) -def test_transaction_abort_then_retry_spans(mock_session_multiplexed, mock_session_id): +def test_transaction_abort_then_retry_spans(mock_session_id): from opentelemetry.trace.status import StatusCode - # Mock session properties for testing. - mock_session_multiplexed.return_value = session_multiplexed = False mock_session_id.return_value = session_id = "session-id" + multiplexed = is_multiplexed_enabled(TransactionType.READ_WRITE) db, trace_exporter = create_db_trace_exporter() @@ -247,8 +244,8 @@ def select_in_txn(txn): ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. 
Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), - ("Using session", {"id": session_id, "multiplexed": session_multiplexed}), - ("Returning session", {"id": session_id, "multiplexed": session_multiplexed}), + ("Using session", {"id": session_id, "multiplexed": multiplexed}), + ("Returning session", {"id": session_id, "multiplexed": multiplexed}), ( "Transaction was aborted in user operation, retrying", {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, @@ -417,7 +414,6 @@ def tx_update(txn): reason="Tracing requires OpenTelemetry", ) def test_database_partitioned_error(): - import os from opentelemetry.trace.status import StatusCode db, trace_exporter = create_db_trace_exporter() @@ -428,12 +424,9 @@ def test_database_partitioned_error(): pass got_statuses, got_events = finished_spans_statuses(trace_exporter) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.PARTITIONED) - multiplexed_partitioned_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS") == "true" - ) - - if multiplexed_partitioned_enabled: + if multiplexed_enabled: expected_event_names = [ "Creating Session", "Using session", @@ -496,7 +489,7 @@ def test_database_partitioned_error(): expected_session_span_name = ( "CloudSpanner.CreateMultiplexedSession" - if multiplexed_partitioned_enabled + if multiplexed_enabled else "CloudSpanner.CreateSession" ) want_statuses = [ diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index f11cdc99e6..83df30a6b8 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -29,6 +29,7 @@ from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud._helpers import UTC from google.cloud.spanner_v1.data_types import JsonObject +from google.cloud.spanner_v1.session_options import TransactionType from .testdata import singer_pb2 from tests import _helpers as ot_helpers from . 
import _helpers @@ -37,7 +38,7 @@ REQ_RAND_PROCESS_ID, parse_request_id, ) - +from .._helpers import is_multiplexed_enabled SOME_DATE = datetime.date(2011, 1, 17) SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612) @@ -430,8 +431,6 @@ def test_session_crud(sessions_database): def test_batch_insert_then_read(sessions_database, ot_exporter): - import os - db_name = sessions_database.name sd = _sample_data @@ -453,10 +452,7 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): nth_req0 = sampling_req_id[-2] db = sessions_database - - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "").lower() == "true" - ) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) # [A] Verify batch checkout spans # ------------------------------- @@ -690,12 +686,7 @@ def transaction_work(transaction): assert rows == [] if ot_exporter is not None: - import os - - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "").lower() == "true" - ) - + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) span_list = ot_exporter.get_finished_spans() got_span_names = [span.name for span in span_list] @@ -3332,17 +3323,13 @@ def test_interval_array_cast(transaction): def test_session_id_and_multiplexed_flag_behavior(sessions_database, ot_exporter): - import os - sd = _sample_data with sessions_database.batch() as batch: batch.delete(sd.TABLE, sd.ALL) batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA) - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS", "").lower() == "true" - ) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) snapshot1_session_id = None snapshot2_session_id = None diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 7c5b9691fd..258e9913f0 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -38,6 +38,7 @@ from google.cloud.spanner_v1.session import Session from 
google.cloud.spanner_v1.session_options import SessionOptions, TransactionType from tests._builders import build_spanner_api +from tests._helpers import is_multiplexed_enabled DML_WO_PARAM = """ DELETE FROM citizens @@ -1527,7 +1528,6 @@ def test_session_factory_w_labels(self): self.assertEqual(session.labels, labels) def test_snapshot_defaults(self): - import os from google.cloud.spanner_v1.database import SnapshotCheckout from google.cloud.spanner_v1.snapshot import Snapshot @@ -1539,9 +1539,7 @@ def test_snapshot_defaults(self): database = self._make_one(self.DATABASE_ID, instance, pool=pool) # Check if multiplexed sessions are enabled for read operations - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true" - ) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) if multiplexed_enabled: # When multiplexed sessions are enabled, configure the sessions manager @@ -1575,7 +1573,6 @@ def test_snapshot_defaults(self): def test_snapshot_w_read_timestamp_and_multi_use(self): import datetime - import os from google.cloud._helpers import UTC from google.cloud.spanner_v1.database import SnapshotCheckout from google.cloud.spanner_v1.snapshot import Snapshot @@ -1589,9 +1586,7 @@ def test_snapshot_w_read_timestamp_and_multi_use(self): database = self._make_one(self.DATABASE_ID, instance, pool=pool) # Check if multiplexed sessions are enabled for read operations - multiplexed_enabled = ( - os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true" - ) + multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) if multiplexed_enabled: # When multiplexed sessions are enabled, configure the sessions manager From 36c9775804f1a8c18a66824cf9983b433e6d4bed Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 09:39:16 -0700 Subject: [PATCH 30/41] feat: Multiplexed sessions - Add unit tests for begin transaction base class, simplify `_SnapshotBase` tests, remove redundant tests. 
Signed-off-by: Taylor Curran --- tests/_builders.py | 31 +- tests/unit/test_snapshot.py | 630 ++++++++++++++++----------------- tests/unit/test_transaction.py | 90 +---- 3 files changed, 333 insertions(+), 418 deletions(-) diff --git a/tests/_builders.py b/tests/_builders.py index 70b92706c2..b934bc91b3 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -25,9 +25,12 @@ from google.cloud.spanner_v1.session import Session from google.cloud.spanner_v1.transaction import Transaction -from google.cloud.spanner_v1.types import Session as SessionPB -from google.cloud.spanner_v1.types import Transaction as TransactionPB -from google.cloud.spanner_v1.types import CommitResponse as CommitResponsePB +from google.cloud.spanner_v1.types import ( + CommitResponse as CommitResponsePB, + MultiplexedSessionPrecommitToken as PrecommitTokenPB, + Session as SessionPB, + Transaction as TransactionPB, +) from google.cloud._helpers import _datetime_to_pb_timestamp from tests._helpers import HAS_OPENTELEMETRY_INSTALLED, get_test_ot_exporter @@ -39,20 +42,22 @@ _INSTANCE_ID = "default-instance-id" _DATABASE_ID = "default-database-id" _SESSION_ID = "default-session-id" -_TRANSACTION_ID = b"default-transaction-id" _PROJECT_NAME = "projects/" + _PROJECT_ID _INSTANCE_NAME = _PROJECT_NAME + "/instances/" + _INSTANCE_ID _DATABASE_NAME = _INSTANCE_NAME + "/databases/" + _DATABASE_ID _SESSION_NAME = _DATABASE_NAME + "/sessions/" + _SESSION_ID +_TRANSACTION_ID = b"default-transaction-id" +_PRECOMMIT_TOKEN = b"default-precommit-token" +_SEQUENCE_NUMBER = -1 _TIMESTAMP = _datetime_to_pb_timestamp(datetime.now()) # Protocol buffers # ---------------- -def _build_commit_response_pb(**kwargs) -> CommitResponsePB: +def build_commit_response_pb(**kwargs) -> CommitResponsePB: """Builds and returns a commit response protocol buffer for testing using the given arguments. 
If an expected argument is not provided, a default value will be used.""" @@ -62,6 +67,20 @@ def _build_commit_response_pb(**kwargs) -> CommitResponsePB: return CommitResponsePB(**kwargs) +def build_precommit_token_pb(**kwargs) -> PrecommitTokenPB: + """Builds and returns a multiplexed session precommit token protocol buffer for + testing using the given arguments. If an expected argument is not provided, a + default value will be used.""" + + if "precommit_token" not in kwargs: + kwargs["precommit_token"] = _PRECOMMIT_TOKEN + + if "seq_num" not in kwargs: + kwargs["seq_num"] = _SEQUENCE_NUMBER + + return PrecommitTokenPB(**kwargs) + + def build_session_pb(**kwargs) -> SessionPB: """Builds and returns a session protocol buffer for testing using the given arguments. If an expected argument is not provided, a default value will be used.""" @@ -195,7 +214,7 @@ def build_spanner_api() -> SpannerClient: # Mock API calls with default return values. api.begin_transaction.return_value = build_transaction_pb() - api.commit.return_value = _build_commit_response_pb() + api.commit.return_value = build_commit_response_pb() api.create_session.return_value = build_session_pb() return api diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index ae8f1b72da..579294218c 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -11,15 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from typing import Mapping from google.api_core import gapic_v1 import mock +from google.api_core.exceptions import InternalServerError, Aborted +from google.cloud.spanner_admin_database_v1 import Database from google.cloud.spanner_v1 import ( RequestOptions, DirectedReadOptions, BeginTransactionRequest, + TransactionSelector, +) +from google.cloud.spanner_v1.snapshot import _SnapshotBase +from tests._builders import ( + build_precommit_token_pb, + build_spanner_api, + build_session, + build_transaction_pb, ) from tests._helpers import ( OpenTelemetryBase, @@ -83,6 +93,16 @@ }, } +TRANSACTION_ID = b"transaction-id" + +PRECOMMIT_TOKEN_1 = build_precommit_token_pb(precommit_token=b"1", seq_num=1) +PRECOMMIT_TOKEN_2 = build_precommit_token_pb(precommit_token=b"2", seq_num=2) + +# Common errors for testing. +INTERNAL_SERVER_ERROR_UNEXPECTED_EOS = InternalServerError( + "Received unexpected EOS on DATA frame from server" +) + def _makeTimestamp(): import datetime @@ -119,7 +139,7 @@ def _make_txn_selector(self): return _Derived(session) - def _make_spanner_api(self): + def build_spanner_api(self): from google.cloud.spanner_v1 import SpannerClient return mock.create_autospec(SpannerClient, instance=True) @@ -161,7 +181,7 @@ def test_iteration_w_empty_raw(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -183,7 +203,7 @@ def test_iteration_w_non_empty_raw(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = 
self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -210,7 +230,7 @@ def test_iteration_w_raw_w_resume_tken(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -239,7 +259,7 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -249,8 +269,6 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): - from google.api_core.exceptions import InternalServerError - ITEMS = ( self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN), @@ -258,15 +276,13 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): ) before = _MockIterator( fail_after=True, - error=InternalServerError( - "Received unexpected EOS on DATA frame from server" - ), + error=INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, ) after = _MockIterator(*ITEMS) request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -288,7 
+304,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): request = mock.Mock(spec=["resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -318,7 +334,7 @@ def test_iteration_w_raw_raising_unavailable(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -328,23 +344,19 @@ def test_iteration_w_raw_raising_unavailable(self): self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error(self): - from google.api_core.exceptions import InternalServerError - FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) SECOND = (self._make_item(2),) # discarded after 503 LAST = (self._make_item(3),) before = _MockIterator( *(FIRST + SECOND), fail_after=True, - error=InternalServerError( - "Received unexpected EOS on DATA frame from server" - ), + error=INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, ) after = _MockIterator(*LAST) request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -366,7 +378,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): request = mock.Mock(test="test", 
spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -395,7 +407,7 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -417,7 +429,7 @@ def test_iteration_w_raw_w_multiuse(self): request = ReadRequest(transaction=None) restart = mock.Mock(spec=[], return_value=before) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True @@ -448,7 +460,7 @@ def test_iteration_w_raw_raising_unavailable_w_multiuse(self): request = ReadRequest(transaction=None) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True @@ -486,7 +498,7 @@ def test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): request = ReadRequest(transaction=None) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) derived._multi_use = True @@ -509,22 +521,18 @@ def 
test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): self.assertNoSpans() def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): - from google.api_core.exceptions import InternalServerError - FIRST = (self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN)) SECOND = (self._make_item(2), self._make_item(3)) before = _MockIterator( *FIRST, fail_after=True, - error=InternalServerError( - "Received unexpected EOS on DATA frame from server" - ), + error=INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, ) after = _MockIterator(*SECOND) request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -545,7 +553,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], side_effect=[before, after]) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut(derived, restart, request, session=session) @@ -569,7 +577,7 @@ def test_iteration_w_span_creation(self): request = mock.Mock(test="test", spec=["test", "resume_token"]) restart = mock.Mock(spec=[], return_value=raw) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut( @@ -599,7 +607,7 @@ def test_iteration_w_multiple_span_creation(self): restart = mock.Mock(spec=[], side_effect=[before, after]) name = "TestSpan" database = _Database() - 
database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() session = _Session(database) derived = self._makeDerived(session) resumable = self._call_fut( @@ -624,52 +632,34 @@ def test_iteration_w_multiple_span_creation(self): class Test_SnapshotBase(OpenTelemetryBase): - PROJECT_ID = "project-id" - INSTANCE_ID = "instance-id" - INSTANCE_NAME = "projects/" + PROJECT_ID + "/instances/" + INSTANCE_ID - DATABASE_ID = "database-id" - DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID - SESSION_ID = "session-id" - SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + class _Derived(_SnapshotBase): + """A minimally-implemented _SnapshotBase-derived class for testing""" - def _getTargetClass(self): - from google.cloud.spanner_v1.snapshot import _SnapshotBase + # Use a simplified implementation of _make_txn_selector + # that always returns the same transaction selector. + TRANSACTION_SELECTOR = TransactionSelector() - return _SnapshotBase + def _make_txn_selector(self) -> TransactionSelector: + return self.TRANSACTION_SELECTOR - def _make_one(self, session): - return self._getTargetClass()(session) + @staticmethod + def _build_derived(session=None, multi_use=False, read_only=True): + """Builds and returns an instance of a minimally-implemented + _SnapshotBase-derived class for testing.""" - def _makeDerived(self, session): - class _Derived(self._getTargetClass()): - _transaction_id = None - _multi_use = False + session = session or build_session() + if session.session_id is None: + session.create() - def _make_txn_selector(self): - from google.cloud.spanner_v1 import ( - TransactionOptions, - TransactionSelector, - ) - - if self._transaction_id: - return TransactionSelector(id=self._transaction_id) - options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(strong=True) - ) - if self._multi_use: - return TransactionSelector(begin=options) - return TransactionSelector(single_use=options) - - return _Derived(session) 
- - def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient + derived = Test_SnapshotBase._Derived(session=session) + derived._multi_use = multi_use + derived._read_only = read_only - return mock.create_autospec(SpannerClient, instance=True) + return derived def test_ctor(self): session = _Session() - base = self._make_one(session) + base = _SnapshotBase(session) self.assertIs(base._session, session) self.assertEqual(base._execute_sql_request_count, 0) @@ -677,19 +667,175 @@ def test_ctor(self): def test__make_txn_selector_virtual(self): session = _Session() - base = self._make_one(session) + base = _SnapshotBase(session) with self.assertRaises(NotImplementedError): base._make_txn_selector() + def test_begin_error_not_multi_use(self): + derived = self._build_derived(multi_use=False) + + self.reset() + with self.assertRaises(ValueError): + derived.begin() + + self.assertNoSpans() + + def test_begin_error_already_begun(self): + derived = self._build_derived(multi_use=True) + derived.begin() + + self.reset() + with self.assertRaises(ValueError): + derived.begin() + + self.assertNoSpans() + + def test_begin_error_other(self): + derived = self._build_derived(multi_use=True) + + database = derived._session._database + begin_transaction = database.spanner_api.begin_transaction + begin_transaction.side_effect = RuntimeError() + + self.reset() + with self.assertRaises(RuntimeError): + derived.begin() + + if not HAS_OPENTELEMETRY_INSTALLED: + return + + self.assertSpanAttributes( + name="CloudSpanner._Derived.begin", + status=StatusCode.ERROR, + attributes=_build_span_attributes(database), + ) + + def test_begin_read_write(self): + derived = self._build_derived(multi_use=True, read_only=False) + + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb() + + self._execute_begin(derived) + + def test_begin_read_only(self): + derived = self._build_derived(multi_use=True, 
read_only=True) + + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb() + + self._execute_begin(derived) + + def test_begin_precommit_token(self): + derived = self._build_derived(multi_use=True) + + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb( + precommit_token=PRECOMMIT_TOKEN_1 + ) + + self._execute_begin(derived) + + def test_begin_retry_for_internal_server_error(self): + derived = self._build_derived(multi_use=True) + + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.side_effect = [ + INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, + build_transaction_pb(), + ] + + self._execute_begin(derived, attempts=2) + + expected_statuses = [ + ( + "Transaction Begin Attempt Failed. Retrying", + {"attempt": 1, "sleep_seconds": 4}, + ) + ] + actual_statuses = self.finished_spans_events_statuses() + self.assertEqual(expected_statuses, actual_statuses) + + def test_begin_retry_for_aborted(self): + derived = self._build_derived(multi_use=True) + + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.side_effect = [ + Aborted("test"), + build_transaction_pb(), + ] + + self._execute_begin(derived, attempts=2) + + expected_statuses = [ + ( + "Transaction Begin Attempt Failed. Retrying", + {"attempt": 1, "sleep_seconds": 4}, + ) + ] + actual_statuses = self.finished_spans_events_statuses() + self.assertEqual(expected_statuses, actual_statuses) + + def _execute_begin(self, derived: _Derived, attempts: int = 1): + """Helper for testing _SnapshotBase.begin(). Executes method and verifies + transaction state, begin transaction API call, and span attributes and events. + """ + + session = derived._session + database = session._database + + # Clear spans. + self.reset() + + transaction_id = derived.begin() + + # Verify transaction state. 
+ begin_transaction = database.spanner_api.begin_transaction + expected_transaction_id = begin_transaction.return_value.id or None + expected_precommit_token = ( + begin_transaction.return_value.precommit_token or None + ) + + self.assertEqual(transaction_id, expected_transaction_id) + self.assertEqual(derived._transaction_id, expected_transaction_id) + self.assertEqual(derived._precommit_token, expected_precommit_token) + + # Verify begin transaction API call. + self.assertEqual(begin_transaction.call_count, attempts) + + expected_metadata = [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-request-id", _build_request_id(database, attempts)), + ] + if not derived._read_only and database._route_to_leader_enabled: + expected_metadata.insert(-1, ("x-goog-spanner-route-to-leader", "true")) + + database.spanner_api.begin_transaction.assert_called_with( + request=BeginTransactionRequest( + session=session.name, options=self._Derived.TRANSACTION_SELECTOR.begin + ), + metadata=expected_metadata, + ) + + if not HAS_OPENTELEMETRY_INSTALLED: + return + + # Verify span attributes. 
+ expected_span_name = "CloudSpanner._Derived.begin" + self.assertSpanAttributes( + name=expected_span_name, + attributes=_build_span_attributes(database, attempts), + ) + def test_read_other_error(self): from google.cloud.spanner_v1.keyset import KeySet keyset = KeySet(all_=True) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() database.spanner_api.streaming_read.side_effect = RuntimeError() session = _Session(database) - derived = self._makeDerived(session) + derived = self._build_derived(session) with self.assertRaises(RuntimeError): list(derived.read(TABLE_NAME, COLUMNS, keyset)) @@ -706,7 +852,7 @@ def test_read_other_error(self): ), ) - def _read_helper( + def _execute_read( self, multi_use, first=True, @@ -718,16 +864,16 @@ def _read_helper( directed_read_options=None, directed_read_options_at_client_level=None, ): + """Helper for testing _SnapshotBase.read(). Executes method and verifies + transaction state, begin transaction API call, and span attributes and events. 
+ """ + from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( PartialResultSet, ResultSetMetadata, ResultSetStats, ) - from google.cloud.spanner_v1 import ( - TransactionSelector, - TransactionOptions, - ) from google.cloud.spanner_v1 import ReadRequest from google.cloud.spanner_v1 import Type, StructType from google.cloud.spanner_v1 import TypeCode @@ -759,10 +905,10 @@ def _read_helper( database = _Database( directed_read_options=directed_read_options_at_client_level ) - api = database.spanner_api = self._make_spanner_api() + api = database.spanner_api = build_spanner_api() api.streaming_read.return_value = _MockIterator(*result_sets) session = _Session(database) - derived = self._makeDerived(session) + derived = self._build_derived(session) derived._multi_use = multi_use derived._read_request_count = count if not first: @@ -804,18 +950,6 @@ def _read_helper( self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(strong=True) - ) - - if multi_use: - if first: - expected_transaction = TransactionSelector(begin=txn_options) - else: - expected_transaction = TransactionSelector(id=TXN_ID) - else: - expected_transaction = TransactionSelector(single_use=txn_options) - if partition is not None: expected_limit = 0 else: @@ -832,11 +966,11 @@ def _read_helper( ) expected_request = ReadRequest( - session=self.SESSION_NAME, + session=session.name, table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), - transaction=expected_transaction, + transaction=self._Derived.TRANSACTION_SELECTOR, index=INDEX, limit=expected_limit, partition_token=partition, @@ -868,76 +1002,76 @@ def _read_helper( ) def test_read_wo_multi_use(self): - self._read_helper(multi_use=False) + self._execute_read(multi_use=False) def test_read_w_request_tag_success(self): request_options = RequestOptions( request_tag="tag-1", ) - 
self._read_helper(multi_use=False, request_options=request_options) + self._execute_read(multi_use=False, request_options=request_options) def test_read_w_transaction_tag_success(self): request_options = RequestOptions( transaction_tag="tag-1-1", ) - self._read_helper(multi_use=False, request_options=request_options) + self._execute_read(multi_use=False, request_options=request_options) def test_read_w_request_and_transaction_tag_success(self): request_options = RequestOptions( request_tag="tag-1", transaction_tag="tag-1-1", ) - self._read_helper(multi_use=False, request_options=request_options) + self._execute_read(multi_use=False, request_options=request_options) def test_read_w_request_and_transaction_tag_dictionary_success(self): request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} - self._read_helper(multi_use=False, request_options=request_options) + self._execute_read(multi_use=False, request_options=request_options) def test_read_w_incorrect_tag_dictionary_error(self): request_options = {"incorrect_tag": "tag-1-1"} with self.assertRaises(ValueError): - self._read_helper(multi_use=False, request_options=request_options) + self._execute_read(multi_use=False, request_options=request_options) def test_read_wo_multi_use_w_read_request_count_gt_0(self): with self.assertRaises(ValueError): - self._read_helper(multi_use=False, count=1) + self._execute_read(multi_use=False, count=1) def test_read_w_multi_use_wo_first(self): - self._read_helper(multi_use=True, first=False) + self._execute_read(multi_use=True, first=False) def test_read_w_multi_use_wo_first_w_count_gt_0(self): - self._read_helper(multi_use=True, first=False, count=1) + self._execute_read(multi_use=True, first=False, count=1) def test_read_w_multi_use_w_first_w_partition(self): PARTITION = b"FADEABED" - self._read_helper(multi_use=True, first=True, partition=PARTITION) + self._execute_read(multi_use=True, first=True, partition=PARTITION) def 
test_read_w_multi_use_w_first_w_count_gt_0(self): with self.assertRaises(ValueError): - self._read_helper(multi_use=True, first=True, count=1) + self._execute_read(multi_use=True, first=True, count=1) def test_read_w_timeout_param(self): - self._read_helper(multi_use=True, first=False, timeout=2.0) + self._execute_read(multi_use=True, first=False, timeout=2.0) def test_read_w_retry_param(self): - self._read_helper(multi_use=True, first=False, retry=Retry(deadline=60)) + self._execute_read(multi_use=True, first=False, retry=Retry(deadline=60)) def test_read_w_timeout_and_retry_params(self): - self._read_helper( + self._execute_read( multi_use=True, first=False, retry=Retry(deadline=60), timeout=2.0 ) def test_read_w_directed_read_options(self): - self._read_helper(multi_use=False, directed_read_options=DIRECTED_READ_OPTIONS) + self._execute_read(multi_use=False, directed_read_options=DIRECTED_READ_OPTIONS) def test_read_w_directed_read_options_at_client_level(self): - self._read_helper( + self._execute_read( multi_use=False, directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) def test_read_w_directed_read_options_override(self): - self._read_helper( + self._execute_read( multi_use=False, directed_read_options=DIRECTED_READ_OPTIONS, directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, @@ -945,10 +1079,10 @@ def test_read_w_directed_read_options_override(self): def test_execute_sql_other_error(self): database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() database.spanner_api.execute_streaming_sql.side_effect = RuntimeError() session = _Session(database) - derived = self._makeDerived(session) + derived = self._build_derived(session) with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) @@ -979,16 +1113,16 @@ def _execute_sql_helper( directed_read_options=None, directed_read_options_at_client_level=None, ): + """Helper for testing 
_SnapshotBase.execute_sql(). Executes method and verifies + transaction state, begin transaction API call, and span attributes and events. + """ + from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( PartialResultSet, ResultSetMetadata, ResultSetStats, ) - from google.cloud.spanner_v1 import ( - TransactionSelector, - TransactionOptions, - ) from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import Type, StructType from google.cloud.spanner_v1 import TypeCode @@ -1021,11 +1155,10 @@ def _execute_sql_helper( database = _Database( directed_read_options=directed_read_options_at_client_level ) - api = database.spanner_api = self._make_spanner_api() + api = database.spanner_api = build_spanner_api() api.execute_streaming_sql.return_value = iterator session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = multi_use + derived = self._build_derived(session, multi_use=multi_use) derived._read_request_count = count derived._execute_sql_request_count = sql_count if not first: @@ -1055,18 +1188,6 @@ def _execute_sql_helper( self.assertEqual(result_set.metadata, metadata_pb) self.assertEqual(result_set.stats, stats_pb) - txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(strong=True) - ) - - if multi_use: - if first: - expected_transaction = TransactionSelector(begin=txn_options) - else: - expected_transaction = TransactionSelector(id=TXN_ID) - else: - expected_transaction = TransactionSelector(single_use=txn_options) - expected_params = Struct( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) @@ -1089,9 +1210,9 @@ def _execute_sql_helper( ) expected_request = ExecuteSqlRequest( - session=self.SESSION_NAME, + session=session.name, sql=SQL_QUERY_WITH_PARAM, - transaction=expected_transaction, + transaction=self._Derived.TRANSACTION_SELECTOR, params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, @@ -1233,7 +1354,6 @@ def 
_partition_read_helper( from google.cloud.spanner_v1 import PartitionReadRequest from google.cloud.spanner_v1 import PartitionResponse from google.cloud.spanner_v1 import Transaction - from google.cloud.spanner_v1 import TransactionSelector keyset = KeySet(all_=True) new_txn_id = b"ABECAB91" @@ -1247,10 +1367,10 @@ def _partition_read_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = self._make_spanner_api() + api = database.spanner_api = build_spanner_api() api.partition_read.return_value = response session = _Session(database) - derived = self._makeDerived(session) + derived = self._build_derived(session) derived._multi_use = multi_use if w_txn: derived._transaction_id = TXN_ID @@ -1269,18 +1389,16 @@ def _partition_read_helper( self.assertEqual(tokens, [token_1, token_2]) - expected_txn_selector = TransactionSelector(id=TXN_ID) - expected_partition_options = PartitionOptions( partition_size_bytes=size, max_partitions=max_partitions ) expected_request = PartitionReadRequest( - session=self.SESSION_NAME, + session=session.name, table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), - transaction=expected_txn_selector, + transaction=self._Derived.TRANSACTION_SELECTOR, index=index, partition_options=expected_partition_options, ) @@ -1326,11 +1444,10 @@ def test_partition_read_other_error(self): keyset = KeySet(all_=True) database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() database.spanner_api.partition_read.side_effect = RuntimeError() session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = True + derived = self._build_derived(session, multi_use=True) derived._transaction_id = TXN_ID with self.assertRaises(RuntimeError): @@ -1350,14 +1467,13 @@ def test_partition_read_other_error(self): def test_partition_read_w_retry(self): from google.cloud.spanner_v1.keyset import KeySet - from google.api_core.exceptions 
import InternalServerError from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionResponse from google.cloud.spanner_v1 import Transaction keyset = KeySet(all_=True) database = _Database() - api = database.spanner_api = self._make_spanner_api() + api = database.spanner_api = build_spanner_api() new_txn_id = b"ABECAB91" token_1 = b"FACE0FFF" token_2 = b"BADE8CAF" @@ -1369,12 +1485,12 @@ def test_partition_read_w_retry(self): transaction=Transaction(id=new_txn_id), ) database.spanner_api.partition_read.side_effect = [ - InternalServerError("Received unexpected EOS on DATA frame from server"), + INTERNAL_SERVER_ERROR_UNEXPECTED_EOS, response, ] session = _Session(database) - derived = self._makeDerived(session) + derived = self._build_derived(session) derived._multi_use = True derived._transaction_id = TXN_ID @@ -1413,13 +1529,16 @@ def _partition_query_helper( retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): + """Helper for testing _SnapshotBase.partition_query(). Executes method and verifies + transaction state, begin transaction API call, and span attributes and events. 
+ """ + from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import Partition from google.cloud.spanner_v1 import PartitionOptions from google.cloud.spanner_v1 import PartitionQueryRequest from google.cloud.spanner_v1 import PartitionResponse from google.cloud.spanner_v1 import Transaction - from google.cloud.spanner_v1 import TransactionSelector from google.cloud.spanner_v1._helpers import _make_value_pb new_txn_id = b"ABECAB91" @@ -1433,11 +1552,10 @@ def _partition_query_helper( transaction=Transaction(id=new_txn_id), ) database = _Database() - api = database.spanner_api = self._make_spanner_api() + api = database.spanner_api = build_spanner_api() api.partition_query.return_value = response session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = multi_use + derived = self._build_derived(session, multi_use=multi_use) if w_txn: derived._transaction_id = TXN_ID @@ -1459,16 +1577,14 @@ def _partition_query_helper( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) - expected_txn_selector = TransactionSelector(id=TXN_ID) - expected_partition_options = PartitionOptions( partition_size_bytes=size, max_partitions=max_partitions ) expected_request = PartitionQueryRequest( - session=self.SESSION_NAME, + session=session.name, sql=SQL_QUERY_WITH_PARAM, - transaction=expected_txn_selector, + transaction=self._Derived.TRANSACTION_SELECTOR, params=expected_params, param_types=PARAM_TYPES, partition_options=expected_partition_options, @@ -1502,11 +1618,10 @@ def _partition_query_helper( def test_partition_query_other_error(self): database = _Database() - database.spanner_api = self._make_spanner_api() + database.spanner_api = build_spanner_api() database.spanner_api.partition_query.side_effect = RuntimeError() session = _Session(database) - derived = self._makeDerived(session) - derived._multi_use = True + derived = self._build_derived(session, multi_use=True) derived._transaction_id = TXN_ID with 
self.assertRaises(RuntimeError): @@ -1570,11 +1685,6 @@ def _getTargetClass(self): def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) - def _make_spanner_api(self): - from google.cloud.spanner_v1 import SpannerClient - - return mock.create_autospec(SpannerClient, instance=True) - def _makeDuration(self, seconds=1, microseconds=0): import datetime @@ -1795,165 +1905,6 @@ def test__make_txn_selector_w_exact_staleness_w_multi_use(self): type(options).pb(options).read_only.exact_staleness.nanos, 123456000 ) - def test_begin_wo_multi_use(self): - session = _Session() - snapshot = self._make_one(session) - with self.assertRaises(ValueError): - snapshot.begin() - - def test_begin_w_read_request_count_gt_0(self): - session = _Session() - snapshot = self._make_one(session, multi_use=True) - snapshot._read_request_count = 1 - with self.assertRaises(ValueError): - snapshot.begin() - - def test_begin_w_existing_txn_id(self): - session = _Session() - snapshot = self._make_one(session, multi_use=True) - snapshot._transaction_id = TXN_ID - with self.assertRaises(ValueError): - snapshot.begin() - - def test_begin_w_other_error(self): - database = _Database() - database.spanner_api = self._make_spanner_api() - database.spanner_api.begin_transaction.side_effect = RuntimeError() - timestamp = _makeTimestamp() - session = _Session(database) - snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) - - with self.assertRaises(RuntimeError): - snapshot.begin() - - if not HAS_OPENTELEMETRY_INSTALLED: - return - - span_list = self.get_finished_spans() - got_span_names = [span.name for span in span_list] - want_span_names = ["CloudSpanner.Snapshot.begin"] - assert got_span_names == want_span_names - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" - self.assertSpanAttributes( - "CloudSpanner.Snapshot.begin", - status=StatusCode.ERROR, - attributes=dict(BASE_ATTRIBUTES, 
x_goog_spanner_request_id=req_id), - ) - - def test_begin_w_retry(self): - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - ) - from google.api_core.exceptions import InternalServerError - - database = _Database() - api = database.spanner_api = self._make_spanner_api() - database.spanner_api.begin_transaction.side_effect = [ - InternalServerError("Received unexpected EOS on DATA frame from server"), - TransactionPB(id=TXN_ID), - ] - timestamp = _makeTimestamp() - session = _Session(database) - snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) - - snapshot.begin() - self.assertEqual(api.begin_transaction.call_count, 2) - - def test_begin_ok_exact_staleness(self): - from google.protobuf.duration_pb2 import Duration - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - TransactionOptions, - ) - - transaction_pb = TransactionPB(id=TXN_ID) - database = _Database() - api = database.spanner_api = self._make_spanner_api() - api.begin_transaction.return_value = transaction_pb - duration = self._makeDuration(seconds=SECONDS, microseconds=MICROS) - session = _Session(database) - snapshot = self._make_one(session, exact_staleness=duration, multi_use=True) - - txn_id = snapshot.begin() - - self.assertEqual(txn_id, TXN_ID) - self.assertEqual(snapshot._transaction_id, TXN_ID) - - expected_duration = Duration(seconds=SECONDS, nanos=MICROS * 1000) - expected_txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly( - exact_staleness=expected_duration, return_read_timestamp=True - ) - ) - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" - api.begin_transaction.assert_called_once_with( - request=BeginTransactionRequest( - session=session.name, - options=expected_txn_options, - mutation_key=None, - ), - metadata=[ - ("google-cloud-resource-prefix", database.name), - ( - "x-goog-spanner-request-id", - req_id, - ), - ], - ) - - self.assertSpanAttributes( - 
"CloudSpanner.Snapshot.begin", - status=StatusCode.OK, - attributes=dict(BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id), - ) - - def test_begin_ok_exact_strong(self): - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - TransactionOptions, - ) - - transaction_pb = TransactionPB(id=TXN_ID) - database = _Database() - api = database.spanner_api = self._make_spanner_api() - api.begin_transaction.return_value = transaction_pb - session = _Session(database) - snapshot = self._make_one(session, multi_use=True) - - txn_id = snapshot.begin() - - self.assertEqual(txn_id, TXN_ID) - self.assertEqual(snapshot._transaction_id, TXN_ID) - - expected_txn_options = TransactionOptions( - read_only=TransactionOptions.ReadOnly( - strong=True, return_read_timestamp=True - ) - ) - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" - api.begin_transaction.assert_called_once_with( - request=BeginTransactionRequest( - session=session.name, - options=expected_txn_options, - ), - metadata=[ - ("google-cloud-resource-prefix", database.name), - ( - "x-goog-spanner-request-id", - req_id, - ), - ], - ) - - self.assertSpanAttributes( - "CloudSpanner.Snapshot.begin", - status=StatusCode.OK, - attributes=dict(BASE_ATTRIBUTES, x_goog_spanner_request_id=req_id), - ) - class _Client(object): NTH_CLIENT = AtomicCounter() @@ -2041,3 +1992,34 @@ def __next__(self): raise next = __next__ + + +def _build_span_attributes( + database: Database, attempt: int = 1, **extra_attributes +) -> Mapping[str, str]: + """Builds the attributes for spans using the given database and extra attributes.""" + + attributes = enrich_with_otel_scope( + { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": database.name, + "net.host.name": "spanner.googleapis.com", + "gcp.client.service": "spanner", + "gcp.client.version": LIB_VERSION, + "gcp.client.repo": "googleapis/python-spanner", + "x_goog_spanner_request_id": 
_build_request_id(database, attempt), + } + ) + + if extra_attributes: + attributes.update(extra_attributes) + + return attributes + + +def _build_request_id(database: Database, attempt: int = 1) -> str: + """Builds a request ID for an Spanner Client API request with the given database and attempt number.""" + + client = database._instance._client + return f"1.{REQ_RAND_PROCESS_ID}.{client._nth_client_id}.{database._channel_id}.{client._nth_request.value}.{attempt}" diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 201cf5de6c..1d28629062 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -15,12 +15,7 @@ import mock -from google.cloud.spanner_v1 import ( - RequestOptions, - BeginTransactionRequest, - TransactionOptions, - CommitRequest, -) +from google.cloud.spanner_v1 import RequestOptions, CommitRequest from google.cloud.spanner_v1 import DefaultTransactionOptions from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode @@ -32,7 +27,7 @@ ) from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID -from tests._builders import build_transaction, build_transaction_pb +from tests._builders import build_transaction from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -109,15 +104,6 @@ def test__make_txn_selector(self): selector = transaction._make_txn_selector() self.assertEqual(selector.id, self.TRANSACTION_ID) - def test_begin_already_begun(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - with self.assertRaises(ValueError): - transaction.begin() - - self.assertNoSpans() - def test_begin_already_rolled_back(self): session = _Session() transaction = self._make_one(session) @@ -136,78 +122,6 @@ def test_begin_already_committed(self): self.assertNoSpans() - def test_begin_w_other_error(self): - database = _Database() - 
database.spanner_api = self._make_spanner_api() - database.spanner_api.begin_transaction.side_effect = RuntimeError() - session = _Session(database) - transaction = self._make_one(session) - - with self.assertRaises(RuntimeError): - transaction.begin() - - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" - self.assertSpanAttributes( - "CloudSpanner.Transaction.begin", - status=StatusCode.ERROR, - attributes=self._build_span_attributes( - database, x_goog_spanner_request_id=req_id - ), - ) - - def test_begin_ok(self): - transaction = build_transaction() - session = transaction._session - database = session._database - - begin_transaction = database.spanner_api.begin_transaction - begin_transaction.return_value = build_transaction_pb(id=self.TRANSACTION_ID) - - transaction_id = transaction.begin() - - self.assertEqual(transaction_id, self.TRANSACTION_ID) - self.assertEqual(transaction._transaction_id, self.TRANSACTION_ID) - - request_id = self._build_request_id(database) - - begin_transaction.assert_called_once_with( - request=BeginTransactionRequest( - session=session.name, - options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), - ), - metadata=[ - ("google-cloud-resource-prefix", database.name), - ("x-goog-spanner-route-to-leader", "true"), - ("x-goog-spanner-request-id", request_id), - ], - ) - - self.assertSpanAttributes( - "CloudSpanner.Transaction.begin", - attributes=self._build_span_attributes( - database, x_goog_spanner_request_id=request_id - ), - ) - - def test_begin_w_retry(self): - from google.cloud.spanner_v1 import ( - Transaction as TransactionPB, - ) - from google.api_core.exceptions import InternalServerError - - database = _Database() - api = database.spanner_api = self._make_spanner_api() - database.spanner_api.begin_transaction.side_effect = [ - InternalServerError("Received unexpected EOS on DATA frame from server"), - TransactionPB(id=self.TRANSACTION_ID), - ] - - session = 
_Session(database) - transaction = self._make_one(session) - transaction.begin() - - self.assertEqual(api.begin_transaction.call_count, 2) - def test_rollback_not_begun(self): database = _Database() api = database.spanner_api = self._make_spanner_api() From 418eddaccefd961499ff45677b2478592894d79d Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 13:14:04 -0700 Subject: [PATCH 31/41] feat: Multiplexed sessions - Attempt to fix `test_transaction_read_and_insert_then_rollback` and add `build_request_id` helper method, fix `test_snapshot` and `test_transaction` failures. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/request_id_header.py | 6 +- tests/system/test_session_api.py | 305 ++++++++----------- tests/unit/test_snapshot.py | 27 +- tests/unit/test_transaction.py | 12 +- 4 files changed, 152 insertions(+), 198 deletions(-) diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index c095bc88e2..b540b725f5 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -39,7 +39,7 @@ def generate_rand_uint64(): def with_request_id( client_id, channel_id, nth_request, attempt, other_metadata=[], span=None ): - req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" + req_id = build_request_id(client_id, channel_id, nth_request, attempt) all_metadata = (other_metadata or []).copy() all_metadata.append((REQ_ID_HEADER_KEY, req_id)) @@ -49,6 +49,10 @@ def with_request_id( return all_metadata +def build_request_id(client_id, channel_id, nth_request, attempt): + return f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" + + def parse_request_id(request_id_str): splits = request_id_str.split(".") version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list( diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 
83df30a6b8..957ac0bdb5 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -15,6 +15,7 @@ import collections import datetime import decimal + import math import struct import threading @@ -28,6 +29,8 @@ from google.cloud import spanner_v1 from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud._helpers import UTC + +from google.cloud.spanner_v1._helpers import AtomicCounter from google.cloud.spanner_v1.data_types import JsonObject from google.cloud.spanner_v1.session_options import TransactionType from .testdata import singer_pb2 @@ -37,6 +40,7 @@ from google.cloud.spanner_v1.request_id_header import ( REQ_RAND_PROCESS_ID, parse_request_id, + build_request_id, ) from .._helpers import is_multiplexed_enabled @@ -687,210 +691,147 @@ def transaction_work(transaction): if ot_exporter is not None: multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY) + span_list = ot_exporter.get_finished_spans() - got_span_names = [span.name for span in span_list] - if multiplexed_enabled: - # With multiplexed sessions enabled: - # - Batch operations use multiplexed sessions (GetSession) - # - run_in_transaction uses regular sessions (GetSession) - # - Snapshot (read-only) re-use existing multiplexed sessions - # Note: Session creation span may not appear if session is reused from pool - expected_span_names = [ - "CloudSpanner.Batch.commit", # Batch commit - "CloudSpanner.GetSession", # Transaction session - "CloudSpanner.Transaction.read", # First read - "CloudSpanner.Transaction.read", # Second read - "CloudSpanner.Transaction.rollback", # Rollback due to exception - "CloudSpanner.Session.run_in_transaction", # Session transaction wrapper - "CloudSpanner.Database.run_in_transaction", # Database transaction wrapper - "CloudSpanner.Snapshot.read", # Snapshot read - ] - # Check if we have a multiplexed session creation span - if "CloudSpanner.CreateMultiplexedSession" in got_span_names: - 
expected_span_names.insert(-1, "CloudSpanner.CreateMultiplexedSession") - else: - # Without multiplexed sessions, all operations use regular sessions - expected_span_names = [ - "CloudSpanner.GetSession", # Batch operation - "CloudSpanner.Batch.commit", # Batch commit - "CloudSpanner.GetSession", # Transaction session - "CloudSpanner.Transaction.read", # First read - "CloudSpanner.Transaction.read", # Second read - "CloudSpanner.Transaction.rollback", # Rollback due to exception - "CloudSpanner.Session.run_in_transaction", # Session transaction wrapper - "CloudSpanner.Database.run_in_transaction", # Database transaction wrapper - "CloudSpanner.Snapshot.read", # Snapshot read - ] - # Check if we have a session creation span for snapshot - if len(got_span_names) > len(expected_span_names): - expected_span_names.insert(-1, "CloudSpanner.GetSession") + # Determine the first request ID from the spans, + # and use an atomic counter to track it. + first_request_id = span_list[0].attributes["x_goog_spanner_request_id"] + first_request_id = (parse_request_id(first_request_id))[-2] + request_id_counter = AtomicCounter(start_value=first_request_id - 1) + + def _build_request_id(): + return build_request_id( + client_id=sessions_database._nth_client_id, + channel_id=sessions_database._channel_id, + nth_request=request_id_counter.increment(), + attempt=1, + ) - assert got_span_names == expected_span_names + expected_span_properties = [] + + # [A] Batch spans + if not multiplexed_enabled: + expected_span_properties.append( + { + "name": "CloudSpanner.GetSession", + "attributes": _make_attributes( + db_name, + session_found=True, + x_goog_spanner_request_id=_build_request_id(), + ), + } + ) - sampling_req_id = parse_request_id( - span_list[0].attributes["x_goog_spanner_request_id"] + expected_span_properties.append( + { + "name": "CloudSpanner.Batch.commit", + "attributes": _make_attributes( + db_name, + num_mutations=1, + x_goog_spanner_request_id=_build_request_id(), + ), + } ) 
- nth_req0 = sampling_req_id[-2] - - db = sessions_database - # Span 0: batch operation (always uses GetSession from pool) - assert_span_attributes( - ot_exporter, - "CloudSpanner.GetSession", - attributes=_make_attributes( - db_name, - session_found=True, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 0}.1", - ), - span=span_list[0], + # [B] Transaction spans + expected_span_properties.append( + { + "name": "CloudSpanner.GetSession", + "attributes": _make_attributes( + db_name, + session_found=True, + x_goog_spanner_request_id=_build_request_id(), + ), + } ) - # Span 1: batch commit - assert_span_attributes( - ot_exporter, - "CloudSpanner.Batch.commit", - attributes=_make_attributes( - db_name, - num_mutations=1, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 1}.1", - ), - span=span_list[1], + expected_span_properties.append( + { + "name": "CloudSpanner.Transaction.read", + "attributes": _make_attributes( + db_name, + table_id=sd.TABLE, + columns=sd.COLUMNS, + x_goog_spanner_request_id=_build_request_id(), + ), + } ) - # Span 2: GetSession for transaction - assert_span_attributes( - ot_exporter, - "CloudSpanner.GetSession", - attributes=_make_attributes( - db_name, - session_found=True, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 2}.1", - ), - span=span_list[2], + expected_span_properties.append( + { + "name": "CloudSpanner.Transaction.read", + "attributes": _make_attributes( + db_name, + table_id=sd.TABLE, + columns=sd.COLUMNS, + x_goog_spanner_request_id=_build_request_id(), + ), + } ) - # Span 3: First transaction read - assert_span_attributes( - ot_exporter, - "CloudSpanner.Transaction.read", - attributes=_make_attributes( - db_name, - table_id=sd.TABLE, - columns=sd.COLUMNS, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 3}.1", - ), - 
span=span_list[3], + expected_span_properties.append( + { + "name": "CloudSpanner.Transaction.rollback", + "attributes": _make_attributes( + db_name, x_goog_spanner_request_id=_build_request_id() + ), + } ) - # Span 4: Second transaction read - assert_span_attributes( - ot_exporter, - "CloudSpanner.Transaction.read", - attributes=_make_attributes( - db_name, - table_id=sd.TABLE, - columns=sd.COLUMNS, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 4}.1", - ), - span=span_list[4], + expected_span_properties.append( + { + "name": "CloudSpanner.Session.run_in_transaction", + "status": ot_helpers.StatusCode.ERROR, + "attributes": _make_attributes(db_name), + } ) - # Span 5: Transaction rollback - assert_span_attributes( - ot_exporter, - "CloudSpanner.Transaction.rollback", - attributes=_make_attributes( - db_name, - x_goog_spanner_request_id=f"1.{REQ_RAND_PROCESS_ID}.{db._nth_client_id}.{db._channel_id}.{nth_req0 + 5}.1", - ), - span=span_list[5], + expected_span_properties.append( + { + "name": "CloudSpanner.Database.run_in_transaction", + "status": ot_helpers.StatusCode.ERROR, + "attributes": _make_attributes(db_name), + } ) - # Span 6: Session.run_in_transaction (ERROR status due to intentional exception) - assert_span_attributes( - ot_exporter, - "CloudSpanner.Session.run_in_transaction", - status=ot_helpers.StatusCode.ERROR, - attributes=_make_attributes(db_name), - span=span_list[6], - ) + # [C] Snapshot spans + if not multiplexed_enabled: + expected_span_properties.append( + { + "name": "CloudSpanner.GetSession", + "attributes": _make_attributes( + db_name, + session_found=True, + x_goog_spanner_request_id=_build_request_id(), + ), + } + ) - # Span 7: Database.run_in_transaction (ERROR status due to intentional exception) - assert_span_attributes( - ot_exporter, - "CloudSpanner.Database.run_in_transaction", - status=ot_helpers.StatusCode.ERROR, - attributes=_make_attributes(db_name), - span=span_list[7], + 
expected_span_properties.append( + { + "name": "CloudSpanner.Snapshot.read", + "attributes": _make_attributes( + db_name, + table_id=sd.TABLE, + columns=sd.COLUMNS, + x_goog_spanner_request_id=_build_request_id(), + ), + } ) - # Check if we have a snapshot session creation span - snapshot_read_span_index = -1 - snapshot_session_span_index = -1 - - for i, span in enumerate(span_list): - if span.name == "CloudSpanner.Snapshot.read": - snapshot_read_span_index = i - break + # Verify spans. + assert len(span_list) == len(expected_span_properties) - # Look for session creation span before the snapshot read - if snapshot_read_span_index > 8: - snapshot_session_span_index = snapshot_read_span_index - 1 - - if ( - multiplexed_enabled - and span_list[snapshot_session_span_index].name - == "CloudSpanner.CreateMultiplexedSession" - ): - expected_snapshot_span_name = "CloudSpanner.CreateMultiplexedSession" - snapshot_session_attributes = _make_attributes( - db_name, - x_goog_spanner_request_id=span_list[ - snapshot_session_span_index - ].attributes["x_goog_spanner_request_id"], - ) - assert_span_attributes( - ot_exporter, - expected_snapshot_span_name, - attributes=snapshot_session_attributes, - span=span_list[snapshot_session_span_index], - ) - elif ( - not multiplexed_enabled - and span_list[snapshot_session_span_index].name - == "CloudSpanner.GetSession" - ): - expected_snapshot_span_name = "CloudSpanner.GetSession" - snapshot_session_attributes = _make_attributes( - db_name, - session_found=True, - x_goog_spanner_request_id=span_list[ - snapshot_session_span_index - ].attributes["x_goog_spanner_request_id"], - ) - assert_span_attributes( - ot_exporter, - expected_snapshot_span_name, - attributes=snapshot_session_attributes, - span=span_list[snapshot_session_span_index], - ) - - # Snapshot read span - assert_span_attributes( - ot_exporter, - "CloudSpanner.Snapshot.read", - attributes=_make_attributes( - db_name, - table_id=sd.TABLE, - columns=sd.COLUMNS, - 
x_goog_spanner_request_id=span_list[ - snapshot_read_span_index - ].attributes["x_goog_spanner_request_id"], - ), - span=span_list[snapshot_read_span_index], - ) + for i, expected in enumerate(expected_span_properties): + expected = expected_span_properties[i] + assert_span_attributes( + span=span_list[i], + name=expected["name"], + status=expected.get("status", ot_helpers.StatusCode.OK), + attributes=expected["attributes"], + ot_exporter=ot_exporter, + ) @_helpers.retry_maybe_conflict diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 579294218c..44cc1ab530 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -43,7 +43,10 @@ AtomicCounter, ) from google.cloud.spanner_v1.param_types import INT64 -from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from google.cloud.spanner_v1.request_id_header import ( + REQ_RAND_PROCESS_ID, + build_request_id, +) from google.api_core.retry import Retry TABLE_NAME = "citizens" @@ -824,7 +827,7 @@ def _execute_begin(self, derived: _Derived, attempts: int = 1): expected_span_name = "CloudSpanner._Derived.begin" self.assertSpanAttributes( name=expected_span_name, - attributes=_build_span_attributes(database, attempts), + attributes=_build_span_attributes(database, attempt=attempts), ) def test_read_other_error(self): @@ -1994,12 +1997,10 @@ def __next__(self): next = __next__ -def _build_span_attributes( - database: Database, attempt: int = 1, **extra_attributes -) -> Mapping[str, str]: +def _build_span_attributes(database: Database, attempt: int = 1) -> Mapping[str, str]: """Builds the attributes for spans using the given database and extra attributes.""" - attributes = enrich_with_otel_scope( + return enrich_with_otel_scope( { "db.type": "spanner", "db.url": "spanner.googleapis.com", @@ -2012,14 +2013,14 @@ def _build_span_attributes( } ) - if extra_attributes: - attributes.update(extra_attributes) - - return attributes - -def _build_request_id(database: 
Database, attempt: int = 1) -> str: +def _build_request_id(database: Database, attempt: int) -> str: """Builds a request ID for an Spanner Client API request with the given database and attempt number.""" client = database._instance._client - return f"1.{REQ_RAND_PROCESS_ID}.{client._nth_client_id}.{database._channel_id}.{client._nth_request.value}.{attempt}" + return build_request_id( + client_id=client._nth_client_id, + channel_id=database._channel_id, + nth_request=client._nth_request.value, + attempt=attempt, + ) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 1d28629062..99a204f486 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -26,7 +26,10 @@ _metadata_with_request_id, ) from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID +from google.cloud.spanner_v1.request_id_header import ( + REQ_RAND_PROCESS_ID, + build_request_id, +) from tests._builders import build_transaction from tests._helpers import ( @@ -949,7 +952,12 @@ def _build_request_id(database: Database, attempt: int = 1) -> str: """Builds a request ID for an Spanner Client API request with the given database and attempt number.""" client = database._instance._client - return f"1.{REQ_RAND_PROCESS_ID}.{client._nth_client_id}.{database._channel_id}.{client._nth_request.value}.{attempt}" + return build_request_id( + client_id=client._nth_client_id, + channel_id=database._channel_id, + nth_request=client._nth_request.value, + attempt=attempt, + ) class _Client(object): From da226c1f210a5101cc5882209ef3558a111b942b Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 15:15:54 -0700 Subject: [PATCH 32/41] feat: Multiplexed sessions - Add test for log when new session created by maintenance thread. 
Signed-off-by: Taylor Curran --- tests/unit/test_database_session_manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index 3fa1252837..89dd21012a 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -171,6 +171,10 @@ def test_multiplexed_maintenance(self): self.assertTrue(session_2.is_multiplexed) self.assertNotEqual(session_1, session_2) + # Verify logger calls. + info = manager._database.logger.info + info.assert_called_with("Created multiplexed session.") + def test_multiplexed_maintenance_terminates_disabled(self): manager = self._manager self._enable_multiplexed_sessions() From c6c130e983f8bd4c57c59d4fe9aa76d0c67ea225 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 16:02:00 -0700 Subject: [PATCH 33/41] feat: Multiplexed sessions - Add additional multiplexed unit tests for `_SnapshotBase`. Signed-off-by: Taylor Curran --- tests/unit/test_snapshot.py | 85 +++++++++++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 44cc1ab530..54955f735a 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -96,8 +96,6 @@ }, } -TRANSACTION_ID = b"transaction-id" - PRECOMMIT_TOKEN_1 = build_precommit_token_pb(precommit_token=b"1", seq_num=1) PRECOMMIT_TOKEN_2 = build_precommit_token_pb(precommit_token=b"2", seq_num=2) @@ -866,6 +864,7 @@ def _execute_read( request_options=None, directed_read_options=None, directed_read_options_at_client_level=None, + use_multiplexed=False, ): """Helper for testing _SnapshotBase.read(). Executes method and verifies transaction state, begin transaction API call, and span attributes and events. 
@@ -891,14 +890,33 @@ def _execute_read( StructType.Field(name="age", type_=Type(code=TypeCode.INT64)), ] ) - metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + + # If the transaction had not already begun, the first result + # set will include metadata with information about the transaction. + transaction_pb = build_transaction_pb(id=TXN_ID) if first else None + metadata_pb = ResultSetMetadata( + row_type=struct_type_pb, + transaction=transaction_pb, + ) + stats_pb = ResultSetStats( query_stats=Struct(fields={"rows_returned": _make_value_pb(2)}) ) - result_sets = [ - PartialResultSet(metadata=metadata_pb), - PartialResultSet(stats=stats_pb), - ] + + # Precommit tokens will be included in the result sets if the transaction is on + # a multiplexed session. Precommit tokens may be returned out of order. + partial_result_set_1_args = {"metadata": metadata_pb} + if use_multiplexed: + partial_result_set_1_args["precommit_token"] = PRECOMMIT_TOKEN_2 + partial_result_set_1 = PartialResultSet(**partial_result_set_1_args) + + partial_result_set_2_args = {"stats": stats_pb} + if use_multiplexed: + partial_result_set_2_args["precommit_token"] = PRECOMMIT_TOKEN_1 + partial_result_set_2 = PartialResultSet(**partial_result_set_2_args) + + result_sets = [partial_result_set_1, partial_result_set_2] + for i in range(len(result_sets)): result_sets[i].values.extend(VALUE_PBS[i]) KEYS = [["bharney@example.com"], ["phred@example.com"]] @@ -908,6 +926,7 @@ def _execute_read( database = _Database( directed_read_options=directed_read_options_at_client_level ) + api = database.spanner_api = build_spanner_api() api.streaming_read.return_value = _MockIterator(*result_sets) session = _Session(database) @@ -1004,6 +1023,12 @@ def _execute_read( ), ) + if first: + self.assertEqual(derived._transaction_id, TXN_ID) + + if use_multiplexed: + self.assertEqual(derived._precommit_token, PRECOMMIT_TOKEN_2) + def test_read_wo_multi_use(self): self._execute_read(multi_use=False) @@ -1039,6 
+1064,9 @@ def test_read_wo_multi_use_w_read_request_count_gt_0(self): with self.assertRaises(ValueError): self._execute_read(multi_use=False, count=1) + def test_read_w_multi_use_w_first(self): + self._execute_read(multi_use=True, first=True) + def test_read_w_multi_use_wo_first(self): self._execute_read(multi_use=True, first=False) @@ -1080,6 +1108,9 @@ def test_read_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) + def test_read_w_precommit_tokens(self): + self._execute_read(multi_use=True, use_multiplexed=True) + def test_execute_sql_other_error(self): database = _Database() database.spanner_api = build_spanner_api() @@ -1115,6 +1146,7 @@ def _execute_sql_helper( retry=gapic_v1.method.DEFAULT, directed_read_options=None, directed_read_options_at_client_level=None, + use_multiplexed=False, ): """Helper for testing _SnapshotBase.execute_sql(). Executes method and verifies transaction state, begin transaction API call, and span attributes and events. @@ -1144,14 +1176,34 @@ def _execute_sql_helper( StructType.Field(name="age", type_=Type(code=TypeCode.INT64)), ] ) - metadata_pb = ResultSetMetadata(row_type=struct_type_pb) + + # If the transaction has not already begun, the first result set will + # include metadata with information about the newly-begun transaction. + transaction_pb = build_transaction_pb(id=TXN_ID) if first else None + metadata_pb = ResultSetMetadata( + row_type=struct_type_pb, + transaction=transaction_pb, + ) + stats_pb = ResultSetStats( query_stats=Struct(fields={"rows_returned": _make_value_pb(2)}) ) - result_sets = [ - PartialResultSet(metadata=metadata_pb), - PartialResultSet(stats=stats_pb), - ] + + # Precommit tokens will be included in the result sets if the transaction is on + # a multiplexed session. Return the precommit tokens out of order to verify that + # the transaction tracks the one with the highest sequence number. 
+ partial_result_set_1_args = {"metadata": metadata_pb} + if use_multiplexed: + partial_result_set_1_args["precommit_token"] = PRECOMMIT_TOKEN_2 + partial_result_set_1 = PartialResultSet(**partial_result_set_1_args) + + partial_result_set_2_args = {"stats": stats_pb} + if use_multiplexed: + partial_result_set_2_args["precommit_token"] = PRECOMMIT_TOKEN_1 + partial_result_set_2 = PartialResultSet(**partial_result_set_2_args) + + result_sets = [partial_result_set_1, partial_result_set_2] + for i in range(len(result_sets)): result_sets[i].values.extend(VALUE_PBS[i]) iterator = _MockIterator(*result_sets) @@ -1253,6 +1305,12 @@ def _execute_sql_helper( ), ) + if first: + self.assertEqual(derived._transaction_id, TXN_ID) + + if use_multiplexed: + self.assertEqual(derived._precommit_token, PRECOMMIT_TOKEN_2) + def test_execute_sql_wo_multi_use(self): self._execute_sql_helper(multi_use=False) @@ -1341,6 +1399,9 @@ def test_execute_sql_w_directed_read_options_override(self): directed_read_options_at_client_level=DIRECTED_READ_OPTIONS_FOR_CLIENT, ) + def test_execute_sql_w_precommit_tokens(self): + self._execute_sql_helper(multi_use=True, use_multiplexed=True) + def _partition_read_helper( self, multi_use, From 1018f4c10b63cfa715b182be3c45ef1d65b35414 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 16:11:49 -0700 Subject: [PATCH 34/41] feat: Multiplexed sessions - Cleanup `Transaction` by extracting some constants for next step. 
Signed-off-by: Taylor Curran --- tests/unit/test_transaction.py | 90 +++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 99a204f486..1a8d40214f 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -15,7 +15,7 @@ import mock -from google.cloud.spanner_v1 import RequestOptions, CommitRequest +from google.cloud.spanner_v1 import RequestOptions, CommitRequest, Mutation, KeySet from google.cloud.spanner_v1 import DefaultTransactionOptions from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode @@ -25,12 +25,13 @@ AtomicCounter, _metadata_with_request_id, ) +from google.cloud.spanner_v1.batch import _make_write_pb from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.request_id_header import ( REQ_RAND_PROCESS_ID, build_request_id, ) -from tests._builders import build_transaction +from tests._builders import build_transaction, build_precommit_token_pb from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -40,12 +41,16 @@ enrich_with_otel_scope, ) +KEYS = [[0], [1], [2]] +KEYSET = KeySet(keys=KEYS) +KEYSET_PB = KEYSET._to_pb() + TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] -VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], -] +VALUE_1 = ["phred@exammple.com", "Phred", "Phlyntstone", 32] +VALUE_2 = ["bharney@example.com", "Bharney", "Rhubble", 31] +VALUES = [VALUE_1, VALUE_2] + DML_QUERY = """\ INSERT INTO citizens(first_name, last_name, age) VALUES ("Phred", "Phlyntstone", 32) @@ -57,6 +62,17 @@ PARAMS = {"age": 30} PARAM_TYPES = {"age": Type(code=TypeCode.INT64)} +TRANSACTION_ID = b"transaction-id" +TRANSACTION_TAG = "transaction-tag" + +PRECOMMIT_TOKEN_PB_0 = build_precommit_token_pb(precommit_token=b"0", seq_num=0) +PRECOMMIT_TOKEN_PB_1 = 
build_precommit_token_pb(precommit_token=b"1", seq_num=1) +PRECOMMIT_TOKEN_PB_2 = build_precommit_token_pb(precommit_token=b"2", seq_num=2) + +DELETE_MUTATION = Mutation(delete=Mutation.Delete(table=TABLE_NAME, key_set=KEYSET_PB)) +INSERT_MUTATION = Mutation(insert=_make_write_pb(TABLE_NAME, COLUMNS, VALUES)) +UPDATE_MUTATION = Mutation(update=_make_write_pb(TABLE_NAME, COLUMNS, VALUES)) + class TestTransaction(OpenTelemetryBase): PROJECT_ID = "project-id" @@ -66,8 +82,6 @@ class TestTransaction(OpenTelemetryBase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session-id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID - TRANSACTION_ID = b"DEADBEEF" - TRANSACTION_TAG = "transaction-tag" def _getTargetClass(self): from google.cloud.spanner_v1.transaction import Transaction @@ -103,9 +117,9 @@ def test_ctor_defaults(self): def test__make_txn_selector(self): session = _Session() transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID selector = transaction._make_txn_selector() - self.assertEqual(selector.id, self.TRANSACTION_ID) + self.assertEqual(selector.id, TRANSACTION_ID) def test_begin_already_rolled_back(self): session = _Session() @@ -142,7 +156,7 @@ def test_rollback_not_begun(self): def test_rollback_already_committed(self): session = _Session() transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.committed = object() with self.assertRaises(ValueError): transaction.rollback() @@ -152,7 +166,7 @@ def test_rollback_already_committed(self): def test_rollback_already_rolled_back(self): session = _Session() transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.rolled_back = True with self.assertRaises(ValueError): transaction.rollback() @@ -165,7 +179,7 @@ def 
test_rollback_w_other_error(self): database.spanner_api.rollback.side_effect = RuntimeError("other error") session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.insert(TABLE_NAME, COLUMNS, VALUES) with self.assertRaises(RuntimeError): @@ -190,7 +204,7 @@ def test_rollback_ok(self): api = database.spanner_api = _FauxSpannerAPI(_rollback_response=empty_pb) session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.replace(TABLE_NAME, COLUMNS, VALUES) transaction.rollback() @@ -200,7 +214,7 @@ def test_rollback_ok(self): session_id, txn_id, metadata = api._rolled_back self.assertEqual(session_id, session.name) - self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(txn_id, TRANSACTION_ID) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertEqual( metadata, @@ -256,7 +270,7 @@ def test_commit_already_committed(self): database.spanner_api = self._make_spanner_api() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.committed = object() with self.assertRaises(ValueError): transaction.commit() @@ -288,7 +302,7 @@ def test_commit_already_rolled_back(self): database.spanner_api = self._make_spanner_api() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.rolled_back = True with self.assertRaises(ValueError): transaction.commit() @@ -321,7 +335,7 @@ def test_commit_w_other_error(self): database.spanner_api.commit.side_effect = RuntimeError() session = _Session(database) transaction = self._make_one(session) - 
transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID transaction.replace(TABLE_NAME, COLUMNS, VALUES) with self.assertRaises(RuntimeError): @@ -363,8 +377,8 @@ def _commit_helper( api = database.spanner_api = _FauxSpannerAPI(_commit_response=response) session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG + transaction._transaction_id = TRANSACTION_ID + transaction.transaction_tag = TRANSACTION_TAG if mutate: transaction.delete(TABLE_NAME, keyset) @@ -388,21 +402,19 @@ def _commit_helper( ) = api._committed if request_options is None: - expected_request_options = RequestOptions( - transaction_tag=self.TRANSACTION_TAG - ) + expected_request_options = RequestOptions(transaction_tag=TRANSACTION_TAG) elif type(request_options) is dict: expected_request_options = RequestOptions(request_options) - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request_options.request_tag = None else: expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request_options.request_tag = None self.assertEqual(max_commit_delay_in, max_commit_delay) self.assertEqual(session_id, session.name) - self.assertEqual(txn_id, self.TRANSACTION_ID) + self.assertEqual(txn_id, TRANSACTION_ID) self.assertEqual(mutations, transaction._mutations) req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" self.assertEqual( @@ -504,7 +516,7 @@ def test_execute_update_other_error(self): database.spanner_api.execute_sql.side_effect = RuntimeError() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID with 
self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) @@ -536,8 +548,8 @@ def _execute_update_helper( api.execute_sql.return_value = ResultSet(stats=stats_pb) session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG + transaction._transaction_id = TRANSACTION_ID + transaction.transaction_tag = TRANSACTION_TAG transaction._execute_sql_request_count = count if request_options is None: @@ -558,7 +570,7 @@ def _execute_update_helper( self.assertEqual(row_count, 1) - expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_transaction = TransactionSelector(id=TRANSACTION_ID) expected_params = Struct( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) @@ -569,7 +581,7 @@ def _execute_update_helper( expected_query_options, query_options ) expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, @@ -653,7 +665,7 @@ def test_execute_update_error(self): database.spanner_api.execute_sql.side_effect = RuntimeError() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID with self.assertRaises(RuntimeError): transaction.execute_update(DML_QUERY) @@ -680,7 +692,7 @@ def test_batch_update_other_error(self): database.spanner_api.execute_batch_dml.side_effect = RuntimeError() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID with self.assertRaises(RuntimeError): transaction.batch_update(statements=[DML_QUERY]) @@ -736,8 +748,8 @@ def _batch_update_helper( api.execute_batch_dml.return_value = response session = 
_Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID - transaction.transaction_tag = self.TRANSACTION_TAG + transaction._transaction_id = TRANSACTION_ID + transaction.transaction_tag = TRANSACTION_TAG transaction._execute_sql_request_count = count if request_options is None: @@ -755,7 +767,7 @@ def _batch_update_helper( self.assertEqual(status, expected_status) self.assertEqual(row_counts, expected_row_counts) - expected_transaction = TransactionSelector(id=self.TRANSACTION_ID) + expected_transaction = TransactionSelector(id=TRANSACTION_ID) expected_insert_params = Struct( fields={ key: _make_value_pb(value) for (key, value) in insert_params.items() @@ -771,7 +783,7 @@ def _batch_update_helper( ExecuteBatchDmlRequest.Statement(sql=delete_dml), ] expected_request_options = request_options - expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.transaction_tag = TRANSACTION_TAG expected_request = ExecuteBatchDmlRequest( session=self.SESSION_NAME, @@ -843,7 +855,7 @@ def test_batch_update_error(self): api.execute_batch_dml.side_effect = RuntimeError() session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = self.TRANSACTION_ID + transaction._transaction_id = TRANSACTION_ID insert_dml = "INSERT INTO table(pkey, desc) VALUES (%pkey, %desc)" insert_params = {"pkey": 12345, "desc": "DESCRIPTION"} @@ -905,7 +917,7 @@ def test_context_mgr_failure(self): empty_pb = Empty() from google.cloud.spanner_v1 import Transaction as TransactionPB - transaction_pb = TransactionPB(id=self.TRANSACTION_ID) + transaction_pb = TransactionPB(id=TRANSACTION_ID) database = _Database() api = database.spanner_api = _FauxSpannerAPI( _begin_transaction_response=transaction_pb, _rollback_response=empty_pb From b761e85998e3fef61eed6fafa86ee6f47594d06b Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 4 Jun 2025 21:28:53 -0700 Subject: [PATCH 35/41] feat: 
Multiplexed sessions - Add additional `Transaction` tests for new multiplexed behaviour. Signed-off-by: Taylor Curran --- tests/_builders.py | 17 +- tests/unit/test_transaction.py | 433 +++++++++++++++++++++++++-------- 2 files changed, 329 insertions(+), 121 deletions(-) diff --git a/tests/_builders.py b/tests/_builders.py index b934bc91b3..1521219dea 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -33,7 +33,6 @@ ) from google.cloud._helpers import _datetime_to_pb_timestamp -from tests._helpers import HAS_OPENTELEMETRY_INSTALLED, get_test_ot_exporter # Default values used to populate required or expected attributes. # Tests should not depend on them: if a test requires a specific @@ -181,9 +180,8 @@ def build_transaction(session=None) -> Transaction: # Ensure session exists. if session.session_id is None: - session.create() + session._session_id = _SESSION_ID - _clear_spans() return session.transaction() @@ -218,16 +216,3 @@ def build_spanner_api() -> SpannerClient: api.create_session.return_value = build_session_pb() return api - - -# Helper functions -# ---------------- - - -def _clear_spans() -> None: - """Clears the spans collected by the OpenTelemetry exporter. - This ensures that spans generated while building test objects - do not interfere with the tests.""" - - if HAS_OPENTELEMETRY_INSTALLED: - get_test_ot_exporter().clear() diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 1a8d40214f..b1269884bb 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -12,10 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Mapping +from datetime import timedelta import mock -from google.cloud.spanner_v1 import RequestOptions, CommitRequest, Mutation, KeySet +from google.cloud.spanner_v1 import ( + RequestOptions, + CommitRequest, + Mutation, + KeySet, + BeginTransactionRequest, + TransactionOptions, + ResultSetMetadata, +) from google.cloud.spanner_v1 import DefaultTransactionOptions from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode @@ -31,7 +40,13 @@ REQ_RAND_PROCESS_ID, build_request_id, ) -from tests._builders import build_transaction, build_precommit_token_pb +from tests._builders import ( + build_transaction, + build_precommit_token_pb, + build_session, + build_commit_response_pb, + build_transaction_pb, +) from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -356,50 +371,107 @@ def test_commit_w_other_error(self): def _commit_helper( self, - mutate=True, + mutations=None, return_commit_stats=False, request_options=None, max_commit_delay_in=None, + retry_for_precommit_token=None, + is_multiplexed=False, + expected_begin_mutation=None, ): - import datetime + from google.cloud.spanner_v1 import CommitRequest - from google.cloud.spanner_v1 import CommitResponse - from google.cloud.spanner_v1.keyset import KeySet - from google.cloud._helpers import UTC + # [A] Build transaction + # --------------------- + + session = build_session(is_multiplexed=is_multiplexed) + transaction = build_transaction(session=session) + + database = session._database + api = database.spanner_api - now = datetime.datetime.utcnow().replace(tzinfo=UTC) - keys = [[0], [1], [2]] - keyset = KeySet(keys=keys) - response = CommitResponse(commit_timestamp=now) - if return_commit_stats: - response.commit_stats.mutation_count = 4 - database = _Database() - api = database.spanner_api = _FauxSpannerAPI(_commit_response=response) - session = _Session(database) - transaction = self._make_one(session) - transaction._transaction_id = TRANSACTION_ID 
transaction.transaction_tag = TRANSACTION_TAG - if mutate: - transaction.delete(TABLE_NAME, keyset) + if mutations is not None: + transaction._mutations = mutations + + # [B] Build responses + # ------------------- + + # Mock begin API call. + begin_precommit_token_pb = PRECOMMIT_TOKEN_PB_0 + begin_transaction = api.begin_transaction + begin_transaction.return_value = build_transaction_pb( + id=TRANSACTION_ID, precommit_token=begin_precommit_token_pb + ) + + # Mock commit API call. + retry_precommit_token = PRECOMMIT_TOKEN_PB_1 + commit_response_pb = build_commit_response_pb( + precommit_token=retry_precommit_token if retry_for_precommit_token else None + ) + if return_commit_stats: + commit_response_pb.commit_stats.mutation_count = 4 + + commit = api.commit + commit.return_value = commit_response_pb + + # [C] Begin transaction, add mutations, and execute commit + # -------------------------------------------------------- + + # Transaction must be begun unless it is mutations-only. + if mutations is None: + transaction._transaction_id = TRANSACTION_ID - transaction.commit( + commit_timestamp = transaction.commit( return_commit_stats=return_commit_stats, request_options=request_options, max_commit_delay=max_commit_delay_in, ) - self.assertEqual(transaction.committed, now) + # [D] Verify results + # ------------------ + + # Verify transaction state. + self.assertEqual(transaction.committed, commit_timestamp) self.assertIsNone(session._transaction) - ( - session_id, - mutations, - txn_id, - actual_request_options, - max_commit_delay, - metadata, - ) = api._committed + if return_commit_stats: + self.assertEqual(transaction.commit_stats.mutation_count, 4) + + nth_request_counter = AtomicCounter() + base_metadata = [ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ] + + # Verify begin API call. 
+ if mutations is not None: + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + + expected_begin_transaction_request = BeginTransactionRequest( + session=session.name, + options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), + mutation_key=expected_begin_mutation, + ) + + expected_begin_metadata = base_metadata.copy() + expected_begin_metadata.append( + ( + "x-goog-spanner-request-id", + self._build_request_id( + database, nth_request=nth_request_counter.increment() + ), + ) + ) + + begin_transaction.assert_called_once_with( + request=expected_begin_transaction_request, + metadata=expected_begin_metadata, + ) + + # Verify commit API call(s). + self.assertEqual(commit.call_count, 1 if not retry_for_precommit_token else 2) if request_options is None: expected_request_options = RequestOptions(transaction_tag=TRANSACTION_TAG) @@ -412,79 +484,135 @@ def _commit_helper( expected_request_options.transaction_tag = TRANSACTION_TAG expected_request_options.request_tag = None - self.assertEqual(max_commit_delay_in, max_commit_delay) - self.assertEqual(session_id, session.name) - self.assertEqual(txn_id, TRANSACTION_ID) - self.assertEqual(mutations, transaction._mutations) - req_id = f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.1.1" - self.assertEqual( - metadata, - [ - ("google-cloud-resource-prefix", database.name), - ("x-goog-spanner-route-to-leader", "true"), - ( - "x-goog-spanner-request-id", - req_id, - ), - ], - ) - self.assertEqual(actual_request_options, expected_request_options) + common_expected_commit_response_args = { + "session": session.name, + "transaction_id": TRANSACTION_ID, + "return_commit_stats": return_commit_stats, + "max_commit_delay": max_commit_delay_in, + "request_options": expected_request_options, + } - if return_commit_stats: - self.assertEqual(transaction.commit_stats.mutation_count, 4) + expected_commit_request = CommitRequest( + mutations=transaction._mutations, + 
             precommit_token=transaction._precommit_token,
+            **common_expected_commit_response_args,
+        )
 
-        self.assertSpanAttributes(
-            "CloudSpanner.Transaction.commit",
-            attributes=self._build_span_attributes(
-                database,
-                num_mutations=len(transaction._mutations),
-                x_goog_spanner_request_id=req_id,
-            ),
+        expected_commit_metadata = base_metadata.copy()
+        expected_commit_metadata.append(
+            (
+                "x-goog-spanner-request-id",
+                self._build_request_id(
+                    database, nth_request=nth_request_counter.increment()
+                ),
+            )
+        )
+        commit.assert_any_call(
+            request=expected_commit_request,
+            metadata=expected_commit_metadata,
         )
 
+        if retry_for_precommit_token:
+            expected_retry_request = CommitRequest(
+                precommit_token=retry_precommit_token,
+                **common_expected_commit_response_args,
+            )
+            expected_retry_metadata = base_metadata.copy()
+            expected_retry_metadata.append(
+                (
+                    "x-goog-spanner-request-id",
+                    self._build_request_id(
+                        database, nth_request=nth_request_counter.increment()
+                    ),
+                )
+            )
+            commit.assert_any_call(
+                request=expected_retry_request,
+                metadata=expected_retry_metadata,
+            )
+
         if not HAS_OPENTELEMETRY_INSTALLED:
             return
 
-        span_list = self.get_finished_spans()
-        got_span_names = [span.name for span in span_list]
-        want_span_names = ["CloudSpanner.Transaction.commit"]
-        assert got_span_names == want_span_names
+        # Verify span names.
+        expected_names = ["CloudSpanner.Transaction.commit"]
+        if mutations is not None:
+            expected_names.append("CloudSpanner.Transaction.begin")
 
-        got_span_events_statuses = self.finished_spans_events_statuses()
-        want_span_events_statuses = [("Starting Commit", {}), ("Commit Done", {})]
-        assert got_span_events_statuses == want_span_events_statuses
+        # Verify span events statuses.
+        expected_statuses = [("Starting Commit", {})]
+        if retry_for_precommit_token:
+            expected_statuses.append(
+                ("Transaction Commit Attempt Failed. 
Retrying", {}) + ) + expected_statuses.append(("Commit Done", {})) + + actual_statuses = self.finished_spans_events_statuses() + self.assertEqual(actual_statuses, expected_statuses) + + def test_commit_mutations_only_not_multiplexed(self): + self._commit_helper(mutations=[DELETE_MUTATION], is_multiplexed=False) + + def test_commit_mutations_only_multiplexed_w_non_insert_mutation(self): + self._commit_helper( + mutations=[DELETE_MUTATION], + is_multiplexed=True, + expected_begin_mutation=DELETE_MUTATION, + ) + + def test_commit_mutations_only_multiplexed_w_insert_mutation(self): + self._commit_helper( + mutations=[INSERT_MUTATION], + is_multiplexed=True, + expected_begin_mutation=INSERT_MUTATION, + ) - def test_commit_no_mutations(self): - self._commit_helper(mutate=False) + def test_commit_mutations_only_multiplexed_w_non_insert_and_insert_mutations(self): + self._commit_helper( + mutations=[INSERT_MUTATION, DELETE_MUTATION], + is_multiplexed=True, + expected_begin_mutation=DELETE_MUTATION, + ) - def test_commit_w_mutations(self): - self._commit_helper(mutate=True) + def test_commit_mutations_only_multiplexed_w_multiple_insert_mutations(self): + insert_1 = Mutation(insert=_make_write_pb(TABLE_NAME, COLUMNS, [VALUE_1])) + insert_2 = Mutation( + insert=_make_write_pb(TABLE_NAME, COLUMNS, [VALUE_1, VALUE_2]) + ) + + self._commit_helper( + mutations=[insert_1, insert_2], + is_multiplexed=True, + expected_begin_mutation=insert_2, + ) + + def test_commit_mutations_only_multiplexed_w_multiple_non_insert_mutations(self): + mutations = [UPDATE_MUTATION, DELETE_MUTATION] + self._commit_helper( + mutations=mutations, + is_multiplexed=True, + expected_begin_mutation=mutations[0], + ) def test_commit_w_return_commit_stats(self): self._commit_helper(return_commit_stats=True) def test_commit_w_max_commit_delay(self): - import datetime - - self._commit_helper(max_commit_delay_in=datetime.timedelta(milliseconds=100)) + 
self._commit_helper(max_commit_delay_in=timedelta(milliseconds=100)) def test_commit_w_request_tag_success(self): - request_options = RequestOptions( - request_tag="tag-1", - ) + request_options = RequestOptions(request_tag="tag-1") self._commit_helper(request_options=request_options) def test_commit_w_transaction_tag_ignored_success(self): - request_options = RequestOptions( - transaction_tag="tag-1-1", - ) + request_options = RequestOptions(transaction_tag="tag-1-1") self._commit_helper(request_options=request_options) def test_commit_w_request_and_transaction_tag_success(self): - request_options = RequestOptions( - request_tag="tag-1", - transaction_tag="tag-1-1", - ) + request_options = RequestOptions(request_tag="tag-1", transaction_tag="tag-1-1") self._commit_helper(request_options=request_options) def test_commit_w_request_and_transaction_tag_dictionary_success(self): @@ -496,6 +624,22 @@ def test_commit_w_incorrect_tag_dictionary_error(self): with self.assertRaises(ValueError): self._commit_helper(request_options=request_options) + def test_commit_w_retry_for_precommit_token(self): + self._commit_helper(retry_for_precommit_token=True) + + def test_commit_w_retry_for_precommit_token_then_error(self): + transaction = build_transaction() + + commit = transaction._session._database.spanner_api.commit + commit.side_effect = [ + build_commit_response_pb(precommit_token=PRECOMMIT_TOKEN_PB_0), + RuntimeError(), + ] + + transaction.begin() + with self.assertRaises(RuntimeError): + transaction.commit() + def test__make_params_pb_w_params_w_param_types(self): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1._helpers import _make_value_pb @@ -528,6 +672,8 @@ def _execute_update_helper( request_options=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + begin=True, + use_multiplexed=False, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -542,16 +688,30 @@ def _execute_update_helper( 
from google.cloud.spanner_v1 import ExecuteSqlRequest MODE = 2 # PROFILE - stats_pb = ResultSetStats(row_count_exact=1) database = _Database() api = database.spanner_api = self._make_spanner_api() - api.execute_sql.return_value = ResultSet(stats=stats_pb) + + # If the transaction had not already begun, the first result set will include + # metadata with information about the transaction. Precommit tokens will be + # included in the result sets if the transaction is on a multiplexed session. + transaction_pb = None if begin else build_transaction_pb(id=TRANSACTION_ID) + metadata_pb = ResultSetMetadata(transaction=transaction_pb) + precommit_token_pb = PRECOMMIT_TOKEN_PB_0 if use_multiplexed else None + + api.execute_sql.return_value = ResultSet( + stats=ResultSetStats(row_count_exact=1), + metadata=metadata_pb, + precommit_token=precommit_token_pb, + ) + session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = TRANSACTION_ID transaction.transaction_tag = TRANSACTION_TAG transaction._execute_sql_request_count = count + if begin: + transaction._transaction_id = TRANSACTION_ID + if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: @@ -570,7 +730,14 @@ def _execute_update_helper( self.assertEqual(row_count, 1) - expected_transaction = TransactionSelector(id=TRANSACTION_ID) + expected_transaction = ( + TransactionSelector(id=transaction._transaction_id) + if begin + else TransactionSelector( + begin=TransactionOptions(read_write=TransactionOptions.ReadWrite()) + ) + ) + expected_params = Struct( fields={key: _make_value_pb(value) for (key, value) in PARAMS.items()} ) @@ -608,7 +775,6 @@ def _execute_update_helper( ], ) - self.assertEqual(transaction._execute_sql_request_count, count + 1) self.assertSpanAttributes( "CloudSpanner.Transaction.execute_update", attributes=self._build_span_attributes( @@ -616,6 +782,12 @@ def _execute_update_helper( ), ) + 
self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + self.assertEqual(transaction._execute_sql_request_count, count + 1) + + if use_multiplexed: + self.assertEqual(transaction._precommit_token, PRECOMMIT_TOKEN_PB_0) + def test_execute_update_new_transaction(self): self._execute_update_helper() @@ -679,6 +851,12 @@ def test_execute_update_w_query_options(self): query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3") ) + def test_execute_update_wo_begin(self): + self._execute_update_helper(begin=False) + + def test_execute_update_w_precommit_token(self): + self._execute_update_helper(use_multiplexed=True) + def test_execute_update_w_request_options(self): self._execute_update_helper( request_options=RequestOptions( @@ -704,6 +882,8 @@ def _batch_update_helper( request_options=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + begin=True, + use_multiplexed=False, ): from google.rpc.status_pb2 import Status from google.protobuf.struct_pb2 import Struct @@ -727,31 +907,51 @@ def _batch_update_helper( delete_dml, ] - stats_pbs = [ - ResultSetStats(row_count_exact=1), - ResultSetStats(row_count_exact=2), - ResultSetStats(row_count_exact=3), + # These precommit tokens are intentionally returned with sequence numbers out + # of order to test that the transaction saves the precommit token with the + # highest sequence number. 
+ precommit_tokens = [ + PRECOMMIT_TOKEN_PB_2, + PRECOMMIT_TOKEN_PB_0, + PRECOMMIT_TOKEN_PB_1, ] - if error_after is not None: - stats_pbs = stats_pbs[:error_after] - expected_status = Status(code=400) - else: - expected_status = Status(code=200) - expected_row_counts = [stats.row_count_exact for stats in stats_pbs] - response = ExecuteBatchDmlResponse( - status=expected_status, - result_sets=[ResultSet(stats=stats_pb) for stats_pb in stats_pbs], - ) + expected_status = Status(code=200) if error_after is None else Status(code=400) + + result_sets = [] + for i in range(len(precommit_tokens)): + if error_after is not None and i == error_after: + break + + result_set_args = {"stats": {"row_count_exact": i}} + + # If the transaction had not already begun, the first result + # set will include metadata with information about the transaction. + if not begin and i == 0: + result_set_args["metadata"] = {"transaction": {"id": TRANSACTION_ID}} + + # Precommit tokens will be included in the result + # sets if the transaction is on a multiplexed session. 
+ if use_multiplexed: + result_set_args["precommit_token"] = precommit_tokens[i] + + result_sets.append(ResultSet(**result_set_args)) + database = _Database() api = database.spanner_api = self._make_spanner_api() - api.execute_batch_dml.return_value = response + api.execute_batch_dml.return_value = ExecuteBatchDmlResponse( + status=expected_status, + result_sets=result_sets, + ) + session = _Session(database) transaction = self._make_one(session) - transaction._transaction_id = TRANSACTION_ID transaction.transaction_tag = TRANSACTION_TAG transaction._execute_sql_request_count = count + if begin: + transaction._transaction_id = TRANSACTION_ID + if request_options is None: request_options = RequestOptions() elif type(request_options) is dict: @@ -765,9 +965,18 @@ def _batch_update_helper( ) self.assertEqual(status, expected_status) - self.assertEqual(row_counts, expected_row_counts) + self.assertEqual( + row_counts, [result_set.stats.row_count_exact for result_set in result_sets] + ) + + expected_transaction = ( + TransactionSelector(id=transaction._transaction_id) + if begin + else TransactionSelector( + begin=TransactionOptions(read_write=TransactionOptions.ReadWrite()) + ) + ) - expected_transaction = TransactionSelector(id=TRANSACTION_ID) expected_insert_params = Struct( fields={ key: _make_value_pb(value) for (key, value) in insert_params.items() @@ -807,6 +1016,13 @@ def _batch_update_helper( ) self.assertEqual(transaction._execute_sql_request_count, count + 1) + self.assertEqual(transaction._transaction_id, TRANSACTION_ID) + + if use_multiplexed: + self.assertEqual(transaction._precommit_token, PRECOMMIT_TOKEN_PB_2) + + def test_batch_update_wo_begin(self): + self._batch_update_helper(begin=False) def test_batch_update_wo_errors(self): self._batch_update_helper( @@ -886,6 +1102,9 @@ def test_batch_update_w_retry_param(self): def test_batch_update_w_timeout_and_retry_params(self): self._batch_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0) + def 
test_batch_update_w_precommit_token(self): + self._batch_update_helper(use_multiplexed=True) + def test_context_mgr_success(self): transaction = build_transaction() session = transaction._session @@ -960,14 +1179,18 @@ def _build_span_attributes( return attributes @staticmethod - def _build_request_id(database: Database, attempt: int = 1) -> str: + def _build_request_id( + database: Database, nth_request: int = None, attempt: int = 1 + ) -> str: """Builds a request ID for an Spanner Client API request with the given database and attempt number.""" client = database._instance._client + nth_request = nth_request or client._nth_request.value + return build_request_id( client_id=client._nth_client_id, channel_id=database._channel_id, - nth_request=client._nth_request.value, + nth_request=nth_request, attempt=attempt, ) From 665547bf8fb4f7fba9a4e8f1f5deeb04500b1c0a Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 5 Jun 2025 07:50:55 -0700 Subject: [PATCH 36/41] feat: Multiplexed sessions - Fix linter Signed-off-by: Taylor Curran --- tests/unit/test_transaction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index b1269884bb..d9448ef5ba 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -889,7 +889,6 @@ def _batch_update_helper( from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import param_types from google.cloud.spanner_v1 import ResultSet - from google.cloud.spanner_v1 import ResultSetStats from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteBatchDmlResponse from google.cloud.spanner_v1 import TransactionSelector From 0bd5fd1939606a5cc709e357036903e0ac7c9a8e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 5 Jun 2025 08:53:11 -0700 Subject: [PATCH 37/41] feat: Multiplexed sessions - Remove unnecessary TODO Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 51900344ab..89f610d988 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -467,7 +467,6 @@ def batch(self): return Batch(self) - # TODO multiplexed - remove def transaction(self): """Create a transaction to perform a set of reads with shared staleness. @@ -478,6 +477,7 @@ def transaction(self): if self._session_id is None: raise ValueError("Session has not been created.") + # TODO multiplexed - remove if self._transaction is not None: self._transaction.rolled_back = True self._transaction = None From 4cb1f0560bab765b83883f64c8b515b4fad2039f Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 5 Jun 2025 09:50:34 -0700 Subject: [PATCH 38/41] feat: Multiplexed sessions - Remove unnecessary constants. Signed-off-by: Taylor Curran --- tests/_helpers.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index 2f5eed98de..89b7750e02 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -34,15 +34,6 @@ _TEST_OT_EXPORTER = None _TEST_OT_PROVIDER_INITIALIZED = False -# Environment variables for enabling multiplexed sessions -"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" -ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED = ( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" -) -ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" -) - def is_multiplexed_enabled(transaction_type: TransactionType) -> bool: """Returns whether multiplexed sessions are enabled for the given transaction type.""" From 17f3c5ff9721cb7b2377b8f0b54ca5e3c1e1af64 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Thu, 5 Jun 2025 13:06:02 -0700 Subject: [PATCH 39/41] feat: Multiplexed sessions - Remove support for disabling the use of multiplexed sessions due to runtime failures. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_dbapi/connection.py | 2 +- google/cloud/spanner_v1/client.py | 3 - google/cloud/spanner_v1/database.py | 6 +- .../spanner_v1/database_sessions_manager.py | 128 ++++++++++++---- google/cloud/spanner_v1/session_options.py | 138 ----------------- tests/_helpers.py | 2 +- tests/system/test_observability_options.py | 2 +- tests/system/test_session_api.py | 2 +- tests/unit/spanner_dbapi/test_connection.py | 2 +- tests/unit/test_database.py | 6 +- tests/unit/test_database_session_manager.py | 133 +++++++++++----- tests/unit/test_session_options.py | 145 ------------------ 12 files changed, 199 insertions(+), 370 deletions(-) delete mode 100644 google/cloud/spanner_v1/session_options.py delete mode 100644 tests/unit/test_session_options.py diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index ef0db6f784..1a2b117e4c 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -28,7 +28,7 @@ from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_v1 import RequestOptions, TransactionOptions -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_dbapi.exceptions import ( diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index 10db8c136e..e0e8c44058 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -60,7 +60,6 @@ from google.cloud.spanner_v1.metrics.metrics_exporter import ( CloudMonitoringMetricsExporter, ) -from google.cloud.spanner_v1.session_options import SessionOptions try: from opentelemetry import metrics @@ -270,8 +269,6 @@ def __init__( self._nth_client_id = 
Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter(0) - self._session_options = SessionOptions() - @property def _next_nth_request(self): return self._nth_request.increment() diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 822531a435..5cff58bcb2 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -61,8 +61,10 @@ from google.cloud.spanner_v1.merged_result_set import MergedResultSet from google.cloud.spanner_v1.pool import BurstyPool from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import TransactionType -from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager +from google.cloud.spanner_v1.database_sessions_manager import ( + DatabaseSessionsManager, + TransactionType, +) from google.cloud.spanner_v1.snapshot import _restart_on_unavailable from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_v1.streamed import StreamedResultSet diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 44ca8502c0..09f93cdcd6 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from enum import Enum +from os import getenv from datetime import timedelta from threading import Event, Lock, Thread from time import sleep, time @@ -18,13 +20,20 @@ from weakref import ref from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import TransactionType from google.cloud.spanner_v1._opentelemetry_tracing import ( get_current_span, add_span_event, ) +class TransactionType(Enum): + """Transaction types for session options.""" + + READ_ONLY = "read-only" + PARTITIONED = "partitioned" + READ_WRITE = "read/write" + + class DatabaseSessionsManager(object): """Manages sessions for a Cloud Spanner database. @@ -32,9 +41,8 @@ class DatabaseSessionsManager(object): transaction type using :meth:`get_session`, and returned to the session manager using :meth:`put_session`. - The sessions returned by the session manager depend on the client's session options - (see :class:`~google.cloud.spanner_v1.session_options.SessionOptions`) and the - provided session pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). + The sessions returned by the session manager depend on the configured environment variables + and the provided session pool (see :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`). :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to manage sessions for. @@ -43,6 +51,13 @@ class DatabaseSessionsManager(object): :param pool: The pool to get non-multiplexed sessions from. """ + # Environment variables for multiplexed sessions + _ENV_VAR_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" + _ENV_VAR_MULTIPLEXED_PARTITIONED = ( + "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" + ) + _ENV_VAR_MULTIPLEXED_READ_WRITE = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" + # Intervals for the maintenance thread to check and refresh the multiplexed session. 
_MAINTENANCE_THREAD_POLLING_INTERVAL = timedelta(minutes=10) _MAINTENANCE_THREAD_REFRESH_INTERVAL = timedelta(days=7) @@ -55,13 +70,14 @@ def __init__(self, database, pool): # database session manager is created, a maintenance thread is initialized to # periodically delete and recreate the multiplexed session so that it remains # valid. Because of this concurrency, we need to use a lock whenever we access - # the multiplexed session to avoid any race conditions. We also create an event - # so that the thread can terminate if the use of multiplexed session has been - # disabled for all transactions. + # the multiplexed session to avoid any race conditions. self._multiplexed_session: Optional[Session] = None self._multiplexed_session_thread: Optional[Thread] = None self._multiplexed_session_lock: Lock = Lock() - self._multiplexed_session_disabled_event: Event = Event() + + # Event to terminate the maintenance thread. + # Only used for testing purposes. + self._multiplexed_session_terminate_event: Event = Event() def get_session(self, transaction_type: TransactionType) -> Session: """Returns a session for the given transaction type from the database session manager. @@ -70,8 +86,7 @@ def get_session(self, transaction_type: TransactionType) -> Session: :returns: a session for the given transaction type. 
""" - session_options = self._database._instance._client._session_options - use_multiplexed = session_options.use_multiplexed(transaction_type) + use_multiplexed = self._use_multiplexed(transaction_type) # TODO multiplexed: enable for read/write transactions if use_multiplexed and transaction_type == TransactionType.READ_WRITE: @@ -149,15 +164,6 @@ def _build_multiplexed_session(self) -> Session: return session - def _disable_multiplexed_sessions(self) -> None: - """Disables multiplexed sessions for all transactions.""" - - self._multiplexed_session = None - self._multiplexed_session_disabled_event.set() - - session_options = self._database._instance._client._session_options - session_options.disable_multiplexed(self._database.logger) - def _build_maintenance_thread(self) -> Thread: """Builds and returns a multiplexed session maintenance thread for the database session manager. This thread will periodically delete @@ -185,34 +191,33 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: This method will delete and recreate the referenced database session manager's multiplexed session to ensure that it is always valid. The method will run until - the database session manager is deleted, the multiplexed session is deleted, or - building a multiplexed session fails. + the database session manager is deleted or the multiplexed session is deleted. :type session_manager_ref: :class:`_weakref.ReferenceType` :param session_manager_ref: A weak reference to the database session manager. 
""" - session_manager = session_manager_ref() - if session_manager is None: + manager = session_manager_ref() + if manager is None: return polling_interval_seconds = ( - session_manager._MAINTENANCE_THREAD_POLLING_INTERVAL.total_seconds() + manager._MAINTENANCE_THREAD_POLLING_INTERVAL.total_seconds() ) refresh_interval_seconds = ( - session_manager._MAINTENANCE_THREAD_REFRESH_INTERVAL.total_seconds() + manager._MAINTENANCE_THREAD_REFRESH_INTERVAL.total_seconds() ) session_created_time = time() while True: # Terminate the thread is the database session manager has been deleted. - session_manager = session_manager_ref() - if session_manager is None: + manager = session_manager_ref() + if manager is None: return - # Terminate the thread if the use of multiplexed sessions has been disabled. - if session_manager._multiplexed_session_disabled_event.is_set(): + # Terminate the thread if corresponding event is set. + if manager._multiplexed_session_terminate_event.is_set(): return # Wait for until the refresh interval has elapsed. @@ -220,10 +225,65 @@ def _maintain_multiplexed_session(session_manager_ref) -> None: sleep(polling_interval_seconds) continue - with session_manager._multiplexed_session_lock: - session_manager._multiplexed_session.delete() - session_manager._multiplexed_session = ( - session_manager._build_multiplexed_session() - ) + with manager._multiplexed_session_lock: + manager._multiplexed_session.delete() + manager._multiplexed_session = manager._build_multiplexed_session() session_created_time = time() + + @classmethod + def _use_multiplexed(cls, transaction_type: TransactionType) -> bool: + """Returns whether to use multiplexed sessions for the given transaction type. + + Multiplexed sessions are enabled for read-only transactions if: + * _ENV_VAR_MULTIPLEXED is set to true. + + Multiplexed sessions are enabled for partitioned transactions if: + * _ENV_VAR_MULTIPLEXED is set to true; and + * _ENV_VAR_MULTIPLEXED_PARTITIONED is set to true. 
+ + Multiplexed sessions are enabled for read/write transactions if: + * _ENV_VAR_MULTIPLEXED is set to true; and + * _ENV_VAR_MULTIPLEXED_READ_WRITE is set to true. + + :type transaction_type: :class:`TransactionType` + :param transaction_type: the type of transaction + + :rtype: bool + :returns: True if multiplexed sessions should be used for the given transaction + type, False otherwise. + + :raises ValueError: if the transaction type is not supported. + """ + + if transaction_type is TransactionType.READ_ONLY: + return cls._getenv(cls._ENV_VAR_MULTIPLEXED) + + elif transaction_type is TransactionType.PARTITIONED: + return cls._getenv(cls._ENV_VAR_MULTIPLEXED) and cls._getenv( + cls._ENV_VAR_MULTIPLEXED_PARTITIONED + ) + + elif transaction_type is TransactionType.READ_WRITE: + return cls._getenv(cls._ENV_VAR_MULTIPLEXED) and cls._getenv( + cls._ENV_VAR_MULTIPLEXED_READ_WRITE + ) + + raise ValueError(f"Transaction type {transaction_type} is not supported.") + + @classmethod + def _getenv(cls, env_var_name: str) -> bool: + """Returns the value of the given environment variable as a boolean. + + True values are '1' and 'true' (case-insensitive). + All other values are considered false. + + :type env_var_name: str + :param env_var_name: the name of the boolean environment variable + + :rtype: bool + :returns: True if the environment variable is set to a true value, False otherwise. + """ + + env_var_value = getenv(env_var_name, "").lower().strip() + return env_var_value in ["1", "true"] diff --git a/google/cloud/spanner_v1/session_options.py b/google/cloud/spanner_v1/session_options.py deleted file mode 100644 index 7e68b235a3..0000000000 --- a/google/cloud/spanner_v1/session_options.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2025 Google LLC All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -from enum import Enum -from typing import Mapping - - -class TransactionType(Enum): - """Transaction types for session options.""" - - READ_ONLY = "read-only" - PARTITIONED = "partitioned" - READ_WRITE = "read/write" - - -class SessionOptions(object): - """Represents the session options for the Cloud Spanner Python client. - We can use :class:`SessionOptions` to determine whether multiplexed sessions - should be used for a specific transaction type with :meth:`use_multiplexed`. The use - of multiplexed session can be disabled for a specific transaction type or for all - transaction types with :meth:`disable_multiplexed`. - """ - - # Environment variables for multiplexed sessions - ENV_VAR_ENABLE_MULTIPLEXED = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" - ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED = ( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" - ) - ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE = ( - "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW" - ) - - def __init__(self): - # Internal overrides to disable the use of multiplexed - # sessions in case of runtime errors. - self._is_multiplexed_enabled: Mapping[TransactionType, str] = { - TransactionType.READ_ONLY: True, - TransactionType.PARTITIONED: True, - TransactionType.READ_WRITE: True, - } - - def use_multiplexed(self, transaction_type: TransactionType) -> bool: - """Returns whether to use multiplexed sessions for the given transaction type. 
- - Multiplexed sessions are enabled for read-only transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; and - * multiplexed sessions have not been disabled for read-only transactions. - - Multiplexed sessions are enabled for partitioned transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED is set to true; and - * multiplexed sessions have not been disabled for partitioned transactions. - - Multiplexed sessions are enabled for read/write transactions if: - * ENV_VAR_ENABLE_MULTIPLEXED is set to true; - * ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE is set to true; and - * multiplexed sessions have not been disabled for read/write transactions. - - :type transaction_type: :class:`TransactionType` - :param transaction_type: the type of transaction - """ - - if transaction_type is TransactionType.READ_ONLY: - return ( - self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and self._is_multiplexed_enabled[transaction_type] - ) - - elif transaction_type is TransactionType.PARTITIONED: - return ( - self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED) - and self._is_multiplexed_enabled[transaction_type] - ) - - elif transaction_type is TransactionType.READ_WRITE: - return ( - self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED) - and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE) - and self._is_multiplexed_enabled[transaction_type] - ) - - raise ValueError(f"Transaction type {transaction_type} is not supported.") - - def disable_multiplexed( - self, logger: logging.Logger = None, transaction_type: TransactionType = None - ) -> None: - """Disables the use of multiplexed sessions for the given transaction type. - If no transaction type is specified, disables the use of multiplexed sessions - for all transaction types. - - :type logger: :class:`Logger` - :param logger: logger for logging disabling the use of multiplexed sessions. 
- - :type transaction_type: :class:`TransactionType` - :param transaction_type: (Optional) the type of transaction for which to disable - the use of multiplexed sessions. - """ - - if transaction_type and transaction_type not in self._is_multiplexed_enabled: - raise ValueError(f"Transaction type '{transaction_type}' is not supported.") - - logger = logger or logging.getLogger(__name__) - - transaction_types_to_disable = ( - [transaction_type] - if transaction_type is not None - else list(TransactionType) - ) - - for transaction_type_to_disable in transaction_types_to_disable: - logger.warning( - f"Disabling multiplexed sessions for {transaction_type_to_disable.value} transactions" - ) - self._is_multiplexed_enabled[transaction_type_to_disable] = False - - return - - @staticmethod - def _getenv(name: str) -> bool: - """Returns the value of the given environment variable as a boolean. - True values are '1' and 'true' (case-insensitive); all other values are - considered false. - """ - env_var = os.getenv(name, "").lower().strip() - return env_var in ["1", "true"] diff --git a/tests/_helpers.py b/tests/_helpers.py index 89b7750e02..32feedc514 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -4,7 +4,7 @@ import mock from google.cloud.spanner_v1 import gapic_version -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType LIB_VERSION = gapic_version.__version__ diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index d25f5e73d7..50a6432d3b 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -16,7 +16,7 @@ from mock import PropertyMock, patch from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from . 
import _helpers from google.cloud.spanner_v1 import Client from google.api_core.exceptions import Aborted diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 957ac0bdb5..1b4a6dc183 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -32,7 +32,7 @@ from google.cloud.spanner_v1._helpers import AtomicCounter from google.cloud.spanner_v1.data_types import JsonObject -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from .testdata import singer_pb2 from tests import _helpers as ot_helpers from . import _helpers diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index dbef230417..0bfab5bab9 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -37,7 +37,7 @@ ClientSideStatementType, AutocommitDmlMode, ) -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from tests._builders import build_connection, build_session PROJECT = "test-project" diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 258e9913f0..3668edfe5b 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -21,6 +21,7 @@ Database as DatabasePB, DatabaseDialect, ) + from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry from google.protobuf.field_mask_pb2 import FieldMask @@ -36,7 +37,7 @@ ) from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID from google.cloud.spanner_v1.session import Session -from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from tests._builders import build_spanner_api from tests._helpers import 
is_multiplexed_enabled @@ -1449,8 +1450,6 @@ def _execute_partitioned_dml_helper( # Verify that the correct session type was used based on environment if multiplexed_partitioned_enabled: # Verify that sessions_manager.get_session was called with PARTITIONED transaction type - from google.cloud.spanner_v1.session_options import TransactionType - database._sessions_manager.get_session.assert_called_with( TransactionType.PARTITIONED ) @@ -3498,7 +3497,6 @@ def __init__( self.observability_options = observability_options self._nth_client_id = _Client.NTH_CLIENT.increment() self._nth_request = AtomicCounter() - self._session_options = SessionOptions() @property def _next_nth_request(self): diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index 89dd21012a..7626bd0d60 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -13,14 +13,14 @@ # limitations under the License. from datetime import timedelta from mock import Mock, patch -from threading import Thread +from os import environ from time import time, sleep from typing import Callable from unittest import TestCase from google.api_core.exceptions import BadRequest, FailedPrecondition from google.cloud.spanner_v1.database_sessions_manager import DatabaseSessionsManager -from google.cloud.spanner_v1.session_options import TransactionType +from google.cloud.spanner_v1.database_sessions_manager import TransactionType from tests._builders import build_database @@ -31,6 +31,17 @@ _MAINTENANCE_THREAD_REFRESH_INTERVAL=timedelta(seconds=2), ) class TestDatabaseSessionManager(TestCase): + @classmethod + def setUpClass(cls): + # Save the original environment variables. + cls._original_env = dict(environ) + + @classmethod + def tearDownClass(cls): + # Restore environment variables. + environ.clear() + environ.update(cls._original_env) + def setUp(self): # Build session manager. 
database = build_database() @@ -42,15 +53,15 @@ def setUp(self): pool.put = Mock(wraps=pool.put) def tearDown(self): - # If the maintenance thread is still alive, disable multiplexed sessions and - # wait for the thread to terminate. We need to do this to ensure that the - # thread is properly cleaned up and does not interfere with other tests. - sessions_manager = self._manager - thread = sessions_manager._multiplexed_session_thread + # If the maintenance thread is still alive, set the event and wait + # for the thread to terminate. We need to do this to ensure that the + # thread does not interfere with other tests. + manager = self._manager + thread = manager._multiplexed_session_thread if thread and thread.is_alive(): - sessions_manager._multiplexed_session_disabled_event.set() - self._assert_thread_terminated(thread) + manager._multiplexed_session_terminate_event.set() + self._assert_true_with_timeout(lambda: not thread.is_alive()) def test_read_only_pooled(self): manager = self._manager @@ -175,25 +186,11 @@ def test_multiplexed_maintenance(self): info = manager._database.logger.info info.assert_called_with("Created multiplexed session.") - def test_multiplexed_maintenance_terminates_disabled(self): - manager = self._manager - self._enable_multiplexed_sessions() - - # Maintenance thread is started. - session_1 = manager.get_session(TransactionType.READ_ONLY) - self.assertTrue(session_1.is_multiplexed) - - manager._multiplexed_session_disabled_event.set() - - thread = manager._multiplexed_session_thread - self._assert_thread_terminated(thread) - def test_exception_bad_request(self): manager = self._manager api = manager._database.spanner_api api.create_session.side_effect = BadRequest("") - # Verify that BadRequest is not caught. 
with self.assertRaises(BadRequest): manager.get_session(TransactionType.READ_ONLY) @@ -202,12 +199,76 @@ def test_exception_failed_precondition(self): api = manager._database.spanner_api api.create_session.side_effect = FailedPrecondition("") - # Verify that FailedPrecondition is not caught. with self.assertRaises(FailedPrecondition): manager.get_session(TransactionType.READ_ONLY) + def test__use_multiplexed_read_only(self): + transaction_type = TransactionType.READ_ONLY + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" + self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" + self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + def test__use_multiplexed_partitioned(self): + transaction_type = TransactionType.PARTITIONED + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" + self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "false" + self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "true" + self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + def test__use_multiplexed_read_write(self): + transaction_type = TransactionType.READ_WRITE + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" + self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "false" + self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "true" + 
self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type)) + + def test__use_multiplexed_unsupported_transaction_type(self): + unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" + + with self.assertRaises(ValueError): + DatabaseSessionsManager._use_multiplexed(unsupported_type) + + def test__getenv(self): + true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] + for value in true_values: + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value + self.assertTrue( + DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) + ) + + false_values = ["", "0", "false", "False", "FALSE", " false "] + for value in false_values: + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value + self.assertFalse( + DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) + ) + + del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] + self.assertFalse( + DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY) + ) + def _assert_true_with_timeout(self, condition: Callable) -> None: - """Asserts that the given condition is met within a timeout period.""" + """Asserts that the given condition is met within a timeout period. + + :type condition: Callable + :param condition: A callable that returns a boolean indicating whether the condition is met. 
+ """ sleep_seconds = 0.1 timeout_seconds = 10 @@ -218,22 +279,16 @@ def _assert_true_with_timeout(self, condition: Callable) -> None: self.assertTrue(condition()) - def _assert_thread_terminated(self, thread: Thread) -> None: - """Asserts that the given thread is terminated.""" - - def _is_thread_terminated(): - return not thread.is_alive() - - self._assert_true_with_timeout(_is_thread_terminated) - - def _disable_multiplexed_sessions(self) -> None: + @staticmethod + def _disable_multiplexed_sessions() -> None: """Sets environment variables to disable multiplexed sessions for all transactions types.""" - options = self._manager._database._instance._client._session_options - options.use_multiplexed = Mock(return_value=False) + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false" - def _enable_multiplexed_sessions(self) -> None: + @staticmethod + def _enable_multiplexed_sessions() -> None: """Sets environment variables to enable multiplexed sessions for all transaction types.""" - options = self._manager._database._instance._client._session_options - options.use_multiplexed = Mock(return_value=True) + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true" + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "true" + environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "true" diff --git a/tests/unit/test_session_options.py b/tests/unit/test_session_options.py deleted file mode 100644 index 18291eae34..0000000000 --- a/tests/unit/test_session_options.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2025 Google LLC All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from logging import Logger -from os import environ -from unittest import TestCase - -from google.cloud.spanner_v1.session_options import SessionOptions, TransactionType -from tests._builders import build_logger - - -class TestSessionOptions(TestCase): - @classmethod - def setUpClass(cls): - # Save the original environment variables. - cls._original_env = dict(environ) - - @classmethod - def tearDownClass(cls): - # Restore environment variables. - environ.clear() - environ.update(cls._original_env) - - def setUp(self): - self.logger: Logger = build_logger() - - def test_use_multiplexed_for_read_only(self): - session_options = SessionOptions() - transaction_type = TransactionType.READ_ONLY - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - self.assertTrue(session_options.use_multiplexed(transaction_type)) - - session_options.disable_multiplexed(self.logger, transaction_type) - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - self.logger.warning.assert_called_once_with( - "Disabling multiplexed sessions for read-only transactions" - ) - - def test_use_multiplexed_for_partitioned(self): - session_options = SessionOptions() - transaction_type = TransactionType.PARTITIONED - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - 
environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "false" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" - self.assertTrue(session_options.use_multiplexed(transaction_type)) - - session_options.disable_multiplexed(self.logger, transaction_type) - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - self.logger.warning.assert_called_once_with( - "Disabling multiplexed sessions for partitioned transactions" - ) - - def test_use_multiplexed_for_read_write(self): - session_options = SessionOptions() - transaction_type = TransactionType.READ_WRITE - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "false" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "false" - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - self.assertTrue(session_options.use_multiplexed(transaction_type)) - - session_options.disable_multiplexed(self.logger, transaction_type) - self.assertFalse(session_options.use_multiplexed(transaction_type)) - - self.logger.warning.assert_called_once_with( - "Disabling multiplexed sessions for read/write transactions" - ) - - def test_disable_multiplexed_all(self): - session_options = SessionOptions() - - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = "true" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_PARTITIONED] = "true" - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED_FOR_READ_WRITE] = "true" - - session_options.disable_multiplexed(self.logger) - - self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) - self.assertFalse(session_options.use_multiplexed(TransactionType.PARTITIONED)) - 
self.assertFalse(session_options.use_multiplexed(TransactionType.READ_WRITE)) - - warning = self.logger.warning - self.assertEqual(warning.call_count, 3) - warning.assert_any_call( - "Disabling multiplexed sessions for read-only transactions" - ) - warning.assert_any_call( - "Disabling multiplexed sessions for partitioned transactions" - ) - warning.assert_any_call( - "Disabling multiplexed sessions for read/write transactions" - ) - - def test_unsupported_transaction_type(self): - session_options = SessionOptions() - unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE" - - with self.assertRaises(ValueError): - session_options.use_multiplexed(unsupported_type) - - with self.assertRaises(ValueError): - session_options.disable_multiplexed(self.logger, unsupported_type) - - def test_env_var_values(self): - session_options = SessionOptions() - - true_values = ["1", " 1", " 1", "true", "True", "TRUE", " true "] - for value in true_values: - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value - self.assertTrue(session_options.use_multiplexed(TransactionType.READ_ONLY)) - - false_values = ["", "0", "false", "False", "FALSE", " false "] - for value in false_values: - environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] = value - self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) - - del environ[SessionOptions.ENV_VAR_ENABLE_MULTIPLEXED] - self.assertFalse(session_options.use_multiplexed(TransactionType.READ_ONLY)) From 9c3db74b102336bffdecaf7ece18ffb5f4d4d001 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 06:57:58 -0700 Subject: [PATCH 40/41] feat: Multiplexed sessions - Make deprecation comments a bit more clear. 
Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 5 +++-- google/cloud/spanner_v1/pool.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 5cff58bcb2..0bbc794920 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -831,8 +831,9 @@ def _nth_client_id(self): def session(self, labels=None, database_role=None): """Factory to create a session for this database. - Deprecated. Sessions should be checked out using context - managers, rather than retrieved directly from the database. + Deprecated. Sessions should be checked out indirectly using context + managers or :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, + rather than built directly from the database. :type labels: dict (str -> str) or None :param labels: (Optional) user-assigned labels for the session. diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 0257cf1211..a75c13cb7a 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -138,8 +138,9 @@ def _new_session(self): def session(self, **kwargs): """Check out a session from the pool. - Deprecated. Sessions should be checked out using context - managers, rather than directly from the pool. + Deprecated. Sessions should be checked out indirectly using context + managers or :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, + rather than checked out directly from the pool. :param kwargs: (optional) keyword arguments, passed through to the returned checkout. @@ -796,8 +797,9 @@ def begin_pending_transactions(self): class SessionCheckout(object): """Context manager: hold session checked out from a pool. - Deprecated. Sessions should be checked out using context - managers, rather than directly from the pool. + Deprecated. 
Sessions should be checked out indirectly using context + managers or :meth:`~google.cloud.spanner_v1.database.Database.run_in_transaction`, + rather than checked out directly from the pool. :type pool: concrete subclass of :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool` From 17a2f4701a0e5aba7965f284b1c4a2e044ee2f71 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 07:07:43 -0700 Subject: [PATCH 41/41] feat: Multiplexed sessions - Add some more type hints. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/database.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 0bbc794920..e8ddc48c60 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -16,6 +16,7 @@ import copy import functools +from typing import Optional import grpc import logging @@ -1305,8 +1306,8 @@ def __init__( **kw, ): self._database: Database = database - self._session: Session = None - self._batch: Batch = None + self._session: Optional[Session] = None + self._batch: Optional[Batch] = None if request_options is None: self._request_options = RequestOptions() @@ -1381,7 +1382,7 @@ class MutationGroupsCheckout(object): def __init__(self, database): self._database: Database = database - self._session: Session = None + self._session: Optional[Session] = None def __enter__(self): """Begin ``with`` block.""" @@ -1421,7 +1422,7 @@ class SnapshotCheckout(object): def __init__(self, database, **kw): self._database: Database = database - self._session: Session = None + self._session: Optional[Session] = None self._kw: dict = kw def __enter__(self): @@ -1464,11 +1465,14 @@ def __init__( session_id=None, transaction_id=None, ): - self._database = database - self._session_id = session_id - self._session = None - self._snapshot = None - self._transaction_id = transaction_id + self._database: Database = database + + 
self._session_id: Optional[str] = session_id + self._transaction_id: Optional[bytes] = transaction_id + + self._session: Optional[Session] = None + self._snapshot: Optional[Snapshot] = None + self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness