From 6355144b8c3514cccc5c2ab4f7c4fd8e76a1a9fc Mon Sep 17 00:00:00 2001 From: hozan23 Date: Sun, 23 Jun 2024 15:57:43 +0200 Subject: Fix the issue with message dispatcher and channels Resolved a previous error where each subscription would create a new channel with the fixed buffer size. This caused blocking when the channel buffer was full, preventing the client from handling additional messages. Now, there is a `subscriptions` struct that holds a queue for receiving notifications, ensuring the notify function does not block. --- jsonrpc/client/client.go | 124 ++++++++++++++++++------------ jsonrpc/client/concurrent_queue.go | 88 +++++++++++++++++++++ jsonrpc/client/message_dispatcher.go | 64 ++++++++------- jsonrpc/client/message_dispatcher_test.go | 49 ++++++------ jsonrpc/client/subscription.go | 83 ++++++++++++++++++++ jsonrpc/client/subscription_test.go | 86 +++++++++++++++++++++ jsonrpc/client/subscriptions.go | 81 +++++++++++++++++++ jsonrpc/client/subscriptions_test.go | 87 +++++++++++++++++++++ 8 files changed, 557 insertions(+), 105 deletions(-) create mode 100644 jsonrpc/client/concurrent_queue.go create mode 100644 jsonrpc/client/subscription.go create mode 100644 jsonrpc/client/subscription_test.go create mode 100644 jsonrpc/client/subscriptions.go create mode 100644 jsonrpc/client/subscriptions_test.go (limited to 'jsonrpc/client') diff --git a/jsonrpc/client/client.go b/jsonrpc/client/client.go index afcc5c7..1f12ced 100644 --- a/jsonrpc/client/client.go +++ b/jsonrpc/client/client.go @@ -3,9 +3,11 @@ package client import ( "bytes" "encoding/json" + "errors" "fmt" "math/rand" "strconv" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -20,21 +22,34 @@ const ( // Default timeout for receiving requests from the server, in milliseconds. DefaultTimeout = 3000 + + // The default buffer size for a subscription. + DefaultSubscriptionBufferSize = 10000 +) + +var ( + ClientIsDisconnectedErr = errors.New("Client is disconnected and closed") + TimeoutError = errors.New("Timeout Error") + InvalidResponseIDErr = errors.New("Invalid response ID") + InvalidResponseResultErr = errors.New("Invalid response result") + receivedStopSignalErr = errors.New("Received stop signal") ) // RPCClientConfig Holds the configuration settings for the RPC client. type RPCClientConfig struct { - Timeout int // Timeout for receiving requests from the server, in milliseconds. - Addr string // Address of the RPC server. + Timeout int // Timeout for receiving requests from the server, in milliseconds. + Addr string // Address of the RPC server. + SubscriptionBufferSize int // The buffer size for a subscription. } // RPCClient RPC Client type RPCClient struct { config RPCClientConfig conn *websocket.Conn - requests *messageDispatcher[message.RequestID, message.Response] - subscriptions *messageDispatcher[message.SubscriptionID, json.RawMessage] - stop_signal chan struct{} + requests *messageDispatcher + subscriptions *subscriptions + stopSignal chan struct{} + isClosed atomic.Bool } // NewRPCClient Creates a new instance of RPCClient with the provided configuration. @@ -46,25 +61,29 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { } log.Infof("Successfully connected to the server: %s", config.Addr) - if config.Timeout == 0 { + if config.Timeout <= 0 { config.Timeout = DefaultTimeout } - stop_signal := make(chan struct{}) + if config.SubscriptionBufferSize <= 0 { + config.SubscriptionBufferSize = DefaultSubscriptionBufferSize + } + + stopSignal := make(chan struct{}) - requests := newMessageDispatcher[message.RequestID, message.Response](0) - subscriptions := newMessageDispatcher[message.SubscriptionID, json.RawMessage](100) + requests := newMessageDispatcher() + subscriptions := newSubscriptions(config.SubscriptionBufferSize) client := &RPCClient{ conn: conn, config: config, requests: requests, subscriptions: subscriptions, - stop_signal: stop_signal, + stopSignal: stopSignal, } go func() { - if err := client.backgroundReceivingLoop(stop_signal); err != nil { + if err := client.backgroundReceivingLoop(stopSignal); err != nil { client.Close() } }() @@ -74,9 +93,14 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { // Close Closes the underlying websocket connection and stop the receiving loop. func (client *RPCClient) Close() { + // Check if it's already closed + if !client.isClosed.CompareAndSwap(false, true) { + return + } + log.Warn("Close the rpc client...") // Send stop signal to the background receiving loop - close(client.stop_signal) + close(client.stopSignal) // Close the underlying websocket connection err := client.conn.Close() @@ -90,7 +114,7 @@ func (client *RPCClient) Close() { // Call Sends an RPC call to the server with the specified method and // parameters, and returns the response. -func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) { +func (client *RPCClient) Call(method string, params any) (json.RawMessage, error) { log.Tracef("Call -> method: %s, params: %v", method, params) response, err := client.sendRequest(method, params) if err != nil { @@ -101,29 +125,27 @@ func (client *RPCClient) Call(method string, params any) (*json.RawMessage, erro } // Subscribe Sends a subscription request to the server with the specified -// method and parameters, and it returns the subscription ID and the channel to -// receive notifications. -func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) { +// method and parameters, and it returns the subscription. +func (client *RPCClient) Subscribe(method string, params any) (*Subscription, error) { log.Tracef("Sbuscribe -> method: %s, params: %v", method, params) response, err := client.sendRequest(method, params) if err != nil { - return 0, nil, err + return nil, err } if response.Result == nil { - return 0, nil, fmt.Errorf("Invalid response result") + return nil, InvalidResponseResultErr } var subID message.SubscriptionID - err = json.Unmarshal(*response.Result, &subID) + err = json.Unmarshal(response.Result, &subID) if err != nil { - return 0, nil, err + return nil, err } - // Register a new subscription - sub := client.subscriptions.register(subID) + sub := client.subscriptions.subscribe(subID) - return subID, sub, nil + return sub, nil } // Unsubscribe Sends an unsubscription request to the server to cancel the @@ -135,44 +157,49 @@ func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID return err } - // On success unregister the subscription channel - client.subscriptions.unregister(subID) + // On success unsubscribe + client.subscriptions.unsubscribe(subID) return nil } // backgroundReceivingLoop Starts reading new messages from the underlying connection. -func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error { +func (client *RPCClient) backgroundReceivingLoop(stopSignal <-chan struct{}) error { log.Debug("Background loop started") - new_msg := make(chan []byte) - receive_err := make(chan error) + newMsgCh := make(chan []byte) + receiveErrCh := make(chan error) // Start listing for new messages go func() { for { _, msg, err := client.conn.ReadMessage() if err != nil { - receive_err <- err + receiveErrCh <- err + return + } + select { + case <-client.stopSignal: return + case newMsgCh <- msg: } - new_msg <- msg } }() for { select { - case <-stop_signal: - log.Warn("Stopping background receiving loop: received stop signal") + case err := <-receiveErrCh: + log.WithError(err).Error("Read a new msg") + return err + case <-stopSignal: + log.Debug("Background receiving loop stopped %w", receivedStopSignalErr) return nil - case msg := <-new_msg: + case msg := <-newMsgCh: err := client.handleNewMsg(msg) if err != nil { - log.WithError(err).Error("Handle a new received msg") + log.WithError(err).Error("Handle a msg") + return err } - case err := <-receive_err: - log.WithError(err).Error("Receive a new msg") - return err } } } @@ -186,7 +213,7 @@ func (client *RPCClient) handleNewMsg(msg []byte) error { decoder.DisallowUnknownFields() if err := decoder.Decode(&response); err == nil { if response.ID == nil { - return fmt.Errorf("Response doesn't have an id") + return InvalidResponseIDErr } err := client.requests.dispatch(*response.ID, response) @@ -202,18 +229,13 @@ func (client *RPCClient) handleNewMsg(msg []byte) error { if err := json.Unmarshal(msg, ¬ification); err == nil { ntRes := message.NotificationResult{} - if err := json.Unmarshal(*notification.Params, &ntRes); err != nil { + if err := json.Unmarshal(notification.Params, &ntRes); err != nil { return fmt.Errorf("Failed to unmarshal notification params: %w", err) } - // TODO: Consider using a more efficient design here because, - // on each registration, messageDispatcher will create a new subscription - // channel with the provided buffer size. If the buffer for the subscription - // channel is full, calling the dispatch method will block and in this - // case the client will not be able to handle any more messages. - err := client.subscriptions.dispatch(ntRes.Subscription, *ntRes.Result) + err := client.subscriptions.notify(ntRes.Subscription, ntRes.Result) if err != nil { - return fmt.Errorf("Dispatch a notification: %w", err) + return fmt.Errorf("Notify a subscriber: %w", err) } log.Debugf("<-- %s", notification.String()) @@ -228,6 +250,10 @@ func (client *RPCClient) handleNewMsg(msg []byte) error { func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) { response := message.Response{} + if client.isClosed.Load() { + return response, ClientIsDisconnectedErr + } + params_bytes, err := json.Marshal(params) if err != nil { return response, err @@ -262,8 +288,10 @@ func (client *RPCClient) sendRequest(method string, params any) (message.Respons // Waits the response, it fails and return error if it exceed the timeout select { case response = <-rx_ch: + case <-client.stopSignal: + return response, ClientIsDisconnectedErr case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): - return response, fmt.Errorf("Timeout error") + return response, TimeoutError } err = validateResponse(&response, id) @@ -285,7 +313,7 @@ func validateResponse(res *message.Response, reqID message.RequestID) error { if res.ID != nil { if *res.ID != reqID { - return fmt.Errorf("Invalid response id") + return InvalidResponseIDErr } } diff --git a/jsonrpc/client/concurrent_queue.go b/jsonrpc/client/concurrent_queue.go new file mode 100644 index 0000000..daf92cd --- /dev/null +++ b/jsonrpc/client/concurrent_queue.go @@ -0,0 +1,88 @@ +package client + +import ( + "errors" + "sync" + "sync/atomic" +) + +// queue A concurrent queue. +type queue[T any] struct { + lock sync.Mutex + cond *sync.Cond + items []T + bufferSize int + stopSignal chan struct{} + isClosed atomic.Bool +} + +var ( + queueIsFullErr = errors.New("Queue is full") + queueIsClosedErr = errors.New("Queue is closed") +) + +// newQueue creates a new queue with the specified buffer size. +func newQueue[T any](bufferSize int) *queue[T] { + q := &queue[T]{ + bufferSize: bufferSize, + items: make([]T, 0), + stopSignal: make(chan struct{}), + } + q.cond = sync.NewCond(&q.lock) + return q +} + +// push Adds a new item to the queue and returns an error if the queue +// is full. +func (q *queue[T]) push(item T) error { + if q.isClosed.Load() { + return queueIsClosedErr + } + + q.lock.Lock() + defer q.lock.Unlock() + if len(q.items) >= q.bufferSize { + return queueIsFullErr + } + + q.items = append(q.items, item) + q.cond.Signal() + return nil +} + +// pop waits for and removes the first element from the queue, then returns it. +func (q *queue[T]) pop() (T, bool) { + var t T + if q.isClosed.Load() { + return t, false + } + q.lock.Lock() + defer q.lock.Unlock() + + // Wait for an item to be available or for a stop signal. + // This ensures that the waiting stops once the queue is cleared. + for len(q.items) == 0 { + select { + case <-q.stopSignal: + return t, false + default: + q.cond.Wait() + } + } + + item := q.items[0] + q.items = q.items[1:] + return item, true +} + +// clear Clears all elements from the queue. +func (q *queue[T]) clear() { + if !q.isClosed.CompareAndSwap(false, true) { + return + } + q.lock.Lock() + defer q.lock.Unlock() + close(q.stopSignal) + q.items = nil + q.cond.Broadcast() +} diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go index 9177f6e..adedd1c 100644 --- a/jsonrpc/client/message_dispatcher.go +++ b/jsonrpc/client/message_dispatcher.go @@ -1,76 +1,74 @@ package client import ( - "fmt" + "errors" "sync" + + "github.com/karyontech/karyon-go/jsonrpc/message" +) + +var ( + requestChannelNotFoundErr = errors.New("Request channel not found") ) -// messageDispatcher Is a generic structure that holds a map of keys and +// messageDispatcher Is a structure that holds a map of request IDs and // channels, and it is protected by mutex -type messageDispatcher[K comparable, V any] struct { +type messageDispatcher struct { sync.Mutex - chans map[K]chan<- V - bufferSize int + chans map[message.RequestID]chan<- message.Response } // newMessageDispatcher Creates a new messageDispatcher -func newMessageDispatcher[K comparable, V any](bufferSize int) *messageDispatcher[K, V] { - chans := make(map[K]chan<- V) - return &messageDispatcher[K, V]{ - chans: chans, - bufferSize: bufferSize, +func newMessageDispatcher() *messageDispatcher { + chans := make(map[message.RequestID]chan<- message.Response) + return &messageDispatcher{ + chans: chans, } } -// register Registers a new channel with a given key. It returns the receiving channel. -func (c *messageDispatcher[K, V]) register(key K) <-chan V { +// register Registers a new request channel with the given id. It returns a +// channel for receiving response. +func (c *messageDispatcher) register(key message.RequestID) <-chan message.Response { c.Lock() defer c.Unlock() - ch := make(chan V, c.bufferSize) + ch := make(chan message.Response) c.chans[key] = ch return ch } -// length Returns the number of channels -func (c *messageDispatcher[K, V]) length() int { +// dispatch Disptaches the response to the channel with the given request id +func (c *messageDispatcher) dispatch(key message.RequestID, res message.Response) error { c.Lock() defer c.Unlock() - return len(c.chans) -} - -// dispatch Disptaches the msg to the channel with the given key -func (c *messageDispatcher[K, V]) dispatch(key K, msg V) error { - c.Lock() - ch, ok := c.chans[key] - c.Unlock() - - if !ok { - return fmt.Errorf("Channel not found") + if ch, ok := c.chans[key]; ok { + ch <- res + } else { + return requestChannelNotFoundErr } - ch <- msg return nil } -// unregister Unregisters the channel with the provided key -func (c *messageDispatcher[K, V]) unregister(key K) { +// unregister Unregisters the request with the provided id +func (c *messageDispatcher) unregister(key message.RequestID) { c.Lock() defer c.Unlock() + if ch, ok := c.chans[key]; ok { close(ch) delete(c.chans, key) } } -// clear Closes all the channels and remove them from the map -func (c *messageDispatcher[K, V]) clear() { +// clear Closes all the request channels and remove them from the map +func (c *messageDispatcher) clear() { c.Lock() defer c.Unlock() - for k, ch := range c.chans { + for _, ch := range c.chans { close(ch) - delete(c.chans, k) } + c.chans = nil } diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go index a1fc1c6..4d7cc60 100644 --- a/jsonrpc/client/message_dispatcher_test.go +++ b/jsonrpc/client/message_dispatcher_test.go @@ -5,58 +5,61 @@ import ( "sync/atomic" "testing" + "github.com/karyontech/karyon-go/jsonrpc/message" "github.com/stretchr/testify/assert" ) func TestDispatchToChannel(t *testing.T) { - messageDispatcher := newMessageDispatcher[int, int](10) + messageDispatcher := newMessageDispatcher() - chanKey := 1 - rx := messageDispatcher.register(chanKey) + req1 := "1" + rx := messageDispatcher.register(req1) - chanKey2 := 2 - rx2 := messageDispatcher.register(chanKey2) + req2 := "2" + rx2 := messageDispatcher.register(req2) var wg sync.WaitGroup wg.Add(1) go func() { + defer wg.Done() for i := 0; i < 50; i++ { - err := messageDispatcher.dispatch(chanKey, i) + res := message.Response{ID: &req1} + err := messageDispatcher.dispatch(req1, res) assert.Nil(t, err) } - messageDispatcher.unregister(chanKey) - wg.Done() + messageDispatcher.unregister(req1) }() wg.Add(1) go func() { + defer wg.Done() for i := 0; i < 50; i++ { - err := messageDispatcher.dispatch(chanKey2, i) + res := message.Response{ID: &req2} + err := messageDispatcher.dispatch(req2, res) assert.Nil(t, err) } - messageDispatcher.unregister(chanKey2) - wg.Done() + messageDispatcher.unregister(req2) }() var receivedItem atomic.Int32 wg.Add(1) go func() { + defer wg.Done() for range rx { receivedItem.Add(1) } - wg.Done() }() wg.Add(1) go func() { + defer wg.Done() for range rx2 { receivedItem.Add(1) } - wg.Done() }() wg.Wait() @@ -64,33 +67,31 @@ func TestDispatchToChannel(t *testing.T) { } func TestUnregisterChannel(t *testing.T) { - messageDispatcher := newMessageDispatcher[int, int](1) + messageDispatcher := newMessageDispatcher() - chanKey := 1 - rx := messageDispatcher.register(chanKey) + req := "1" + rx := messageDispatcher.register(req) - messageDispatcher.unregister(chanKey) - assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") + messageDispatcher.unregister(req) _, ok := <-rx assert.False(t, ok, "chan closed") - err := messageDispatcher.dispatch(chanKey, 1) + err := messageDispatcher.dispatch(req, message.Response{ID: &req}) assert.NotNil(t, err) } func TestClearChannels(t *testing.T) { - messageDispatcher := newMessageDispatcher[int, int](1) + messageDispatcher := newMessageDispatcher() - chanKey := 1 - rx := messageDispatcher.register(chanKey) + req := "1" + rx := messageDispatcher.register(req) messageDispatcher.clear() - assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") _, ok := <-rx assert.False(t, ok, "chan closed") - err := messageDispatcher.dispatch(chanKey, 1) + err := messageDispatcher.dispatch(req, message.Response{ID: &req}) assert.NotNil(t, err) } diff --git a/jsonrpc/client/subscription.go b/jsonrpc/client/subscription.go new file mode 100644 index 0000000..a9a94e7 --- /dev/null +++ b/jsonrpc/client/subscription.go @@ -0,0 +1,83 @@ +package client + +import ( + "encoding/json" + "errors" + "fmt" + "sync/atomic" + + log "github.com/sirupsen/logrus" +) + +var ( + subscriptionIsClosedErr = errors.New("Subscription is closed") +) + +// Subscription A subscription established when the client's subscribe to a method +type Subscription struct { + ch chan json.RawMessage + ID int + queue *queue[json.RawMessage] + stopSignal chan struct{} + isClosed atomic.Bool +} + +// newSubscription Creates a new Subscription +func newSubscription(subID int, bufferSize int) *Subscription { + sub := &Subscription{ + ch: make(chan json.RawMessage), + ID: subID, + queue: newQueue[json.RawMessage](bufferSize), + stopSignal: make(chan struct{}), + } + sub.startBackgroundJob() + + return sub +} + +// Recv Receives a new notification. +func (s *Subscription) Recv() <-chan json.RawMessage { + return s.ch +} + +// startBackgroundJob starts waiting for the queue to receive new items. +// It stops when it receives a stop signal. +func (s *Subscription) startBackgroundJob() { + go func() { + logger := log.WithField("Subscription", s.ID) + for { + msg, ok := s.queue.pop() + if !ok { + logger.Debug("Background job stopped") + return + } + select { + case <-s.stopSignal: + logger.Debug("Background job stopped: %w", receivedStopSignalErr) + return + case s.ch <- msg: + } + } + }() +} + +// notify adds a new notification to the queue. +func (s *Subscription) notify(nt json.RawMessage) error { + if s.isClosed.Load() { + return subscriptionIsClosedErr + } + if err := s.queue.push(nt); err != nil { + return fmt.Errorf("Unable to push new notification: %w", err) + } + return nil +} + +// stop Terminates the subscription, clears the queue, and closes channels. +func (s *Subscription) stop() { + if !s.isClosed.CompareAndSwap(false, true) { + return + } + close(s.stopSignal) + close(s.ch) + s.queue.clear() +} diff --git a/jsonrpc/client/subscription_test.go b/jsonrpc/client/subscription_test.go new file mode 100644 index 0000000..5928665 --- /dev/null +++ b/jsonrpc/client/subscription_test.go @@ -0,0 +1,86 @@ +package client + +import ( + "encoding/json" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSubscriptionFullQueue(t *testing.T) { + bufSize := 100 + sub := newSubscription(1, bufSize) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + defer sub.stop() + for i := 0; i < bufSize+10; i++ { + b, err := json.Marshal(i) + assert.Nil(t, err) + err = sub.notify(b) + if i > bufSize { + if assert.Error(t, err) { + assert.ErrorIs(t, err, queueIsFullErr) + } + } + } + }() + + wg.Wait() +} + +func TestSubscriptionRecv(t *testing.T) { + bufSize := 100 + sub := newSubscription(1, bufSize) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < bufSize; i++ { + b, err := json.Marshal(i) + assert.Nil(t, err) + err = sub.notify(b) + assert.Nil(t, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + i := 0 + for nt := range sub.Recv() { + var v int + err := json.Unmarshal(nt, &v) + assert.Nil(t, err) + assert.Equal(t, v, i) + i += 1 + if i == bufSize { + break + } + } + }() + + wg.Wait() +} + +func TestSubscriptionStop(t *testing.T) { + sub := newSubscription(1, 10) + + sub.stop() + + _, ok := <-sub.Recv() + assert.False(t, ok) + + b, err := json.Marshal(1) + assert.Nil(t, err) + err = sub.notify(b) + if assert.Error(t, err) { + assert.ErrorIs(t, err, subscriptionIsClosedErr) + } +} diff --git a/jsonrpc/client/subscriptions.go b/jsonrpc/client/subscriptions.go new file mode 100644 index 0000000..0524837 --- /dev/null +++ b/jsonrpc/client/subscriptions.go @@ -0,0 +1,81 @@ +package client + +import ( + "encoding/json" + "errors" + "sync" + + "github.com/karyontech/karyon-go/jsonrpc/message" +) + +var ( + subscriptionNotFoundErr = errors.New("Subscription not found") +) + +// subscriptions Is a structure that holds a map of subscription IDs and +// subscriptions +type subscriptions struct { + sync.Mutex + subs map[message.SubscriptionID]*Subscription + bufferSize int +} + +// newSubscriptions Creates a new subscriptions +func newSubscriptions(bufferSize int) *subscriptions { + subs := make(map[message.SubscriptionID]*Subscription) + return &subscriptions{ + subs: subs, + bufferSize: bufferSize, + } +} + +// subscribe Subscribes and returns a Subscription. +func (c *subscriptions) subscribe(key message.SubscriptionID) *Subscription { + c.Lock() + defer c.Unlock() + + sub := newSubscription(key, c.bufferSize) + c.subs[key] = sub + return sub +} + +// notify Notifies the msg the subscription with the given id +func (c *subscriptions) notify(key message.SubscriptionID, msg json.RawMessage) error { + c.Lock() + defer c.Unlock() + + sub, ok := c.subs[key] + + if !ok { + return subscriptionNotFoundErr + } + + err := sub.notify(msg) + + if err != nil { + return err + } + + return nil +} + +// unsubscribe Unsubscribe from the subscription with the provided id +func (c *subscriptions) unsubscribe(key message.SubscriptionID) { + c.Lock() + defer c.Unlock() + if sub, ok := c.subs[key]; ok { + sub.stop() + delete(c.subs, key) + } +} + +// clear Stops all the subscriptions and remove them from the map +func (c *subscriptions) clear() { + c.Lock() + defer c.Unlock() + + for _, sub := range c.subs { + sub.stop() + } + c.subs = nil +} diff --git a/jsonrpc/client/subscriptions_test.go b/jsonrpc/client/subscriptions_test.go new file mode 100644 index 0000000..cb5705a --- /dev/null +++ b/jsonrpc/client/subscriptions_test.go @@ -0,0 +1,87 @@ +package client + +import ( + "encoding/json" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSubscriptionsSubscribe(t *testing.T) { + bufSize := 100 + subs := newSubscriptions(bufSize) + + var receivedNotifications atomic.Int32 + + var wg sync.WaitGroup + + runSubNotify := func(sub *Subscription) { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < bufSize; i++ { + b, err := json.Marshal(i) + assert.Nil(t, err) + err = sub.notify(b) + assert.Nil(t, err) + } + }() + } + + runSubRecv := func(sub *Subscription) { + wg.Add(1) + go func() { + defer wg.Done() + i := 0 + for nt := range sub.Recv() { + var v int + err := json.Unmarshal(nt, &v) + assert.Nil(t, err) + assert.Equal(t, v, i) + receivedNotifications.Add(1) + i += 1 + if i == bufSize { + break + } + } + sub.stop() + }() + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 3; i++ { + sub := subs.subscribe(i) + runSubNotify(sub) + runSubRecv(sub) + } + }() + + wg.Wait() + assert.Equal(t, receivedNotifications.Load(), int32(bufSize*3)) +} + +func TestSubscriptionsUnsubscribe(t *testing.T) { + bufSize := 100 + subs := newSubscriptions(bufSize) + + var wg sync.WaitGroup + + sub := subs.subscribe(1) + subs.unsubscribe(1) + + _, ok := <-sub.Recv() + assert.False(t, ok) + + b, err := json.Marshal(1) + assert.Nil(t, err) + err = sub.notify(b) + if assert.Error(t, err) { + assert.ErrorIs(t, err, subscriptionIsClosedErr) + } + + wg.Wait() +} -- cgit v1.2.3