Eric Bower
·
2026-02-03
broker.go
1package pubsub
2
3import (
4 "errors"
5 "io"
6 "iter"
7 "log/slog"
8 "reflect"
9 "sync"
10 "time"
11
12 "github.com/antoniomika/syncmap"
13)
14
15/*
16Broker receives published messages and dispatches the message to the
17subscribing clients. An message contains a message topic that clients
18subscribe to and brokers use these subscription lists for determining the
19clients to receive the message.
20*/
21type Broker interface {
22 GetChannels() iter.Seq2[string, *Channel]
23 GetClients() iter.Seq2[string, *Client]
24 Connect(*Client, []*Channel) (error, error)
25 SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error
26}
27
28type BaseBroker struct {
29 Channels *syncmap.Map[string, *Channel]
30 Logger *slog.Logger
31}
32
33func (b *BaseBroker) Cleanup() {
34 toRemove := []string{}
35 for _, channel := range b.GetChannels() {
36 count := 0
37
38 for range channel.GetClients() {
39 count++
40 }
41
42 if count == 0 {
43 channel.Cleanup()
44 toRemove = append(toRemove, channel.Topic)
45 }
46 }
47
48 for _, channel := range toRemove {
49 b.Channels.Delete(channel)
50 }
51}
52
53func (b *BaseBroker) GetChannels() iter.Seq2[string, *Channel] {
54 return b.Channels.Range
55}
56
57func (b *BaseBroker) GetClients() iter.Seq2[string, *Client] {
58 return func(yield func(string, *Client) bool) {
59 for _, channel := range b.GetChannels() {
60 channel.Clients.Range(yield)
61 }
62 }
63}
64
65func (b *BaseBroker) Connect(client *Client, channels []*Channel) (error, error) {
66 for _, channel := range channels {
67 dataChannel := b.ensureChannel(channel)
68 dataChannel.Clients.Store(client.ID, client)
69 client.Channels.Store(dataChannel.Topic, dataChannel)
70 defer func() {
71 client.Channels.Delete(channel.Topic)
72 dataChannel.Clients.Delete(client.ID)
73
74 client.Cleanup()
75
76 count := 0
77 for _, cl := range dataChannel.GetClients() {
78 if cl.Direction == ChannelDirectionInput || cl.Direction == ChannelDirectionInputOutput {
79 count++
80 }
81 }
82
83 if count == 0 {
84 for _, cl := range dataChannel.GetClients() {
85 if !cl.KeepAlive {
86 cl.Cleanup()
87 }
88 }
89 }
90
91 b.Cleanup()
92 }()
93 }
94
95 var (
96 inputErr error
97 outputErr error
98 wg sync.WaitGroup
99 )
100
101 // Pub
102 if client.Direction == ChannelDirectionInput || client.Direction == ChannelDirectionInputOutput {
103 wg.Add(1)
104 go func() {
105 defer wg.Done()
106 for {
107 data := make([]byte, 32*1024)
108 n, err := client.ReadWriter.Read(data)
109
110 data = data[:n]
111
112 channelMessage := ChannelMessage{
113 Data: data,
114 ClientID: client.ID,
115 Direction: ChannelDirectionInput,
116 }
117
118 if client.BlockWrite {
119 mainLoop:
120 for {
121 count := 0
122 for _, channel := range client.GetChannels() {
123 for _, chanClient := range channel.GetClients() {
124 if chanClient.Direction == ChannelDirectionOutput || chanClient.Direction == ChannelDirectionInputOutput {
125 count++
126 }
127 }
128 }
129
130 if count > 0 {
131 break mainLoop
132 }
133
134 select {
135 case <-client.Done:
136 break mainLoop
137 case <-time.After(1 * time.Millisecond):
138 continue
139 }
140 }
141 }
142
143 var sendwg sync.WaitGroup
144
145 for _, channel := range client.GetChannels() {
146 sendwg.Add(1)
147 go func() {
148 defer sendwg.Done()
149 select {
150 case channel.Data <- channelMessage:
151 case <-client.Done:
152 case <-channel.Done:
153 }
154 }()
155 }
156
157 sendwg.Wait()
158
159 if err != nil {
160 if errors.Is(err, io.EOF) {
161 return
162 }
163 inputErr = err
164 return
165 }
166 }
167 }()
168 }
169
170 // Sub
171 if client.Direction == ChannelDirectionOutput || client.Direction == ChannelDirectionInputOutput {
172 wg.Add(1)
173 go func() {
174 defer wg.Done()
175 mainLoop:
176 for {
177 select {
178 case data, ok := <-client.Data:
179 _, err := client.ReadWriter.Write(data.Data)
180 if err != nil {
181 outputErr = err
182 break mainLoop
183 }
184
185 if !ok {
186 break mainLoop
187 }
188 case <-client.Done:
189 break mainLoop
190 }
191 }
192 }()
193 }
194
195 wg.Wait()
196
197 return inputErr, outputErr
198}
199
200func (b *BaseBroker) ensureChannel(channel *Channel) *Channel {
201 dataChannel, _ := b.Channels.LoadOrStore(channel.Topic, channel)
202 dataChannel.Handle()
203 return dataChannel
204}
205
206func (b *BaseBroker) SetDispatcher(dispatcher MessageDispatcher, channels []*Channel) error {
207 for _, channel := range channels {
208 dataChannel := b.ensureChannel(channel)
209 existingDispatcher := dataChannel.GetDispatcher()
210 if reflect.TypeOf(existingDispatcher) != reflect.TypeOf(dispatcher) {
211 dataChannel.SetDispatcher(dispatcher)
212 }
213 }
214 return nil
215}
216
217var _ Broker = (*BaseBroker)(nil)