Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,14 @@ def parse_oss_maintenance_start_msg(response):
@staticmethod
def parse_oss_maintenance_completed_msg(response):
# Expected message format is:
# SMIGRATED <seq_number> <host:port> <slot, range1-range2,...>
# SMIGRATED <seq_number> [<host:port> <slot, range1-range2,...>, ...]
id = response[1]
node_address = safe_str(response[2])
slots = response[3]
nodes_to_slots_mapping_data = response[2]
nodes_to_slots_mapping = {}
for node, slots in nodes_to_slots_mapping_data:
nodes_to_slots_mapping[safe_str(node)] = safe_str(slots)

return OSSNodeMigratedNotification(id, node_address, slots)
return OSSNodeMigratedNotification(id, nodes_to_slots_mapping)

@staticmethod
def parse_maintenance_start_msg(response, notification_type):
Expand Down
18 changes: 18 additions & 0 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ def __init__(
oss_cluster_maint_notifications_handler,
parser,
)
self._processed_start_maint_notifications = set()
self._skipped_end_maint_notifications = set()

@abstractmethod
def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
Expand Down Expand Up @@ -667,6 +669,22 @@ def maintenance_state(self) -> MaintenanceState:
def maintenance_state(self, state: "MaintenanceState"):
self._maintenance_state = state

def add_maint_start_notification(self, id: int):
self._processed_start_maint_notifications.add(id)

def get_processed_start_notifications(self) -> set:
return self._processed_start_maint_notifications

def add_skipped_end_notification(self, id: int):
self._skipped_end_maint_notifications.add(id)

def get_skipped_end_notifications(self) -> set:
return self._skipped_end_maint_notifications

def reset_received_notifications(self):
self._processed_start_maint_notifications.clear()
self._skipped_end_maint_notifications.clear()

def getpeername(self):
"""
Returns the peer name of the connection.
Expand Down
104 changes: 62 additions & 42 deletions redis/maint_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union

from redis.typing import Number

Expand Down Expand Up @@ -463,31 +463,26 @@ class OSSNodeMigratedNotification(MaintenanceNotification):

Args:
id (int): Unique identifier for this notification
node_address (Optional[str]): Address of the node that has completed migration
in the format "host:port"
slots (Optional[List[int]]): List of slots that have been migrated
nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots
"""

DEFAULT_TTL = 30

def __init__(
self,
id: int,
node_address: str,
slots: Optional[List[int]] = None,
nodes_to_slots_mapping: Dict[str, str],
):
super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
self.node_address = node_address
self.slots = slots
self.nodes_to_slots_mapping = nodes_to_slots_mapping

def __repr__(self) -> str:
expiry_time = self.creation_time + self.ttl
remaining = max(0, expiry_time - time.monotonic())
return (
f"{self.__class__.__name__}("
f"id={self.id}, "
f"node_address={self.node_address}, "
f"slots={self.slots}, "
f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}, "
f"ttl={self.ttl}, "
f"creation_time={self.creation_time}, "
f"expires_at={expiry_time}, "
Expand Down Expand Up @@ -899,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification):
return

if notification_type:
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
self.handle_maintenance_start_notification(
MaintenanceState.MAINTENANCE, notification
)
else:
self.handle_maintenance_completed_notification()

def handle_maintenance_start_notification(
self, maintenance_state: MaintenanceState
self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
):
if (
self.connection.maintenance_state == MaintenanceState.MOVING
Expand All @@ -918,6 +915,11 @@ def handle_maintenance_start_notification(
)
# extend the timeout for all created connections
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
if isinstance(notification, OSSNodeMigratingNotification):
# add the notification id to the set of processed start maint notifications
# this is used to skip the unrelaxing of the timeouts if we have received more than
# one start notification before the final end notification
self.connection.add_maint_start_notification(notification.id)

def handle_maintenance_completed_notification(self):
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
Expand All @@ -931,6 +933,9 @@ def handle_maintenance_completed_notification(self):
# timeouts by providing -1 as the relaxed timeout
self.connection.update_current_socket_timeout(-1)
self.connection.maintenance_state = MaintenanceState.NONE
# reset the sets that keep track of received start maint
# notifications and skipped end maint notifications
self.connection.reset_received_notifications()


class OSSMaintNotificationsHandler:
Expand Down Expand Up @@ -999,40 +1004,55 @@ def handle_oss_maintenance_completed_notification(

# Updates the cluster slots cache with the new slots mapping
# This will also update the nodes cache with the new nodes mapping
new_node_host, new_node_port = notification.node_address.split(":")
additional_startup_nodes_info = []
for node_address, _ in notification.nodes_to_slots_mapping.items():
new_node_host, new_node_port = node_address.split(":")
additional_startup_nodes_info.append(
(new_node_host, int(new_node_port))
)
self.cluster_client.nodes_manager.initialize(
disconnect_startup_nodes_pools=False,
additional_startup_nodes_info=[(new_node_host, int(new_node_port))],
additional_startup_nodes_info=additional_startup_nodes_info,
)
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by sub sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in (
current_node.redis_connection.connection_pool._get_in_use_connections()
):
conn.mark_for_reconnect()
with current_node.redis_connection.connection_pool._lock:
# mark for reconnect all in use connections to the node - this will force them to
# disconnect after they complete their current commands
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
# all in flight connections after they are done with current command execution
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
conn.mark_for_reconnect()

if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()
else:
if self.config.is_relaxed_timeouts_enabled():
# reset the timeouts for the node to which the connection is connected
# TODO: add check if other maintenance ops are in progress for the same node - CAE-1038
# and if so, don't reset the timeouts
for conn in (
*current_node.redis_connection.connection_pool._get_in_use_connections(),
*current_node.redis_connection.connection_pool._get_free_connections(),
):
conn.reset_tmp_settings(reset_relaxed_timeout=True)
conn.update_current_socket_timeout(relaxed_timeout=-1)
conn.maintenance_state = MaintenanceState.NONE
if (
current_node
not in self.cluster_client.nodes_manager.nodes_cache.values()
):
# disconnect all free connections to the node - this node will be dropped
# from the cluster, so we don't need to revert the timeouts
for conn in current_node.redis_connection.connection_pool._get_free_connections():
conn.disconnect()
else:
if self.config.is_relaxed_timeouts_enabled():
# reset the timeouts for the node to which the connection is connected
# Perform check if other maintenance ops are in progress for the same node
# and if so, don't reset the timeouts and wait for the last maintenance
# to complete
for conn in (
*current_node.redis_connection.connection_pool._get_in_use_connections(),
*current_node.redis_connection.connection_pool._get_free_connections(),
):
if (
len(conn.get_processed_start_notifications())
> len(conn.get_skipped_end_notifications()) + 1
):
# we have received more start notifications than end notifications
# for this connection - we should not reset the timeouts
# and add the notification id to the set of skipped end notifications
conn.add_skipped_end_notification(notification.id)
else:
conn.reset_tmp_settings(reset_relaxed_timeout=True)
conn.update_current_socket_timeout(relaxed_timeout=-1)
conn.maintenance_state = MaintenanceState.NONE
conn.reset_received_notifications()

# mark the notification as processed
self._processed_notifications.add(notification)
Expand Down
70 changes: 42 additions & 28 deletions tests/maint_notifications/proxy_server_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,51 @@ class RespTranslator:
"""Helper class to translate between RESP and other encodings."""

@staticmethod
def str_or_list_to_resp(txt: str) -> str:
"""
Convert specific string or list to RESP format.
"""
if re.match(r"^<.*>$", txt):
items = txt[1:-1].split(",")
return f"*{len(items)}\r\n" + "\r\n".join(
f"${len(x)}\r\n{x}" for x in items
def oss_maint_notification_to_resp(txt: str) -> str:
"""Convert query to RESP format."""
if txt.startswith("SMIGRATED"):
# Format: SMIGRATED SeqID host:port slot1,range1-range2 host1:port1 slot2,range3-range4
# SMIGRATED 93923 abc.com:6789 123,789-1000 abc.com:4545 1000-2000 abc.com:4323 900,910,920
# SMIGRATED - simple string
# SeqID - integer
# host and slots info are provided as array of arrays
# host:port - simple string
# slots - simple string

parts = txt.split()
notification = parts[0]
seq_id = parts[1]
hosts_and_slots = parts[2:]
resp = (
">3\r\n" # Push message with 3 elements
f"+{notification}\r\n" # Element 1: Command
f":{seq_id}\r\n" # Element 2: SeqID
f"*{len(hosts_and_slots) // 2}\r\n" # Element 3: Array of host:port, slots pairs
)
for i in range(0, len(hosts_and_slots), 2):
resp += "*2\r\n"
resp += f"+{hosts_and_slots[i]}\r\n"
resp += f"+{hosts_and_slots[i + 1]}\r\n"
else:
return f"${len(txt)}\r\n{txt}"

@staticmethod
def cluster_slots_to_resp(resp: str) -> str:
"""Convert query to RESP format."""
return (
f"*{len(resp.split())}\r\n"
+ "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split())
+ "\r\n"
)

@staticmethod
def oss_maint_notification_to_resp(resp: str) -> str:
"""Convert query to RESP format."""
return (
f">{len(resp.split())}\r\n"
+ "\r\n".join(
f"{RespTranslator.str_or_list_to_resp(x)}" for x in resp.split()
# SMIGRATING
# Format: SMIGRATING SeqID slot,range1-range2
# SMIGRATING 93923 123,789-1000
# SMIGRATING - simple string
# SeqID - integer
# slots - simple string

parts = txt.split()
notification = parts[0]
seq_id = parts[1]
slots = parts[2]

resp = (
">3\r\n" # Push message with 3 elements
f"+{notification}\r\n" # Element 1: Command
f":{seq_id}\r\n" # Element 2: SeqID
f"+{slots}\r\n" # Element 3: slots (simple string)
)
+ "\r\n"
)
return resp


@dataclass
Expand Down
Loading