repos / pico

pico services mono repo
git clone https://github.com/picosh/pico.git

pico / pkg / pubsub
Eric Bower  ·  2026-02-03

broker.go

  1package pubsub
  2
  3import (
  4	"errors"
  5	"io"
  6	"iter"
  7	"log/slog"
  8	"reflect"
  9	"sync"
 10	"time"
 11
 12	"github.com/antoniomika/syncmap"
 13)
 14
 15/*
 16Broker receives published messages and dispatches the message to the
 17subscribing clients. An message contains a message topic that clients
 18subscribe to and brokers use these subscription lists for determining the
 19clients to receive the message.
 20*/
 21type Broker interface {
 22	GetChannels() iter.Seq2[string, *Channel]
 23	GetClients() iter.Seq2[string, *Client]
 24	Connect(*Client, []*Channel) (error, error)
 25	SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error
 26}
 27
// BaseBroker implements Broker on top of a syncmap.Map of channels keyed by
// topic.
type BaseBroker struct {
	// Channels holds the broker's channels, keyed by Channel.Topic.
	Channels *syncmap.Map[string, *Channel]
	// Logger is the broker's logger.
	// NOTE(review): not used by any BaseBroker method shown here — confirm
	// whether other code relies on it.
	Logger   *slog.Logger
}
 32
 33func (b *BaseBroker) Cleanup() {
 34	toRemove := []string{}
 35	for _, channel := range b.GetChannels() {
 36		count := 0
 37
 38		for range channel.GetClients() {
 39			count++
 40		}
 41
 42		if count == 0 {
 43			channel.Cleanup()
 44			toRemove = append(toRemove, channel.Topic)
 45		}
 46	}
 47
 48	for _, channel := range toRemove {
 49		b.Channels.Delete(channel)
 50	}
 51}
 52
 53func (b *BaseBroker) GetChannels() iter.Seq2[string, *Channel] {
 54	return b.Channels.Range
 55}
 56
 57func (b *BaseBroker) GetClients() iter.Seq2[string, *Client] {
 58	return func(yield func(string, *Client) bool) {
 59		for _, channel := range b.GetChannels() {
 60			channel.Clients.Range(yield)
 61		}
 62	}
 63}
 64
 65func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error) {
 66	for _, channel := range channels {
 67		dataChannel := b.ensureChannel(channel)
 68		dataChannel.Clients.Store(client.ID, client)
 69		client.Channels.Store(dataChannel.Topic, dataChannel)
 70		defer func() {
 71			client.Channels.Delete(channel.Topic)
 72			dataChannel.Clients.Delete(client.ID)
 73
 74			client.Cleanup()
 75
 76			count := 0
 77			for _, cl := range dataChannel.GetClients() {
 78				if cl.Direction == ChannelDirectionInput || cl.Direction == ChannelDirectionInputOutput {
 79					count++
 80				}
 81			}
 82
 83			if count == 0 {
 84				for _, cl := range dataChannel.GetClients() {
 85					if !cl.KeepAlive {
 86						cl.Cleanup()
 87					}
 88				}
 89			}
 90
 91			b.Cleanup()
 92		}()
 93	}
 94
 95	var (
 96		inputErr  error
 97		outputErr error
 98		wg        sync.WaitGroup
 99	)
100
101	// Pub
102	if client.Direction == ChannelDirectionInput || client.Direction == ChannelDirectionInputOutput {
103		wg.Add(1)
104		go func() {
105			defer wg.Done()
106			for {
107				data := make([]byte, 32*1024)
108				n, err := client.ReadWriter.Read(data)
109
110				data = data[:n]
111
112				channelMessage := ChannelMessage{
113					Data:      data,
114					ClientID:  client.ID,
115					Direction: ChannelDirectionInput,
116				}
117
118				if client.BlockWrite {
119				mainLoop:
120					for {
121						count := 0
122						for _, channel := range client.GetChannels() {
123							for _, chanClient := range channel.GetClients() {
124								if chanClient.Direction == ChannelDirectionOutput || chanClient.Direction == ChannelDirectionInputOutput {
125									count++
126								}
127							}
128						}
129
130						if count > 0 {
131							break mainLoop
132						}
133
134						select {
135						case <-client.Done:
136							break mainLoop
137						case <-time.After(1 * time.Millisecond):
138							continue
139						}
140					}
141				}
142
143				var sendwg sync.WaitGroup
144
145				for _, channel := range client.GetChannels() {
146					sendwg.Add(1)
147					go func() {
148						defer sendwg.Done()
149						select {
150						case channel.Data <- channelMessage:
151						case <-client.Done:
152						case <-channel.Done:
153						}
154					}()
155				}
156
157				sendwg.Wait()
158
159				if err != nil {
160					if errors.Is(err, io.EOF) {
161						return
162					}
163					inputErr = err
164					return
165				}
166			}
167		}()
168	}
169
170	// Sub
171	if client.Direction == ChannelDirectionOutput || client.Direction == ChannelDirectionInputOutput {
172		wg.Add(1)
173		go func() {
174			defer wg.Done()
175		mainLoop:
176			for {
177				select {
178				case data, ok := <-client.Data:
179					_, err := client.ReadWriter.Write(data.Data)
180					if err != nil {
181						outputErr = err
182						break mainLoop
183					}
184
185					if !ok {
186						break mainLoop
187					}
188				case <-client.Done:
189					break mainLoop
190				}
191			}
192		}()
193	}
194
195	wg.Wait()
196
197	return inputErr, outputErr
198}
199
200func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
201	dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel)
202	dataChannel.Handle()
203	return dataChannel
204}
205
206func (b *BaseBroker) SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error {
207	for _, channel := range channels {
208		dataChannel := b.ensureChannel(channel)
209		existingDispatcher := dataChannel.GetDispatcher()
210		if reflect.TypeOf(existingDispatcher) != reflect.TypeOf(dispatcher) {
211			dataChannel.SetDispatcher(dispatcher)
212		}
213	}
214	return nil
215}
216
// Compile-time assertion that *BaseBroker satisfies the Broker interface.
var _ Broker = (*BaseBroker)(nil)