37 changes: 17 additions & 20 deletions aws_advanced_python_wrapper/cluster_topology_monitor.py
@@ -18,14 +18,15 @@
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional

from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.utils.atomic import AtomicReference
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.storage.storage_service import (
StorageService, Topology)
from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \
ThreadSafeConnectionHolder
from aws_advanced_python_wrapper.utils.utils import LogUtils
@@ -46,11 +47,11 @@

class ClusterTopologyMonitor(ABC):
@abstractmethod
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
pass

@abstractmethod
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology:
pass

@abstractmethod
@@ -75,8 +76,6 @@ class ClusterTopologyMonitorImpl(ClusterTopologyMonitor):
INITIAL_BACKOFF_MS = 100
MAX_BACKOFF_MS = 10000

_topology_map: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap()

def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, cluster_id: str,
initial_host_info: HostInfo, properties: Properties, instance_template: HostInfo,
refresh_rate_nano: int, high_refresh_rate_nano: int):
@@ -103,7 +102,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
self._host_threads_writer_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
self._host_threads_writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None)
self._host_threads_reader_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
self._host_threads_latest_topology: AtomicReference[Optional[Tuple[HostInfo, ...]]] = AtomicReference(None)
self._host_threads_latest_topology: AtomicReference[Optional[Topology]] = AtomicReference(None)

self._is_verified_writer_connection = False
self._high_refresh_rate_end_time_nano = 0
@@ -118,7 +117,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,

self._start_monitoring()

def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
current_time_nano = time.time_ns()
if (self._ignore_new_topology_requests_end_time_nano > 0 and
current_time_nano < self._ignore_new_topology_requests_end_time_nano):
@@ -134,12 +133,12 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H
result = self._wait_till_topology_gets_updated(timeout_sec)
return result

def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Topology:
if self._is_verified_writer_connection:
return self._wait_till_topology_gets_updated(timeout_sec)
return self._fetch_topology_and_update_cache(connection)

def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo, ...]:
def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Topology:
current_hosts = self._get_stored_hosts()

self._request_to_update_topology.set()
@@ -162,8 +161,8 @@ def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo,
"ClusterTopologyMonitorImpl.TopologyNotUpdated",
self._cluster_id, timeout_sec * 1000))

def _get_stored_hosts(self) -> Tuple[HostInfo, ...]:
hosts = ClusterTopologyMonitorImpl._topology_map.get(self._cluster_id)
def _get_stored_hosts(self) -> Topology:
hosts = StorageService.get(Topology, self._cluster_id)
if hosts is None:
return ()
return hosts
@@ -296,7 +295,7 @@ def _is_in_panic_mode(self) -> bool:
def _get_host_monitor(self, host_info: HostInfo, writer_host_info: Optional[HostInfo]):
return HostMonitor(self, host_info, writer_host_info)

def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
def _open_any_connection_and_update_topology(self) -> Topology:
writer_verified_by_this_thread = False
if self._monitoring_connection.get() is None:
# Try to connect to the initial host first
@@ -409,7 +408,7 @@ def _delay(self, use_high_refresh_rate: bool) -> None:
while not self._request_to_update_topology.is_set() and time.time() < end_time and not self._stop.is_set():
time.sleep(0.05)

def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Tuple[HostInfo, ...]:
def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Topology:
if connection is None:
return ()

@@ -423,7 +422,7 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) ->
logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex)
return ()

def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
def _fetch_topology_and_update_cache_safe(self) -> Topology:
"""
Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions.
The lock is held during the entire query operation.
@@ -433,16 +432,14 @@ def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
)
return result if result is not None else ()

def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]:
def _query_for_topology(self, connection: Connection) -> Topology:
hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect)
if hosts is not None:
return hosts
return ()

def _update_topology_cache(self, hosts: Tuple[HostInfo, ...]) -> None:
ClusterTopologyMonitorImpl._topology_map.put(
self._cluster_id, hosts, ClusterTopologyMonitorImpl.TOPOLOGY_CACHE_EXPIRATION_NANO)

def _update_topology_cache(self, hosts: Topology) -> None:
StorageService.set(self._cluster_id, hosts, Topology)
# Notify waiting threads
self._request_to_update_topology.clear()
self._topology_updated.set()
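Note: this file now reads and writes topology through StorageService instead of the monitor's own class-level CacheMap. The StorageService and Topology definitions live in aws_advanced_python_wrapper/utils/storage/storage_service.py and are not part of this diff; the sketch below is only an inference from the call sites above (StorageService.get(Topology, key), StorageService.set(key, value, Topology), StorageService.get_all(Topology)) and from Topology being used as a tuple of HostInfo objects. It is a hypothetical stand-in, not the library's actual implementation.

# Hypothetical sketch of the storage surface implied by this diff; names,
# signatures, and expiration behavior are assumptions, not the real module.
from typing import Any, Dict, Optional, Tuple

from aws_advanced_python_wrapper.hostinfo import HostInfo

# Assumed alias: the diff treats Topology as a tuple of HostInfo objects and
# returns () when nothing is available. The real module may define it differently.
Topology = Tuple[HostInfo, ...]


class _CategoryView:
    # Minimal stand-in for whatever get_all() returns; the diff only calls get_dict() on it.
    def __init__(self, entries: Dict[str, Any]) -> None:
        self._entries = entries

    def get_dict(self) -> Dict[str, Any]:
        return dict(self._entries)


class StorageService:
    # One entry dict per item category (e.g. Topology), keyed by cluster ID.
    _data: Dict[Any, Dict[str, Any]] = {}

    @classmethod
    def get(cls, item_class: Any, key: str) -> Optional[Any]:
        # Mirrors StorageService.get(Topology, cluster_id) in the diff.
        return cls._data.get(item_class, {}).get(key)

    @classmethod
    def set(cls, key: str, item: Any, item_class: Any) -> None:
        # Mirrors StorageService.set(cluster_id, hosts, Topology); the old
        # CacheMap.put expiration argument is gone, so expiry is presumably
        # handled inside the real storage layer.
        cls._data.setdefault(item_class, {})[key] = item

    @classmethod
    def get_all(cls, item_class: Any) -> Optional[_CategoryView]:
        # Mirrors StorageService.get_all(Topology), which the host list providers
        # iterate via .get_dict().items().
        entries = cls._data.get(item_class)
        return None if entries is None else _CategoryView(entries)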
69 changes: 39 additions & 30 deletions aws_advanced_python_wrapper/host_list_provider.py
@@ -30,6 +30,8 @@
preserve_transaction_status_with_timeout
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils.storage.storage_service import (
StorageService, Topology)

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
@@ -59,13 +61,13 @@


class HostListProvider(Protocol):
def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def refresh(self, connection: Optional[Connection] = None) -> Topology:
...

def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
...

def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
...

def get_host_role(self, connection: Connection) -> HostRole:
@@ -149,7 +151,6 @@ def is_static_host_list_provider(self) -> bool:


class RdsHostListProvider(DynamicHostListProvider, HostListProvider):
_topology_cache: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap()
# Maps cluster IDs to a boolean representing whether they are a primary cluster ID or not. A primary cluster ID is a
# cluster ID that is equivalent to a cluster URL. Topology info is shared between RdsHostListProviders that have
# the same cluster ID.
@@ -164,9 +165,9 @@ def __init__(self, host_list_provider_service: HostListProviderService, props: P
self._topology_utils = topology_utils

self._rds_utils: RdsUtils = RdsUtils()
self._hosts: Tuple[HostInfo, ...] = ()
self._hosts: Topology = ()
self._cluster_id: str = str(uuid.uuid4())
self._initial_hosts: Tuple[HostInfo, ...] = ()
self._initial_hosts: Topology = ()
self._rds_url_type: Optional[RdsUrlType] = None

self._is_primary_cluster_id: bool = False
@@ -182,7 +183,7 @@ def _initialize(self):
if self._is_initialized:
return

self._initial_hosts: Tuple[HostInfo, ...] = (self._topology_utils.initial_host_info,)
self._initial_hosts: Topology = (self._topology_utils.initial_host_info,)
self._host_list_provider_service.initial_connection_host_info = self._topology_utils.initial_host_info

self._rds_url_type: RdsUrlType = self._rds_utils.identify_rds_type(self._topology_utils.initial_host_info.host)
@@ -210,7 +211,10 @@ def _initialize(self):
self._is_initialized = True

def _get_suggested_cluster_id(self, url: str) -> Optional[ClusterIdSuggestion]:
for key, hosts in RdsHostListProvider._topology_cache.get_dict().items():
topology_cache = StorageService.get_all(Topology)
if topology_cache is None:
return None
for key, hosts in topology_cache.get_dict().items():
is_primary_cluster_id = \
RdsHostListProvider._is_primary_cluster_id_cache.get_with_default(
key, False, self._suggested_cluster_id_refresh_ns)
@@ -244,7 +248,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
self._cluster_id = suggested_primary_cluster_id
self._is_primary_cluster_id = True

cached_hosts = RdsHostListProvider._topology_cache.get(self._cluster_id)
cached_hosts = StorageService.get(Topology, self._cluster_id)
if not cached_hosts or force_update:
if not conn:
# Cannot fetch topology without a connection
@@ -255,7 +259,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
driver_dialect = self._host_list_provider_service.driver_dialect
hosts = self.query_for_topology(conn, driver_dialect)
if hosts is not None and len(hosts) > 0:
RdsHostListProvider._topology_cache.put(self._cluster_id, hosts, self._refresh_rate_ns)
StorageService.set(self._cluster_id, hosts, Topology)
if self._is_primary_cluster_id and cached_hosts is None:
# This cluster_id is primary and a new entry was just created in the cache. When this happens,
# we check for non-primary cluster IDs associated with the same cluster so that the topology
@@ -270,14 +274,18 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False)
else:
return RdsHostListProvider.FetchTopologyResult(self._initial_hosts, False)

def query_for_topology(self, conn, driver_dialect) -> Optional[Tuple[HostInfo, ...]]:
def query_for_topology(self, conn, driver_dialect) -> Optional[Topology]:
return self._topology_utils.query_for_topology(conn, driver_dialect)

def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
def _suggest_cluster_id(self, primary_cluster_id_hosts: Topology):
if not primary_cluster_id_hosts:
return
return None

for cluster_id, hosts in RdsHostListProvider._topology_cache.get_dict().items():
topology_cache = StorageService.get_all(Topology)
if topology_cache is None:
return None

for cluster_id, hosts in topology_cache.get_dict().items():
is_primary_cluster = RdsHostListProvider._is_primary_cluster_id_cache.get_with_default(
cluster_id, False, self._suggested_cluster_id_refresh_ns)
suggested_primary_cluster_id = RdsHostListProvider._cluster_ids_to_update.get(cluster_id)
@@ -293,8 +301,9 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
RdsHostListProvider._cluster_ids_to_update.put(
cluster_id, self._cluster_id, self._suggested_cluster_id_refresh_ns)
break
return None

def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def refresh(self, connection: Optional[Connection] = None) -> Topology:
"""
Get topology information for the database cluster.
This method executes a database query if there is no information for the cluster in the cache, or if the cached topology is outdated.
@@ -311,7 +320,7 @@ def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ..
self._hosts = topology.hosts
return tuple(self._hosts)

def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
"""
Execute a database query to retrieve information for the current cluster topology. Any cached topology information will be ignored.

@@ -327,7 +336,7 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn
self._hosts = topology.hosts
return tuple(self._hosts)

def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
raise AwsWrapperError(
Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "RdsHostListProvider"))

@@ -385,7 +394,7 @@ class ClusterIdSuggestion:

@dataclass()
class FetchTopologyResult:
hosts: Tuple[HostInfo, ...]
hosts: Topology
is_cached_data: bool


@@ -394,7 +403,7 @@ class ConnectionStringHostListProvider(StaticHostListProvider):
def __init__(self, host_list_provider_service: HostListProviderService, props: Properties):
self._host_list_provider_service: HostListProviderService = host_list_provider_service
self._props: Properties = props
self._hosts: Tuple[HostInfo, ...] = ()
self._hosts: Topology = ()
self._is_initialized: bool = False
self._initial_host_info: Optional[HostInfo] = None

@@ -412,15 +421,15 @@ def _initialize(self):
self._host_list_provider_service.initial_connection_host_info = self._initial_host_info
self._is_initialized = True

def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def refresh(self, connection: Optional[Connection] = None) -> Topology:
self._initialize()
return tuple(self._hosts)

def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]:
def force_refresh(self, connection: Optional[Connection] = None) -> Topology:
self._initialize()
return tuple(self._hosts)

def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
raise AwsWrapperError(
Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "ConnectionStringHostListProvider"))

@@ -499,7 +508,7 @@ def query_for_topology(
self,
conn: Connection,
driver_dialect: DriverDialect,
) -> Optional[Tuple[HostInfo, ...]]:
) -> Optional[Topology]:
"""
Query the database for topology information.

@@ -513,7 +522,7 @@ def query_for_topology(
return x

@abstractmethod
def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
pass

def _create_host(self, record: Tuple) -> HostInfo:
@@ -628,7 +637,7 @@ class AuroraTopologyUtils(TopologyUtils):

_executor_name: ClassVar[str] = "AuroraTopologyUtils"

def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
"""
Query the database for topology information.

@@ -643,7 +652,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
except ProgrammingError as e:
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e

def _process_query_results(self, cursor: Cursor) -> Tuple[HostInfo, ...]:
def _process_query_results(self, cursor: Cursor) -> Topology:
"""
Form a list of hosts from the results of the topology query.
:param cursor: The Cursor object containing a reference to the results of the topology query.
@@ -692,7 +701,7 @@ def __init__(
self._writer_host_query = writer_host_query
self._writer_host_column_index = writer_host_column_index

def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]:
def _query_for_topology(self, conn: Connection) -> Optional[Topology]:
try:
with closing(conn.cursor()) as cursor:
cursor.execute(self._writer_host_query)
@@ -709,7 +718,7 @@ def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]
except ProgrammingError as e:
raise AwsWrapperError(Messages.get("RdsHostListProvider.InvalidQuery"), e) from e

def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Tuple[HostInfo, ...]:
def _process_multi_az_query_results(self, cursor: Cursor, writer_id: str) -> Topology:
hosts_dict = {}
for record in cursor:
host: HostInfo = self._create_multi_az_host(record, writer_id)
@@ -789,7 +798,7 @@ def _get_monitor(self) -> Optional[ClusterTopologyMonitor]:
self._high_refresh_rate_ns
), MonitoringRdsHostListProvider._MONITOR_CLEANUP_NANO)

def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Tuple[HostInfo, ...]]:
def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Topology]:
monitor = self._get_monitor()

if monitor is None:
@@ -800,7 +809,7 @@ def query_for_topology(self, connection: Connection, driver_dialect) -> Optional
except TimeoutError:
return None

def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]:
def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology:
monitor = self._get_monitor()

if monitor is None:
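Note: for reference, a minimal usage example of the call-site pattern this PR standardizes on, assuming the sketch after the first file above. The cluster ID and host name are placeholders, and HostInfo construction may take additional arguments in the real wrapper.

# Hypothetical usage; replaces the old per-class CacheMap get/put calls.
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.utils.storage.storage_service import (
    StorageService, Topology)

cluster_id = "example-cluster-id"                     # placeholder key
hosts: Topology = (HostInfo("writer.example.com"),)   # placeholder topology

StorageService.set(cluster_id, hosts, Topology)       # was: _topology_cache.put(key, hosts, expiration_ns)
cached = StorageService.get(Topology, cluster_id)     # was: _topology_cache.get(key)
assert cached == hosts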