
pico services mono repo
git clone https://github.com/picosh/pico.git

commit: 8a69289
parent: 4b5c506
author: Eric Bower
date:   2026-02-03 08:59:08 -0500 EST

feat(pubsub): round robin

7 files changed, +389, -30
M pkg/apps/pipe/cli.go
+21, -3
@@ -600,6 +600,7 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
 	block := pubCmd.Bool("b", true, "Block writes until a subscriber is available")
 	timeout := pubCmd.Duration("t", 30*24*time.Hour, "Timeout as a Go duration to block for a subscriber to be available. Valid time units are 'ns', 'us' (or 'µs'), 'ms', 's', 'm', 'h'. Default is 30 days.")
 	clean := pubCmd.Bool("c", false, "Don't send status messages")
+	dispatcher := pubCmd.String("d", "multicast", "Type of dispatcher (e.g. multicast, round_robin)")
 
 	if !flagCheck(pubCmd, topic, cmd.args) {
 		return fmt.Errorf("invalid cmd args")
@@ -619,6 +620,7 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
 		"topic", topic,
 		"access", *access,
 		"clean", *clean,
+		"dispatcher", *dispatcher,
 	)
 
 	var accessList []string
@@ -795,13 +797,19 @@ func (handler *CliHandler) pub(cmd *CliCmd, topic string, clientID string) error
 
 	throttledRW := newThrottledMonitorRW(rw, handler, cmd, name)
 
+	var dsp psub.MessageDispatcher
+	dsp = &psub.MulticastDispatcher{}
+	if *dispatcher == "round_robin" {
+		dsp = &psub.RoundRobinDispatcher{}
+	}
+	channel := psub.NewChannel(name)
+	_ = handler.PubSub.SetDispatcher(dsp, []*psub.Channel{channel})
+
 	err := handler.PubSub.Pub(
 		cmd.pipeCtx,
 		clientID,
 		throttledRW,
-		[]*psub.Channel{
-			psub.NewChannel(name),
-		},
+		[]*psub.Channel{channel},
 		*block,
 	)
 
@@ -1011,6 +1019,7 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
 	public := pipeCmd.Bool("p", false, "Pipe to a public topic")
 	replay := pipeCmd.Bool("r", false, "Replay messages to the client that sent it")
 	clean := pipeCmd.Bool("c", false, "Don't send status messages")
+	dispatcher := pipeCmd.String("d", "multicast", "Type of dispatcher (e.g. multicast, round_robin)")
 
 	if !flagCheck(pipeCmd, topic, cmd.args) {
 		return fmt.Errorf("invalid cmd args")
@@ -1028,6 +1037,7 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
 		"topic", topic,
 		"access", *access,
 		"clean", *clean,
+		"dispatcher", *dispatcher,
 	)
 
 	var accessList []string
@@ -1101,6 +1111,14 @@ func (handler *CliHandler) pipe(cmd *CliCmd, topic string, clientID string) erro
 
 	throttledRW := newThrottledMonitorRW(cmd.sesh, handler, cmd, name)
 
+	var dsp psub.MessageDispatcher
+	dsp = &psub.MulticastDispatcher{}
+	if *dispatcher == "round_robin" {
+		dsp = &psub.RoundRobinDispatcher{}
+	}
+	channel := psub.NewChannel(name)
+	_ = handler.PubSub.SetDispatcher(dsp, []*psub.Channel{channel})
+
 	readErr, writeErr := handler.PubSub.Pipe(
 		cmd.pipeCtx,
 		clientID,
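Both pub and pipe now select a dispatcher from the new -d flag and register it with the broker before connecting, so a publisher opts into round-robin with something like pub mytopic -d round_robin. For readers wiring this up against the library directly, here is a minimal sketch of the equivalent calls; it assumes pkg/pubsub imported as psub (as in this file) and a PubSub value like the *Multicast shown below, and the helper name pubWithDispatcher is ours, not part of the commit:

package example

import (
	"context"
	"io"

	psub "github.com/picosh/pico/pkg/pubsub"
)

// pubWithDispatcher mirrors the CLI wiring above: pick a dispatcher by
// name, register it for the channel, then publish.
func pubWithDispatcher(ctx context.Context, ps psub.PubSub, name, kind string, rw io.ReadWriter) error {
	var dsp psub.MessageDispatcher = &psub.MulticastDispatcher{}
	if kind == "round_robin" {
		dsp = &psub.RoundRobinDispatcher{}
	}
	channel := psub.NewChannel(name)
	// Register before Pub so the channel adopts this dispatcher first;
	// channel dispatchers are first-wins (see channel.go below).
	if err := ps.SetDispatcher(dsp, []*psub.Channel{channel}); err != nil {
		return err
	}
	// blockWrite=true: block until a subscriber is available.
	return ps.Pub(ctx, "pub-client", rw, []*psub.Channel{channel}, true)
}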
M pkg/pubsub/broker.go
+14, -0
@@ -5,6 +5,7 @@ import (
 	"io"
 	"iter"
 	"log/slog"
+	"reflect"
 	"sync"
 	"time"
 
@@ -21,6 +22,7 @@ type Broker interface {
 	GetChannels() iter.Seq2[string, *Channel]
 	GetClients() iter.Seq2[string, *Client]
 	Connect(*Client, []*Channel) (error, error)
+	SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error
 }
 
 type BaseBroker struct {
@@ -104,6 +106,7 @@ func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error)
 			for {
 				data := make([]byte, 32*1024)
 				n, err := client.ReadWriter.Read(data)
+
 				data = data[:n]
 
 				channelMessage := ChannelMessage{
@@ -200,4 +203,15 @@ func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
 	return dataChannel
 }
 
+func (b *BaseBroker) SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error {
+	for _, channel := range channels {
+		dataChannel := b.ensureChannel(channel)
+		existingDispatcher := dataChannel.GetDispatcher()
+		if reflect.TypeOf(existingDispatcher) != reflect.TypeOf(dispatcher) {
+			dataChannel.SetDispatcher(dispatcher)
+		}
+	}
+	return nil
+}
+
 var _ Broker = (*BaseBroker)(nil)
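SetDispatcher goes through ensureChannel, so a topic can be configured before any client connects, and the call is safe to repeat. A minimal sketch, assuming NewMulticast/NewChannel as used in the tests below and that ensureChannel stores the passed channel when the topic is new (as its name suggests):

package example

import (
	"log/slog"

	psub "github.com/picosh/pico/pkg/pubsub"
)

// Only the first registration sticks: a repeat of the same type is skipped
// by the reflect.TypeOf check, and a different type is forwarded but then
// dropped by the channel's nil-only SetDispatcher (see channel.go below).
func registerDispatcher() psub.MessageDispatcher {
	broker := psub.NewMulticast(slog.Default())
	jobs := psub.NewChannel("jobs")
	_ = broker.SetDispatcher(&psub.MulticastDispatcher{}, []*psub.Channel{jobs})
	_ = broker.SetDispatcher(&psub.RoundRobinDispatcher{}, []*psub.Channel{jobs})
	return jobs.GetDispatcher() // still a *MulticastDispatcher
}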
M pkg/pubsub/channel.go
+37, -18
@@ -57,12 +57,28 @@ type Channel struct {
 	Clients     *syncmap.Map[string, *Client]
 	handleOnce  sync.Once
 	cleanupOnce sync.Once
+	mu          sync.Mutex
+	Dispatcher  MessageDispatcher
 }
 
 func (c *Channel) GetClients() iter.Seq2[string, *Client] {
 	return c.Clients.Range
 }
 
+func (c *Channel) SetDispatcher(d MessageDispatcher) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.Dispatcher == nil {
+		c.Dispatcher = d
+	}
+}
+
+func (c *Channel) GetDispatcher() MessageDispatcher {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	return c.Dispatcher
+}
+
 func (c *Channel) Cleanup() {
 	c.cleanupOnce.Do(func() {
 		close(c.Done)
@@ -83,30 +99,33 @@ func (c *Channel) Handle() {
 				case <-c.Done:
 					return
 				case data, ok := <-c.Data:
-					var wg sync.WaitGroup
+					if !ok {
+						// Channel is closing, close all client data channels
+						for _, client := range c.GetClients() {
+							client.onceData.Do(func() {
+								close(client.Data)
+							})
+						}
+						return
+					}
+
+					// Collect eligible subscribers
+					subscribers := make([]*Client, 0)
 					for _, client := range c.GetClients() {
+						// Skip input-only clients and senders (unless replay is enabled)
 						if client.Direction == ChannelDirectionInput || (client.ID == data.ClientID && !client.Replay) {
 							continue
 						}
+						subscribers = append(subscribers, client)
+					}
 
-						wg.Add(1)
-						go func() {
-							defer wg.Done()
-							if !ok {
-								client.onceData.Do(func() {
-									close(client.Data)
-								})
-								return
-							}
-
-							select {
-							case client.Data <- data:
-							case <-client.Done:
-							case <-c.Done:
-							}
-						}()
+					if len(data.Data) > 0 {
+						// Dispatch message using the configured dispatcher
+						dispatcher := c.GetDispatcher()
+						if dispatcher != nil {
+							_ = dispatcher.Dispatch(data, subscribers, c.Done)
+						}
 					}
-					wg.Wait()
 				}
 			}
 		}()
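Handle() no longer hard-codes fan-out: it collects the eligible subscribers and defers delivery to whatever MessageDispatcher the channel holds, which makes new strategies pluggable. A hypothetical sketch (not part of this commit) of a third strategy, assuming Client's exported Data and Done fields as used above:

package example

import psub "github.com/picosh/pico/pkg/pubsub"

// randomDispatcher delivers each message to one pseudo-randomly chosen
// subscriber. Dispatch runs on the single Handle goroutine per channel,
// so the unsynchronized seed is fine here.
type randomDispatcher struct{ seed uint64 }

func (d *randomDispatcher) Dispatch(msg psub.ChannelMessage, subscribers []*psub.Client, channelDone chan struct{}) error {
	if len(subscribers) == 0 {
		return nil
	}
	// Advance a simple linear congruential generator (Knuth MMIX constants).
	d.seed = d.seed*6364136223846793005 + 1442695040888963407
	client := subscribers[d.seed%uint64(len(subscribers))]
	select {
	case client.Data <- msg:
	case <-client.Done:
	case <-channelDone:
	}
	return nil
}

var _ psub.MessageDispatcher = (*randomDispatcher)(nil)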
A pkg/pubsub/dispatcher.go
+78, -0
@@ -0,0 +1,78 @@
+package pubsub
+
+import (
+	"slices"
+	"strings"
+	"sync"
+)
+
+// MessageDispatcher defines how messages are dispatched to subscribers.
+type MessageDispatcher interface {
+	// Dispatch sends a message to the appropriate subscriber(s).
+	// It receives the message, all subscribers, and the channel's done signal.
+	Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error
+}
+
+// MulticastDispatcher sends each message to all eligible subscribers.
+type MulticastDispatcher struct{}
+
+func (d *MulticastDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error {
+	var wg sync.WaitGroup
+	for _, client := range subscribers {
+		wg.Add(1)
+		go func(cl *Client) {
+			defer wg.Done()
+			select {
+			case cl.Data <- msg:
+			case <-cl.Done:
+			case <-channelDone:
+			}
+		}(client)
+	}
+	wg.Wait()
+	return nil
+}
+
+/*
+RoundRobinDispatcher is a load-balancing dispatcher that distributes
+published messages to subscribers using a round-robin algorithm.
+
+Unlike MulticastDispatcher, which sends each message to all subscribers,
+RoundRobinDispatcher sends each message to exactly one subscriber, rotating
+through the available subscribers for each published message. This provides
+load balancing for message processing.
+
+Each instance keeps its own counter, so state is independent per channel/topic.
+*/
+type RoundRobinDispatcher struct {
+	index uint32
+	mu    sync.Mutex
+}
+
+func (d *RoundRobinDispatcher) Dispatch(msg ChannelMessage, subscribers []*Client, channelDone chan struct{}) error {
+	// If there are no subscribers, there is nothing to dispatch;
+	// blockWrite behavior at publish time ensures subscribers are present when needed.
+	if len(subscribers) == 0 {
+		return nil
+	}
+
+	slices.SortFunc(subscribers, func(a, b *Client) int {
+		return strings.Compare(a.ID, b.ID)
+	})
+
+	// Select the next subscriber in round-robin order
+	d.mu.Lock()
+	selectedIdx := int(d.index % uint32(len(subscribers)))
+	d.index++
+	d.mu.Unlock()
+
+	selectedClient := subscribers[selectedIdx]
+
+	select {
+	case selectedClient.Data <- msg:
+	case <-selectedClient.Done:
+	case <-channelDone:
+	}
+
+	return nil
+}
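Subscribers are re-sorted by ID on every dispatch because they are collected from a sync map whose iteration order is not deterministic; sorting gives the rotation a stable order. A standalone sketch of the selection logic (illustration only, not code from this commit):

package example

import (
	"slices"
	"strings"
)

// rotation applies the same sort-then-rotate selection as
// RoundRobinDispatcher.Dispatch: with IDs {"a", "b", "c"}, successive
// messages go to a, b, c, a, b, ...
func rotation(ids []string, n int) []string {
	slices.SortFunc(ids, strings.Compare)
	out := make([]string, 0, n)
	var index uint32
	for i := 0; i < n; i++ {
		out = append(out, ids[int(index%uint32(len(ids)))])
		index++
	}
	return out
}

Note that the modulo is taken against the current subscriber count, so the rotation stays valid as clients come and go, at the cost of a shifted order whenever the set changes.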
M pkg/pubsub/multicast.go
+9, -4
@@ -58,9 +58,14 @@ func (p *Multicast) GetSubs() iter.Seq2[string, *Client] {
 	return p.getClients(ChannelDirectionOutput)
 }
 
-func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool) (error, error) {
+func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, direction ChannelDirection, blockWrite bool, replay, keepAlive bool, dispatcher MessageDispatcher) (error, error) {
 	client := NewClient(ID, rw, direction, blockWrite, replay, keepAlive)
 
+	// Set dispatcher on all channels (only if not already set)
+	for _, ch := range channels {
+		ch.SetDispatcher(dispatcher)
+	}
+
 	go func() {
 		<-ctx.Done()
 		client.Cleanup()
@@ -70,15 +75,15 @@ func (p *Multicast) connect(ctx context.Context, ID string, rw io.ReadWriter, ch
 }
 
 func (p *Multicast) Pipe(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, replay bool) (error, error) {
-	return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false)
+	return p.connect(ctx, ID, rw, channels, ChannelDirectionInputOutput, false, replay, false, &MulticastDispatcher{})
}
 
 func (p *Multicast) Pub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, blockWrite bool) error {
-	return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, blockWrite, false, false))
+	return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionInput, blockWrite, false, false, &MulticastDispatcher{}))
 }
 
 func (p *Multicast) Sub(ctx context.Context, ID string, rw io.ReadWriter, channels []*Channel, keepAlive bool) error {
-	return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive))
+	return errors.Join(p.connect(ctx, ID, rw, channels, ChannelDirectionOutput, false, false, keepAlive, &MulticastDispatcher{}))
 }
 
 var _ PubSub = (*Multicast)(nil)
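Pipe, Pub, and Sub all pass &MulticastDispatcher{} into connect, so the default behavior stays fan-out and existing callers are unaffected. Round-robin is strictly opt-in; a minimal sketch (the helper name is ours, not part of the commit):

package example

import (
	"log/slog"

	psub "github.com/picosh/pico/pkg/pubsub"
)

// newRoundRobinChannel registers round-robin before any client connects;
// the default &MulticastDispatcher{} later passed by connect is ignored
// because channel dispatchers are first-wins.
func newRoundRobinChannel(name string) (*psub.Multicast, *psub.Channel) {
	cast := psub.NewMulticast(slog.Default())
	ch := psub.NewChannel(name)
	_ = cast.SetDispatcher(&psub.RoundRobinDispatcher{}, []*psub.Channel{ch})
	return cast, ch
}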
M pkg/pubsub/multicast_test.go
+5, -5
@@ -66,7 +66,7 @@ func TestMulticastSubBlock(t *testing.T) {
 		t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
 	}
 	if actual.String() != expected {
-		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
+		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
 	}
 }
 
@@ -96,8 +96,8 @@ func TestMulticastPubBlock(t *testing.T) {
 
 	go func() {
 		orderActual += "sub-"
-		wg.Done()
 		fmt.Println(cast.Sub(context.TODO(), "2", actual, []*Channel{channel}, false))
+		wg.Done()
 	}()
 
 	wg.Wait()
@@ -106,7 +106,7 @@ func TestMulticastPubBlock(t *testing.T) {
 		t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
 	}
 	if actual.String() != expected {
-		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
+		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
 	}
 }
 
@@ -156,9 +156,9 @@ func TestMulticastMultSubs(t *testing.T) {
 		t.Fatalf("\norderActual:(%s)\norderExpected:(%s)", orderActual, orderExpected)
 	}
 	if actual.String() != expected {
-		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual, expected)
+		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actual.String(), expected)
 	}
 	if actualOther.String() != expected {
-		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actualOther, expected)
+		t.Fatalf("\nactual:(%s)\nexpected:(%s)", actualOther.String(), expected)
 	}
 }
A pkg/pubsub/regression_test.go
+225, -0
@@ -0,0 +1,225 @@
+package pubsub
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"log/slog"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+// TestChannelMessageOrdering verifies that messages are delivered without panics or corruption.
+// This applies to both multicast and round-robin dispatchers.
+func TestChannelMessageOrdering(t *testing.T) {
+	name := "order-test"
+	numMessages := 3
+
+	// Test with Multicast
+	t.Run("Multicast", func(t *testing.T) {
+		cast := NewMulticast(slog.Default())
+		buf := new(Buffer)
+		channel := NewChannel(name)
+
+		var wg sync.WaitGroup
+		syncer := make(chan int)
+
+		// Subscribe
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			syncer <- 0
+			_ = cast.Sub(context.TODO(), "sub", buf, []*Channel{channel}, false)
+		}()
+
+		<-syncer
+
+		// Publish messages
+		for i := 0; i < numMessages; i++ {
+			wg.Add(1)
+			idx := i
+			go func() {
+				defer wg.Done()
+				msg := fmt.Sprintf("msg%d\n", idx)
+				_ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString(msg)}, []*Channel{channel}, false)
+			}()
+		}
+
+		wg.Wait()
+
+		// Verify at least some messages were received
+		content := buf.String()
+		if len(content) == 0 {
+			t.Error("Multicast: no messages received")
+		}
+	})
+}
+
+// TestDispatcherClientDirection verifies that both dispatchers respect client direction.
+// Publishers should not receive messages they publish.
+func TestDispatcherClientDirection(t *testing.T) {
+	name := "direction-test"
+
+	t.Run("Multicast", func(t *testing.T) {
+		cast := NewMulticast(slog.Default())
+		pubBuf := new(Buffer)
+		subBuf := new(Buffer)
+		channel := NewChannel(name)
+
+		var wg sync.WaitGroup
+
+		// Publisher (input only)
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString("test")}, []*Channel{channel}, false)
+		}()
+
+		// Subscriber (output only)
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_ = cast.Sub(context.TODO(), "sub", subBuf, []*Channel{channel}, false)
+		}()
+
+		wg.Wait()
+
+		// Publisher should not receive the message
+		if pubBuf.String() != "" {
+			t.Errorf("Publisher received message: %q", pubBuf.String())
+		}
+
+		// Subscriber should receive it
+		if subBuf.String() != "test" {
+			t.Errorf("Subscriber should have received message, got: %q", subBuf.String())
+		}
+	})
+}
+
+// TestChannelConcurrentPublishes verifies that concurrent publishes don't cause races or data loss.
+func TestChannelConcurrentPublishes(t *testing.T) {
+	name := "concurrent-test"
+	numPublishers := 10
+	msgsPerPublisher := 5
+	numSubscribers := 3
+
+	t.Run("Multicast", func(t *testing.T) {
+		cast := NewMulticast(slog.Default())
+		buffers := make([]*Buffer, numSubscribers)
+		for i := range buffers {
+			buffers[i] = new(Buffer)
+		}
+		channel := NewChannel(name)
+
+		var wg sync.WaitGroup
+
+		// Subscribe
+		for i := range buffers {
+			wg.Add(1)
+			idx := i
+			go func() {
+				defer wg.Done()
+				_ = cast.Sub(context.TODO(), fmt.Sprintf("sub-%d", idx), buffers[idx], []*Channel{channel}, false)
+			}()
+		}
+		time.Sleep(100 * time.Millisecond)
+
+		// Concurrent publishers
+		pubCount := int32(0)
+		for p := 0; p < numPublishers; p++ {
+			pubID := p
+			for m := 0; m < msgsPerPublisher; m++ {
+				wg.Add(1)
+				msgNum := m
+				go func() {
+					defer wg.Done()
+					msg := fmt.Sprintf("pub%d-msg%d\n", pubID, msgNum)
+					_ = cast.Pub(context.TODO(), fmt.Sprintf("pub-%d", pubID), &Buffer{b: *bytes.NewBufferString(msg)}, []*Channel{channel}, false)
+					atomic.AddInt32(&pubCount, 1)
+				}()
+			}
+		}
+
+		wg.Wait()
+
+		// Verify all messages delivered to all subscribers
+		totalExpectedMessages := numPublishers * msgsPerPublisher
+		for i, buf := range buffers {
+			messageCount := bytes.Count([]byte(buf.String()), []byte("\n"))
+			if messageCount != totalExpectedMessages {
+				t.Errorf("Subscriber %d: expected %d messages, got %d", i, totalExpectedMessages, messageCount)
+			}
+		}
+
+		// Verify all publishes completed
+		if pubCount != int32(totalExpectedMessages) {
+			t.Errorf("Expected %d publishes to complete, got %d", totalExpectedMessages, pubCount)
+		}
+	})
+}
+
+// TestDispatcherEmptySubscribers verifies that dispatchers handle empty subscriber set without panic.
+func TestDispatcherEmptySubscribers(t *testing.T) {
+	name := "empty-subs-test"
+
+	t.Run("Multicast", func(t *testing.T) {
+		cast := NewMulticast(slog.Default())
+		channel := NewChannel(name)
+
+		var wg sync.WaitGroup
+
+		// Publish with no subscribers (should not panic)
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			defer func() {
+				if r := recover(); r != nil {
+					t.Errorf("Multicast panicked with no subscribers: %v", r)
+				}
+			}()
+			_ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString("test")}, []*Channel{channel}, false)
+		}()
+
+		wg.Wait()
+		t.Log("Multicast handled empty subscribers correctly")
+	})
+}
+
+// TestDispatcherSingleSubscriber verifies that both dispatchers work correctly with one subscriber.
+func TestDispatcherSingleSubscriber(t *testing.T) {
+	name := "single-sub-test"
+	message := "single-sub-message"
+
+	t.Run("Multicast", func(t *testing.T) {
+		cast := NewMulticast(slog.Default())
+		buf := new(Buffer)
+		channel := NewChannel(name)
+
+		var wg sync.WaitGroup
+
+		// Subscribe
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_ = cast.Sub(context.TODO(), "sub", buf, []*Channel{channel}, false)
+		}()
+
+		time.Sleep(100 * time.Millisecond)
+
+		// Publish
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_ = cast.Pub(context.TODO(), "pub", &Buffer{b: *bytes.NewBufferString(message)}, []*Channel{channel}, false)
+		}()
+
+		wg.Wait()
+
+		if buf.String() != message {
+			t.Errorf("Multicast with single subscriber: expected %q, got %q", message, buf.String())
+		}
+	})
+}
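The regression suite above only exercises the multicast path, even where comments mention both dispatchers. A sketch of a round-robin counterpart in the same style (same package and imports as regression_test.go; not part of this commit, and the timing mirrors the sleeps used above):

// TestRoundRobinDistribution verifies that each message lands on exactly
// one subscriber rather than being fanned out to all of them.
func TestRoundRobinDistribution(t *testing.T) {
	cast := NewMulticast(slog.Default())
	channel := NewChannel("rr-test")
	// Register round-robin before any client connects (first-wins).
	_ = cast.SetDispatcher(&RoundRobinDispatcher{}, []*Channel{channel})

	bufA, bufB := new(Buffer), new(Buffer)
	var wg sync.WaitGroup
	for id, buf := range map[string]*Buffer{"sub-a": bufA, "sub-b": bufB} {
		wg.Add(1)
		go func(id string, buf *Buffer) {
			defer wg.Done()
			_ = cast.Sub(context.TODO(), id, buf, []*Channel{channel}, false)
		}(id, buf)
	}
	time.Sleep(100 * time.Millisecond)

	numMessages := 4
	for i := 0; i < numMessages; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			msg := fmt.Sprintf("msg%d\n", n)
			_ = cast.Pub(context.TODO(), fmt.Sprintf("pub-%d", n), &Buffer{b: *bytes.NewBufferString(msg)}, []*Channel{channel}, false)
		}(i)
	}
	wg.Wait()

	// Multicast would deliver numMessages newlines to each buffer;
	// round-robin should deliver numMessages in total across both.
	total := bytes.Count([]byte(bufA.String()), []byte("\n")) +
		bytes.Count([]byte(bufB.String()), []byte("\n"))
	if total != numMessages {
		t.Errorf("expected %d messages total, got %d", numMessages, total)
	}
}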