diff --git a/redis/cluster.py b/redis/cluster.py index 005206a725..dabac841db 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2208,7 +2208,8 @@ def _sharded_message_generator(self): def _pubsubs_generator(self): while True: - yield from self.node_pubsub_mapping.values() + current_nodes = list(self.node_pubsub_mapping.values()) + yield from current_nodes def get_sharded_message( self, ignore_subscribe_messages=False, timeout=0.0, target_node=None diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 4a49ab4dce..fa4a7fa4da 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -871,6 +871,120 @@ def test_pubsub_shardnumsub(self, r): channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_ssubscribe_multiple_channels_different_nodes(self, r): + """ + Test subscribing to multiple sharded channels on different nodes. + Validates that the generator properly handles multiple node_pubsub_mapping entries. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + # Subscribe to first channel + pubsub.ssubscribe(channel1) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Subscribe to second channel (likely different node) + pubsub.ssubscribe(channel2) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Verify both channels are in shard_channels + assert channel1.encode() in pubsub.shard_channels + assert channel2.encode() in pubsub.shard_channels + + pubsub.close() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_ssubscribe_multiple_channels_publish_and_read(self, r): + """ + Test publishing to multiple sharded channels and reading messages. + Validates that _sharded_message_generator properly cycles through + multiple node_pubsub_mapping entries. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + msg1_data = "message-1" + msg2_data = "message-2" + + # Subscribe to both channels + pubsub.ssubscribe(channel1, channel2) + + # Read subscription confirmations + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Publish messages to both channels + r.spublish(channel1, msg1_data) + r.spublish(channel2, msg2_data) + + # Read messages - should get both messages + messages = [] + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "smessage" + messages.append(msg) + + # Verify we got messages from both channels + channels_received = {msg["channel"] for msg in messages} + assert channel1.encode() in channels_received + assert channel2.encode() in channels_received + + pubsub.close() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_generator_handles_concurrent_mapping_changes(self, r): + """ + Test that the generator properly handles mapping changes during iteration. + This validates the fix for the RuntimeError: dictionary changed size during iteration. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + # Subscribe to first channel + pubsub.ssubscribe(channel1) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Get initial mapping size (cluster pubsub only) + assert hasattr(pubsub, "node_pubsub_mapping"), "Test requires ClusterPubSub" + initial_size = len(pubsub.node_pubsub_mapping) + + # Subscribe to second channel (modifies mapping during potential iteration) + pubsub.ssubscribe(channel2) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Verify mapping was updated + assert len(pubsub.node_pubsub_mapping) >= initial_size + + # Publish and read messages - should not raise RuntimeError + r.spublish(channel1, "msg1") + r.spublish(channel2, "msg2") + + messages_received = 0 + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + if msg and msg["type"] == "smessage": + messages_received += 1 + + assert messages_received == 2 + pubsub.close() + class TestPubSubPings: @skip_if_server_version_lt("3.0.0")