From 515abc9d849dc435b99c9ce8aefe94373bba57e3 Mon Sep 17 00:00:00 2001 From: Taishi Kasuga Date: Wed, 9 Aug 2023 10:53:01 +0900 Subject: [PATCH] fix: client should be able to subscribe multiple channels for Pub/Sub --- lib/redis_client/cluster/pub_sub.rb | 52 ++++++++++++++++++++++------- test/redis_client/test_cluster.rb | 31 +++++++++++++++++ 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/lib/redis_client/cluster/pub_sub.rb b/lib/redis_client/cluster/pub_sub.rb index 7d543fdf..444cc113 100644 --- a/lib/redis_client/cluster/pub_sub.rb +++ b/lib/redis_client/cluster/pub_sub.rb @@ -3,33 +3,61 @@ class RedisClient class Cluster class PubSub + MAX_THREADS = Integer(ENV.fetch('REDIS_CLIENT_MAX_THREADS', 5)) + def initialize(router, command_builder) @router = router @command_builder = command_builder - @pubsub = nil + @pubsub_states = {} end def call(*args, **kwargs) - close - command = @command_builder.generate(args, kwargs) - @pubsub = @router.assign_node(command).pubsub - @pubsub.call_v(command) + _call(@command_builder.generate(args, kwargs)) end def call_v(command) - close - command = @command_builder.generate(command) - @pubsub = @router.assign_node(command).pubsub - @pubsub.call_v(command) + _call(@command_builder.generate(command)) end def close - @pubsub&.close - @pubsub = nil + @pubsub_states.each_value(&:close) + @pubsub_states.clear end def next_event(timeout = nil) - @pubsub&.next_event(timeout) + msgs = collect_messages(timeout).compact + return msgs.first if msgs.size == 1 + + msgs + end + + private + + def _call(command) + node_key = @router.find_node_key(command) + pubsub = if @pubsub_states.key?(node_key) + @pubsub_states[node_key] + else + @pubsub_states[node_key] = @router.find_node(node_key).pubsub + end + pubsub.call_v(command) + end + + def collect_messages(timeout) + @pubsub_states.each_slice(MAX_THREADS).each_with_object([]) do |chuncked_pubsub_states, acc| + threads = chuncked_pubsub_states.map do |_, v| + Thread.new(v) do |pubsub| + Thread.current[:reply] = pubsub.next_event(timeout) + rescue StandardError => e + Thread.current[:reply] = e + end + end + + threads.each do |t| + t.join + acc << t[:reply] + end + end end end end diff --git a/test/redis_client/test_cluster.rb b/test/redis_client/test_cluster.rb index 8303a15c..4e70d2b2 100644 --- a/test/redis_client/test_cluster.rb +++ b/test/redis_client/test_cluster.rb @@ -186,6 +186,7 @@ def test_global_pubsub pubsub = @client.pubsub pubsub.call('SUBSCRIBE', "channel#{i}") assert_equal(['subscribe', "channel#{i}", 1], pubsub.next_event(0.1)) + pubsub.close end sub = Fiber.new do |client| @@ -195,6 +196,7 @@ def test_global_pubsub assert_equal(['subscribe', channel, 1], pubsub.next_event(TEST_TIMEOUT_SEC)) Fiber.yield(channel) Fiber.yield(pubsub.next_event(TEST_TIMEOUT_SEC)) + pubsub.close end channel = sub.resume(@client) @@ -216,6 +218,7 @@ def test_sharded_pubsub assert_equal(['ssubscribe', channel, 1], pubsub.next_event(TEST_TIMEOUT_SEC)) Fiber.yield(channel) Fiber.yield(pubsub.next_event(TEST_TIMEOUT_SEC)) + pubsub.close end channel = sub.resume(@client) @@ -224,6 +227,34 @@ def test_sharded_pubsub end end + def test_sharded_pubsub_with_multiple_channels + if TEST_REDIS_MAJOR_VERSION < 7 + skip('Sharded Pub/Sub is supported by Redis 7+.') + return + end + + sub = Fiber.new do |pubsub| + assert_empty(pubsub.next_event(TEST_TIMEOUT_SEC)) + pubsub.call('SSUBSCRIBE', 'chan1') + pubsub.call('SSUBSCRIBE', 'chan2') + assert_equal( + [['ssubscribe', 'chan1', 1], ['ssubscribe', 'chan2', 1]], + pubsub.next_event(TEST_TIMEOUT_SEC).sort_by { |e| e[1] } + ) + Fiber.yield + Fiber.yield(pubsub.next_event(TEST_TIMEOUT_SEC)) + pubsub.close + end + + sub.resume(@client.pubsub) + @client.call('SPUBLISH', 'chan1', 'hello') + @client.call('SPUBLISH', 'chan2', 'world') + assert_equal( + [%w[smessage chan1 hello], %w[smessage chan2 world]], + sub.resume.sort_by { |e| e[1] } + ) + end + def test_close assert_nil(@client.close) end