commit 8a69289
parent 4b5c506
author Eric Bower
date   2026-02-03 08:59:08 -0500 EST

feat(pubsub): round robin

7 files changed, +388, -30

+21, -3
1@@ -600,6 +600,7 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
2 block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
3 timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
4 clean := pubCmd.Bool("c", false, "Don't send status messages")
5+ dispatcher := pubCmd.String("d", "multicast", "Type of dispatcher (e.g. multicast, round_robin)")
6
7 if !flagCheck(pubCmd, topic, cmd.args) {
8 return fmt.Errorf("invalid cmd args")
9@@ -619,6 +620,7 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
10 "topic", topic,
11 "access", *access,
12 "clean", *clean,
13+ "dispatcher", *dispatcher,
14 )
15
16 var accessList []string
17@@ -795,13 +797,19 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
18
19 throttledRW := newThrottledMonitorRW(rw, handler, cmd, name)
20
21+ var dsp psub.MessageDispatcher
22+ dsp = &psub.MulticastDispatcher{}
23+ if *dispatcher == "round_robin" {
24+ dsp = &psub.RoundRobinDispatcher{}
25+ }
26+ channel := psub.NewChannel(name)
27+ _ = handler.PubSub.SetDispatcher(dsp, []*psub.Channel{channel})
28+
29 err := handler.PubSub.Pub(
30 cmd.pipeCtx,
31 clientID,
32 throttledRW,
33- []*psub.Channel{
34- psub.NewChannel(name),
35- },
36+ []*psub.Channel{channel},
37 *block,
38 )
39
40@@ -1011,6 +1019,7 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
41 public := pipeCmd.Bool("p", false, "Pipe to a public topic")
42 replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
43 clean := pipeCmd.Bool("c", false, "Don't send status messages")
44+ dispatcher := pipeCmd.String("d", "multicast", "Type of dispatcher (e.g. multicast, round_robin)")
45
46 if !flagCheck(pipeCmd, topic, cmd.args) {
47 return fmt.Errorf("invalid cmd args")
48@@ -1028,6 +1037,7 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
49 "topic", topic,
50 "access", *access,
51 "clean", *clean,
52+ "dispatcher", *dispatcher,
53 )
54
55 var accessList []string
56@@ -1101,6 +1111,14 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
57
58 throttledRW := newThrottledMonitorRW(cmd.sesh, handler, cmd, name)
59
60+ var dsp psub.MessageDispatcher
61+ dsp = &psub.MulticastDispatcher{}
62+ if *dispatcher == "round_robin" {
63+ dsp = &psub.RoundRobinDispatcher{}
64+ }
65+ channel := psub.NewChannel(name)
66+ _ = handler.PubSub.SetDispatcher(dsp, []*psub.Channel{channel})
67+
68 readErr, writeErr := handler.PubSub.Pipe(
69 cmd.pipeCtx,
70 clientID,
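
Both pub and pipe now take a -d flag that selects the dispatcher for the topic (default: multicast). A hedged usage sketch; the SSH host and argument order are assumptions, not shown in this diff:

    ssh pipe.pico.sh pub jobs -d round_robin    # each message is delivered to exactly one subscriber
    ssh pipe.pico.sh sub jobs                   # start several of these to load-balance across them
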
+14, -0
1@@ -5,6 +5,7 @@ import (
2 "io"
3 "iter"
4 "log/slog"
5+ "reflect"
6 "sync"
7 "time"
8
9@@ -21,6 +22,7 @@ type Broker interface {
10 GetChannels() iter.Seq2[string, *Channel]
11 GetClients() iter.Seq2[string, *Client]
12 Connect(*Client, []*Channel) (error, error)
13+ SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error
14 }
15
16 type BaseBroker struct {
17@@ -104,6 +106,7 @@ func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error)
18 for {
19 data := make([]byte, 32*1024)
20 n, err := client.ReadWriter.Read(data)
21+
22 data = data[:n]
23
24 channelMessage := ChannelMessage{
25@@ -200,4 +203,15 @@ func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
26 return dataChannel
27 }
28
29+func (b *BaseBroker) SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error {
30+ for _, channel := range channels {
31+ dataChannel := b.ensureChannel(channel)
32+ existingDispatcher := dataChannel.GetDispatcher()
33+ if reflect.TypeOf(existingDispatcher) != reflect.TypeOf(dispatcher) {
34+ dataChannel.SetDispatcher(dispatcher)
35+ }
36+ }
37+ return nil
38+}
39+
40 var _ Broker = (*BaseBroker)(nil)
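
Any type that implements MessageDispatcher (defined in the new dispatcher file below) can be handed to Broker.SetDispatcher. A minimal sketch of a custom dispatcher inside the pubsub package; the FirstOnlyDispatcher name is hypothetical and not part of this commit:

    // FirstOnlyDispatcher always delivers to a single, fixed subscriber
    // instead of rotating like RoundRobinDispatcher.
    type FirstOnlyDispatcher struct{}

    func (d *FirstOnlyDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error {
        if len(subscribers) == 0 {
            return nil
        }
        // Deliver to the first eligible subscriber, giving up if either the
        // client or the channel shuts down first.
        select {
        case subscribers[0].Data <- msg:
        case <-subscribers[0].Done:
        case <-channelDone:
        }
        return nil
    }

    // Wiring it up mirrors the built-in dispatchers:
    //   _ = broker.SetDispatcher(&FirstOnlyDispatcher{}, []*Channel{channel})
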
+37, -18
1@@ -57,12 +57,28 @@ type Channel struct {
2 Clients *syncmap.Map[string, *Client]
3 handleOnce sync.Once
4 cleanupOnce sync.Once
5+ mu sync.Mutex
6+ Dispatcher MessageDispatcher
7 }
8
9 func (c *Channel) GetClients() iter.Seq2[string, *Client] {
10 return c.Clients.Range
11 }
12
13+func (c *Channel) SetDispatcher(d MessageDispatcher) {
14+ c.mu.Lock()
15+ defer c.mu.Unlock()
16+ if c.Dispatcher == nil {
17+ c.Dispatcher = d
18+ }
19+}
20+
21+func (c *Channel) GetDispatcher() MessageDispatcher {
22+ c.mu.Lock()
23+ defer c.mu.Unlock()
24+ return c.Dispatcher
25+}
26+
27 func (c *Channel) Cleanup() {
28 c.cleanupOnce.Do(func() {
29 close(c.Done)
30@@ -83,30 +99,33 @@ func (c *Channel) Handle() {
31 case <-c.Done:
32 return
33 case data, ok := <-c.Data:
34- var wg sync.WaitGroup
35+ if !ok {
36+ // Channel is closing, close all client data channels
37+ for _, client := range c.GetClients() {
38+ client.onceData.Do(func() {
39+ close(client.Data)
40+ })
41+ }
42+ return
43+ }
44+
45+ // Collect eligible subscribers
46+ subscribers := make([]*Client, 0)
47 for _, client := range c.GetClients() {
48+ // Skip input-only clients and senders (unless replay is enabled)
49 if client.Direction == ChannelDirectionInput || (client.ID == data.ClientID && !client.Replay) {
50 continue
51 }
52+ subscribers = append(subscribers, client)
53+ }
54
55- wg.Add(1)
56- go func() {
57- defer wg.Done()
58- if !ok {
59- client.onceData.Do(func() {
60- close(client.Data)
61- })
62- return
63- }
64-
65- select {
66- case client.Data <- data:
67- case <-client.Done:
68- case <-c.Done:
69- }
70- }()
71+ if len(data.Data) > 0 {
72+ // Dispatch message using the configured dispatcher
73+ dispatcher := c.GetDispatcher()
74+ if dispatcher != nil {
75+ _ = dispatcher.Dispatch(data, subscribers, c.Done)
76+ }
77 }
78- wg.Wait()
79 }
80 }
81 }()
+78, -0
1@@ -0,0 +1,78 @@
2+package pubsub
3+
4+import (
5+ "slices"
6+ "strings"
7+ "sync"
8+)
9+
10+// MessageDispatcher defines how messages are dispatched to subscribers.
11+type MessageDispatcher interface {
12+ // Dispatch sends a message to the appropriate subscriber(s).
13+ // It receives the message, the eligible subscribers, and the channel's done signal.
14+ Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error
15+}
16+
17+// MulticastDispatcher sends each message to all eligible subscribers.
18+type MulticastDispatcher struct{}
19+
20+func (d *MulticastDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error {
21+ var wg sync.WaitGroup
22+ for _, client := range subscribers {
23+ wg.Add(1)
24+ go func(cl *Client) {
25+ defer wg.Done()
26+ select {
27+ case cl.Data <- msg:
28+ case <-cl.Done:
29+ case <-channelDone:
30+ }
31+ }(client)
32+ }
33+ wg.Wait()
34+ return nil
35+}
36+
37+/*
38+RoundRobinDispatcher is a load-balancing dispatcher that distributes published messages
39+to subscribers using a round-robin algorithm.
40+
41+Unlike MulticastDispatcher, which sends to all subscribers, RoundRobinDispatcher
42+sends each message to exactly one subscriber, rotating through the available
43+subscribers for each published message. This provides load balancing for
44+message processing.
45+
46+It maintains independent round-robin state per channel/topic.
47+*/
48+type RoundRobinDispatcher struct {
49+ index uint32
50+ mu sync.Mutex
51+}
52+
53+func (d *RoundRobinDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error {
54+ // If no subscribers, nothing to dispatch
55+ // BlockWrite behavior at publish time ensures subscribers are present when needed
56+ if len(subscribers) == 0 {
57+ return nil
58+ }
59+
60+ slices.SortFunc(subscribers, func(a, b *Client) int {
61+ return strings.Compare(a.ID, b.ID)
62+ })
63+
64+ // Select the next subscriber in round-robin order
65+ d.mu.Lock()
66+ selectedIdx := int(d.index % uint32(len(subscribers)))
67+ d.index++
68+ d.mu.Unlock()
69+
70+ selectedClient := subscribers[selectedIdx]
71+
72+ select {
73+ case selectedClient.Data <- msg:
74+ case <-selectedClient.Done:
75+ case <-channelDone:
76+ }
77+
78+ return nil
79+}
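
A minimal end-to-end sketch of opting a channel into round-robin delivery, written in the style of the package's own tests (it reuses the Buffer helper from the test files) and assuming Multicast exposes the broker's SetDispatcher, as the CLI handler's handler.PubSub.SetDispatcher call above implies:

    // imports assumed: bytes, context, log/slog, time
    func demoRoundRobin() {
        cast := NewMulticast(slog.Default())
        channel := NewChannel("jobs")

        // Opt this channel into round-robin before any messages flow;
        // each published message then reaches exactly one subscriber.
        _ = cast.SetDispatcher(&RoundRobinDispatcher{}, []*Channel{channel})

        workerA, workerB := new(Buffer), new(Buffer)
        go func() { _ = cast.Sub(context.TODO(), "worker-a", workerA, []*Channel{channel}, false) }()
        go func() { _ = cast.Sub(context.TODO(), "worker-b", workerB, []*Channel{channel}, false) }()
        time.Sleep(100 * time.Millisecond) // crude, test-style wait for both subscribers to register

        // With two publishes, each worker should end up with one message.
        _ = cast.Pub(context.TODO(), "producer-1", &Buffer{b: *bytes.NewBufferString("task-1\n")}, []*Channel{channel}, false)
        _ = cast.Pub(context.TODO(), "producer-2", &Buffer{b: *bytes.NewBufferString("task-2\n")}, []*Channel{channel}, false)
    }
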
+9, -4
1@@ -58,9 +58,14 @@ func (p *Multicast) GetSubs() iter.Seq2[string, *Client] {
2 return p.getClients(ChannelDirectionOutput)
3 }
4
5-func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
6+func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool, dispatcher MessageDispatcher) (error, error) {
7 client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)
8
9+ // Set dispatcher on all channels (only if not already set)
10+ for _, ch := range channels {
11+ ch.SetDispatcher(dispatcher)
12+ }
13+
14 go func() {
15 <-ctx.Done()
16 client.Cleanup()
17@@ -70,15 +75,15 @@ func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, ch
18 }
19
20 func (p *Multicast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
21- return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
22+ return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false, &MulticastDispatcher{})
23 }
24
25 func (p *Multicast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, blockWrite bool) error {
26- return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, blockWrite, false, false))
27+ return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, blockWrite, false, false, &MulticastDispatcher{}))
28 }
29
30 func (p *Multicast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
31- return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
32+ return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive, &MulticastDispatcher{}))
33 }
34
35 var _ PubSub = (*Multicast)(nil)
+5, -5
1@@ -66,7 +66,7 @@ func TestMulticastSubBlock(t *testing.T) {
2 t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
3 }
4 if actual.String() != expected {
5- t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
6+ t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
7 }
8 }
9
10@@ -96,8 +96,8 @@ func TestMulticastPubBlock(t *testing.T) {
11
12 go func() {
13 orderActual += "sub-"
14- wg.Done()
15 fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel}, false))
16+ wg.Done()
17 }()
18
19 wg.Wait()
20@@ -106,7 +106,7 @@ func TestMulticastPubBlock(t *testing.T) {
21 t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
22 }
23 if actual.String() != expected {
24- t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
25+ t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
26 }
27 }
28
29@@ -156,9 +156,9 @@ func TestMulticastMultSubs(t *testing.T) {
30 t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
31 }
32 if actual.String() != expected {
33- t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
34+ t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
35 }
36 if actualOther.String() != expected {
37- t.Fatalf("\nactual:(%s)\nexpected:(%s)", actualOther, expected)
38+ t.Fatalf("\nactual:(%s)\nexpected:(%s)", actualOther.String(), expected)
39 }
40 }
+224, -0
1@@ -0,0 +1,224 @@
2+package pubsub
3+
4+import (
5+ "bytes"
6+ "context"
7+ "fmt"
8+ "log/slog"
9+ "sync"
10+ "sync/atomic"
11+ "testing"
12+ "time"
13+)
14+
15+// TestChannelMessageOrdering verifies that messages are delivered without panics or corruption.
16+// This applies to both multicast and round-robin dispatchers.
17+func TestChannelMessageOrdering(t *testing.T) {
18+ name := "order-test"
19+ numMessages := 3
20+
21+ // Test with Multicast
22+ t.Run("Multicast", func(t *testing.T) {
23+ cast := NewMulticast(slog.Default())
24+ buf := new(Buffer)
25+ channel := NewChannel(name)
26+
27+ var wg sync.WaitGroup
28+ syncer := make(chan int)
29+
30+ // Subscribe
31+ wg.Add(1)
32+ go func() {
33+ defer wg.Done()
34+ syncer <- 0
35+ _ = cast.Sub(context.TODO(), "sub", buf, []*Channel{channel}, false)
36+ }()
37+
38+ <-syncer
39+
40+ // Publish messages
41+ for i := 0; i < numMessages; i++ {
42+ wg.Add(1)
43+ idx := i
44+ go func() {
45+ defer wg.Done()
46+ msg := fmt.Sprintf("msg%d\n", idx)
47+ _ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString(msg)}, []*Channel{channel}, false)
48+ }()
49+ }
50+
51+ wg.Wait()
52+
53+ // Verify at least some messages were received
54+ content := buf.String()
55+ if len(content) == 0 {
56+ t.Error("Multicast: no messages received")
57+ }
58+ })
59+}
60+
61+// TestDispatcherClientDirection verifies that both dispatchers respect client direction.
62+// Publishers should not receive messages they publish.
63+func TestDispatcherClientDirection(t *testing.T) {
64+ name := "direction-test"
65+
66+ t.Run("Multicast", func(t *testing.T) {
67+ cast := NewMulticast(slog.Default())
68+ pubBuf := new(Buffer)
69+ subBuf := new(Buffer)
70+ channel := NewChannel(name)
71+
72+ var wg sync.WaitGroup
73+
74+ // Publisher (input only)
75+ wg.Add(1)
76+ go func() {
77+ defer wg.Done()
78+ _ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString("test")}, []*Channel{channel}, false)
79+ }()
80+
81+ // Subscriber (output only)
82+ wg.Add(1)
83+ go func() {
84+ defer wg.Done()
85+ _ = cast.Sub(context.TODO(), "sub", subBuf, []*Channel{channel}, false)
86+ }()
87+
88+ wg.Wait()
89+
90+ // Publisher should not receive the message
91+ if pubBuf.String() != "" {
92+ t.Errorf("Publisher received message: %q", pubBuf.String())
93+ }
94+
95+ // Subscriber should receive it
96+ if subBuf.String() != "test" {
97+ t.Errorf("Subscriber should have received message, got: %q", subBuf.String())
98+ }
99+ })
100+}
101+
102+// TestChannelConcurrentPublishes verifies that concurrent publishes don't cause races or data loss.
103+func TestChannelConcurrentPublishes(t *testing.T) {
104+ name := "concurrent-test"
105+ numPublishers := 10
106+ msgsPerPublisher := 5
107+ numSubscribers := 3
108+
109+ t.Run("Multicast", func(t *testing.T) {
110+ cast := NewMulticast(slog.Default())
111+ buffers := make([]*Buffer, numSubscribers)
112+ for i := range buffers {
113+ buffers[i] = new(Buffer)
114+ }
115+ channel := NewChannel(name)
116+
117+ var wg sync.WaitGroup
118+
119+ // Subscribe
120+ for i := range buffers {
121+ wg.Add(1)
122+ idx := i
123+ go func() {
124+ defer wg.Done()
125+ _ = cast.Sub(context.TODO(), fmt.Sprintf("sub-%d", idx), buffers[idx], []*Channel{channel}, false)
126+ }()
127+ }
128+ time.Sleep(100 * time.Millisecond)
129+
130+ // Concurrent publishers
131+ pubCount := int32(0)
132+ for p := 0; p < numPublishers; p++ {
133+ pubID := p
134+ for m := 0; m < msgsPerPublisher; m++ {
135+ wg.Add(1)
136+ msgNum := m
137+ go func() {
138+ defer wg.Done()
139+ msg := fmt.Sprintf("pub%d-msg%d\n", pubID, msgNum)
140+ _ = cast.Pub(context.TODO(), fmt.Sprintf("pub-%d", pubID), &Buffer{b: *bytes.NewBufferString(msg)}, []*Channel{channel}, false)
141+ atomic.AddInt32(&pubCount, 1)
142+ }()
143+ }
144+ }
145+
146+ wg.Wait()
147+
148+ // Verify all messages delivered to all subscribers
149+ totalExpectedMessages := numPublishers * msgsPerPublisher
150+ for i, buf := range buffers {
151+ messageCount := bytes.Count([]byte(buf.String()), []byte("\n"))
152+ if messageCount != totalExpectedMessages {
153+ t.Errorf("Subscriber %d: expected %d messages, got %d", i, totalExpectedMessages, messageCount)
154+ }
155+ }
156+
157+ // Verify all publishes completed
158+ if pubCount != int32(totalExpectedMessages) {
159+ t.Errorf("Expected %d publishes to complete, got %d", totalExpectedMessages, pubCount)
160+ }
161+ })
162+}
163+
164+// TestDispatcherEmptySubscribers verifies that dispatchers handle empty subscriber set without panic.
165+func TestDispatcherEmptySubscribers(t *testing.T) {
166+ name := "empty-subs-test"
167+
168+ t.Run("Multicast", func(t *testing.T) {
169+ cast := NewMulticast(slog.Default())
170+ channel := NewChannel(name)
171+
172+ var wg sync.WaitGroup
173+
174+ // Publish with no subscribers (should not panic)
175+ wg.Add(1)
176+ go func() {
177+ defer wg.Done()
178+ defer func() {
179+ if r := recover(); r != nil {
180+ t.Errorf("Multicast panicked with no subscribers: %v", r)
181+ }
182+ }()
183+ _ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString("test")}, []*Channel{channel}, false)
184+ }()
185+
186+ wg.Wait()
187+ t.Log("Multicast handled empty subscribers correctly")
188+ })
189+}
190+
191+// TestDispatcherSingleSubscriber verifies that both dispatchers work correctly with one subscriber.
192+func TestDispatcherSingleSubscriber(t *testing.T) {
193+ name := "single-sub-test"
194+ message := "single-sub-message"
195+
196+ t.Run("Multicast", func(t *testing.T) {
197+ cast := NewMulticast(slog.Default())
198+ buf := new(Buffer)
199+ channel := NewChannel(name)
200+
201+ var wg sync.WaitGroup
202+
203+ // Subscribe
204+ wg.Add(1)
205+ go func() {
206+ defer wg.Done()
207+ _ = cast.Sub(context.TODO(), "sub", buf, []*Channel{channel}, false)
208+ }()
209+
210+ time.Sleep(100 * time.Millisecond)
211+
212+ // Publish
213+ wg.Add(1)
214+ go func() {
215+ defer wg.Done()
216+ _ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString(message)}, []*Channel{channel}, false)
217+ }()
218+
219+ wg.Wait()
220+
221+ if buf.String() != message {
222+ t.Errorf("Multicast with single subscriber: expected %q, got %q", message, buf.String())
223+ }
224+ })
225+}