diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index b6d5f23a..2e67b961 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -18,14 +18,15 @@ import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.atomic import AtomicReference -from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.storage.storage_service import ( + StorageService, Topology) from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \ ThreadSafeConnectionHolder from aws_advanced_python_wrapper.utils.utils import LogUtils @@ -46,11 +47,11 @@ class ClusterTopologyMonitor(ABC): @abstractmethod - def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: pass @abstractmethod - def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology: pass @abstractmethod @@ -75,8 +76,6 @@ class ClusterTopologyMonitorImpl(ClusterTopologyMonitor): INITIAL_BACKOFF_MS = 100 MAX_BACKOFF_MS = 10000 - _topology_map: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap() - def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, cluster_id: str, initial_host_info: HostInfo, properties: Properties, instance_template: HostInfo, refresh_rate_nano: int, high_refresh_rate_nano: int): @@ -103,7 +102,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, self._host_threads_writer_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) self._host_threads_writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None) self._host_threads_reader_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) - self._host_threads_latest_topology: AtomicReference[Optional[Tuple[HostInfo, ...]]] = AtomicReference(None) + self._host_threads_latest_topology: AtomicReference[Optional[Topology]] = AtomicReference(None) self._is_verified_writer_connection = False self._high_refresh_rate_end_time_nano = 0 @@ -118,7 +117,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, self._start_monitoring() - def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: current_time_nano = time.time_ns() if (self._ignore_new_topology_requests_end_time_nano > 0 and current_time_nano < self._ignore_new_topology_requests_end_time_nano): @@ -134,12 +133,12 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H result = self._wait_till_topology_gets_updated(timeout_sec) return result - def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology: if self._is_verified_writer_connection: return self._wait_till_topology_gets_updated(timeout_sec) return self._fetch_topology_and_update_cache(connection) - def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo, ...]: + def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Topology: current_hosts = self._get_stored_hosts() self._request_to_update_topology.set() @@ -162,8 +161,8 @@ def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo, "ClusterTopologyMonitorImpl.TopologyNotUpdated", self._cluster_id, timeout_sec * 1000)) - def _get_stored_hosts(self) -> Tuple[HostInfo, ...]: - hosts = ClusterTopologyMonitorImpl._topology_map.get(self._cluster_id) + def _get_stored_hosts(self) -> Topology: + hosts = StorageService.get(Topology, self._cluster_id) if hosts is None: return () return hosts @@ -296,7 +295,7 @@ def _is_in_panic_mode(self) -> bool: def _get_host_monitor(self, host_info: HostInfo, writer_host_info: Optional[HostInfo]): return HostMonitor(self, host_info, writer_host_info) - def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]: + def _open_any_connection_and_update_topology(self) -> Topology: writer_verified_by_this_thread = False if self._monitoring_connection.get() is None: # Try to connect to the initial host first @@ -409,7 +408,7 @@ def _delay(self, use_high_refresh_rate: bool) -> None: while not self._request_to_update_topology.is_set() and time.time() < end_time and not self._stop.is_set(): time.sleep(0.05) - def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Tuple[HostInfo, ...]: + def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Topology: if connection is None: return () @@ -423,7 +422,7 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex) return () - def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]: + def _fetch_topology_and_update_cache_safe(self) -> Topology: """ Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions. The lock is held during the entire query operation. @@ -433,16 +432,14 @@ def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]: ) return result if result is not None else () - def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]: + def _query_for_topology(self, connection: Connection) -> Topology: hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect) if hosts is not None: return hosts return () - def _update_topology_cache(self, hosts: Tuple[HostInfo, ...]) -> None: - ClusterTopologyMonitorImpl._topology_map.put( - self._cluster_id, hosts, ClusterTopologyMonitorImpl.TOPOLOGY_CACHE_EXPIRATION_NANO) - + def _update_topology_cache(self, hosts: Topology) -> None: + StorageService.set(self._cluster_id, hosts, Topology) # Notify waiting threads self._request_to_update_topology.clear() self._topology_updated.set() diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 53e82ab2..47d2bdea 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -30,6 +30,8 @@ preserve_transaction_status_with_timeout from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ SlidingExpirationCacheContainer +from aws_advanced_python_wrapper.utils.storage.storage_service import ( + StorageService, Topology) if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -59,13 +61,13 @@ class HostListProvider(Protocol): - def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def refresh(self, connection: Optional[Connection] = None) -> Topology: ... - def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def force_refresh(self, connection: Optional[Connection] = None) -> Topology: ... - def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: ... def get_host_role(self, connection: Connection) -> HostRole: @@ -149,7 +151,6 @@ def is_static_host_list_provider(self) -> bool: class RdsHostListProvider(DynamicHostListProvider, HostListProvider): - _topology_cache: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap() # Maps cluster IDs to a boolean representing whether they are a primary cluster ID or not. A primary cluster ID is a # cluster ID that is equivalent to a cluster URL. Topology info is shared between RdsHostListProviders that have # the same cluster ID. @@ -164,9 +165,9 @@ def __init__(self, host_list_provider_service: HostListProviderService, props: P self._topology_utils = topology_utils self._rds_utils: RdsUtils = RdsUtils() - self._hosts: Tuple[HostInfo, ...] = () + self._hosts: Topology = () self._cluster_id: str = str(uuid.uuid4()) - self._initial_hosts: Tuple[HostInfo, ...] = () + self._initial_hosts: Topology = () self._rds_url_type: Optional[RdsUrlType] = None self._is_primary_cluster_id: bool = False @@ -182,7 +183,7 @@ def _initialize(self): if self._is_initialized: return - self._initial_hosts: Tuple[HostInfo, ...] = (self._topology_utils.initial_host_info,) + self._initial_hosts: Topology = (self._topology_utils.initial_host_info,) self._host_list_provider_service.initial_connection_host_info = self._topology_utils.initial_host_info self._rds_url_type: RdsUrlType = self._rds_utils.identify_rds_type(self._topology_utils.initial_host_info.host) @@ -210,7 +211,10 @@ def _initialize(self): self._is_initialized = True def _get_suggested_cluster_id(self, url: str) -> Optional[ClusterIdSuggestion]: - for key, hosts in RdsHostListProvider._topology_cache.get_dict().items(): + topology_cache = StorageService.get_all(Topology) + if topology_cache is None: + return None + for key, hosts in topology_cache.get_dict().items(): is_primary_cluster_id = \ RdsHostListProvider._is_primary_cluster_id_cache.get_with_default( key, False, self._suggested_cluster_id_refresh_ns) @@ -244,7 +248,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) self._cluster_id = suggested_primary_cluster_id self._is_primary_cluster_id = True - cached_hosts = RdsHostListProvider._topology_cache.get(self._cluster_id) + cached_hosts = StorageService.get(Topology, self._cluster_id) if not cached_hosts or force_update: if not conn: # Cannot fetch topology without a connection @@ -255,7 +259,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) driver_dialect = self._host_list_provider_service.driver_dialect hosts = self.query_for_topology(conn, driver_dialect) if hosts is not None and len(hosts) > 0: - RdsHostListProvider._topology_cache.put(self._cluster_id, hosts, self._refresh_rate_ns) + StorageService.set(self._cluster_id, hosts, Topology) if self._is_primary_cluster_id and cached_hosts is None: # This cluster_id is primary and a new entry was just created in the cache. When this happens, # we check for non-primary cluster IDs associated with the same cluster so that the topology @@ -270,14 +274,18 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) else: return RdsHostListProvider.FetchTopologyResult(self._initial_hosts, False) - def query_for_topology(self, conn, driver_dialect) -> Optional[Tuple[HostInfo, ...]]: + def query_for_topology(self, conn, driver_dialect) -> Optional[Topology]: return self._topology_utils.query_for_topology(conn, driver_dialect) - def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]): + def _suggest_cluster_id(self, primary_cluster_id_hosts: Topology): if not primary_cluster_id_hosts: - return + return None - for cluster_id, hosts in RdsHostListProvider._topology_cache.get_dict().items(): + topology_cache = StorageService.get_all(Topology) + if topology_cache is None: + return None + + for cluster_id, hosts in topology_cache.get_dict().items(): is_primary_cluster = RdsHostListProvider._is_primary_cluster_id_cache.get_with_default( cluster_id, False, self._suggested_cluster_id_refresh_ns) suggested_primary_cluster_id = RdsHostListProvider._cluster_ids_to_update.get(cluster_id) @@ -293,8 +301,9 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]): RdsHostListProvider._cluster_ids_to_update.put( cluster_id, self._cluster_id, self._suggested_cluster_id_refresh_ns) break + return None - def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def refresh(self, connection: Optional[Connection] = None) -> Topology: """ Get topology information for the database cluster. This method executes a database query if there is no information for the cluster in the cache, or if the cached topology is outdated. @@ -311,7 +320,7 @@ def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, .. self._hosts = topology.hosts return tuple(self._hosts) - def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def force_refresh(self, connection: Optional[Connection] = None) -> Topology: """ Execute a database query to retrieve information for the current cluster topology. Any cached topology information will be ignored. @@ -327,7 +336,7 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn self._hosts = topology.hosts return tuple(self._hosts) - def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: raise AwsWrapperError( Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "RdsHostListProvider")) @@ -385,7 +394,7 @@ class ClusterIdSuggestion: @dataclass() class FetchTopologyResult: - hosts: Tuple[HostInfo, ...] + hosts: Topology is_cached_data: bool @@ -394,7 +403,7 @@ class ConnectionStringHostListProvider(StaticHostListProvider): def __init__(self, host_list_provider_service: HostListProviderService, props: Properties): self._host_list_provider_service: HostListProviderService = host_list_provider_service self._props: Properties = props - self._hosts: Tuple[HostInfo, ...] = () + self._hosts: Topology = () self._is_initialized: bool = False self._initial_host_info: Optional[HostInfo] = None @@ -412,15 +421,15 @@ def _initialize(self): self._host_list_provider_service.initial_connection_host_info = self._initial_host_info self._is_initialized = True - def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def refresh(self, connection: Optional[Connection] = None) -> Topology: self._initialize() return tuple(self._hosts) - def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: + def force_refresh(self, connection: Optional[Connection] = None) -> Topology: self._initialize() return tuple(self._hosts) - def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: raise AwsWrapperError( Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "ConnectionStringHostListProvider")) @@ -499,7 +508,7 @@ def query_for_topology( self, conn: Connection, driver_dialect: DriverDialect, - ) -> Optional[Tuple[HostInfo, ...]]: + ) -> Optional[Topology]: """ Query the database for topology information. @@ -513,7 +522,7 @@ def query_for_topology( return x @abstractmethod - def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]: + def _query_for_topology(self, conn: Connection) -> Optional[Topology]: pass def _create_host(self, record: Tuple) -> HostInfo: @@ -628,7 +637,7 @@ class AuroraTopologyUtils(TopologyUtils): _executor_name: ClassVar[str] = "AuroraTopologyUtils" - def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]: + def _query_for_topology(self, conn: Connection) -> Optional[Topology]: """ Query the database for topology information. @@ -643,7 +652,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...] except ProgrammingError as e: raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e - def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]: + def _process_query_results(self, cursor: Cursor) -> Topology: """ Form a list of hosts from the results of the topology query. :param cursor: The Cursor object containing a reference to the results of the topology query. @@ -692,7 +701,7 @@ def __init__( self._writer_host_query = writer_host_query self._writer_host_column_index = writer_host_column_index - def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]: + def _query_for_topology(self, conn: Connection) -> Optional[Topology]: try: with closing(conn.cursor()) as cursor: cursor.execute(self._writer_host_query) @@ -709,7 +718,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...] except ProgrammingError as e: raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e - def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Tuple[HostInfo, ...]: + def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Topology: hosts_dict = {} for record in cursor: host: HostInfo = self._create_multi_az_host(record, writer_id) @@ -789,7 +798,7 @@ def _get_monitor(self) -> Optional[ClusterTopologyMonitor]: self._high_refresh_rate_ns ), MonitoringRdsHostListProvider._MONITOR_CLEANUP_NANO) - def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Tuple[HostInfo, ...]]: + def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Topology]: monitor = self._get_monitor() if monitor is None: @@ -800,7 +809,7 @@ def query_for_topology(self, connection: Connection, driver_dialect) -> Optional except TimeoutError: return None - def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: monitor = self._get_monitor() if monitor is None: diff --git a/aws_advanced_python_wrapper/utils/cache_map.py b/aws_advanced_python_wrapper/utils/cache_map.py index 5c0f0282..6f7e35fe 100644 --- a/aws_advanced_python_wrapper/utils/cache_map.py +++ b/aws_advanced_python_wrapper/utils/cache_map.py @@ -23,6 +23,8 @@ class CacheMap(Generic[K, V]): + _DEFAULT_EXPIRATION_TIME = 300_000_000_000 # 5 minutes + def __init__(self): self._cache: Dict[K, CacheItem[V]] = {} self._cleanup_interval_ns: int = 600_000_000_000 # 10 minutes @@ -62,7 +64,7 @@ def get_with_default(self, key: K, value_if_absent: V, item_expiration_ns: int) return None - def put(self, key: K, item: V, item_expiration_ns: int): + def put(self, key: K, item: V, item_expiration_ns: int = _DEFAULT_EXPIRATION_TIME): with self._lock: self._cache[key] = CacheItem(item, time.perf_counter_ns() + item_expiration_ns) self._cleanup() diff --git a/aws_advanced_python_wrapper/utils/storage/__init__.py b/aws_advanced_python_wrapper/utils/storage/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/utils/storage/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/utils/storage/storage_service.py b/aws_advanced_python_wrapper/utils/storage/storage_service.py new file mode 100644 index 00000000..a5543a2c --- /dev/null +++ b/aws_advanced_python_wrapper/utils/storage/storage_service.py @@ -0,0 +1,69 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import MappingProxyType +from typing import (TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Type, + TypeAlias, TypeVar) + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.hostinfo import HostInfo + +from aws_advanced_python_wrapper.utils.cache_map import CacheMap + +V = TypeVar('V') +Topology: TypeAlias = Tuple["HostInfo", ...] + + +class StorageService: + _storage_map: ClassVar[MappingProxyType] = MappingProxyType({ + Topology: CacheMap() + }) + + @staticmethod + def get(item_class: Type[V], key: Any) -> Optional[V]: + cache = StorageService._storage_map.get(item_class) + if cache is None: + return None + + value = cache.get(key) + # TODO: publish data access event + return value + + @staticmethod + def get_all(item_class: Type[V]) -> Optional[CacheMap[Any, V]]: + cache = StorageService._storage_map.get(item_class) + return cache + + @staticmethod + def set(key: Any, item: V, item_class: Type[V]) -> None: + cache = StorageService._storage_map.get(item_class) + if cache is not None: + cache.put(key, item) + + @staticmethod + def remove(item_class: Type, key: Any) -> None: + cache = StorageService._storage_map.get(item_class) + if cache is not None: + cache.remove(key) + + @staticmethod + def clear(item_class: Type) -> None: + cache = StorageService._storage_map.get(item_class) + if cache is not None: + cache.clear() + + @staticmethod + def clear_all() -> None: + for cache in StorageService._storage_map.values(): + cache.clear() diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 808ea48a..b7b974e7 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -37,6 +37,8 @@ from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \ SlidingExpirationCacheContainer +from aws_advanced_python_wrapper.utils.storage.storage_service import \ + StorageService if TYPE_CHECKING: from .utils.test_driver import TestDriver @@ -141,7 +143,7 @@ def pytest_runtest_setup(item): assert cluster_ip == writer_ip RdsUtils.clear_cache() - RdsHostListProvider._topology_cache.clear() + StorageService.clear_all() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() PluginServiceImpl._host_availability_expiring_cache.clear() diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index d8cee5d8..a5e86443 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -29,6 +29,8 @@ from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.storage.storage_service import \ + StorageService from tests.integration.container.utils.conditions import ( disable_on_engines, disable_on_features, enable_on_deployments, enable_on_features, enable_on_num_instances) @@ -78,7 +80,7 @@ def rds_utils(self): @pytest.fixture(autouse=True) def clear_caches(self): - RdsHostListProvider._topology_cache.clear() + StorageService.clear_all() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() yield diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1b4435e2..aa2dafa5 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,10 +20,12 @@ from aws_advanced_python_wrapper.exception_handling import ExceptionManager from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl +from aws_advanced_python_wrapper.utils.storage.storage_service import \ + StorageService def pytest_runtest_setup(item): - RdsHostListProvider._topology_cache.clear() + StorageService.clear_all() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() PluginServiceImpl._host_availability_expiring_cache.clear() diff --git a/tests/unit/test_multi_az_rds_host_list_provider.py b/tests/unit/test_multi_az_rds_host_list_provider.py index d287f29c..30236881 100644 --- a/tests/unit/test_multi_az_rds_host_list_provider.py +++ b/tests/unit/test_multi_az_rds_host_list_provider.py @@ -27,11 +27,13 @@ from aws_advanced_python_wrapper.pep249 import ProgrammingError from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.storage.storage_service import ( + StorageService, Topology) @pytest.fixture(autouse=True) def clear_caches(): - RdsHostListProvider._topology_cache.clear() + StorageService.clear_all() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() @@ -100,7 +102,7 @@ def create_provider(mock_provider_service, props): def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, props, cache_hosts, refresh_ns): provider = create_provider(mock_provider_service, props) - RdsHostListProvider._topology_cache.put(provider._cluster_id, cache_hosts, refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(provider._topology_utils, "_query_for_topology") result = provider.refresh(mock_conn) @@ -112,7 +114,7 @@ def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, def test_get_topology_force_update( mocker, mock_provider_service, mock_conn, cache_hosts, queried_hosts, props, refresh_ns): provider = create_provider(mock_provider_service, props) - RdsHostListProvider._topology_cache.put(provider._cluster_id, cache_hosts, refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(provider._topology_utils, "_query_for_topology") result = provider.force_refresh(mock_conn) @@ -135,7 +137,7 @@ def test_get_topology_timeout(mocker, mock_cursor, mock_provider_service, initia def test_get_topology_invalid_topology( mocker, mock_provider_service, mock_conn, mock_cursor, props, cache_hosts, refresh_ns): provider = create_provider(mock_provider_service, props) - RdsHostListProvider._topology_cache.put(provider._cluster_id, cache_hosts, refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(provider._topology_utils, "_query_for_topology") mock_topology_query( mock_conn, @@ -187,7 +189,7 @@ def test_no_cluster_id_suggestion_for_separate_clusters(mock_provider_service, m actual_hosts_b = provider_b.refresh() assert expected_hosts_b == actual_hosts_b - assert 2 == len(RdsHostListProvider._topology_cache) + assert 2 == len(StorageService.get_all(Topology)) def test_cluster_id_suggestion_for_new_provider_with_cluster_url(mocker, mock_provider_service, mock_conn, mock_cursor): @@ -209,7 +211,7 @@ def test_cluster_id_suggestion_for_new_provider_with_cluster_url(mocker, mock_pr actual_hosts = provider2.refresh() assert expected_hosts == actual_hosts - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) spy.assert_not_called() @@ -234,7 +236,7 @@ def test_cluster_id_suggestion_for_new_provider_with_instance_url( actual_hosts = provider2.refresh() assert expected_hosts == actual_hosts - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) spy.assert_not_called() @@ -260,7 +262,7 @@ def test_cluster_id_suggestion_for_existing_provider(mocker, mock_provider_servi assert provider2._cluster_id != provider1._cluster_id assert provider2._is_primary_cluster_id assert not provider1._is_primary_cluster_id - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) provider2.refresh() assert "my-cluster.cluster-xyz.us-east-2.rds.amazonaws.com:5432" == \ @@ -268,7 +270,7 @@ def test_cluster_id_suggestion_for_existing_provider(mocker, mock_provider_servi spy = mocker.spy(provider1._topology_utils, "_query_for_topology") actual_hosts = provider1.refresh() - assert 2 == len(RdsHostListProvider._topology_cache) + assert 2 == len(StorageService.get_all(Topology)) assert list(expected_hosts).sort(key=lambda h: h.host) == list(actual_hosts).sort(key=lambda h: h.host) assert provider2._cluster_id == provider1._cluster_id assert provider2._is_primary_cluster_id diff --git a/tests/unit/test_rds_host_list_provider.py b/tests/unit/test_rds_host_list_provider.py index 101d5530..85a1ffa8 100644 --- a/tests/unit/test_rds_host_list_provider.py +++ b/tests/unit/test_rds_host_list_provider.py @@ -27,11 +27,13 @@ from aws_advanced_python_wrapper.pep249 import ProgrammingError from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.storage.storage_service import ( + StorageService, Topology) @pytest.fixture(autouse=True) def clear_caches(): - RdsHostListProvider._topology_cache.clear() + StorageService.clear_all() RdsHostListProvider._is_primary_cluster_id_cache.clear() RdsHostListProvider._cluster_ids_to_update.clear() @@ -94,7 +96,7 @@ def refresh_ns(): def test_get_topology_caches_topology(mocker, mock_provider_service, mock_conn, props, cache_hosts, refresh_ns): topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, props, topology_utils) - RdsHostListProvider._topology_cache.put(provider._cluster_id, tuple(cache_hosts), refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(topology_utils, "_query_for_topology") result = provider.refresh(mock_conn) @@ -107,7 +109,7 @@ def test_get_topology_force_update( mocker, mock_provider_service, mock_conn, cache_hosts, queried_hosts, props, refresh_ns): topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, props, topology_utils) - RdsHostListProvider._topology_cache.put(provider._cluster_id, cache_hosts, refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(topology_utils, "_query_for_topology") result = provider.force_refresh(mock_conn) @@ -132,7 +134,7 @@ def test_get_topology_invalid_topology( mocker, mock_provider_service, mock_conn, mock_cursor, props, cache_hosts, refresh_ns): topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, props, topology_utils) - RdsHostListProvider._topology_cache.put(provider._cluster_id, cache_hosts, refresh_ns) + StorageService.set(provider._cluster_id, cache_hosts, Topology) spy = mocker.spy(topology_utils, "_query_for_topology") mock_topology_query(mock_conn, mock_cursor, [("reader", False)]) # Invalid topology: no writer instance @@ -199,7 +201,7 @@ def test_no_cluster_id_suggestion_for_separate_clusters(mock_provider_service, m actual_hosts_b = provider_b.refresh() assert expected_hosts_b == actual_hosts_b - assert 2 == len(RdsHostListProvider._topology_cache) + assert 2 == len(StorageService.get_all(Topology)) def test_cluster_id_suggestion_for_new_provider_with_cluster_url(mocker, mock_provider_service, mock_conn, mock_cursor): @@ -223,7 +225,7 @@ def test_cluster_id_suggestion_for_new_provider_with_cluster_url(mocker, mock_pr actual_hosts = provider2.refresh() assert expected_hosts == actual_hosts - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) spy.assert_not_called() @@ -250,7 +252,7 @@ def test_cluster_id_suggestion_for_new_provider_with_instance_url( actual_hosts = provider2.refresh() assert expected_hosts == actual_hosts - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) spy.assert_not_called() @@ -278,7 +280,7 @@ def test_cluster_id_suggestion_for_existing_provider(mocker, mock_provider_servi assert provider2._cluster_id != provider1._cluster_id assert provider2._is_primary_cluster_id assert not provider1._is_primary_cluster_id - assert 1 == len(RdsHostListProvider._topology_cache) + assert 1 == len(StorageService.get_all(Topology)) provider2.refresh() assert "my-cluster.cluster-xyz.us-east-2.rds.amazonaws.com" == \ @@ -286,7 +288,7 @@ def test_cluster_id_suggestion_for_existing_provider(mocker, mock_provider_servi spy = mocker.spy(provider1._topology_utils, "_query_for_topology") actual_hosts = provider1.refresh() - assert 2 == len(RdsHostListProvider._topology_cache) + assert 2 == len(StorageService.get_all(Topology)) assert list(expected_hosts).sort(key=lambda h: h.host) == list(actual_hosts).sort(key=lambda h: h.host) assert provider2._cluster_id == provider1._cluster_id assert provider2._is_primary_cluster_id