aboutsummaryrefslogtreecommitdiff
path: root/jsonrpc
diff options
context:
space:
mode:
Diffstat (limited to 'jsonrpc')
-rw-r--r--jsonrpc/client/client.go124
-rw-r--r--jsonrpc/client/concurrent_queue.go88
-rw-r--r--jsonrpc/client/message_dispatcher.go64
-rw-r--r--jsonrpc/client/message_dispatcher_test.go49
-rw-r--r--jsonrpc/client/subscription.go83
-rw-r--r--jsonrpc/client/subscription_test.go86
-rw-r--r--jsonrpc/client/subscriptions.go81
-rw-r--r--jsonrpc/client/subscriptions_test.go87
-rw-r--r--jsonrpc/message/message.go12
9 files changed, 563 insertions, 111 deletions
diff --git a/jsonrpc/client/client.go b/jsonrpc/client/client.go
index afcc5c7..1f12ced 100644
--- a/jsonrpc/client/client.go
+++ b/jsonrpc/client/client.go
@@ -3,9 +3,11 @@ package client
import (
"bytes"
"encoding/json"
+ "errors"
"fmt"
"math/rand"
"strconv"
+ "sync/atomic"
"time"
"github.com/gorilla/websocket"
@@ -20,21 +22,34 @@ const (
// Default timeout for receiving requests from the server, in milliseconds.
DefaultTimeout = 3000
+
+ // The default buffer size for a subscription.
+ DefaultSubscriptionBufferSize = 10000
+)
+
+var (
+ ClientIsDisconnectedErr = errors.New("Client is disconnected and closed")
+ TimeoutError = errors.New("Timeout Error")
+ InvalidResponseIDErr = errors.New("Invalid response ID")
+ InvalidResponseResultErr = errors.New("Invalid response result")
+ receivedStopSignalErr = errors.New("Received stop signal")
)
// RPCClientConfig Holds the configuration settings for the RPC client.
type RPCClientConfig struct {
- Timeout int // Timeout for receiving requests from the server, in milliseconds.
- Addr string // Address of the RPC server.
+ Timeout int // Timeout for receiving requests from the server, in milliseconds.
+ Addr string // Address of the RPC server.
+ SubscriptionBufferSize int // The buffer size for a subscription.
}
// RPCClient RPC Client
type RPCClient struct {
config RPCClientConfig
conn *websocket.Conn
- requests *messageDispatcher[message.RequestID, message.Response]
- subscriptions *messageDispatcher[message.SubscriptionID, json.RawMessage]
- stop_signal chan struct{}
+ requests *messageDispatcher
+ subscriptions *subscriptions
+ stopSignal chan struct{}
+ isClosed atomic.Bool
}
// NewRPCClient Creates a new instance of RPCClient with the provided configuration.
@@ -46,25 +61,29 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
}
log.Infof("Successfully connected to the server: %s", config.Addr)
- if config.Timeout == 0 {
+ if config.Timeout <= 0 {
config.Timeout = DefaultTimeout
}
- stop_signal := make(chan struct{})
+ if config.SubscriptionBufferSize <= 0 {
+ config.SubscriptionBufferSize = DefaultSubscriptionBufferSize
+ }
+
+ stopSignal := make(chan struct{})
- requests := newMessageDispatcher[message.RequestID, message.Response](0)
- subscriptions := newMessageDispatcher[message.SubscriptionID, json.RawMessage](100)
+ requests := newMessageDispatcher()
+ subscriptions := newSubscriptions(config.SubscriptionBufferSize)
client := &RPCClient{
conn: conn,
config: config,
requests: requests,
subscriptions: subscriptions,
- stop_signal: stop_signal,
+ stopSignal: stopSignal,
}
go func() {
- if err := client.backgroundReceivingLoop(stop_signal); err != nil {
+ if err := client.backgroundReceivingLoop(stopSignal); err != nil {
client.Close()
}
}()
@@ -74,9 +93,14 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
// Close Closes the underlying websocket connection and stop the receiving loop.
func (client *RPCClient) Close() {
+ // Check if it's already closed
+ if !client.isClosed.CompareAndSwap(false, true) {
+ return
+ }
+
log.Warn("Close the rpc client...")
// Send stop signal to the background receiving loop
- close(client.stop_signal)
+ close(client.stopSignal)
// Close the underlying websocket connection
err := client.conn.Close()
@@ -90,7 +114,7 @@ func (client *RPCClient) Close() {
// Call Sends an RPC call to the server with the specified method and
// parameters, and returns the response.
-func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) {
+func (client *RPCClient) Call(method string, params any) (json.RawMessage, error) {
log.Tracef("Call -> method: %s, params: %v", method, params)
response, err := client.sendRequest(method, params)
if err != nil {
@@ -101,29 +125,27 @@ func (client *RPCClient) Call(method string, params any) (*json.RawMessage, erro
}
// Subscribe Sends a subscription request to the server with the specified
-// method and parameters, and it returns the subscription ID and the channel to
-// receive notifications.
-func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) {
+// method and parameters, and it returns the subscription.
+func (client *RPCClient) Subscribe(method string, params any) (*Subscription, error) {
log.Tracef("Sbuscribe -> method: %s, params: %v", method, params)
response, err := client.sendRequest(method, params)
if err != nil {
- return 0, nil, err
+ return nil, err
}
if response.Result == nil {
- return 0, nil, fmt.Errorf("Invalid response result")
+ return nil, InvalidResponseResultErr
}
var subID message.SubscriptionID
- err = json.Unmarshal(*response.Result, &subID)
+ err = json.Unmarshal(response.Result, &subID)
if err != nil {
- return 0, nil, err
+ return nil, err
}
- // Register a new subscription
- sub := client.subscriptions.register(subID)
+ sub := client.subscriptions.subscribe(subID)
- return subID, sub, nil
+ return sub, nil
}
// Unsubscribe Sends an unsubscription request to the server to cancel the
@@ -135,44 +157,49 @@ func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID
return err
}
- // On success unregister the subscription channel
- client.subscriptions.unregister(subID)
+ // On success unsubscribe
+ client.subscriptions.unsubscribe(subID)
return nil
}
// backgroundReceivingLoop Starts reading new messages from the underlying connection.
-func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error {
+func (client *RPCClient) backgroundReceivingLoop(stopSignal <-chan struct{}) error {
log.Debug("Background loop started")
- new_msg := make(chan []byte)
- receive_err := make(chan error)
+ newMsgCh := make(chan []byte)
+ receiveErrCh := make(chan error)
// Start listing for new messages
go func() {
for {
_, msg, err := client.conn.ReadMessage()
if err != nil {
- receive_err <- err
+ receiveErrCh <- err
+ return
+ }
+ select {
+ case <-client.stopSignal:
return
+ case newMsgCh <- msg:
}
- new_msg <- msg
}
}()
for {
select {
- case <-stop_signal:
- log.Warn("Stopping background receiving loop: received stop signal")
+ case err := <-receiveErrCh:
+ log.WithError(err).Error("Read a new msg")
+ return err
+ case <-stopSignal:
+ log.Debug("Background receiving loop stopped %w", receivedStopSignalErr)
return nil
- case msg := <-new_msg:
+ case msg := <-newMsgCh:
err := client.handleNewMsg(msg)
if err != nil {
- log.WithError(err).Error("Handle a new received msg")
+ log.WithError(err).Error("Handle a msg")
+ return err
}
- case err := <-receive_err:
- log.WithError(err).Error("Receive a new msg")
- return err
}
}
}
@@ -186,7 +213,7 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {
decoder.DisallowUnknownFields()
if err := decoder.Decode(&response); err == nil {
if response.ID == nil {
- return fmt.Errorf("Response doesn't have an id")
+ return InvalidResponseIDErr
}
err := client.requests.dispatch(*response.ID, response)
@@ -202,18 +229,13 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {
if err := json.Unmarshal(msg, &notification); err == nil {
ntRes := message.NotificationResult{}
- if err := json.Unmarshal(*notification.Params, &ntRes); err != nil {
+ if err := json.Unmarshal(notification.Params, &ntRes); err != nil {
return fmt.Errorf("Failed to unmarshal notification params: %w", err)
}
- // TODO: Consider using a more efficient design here because,
- // on each registration, messageDispatcher will create a new subscription
- // channel with the provided buffer size. If the buffer for the subscription
- // channel is full, calling the dispatch method will block and in this
- // case the client will not be able to handle any more messages.
- err := client.subscriptions.dispatch(ntRes.Subscription, *ntRes.Result)
+ err := client.subscriptions.notify(ntRes.Subscription, ntRes.Result)
if err != nil {
- return fmt.Errorf("Dispatch a notification: %w", err)
+ return fmt.Errorf("Notify a subscriber: %w", err)
}
log.Debugf("<-- %s", notification.String())
@@ -228,6 +250,10 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {
func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) {
response := message.Response{}
+ if client.isClosed.Load() {
+ return response, ClientIsDisconnectedErr
+ }
+
params_bytes, err := json.Marshal(params)
if err != nil {
return response, err
@@ -262,8 +288,10 @@ func (client *RPCClient) sendRequest(method string, params any) (message.Respons
// Waits the response, it fails and return error if it exceed the timeout
select {
case response = <-rx_ch:
+ case <-client.stopSignal:
+ return response, ClientIsDisconnectedErr
case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond):
- return response, fmt.Errorf("Timeout error")
+ return response, TimeoutError
}
err = validateResponse(&response, id)
@@ -285,7 +313,7 @@ func validateResponse(res *message.Response, reqID message.RequestID) error {
if res.ID != nil {
if *res.ID != reqID {
- return fmt.Errorf("Invalid response id")
+ return InvalidResponseIDErr
}
}
diff --git a/jsonrpc/client/concurrent_queue.go b/jsonrpc/client/concurrent_queue.go
new file mode 100644
index 0000000..daf92cd
--- /dev/null
+++ b/jsonrpc/client/concurrent_queue.go
@@ -0,0 +1,88 @@
+package client
+
+import (
+ "errors"
+ "sync"
+ "sync/atomic"
+)
+
+// queue A concurrent queue.
+type queue[T any] struct {
+ lock sync.Mutex
+ cond *sync.Cond
+ items []T
+ bufferSize int
+ stopSignal chan struct{}
+ isClosed atomic.Bool
+}
+
+var (
+ queueIsFullErr = errors.New("Queue is full")
+ queueIsClosedErr = errors.New("Queue is closed")
+)
+
+// newQueue creates a new queue with the specified buffer size.
+func newQueue[T any](bufferSize int) *queue[T] {
+ q := &queue[T]{
+ bufferSize: bufferSize,
+ items: make([]T, 0),
+ stopSignal: make(chan struct{}),
+ }
+ q.cond = sync.NewCond(&q.lock)
+ return q
+}
+
+// push Adds a new item to the queue and returns an error if the queue
+// is full.
+func (q *queue[T]) push(item T) error {
+ if q.isClosed.Load() {
+ return queueIsClosedErr
+ }
+
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ if len(q.items) >= q.bufferSize {
+ return queueIsFullErr
+ }
+
+ q.items = append(q.items, item)
+ q.cond.Signal()
+ return nil
+}
+
+// pop waits for and removes the first element from the queue, then returns it.
+func (q *queue[T]) pop() (T, bool) {
+ var t T
+ if q.isClosed.Load() {
+ return t, false
+ }
+ q.lock.Lock()
+ defer q.lock.Unlock()
+
+ // Wait for an item to be available or for a stop signal.
+ // This ensures that the waiting stops once the queue is cleared.
+ for len(q.items) == 0 {
+ select {
+ case <-q.stopSignal:
+ return t, false
+ default:
+ q.cond.Wait()
+ }
+ }
+
+ item := q.items[0]
+ q.items = q.items[1:]
+ return item, true
+}
+
+// clear Clears all elements from the queue.
+func (q *queue[T]) clear() {
+ if !q.isClosed.CompareAndSwap(false, true) {
+ return
+ }
+ q.lock.Lock()
+ defer q.lock.Unlock()
+ close(q.stopSignal)
+ q.items = nil
+ q.cond.Broadcast()
+}
diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go
index 9177f6e..adedd1c 100644
--- a/jsonrpc/client/message_dispatcher.go
+++ b/jsonrpc/client/message_dispatcher.go
@@ -1,76 +1,74 @@
package client
import (
- "fmt"
+ "errors"
"sync"
+
+ "github.com/karyontech/karyon-go/jsonrpc/message"
+)
+
+var (
+ requestChannelNotFoundErr = errors.New("Request channel not found")
)
-// messageDispatcher Is a generic structure that holds a map of keys and
+// messageDispatcher Is a structure that holds a map of request IDs and
// channels, and it is protected by mutex
-type messageDispatcher[K comparable, V any] struct {
+type messageDispatcher struct {
sync.Mutex
- chans map[K]chan<- V
- bufferSize int
+ chans map[message.RequestID]chan<- message.Response
}
// newMessageDispatcher Creates a new messageDispatcher
-func newMessageDispatcher[K comparable, V any](bufferSize int) *messageDispatcher[K, V] {
- chans := make(map[K]chan<- V)
- return &messageDispatcher[K, V]{
- chans: chans,
- bufferSize: bufferSize,
+func newMessageDispatcher() *messageDispatcher {
+ chans := make(map[message.RequestID]chan<- message.Response)
+ return &messageDispatcher{
+ chans: chans,
}
}
-// register Registers a new channel with a given key. It returns the receiving channel.
-func (c *messageDispatcher[K, V]) register(key K) <-chan V {
+// register Registers a new request channel with the given id. It returns a
+// channel for receiving response.
+func (c *messageDispatcher) register(key message.RequestID) <-chan message.Response {
c.Lock()
defer c.Unlock()
- ch := make(chan V, c.bufferSize)
+ ch := make(chan message.Response)
c.chans[key] = ch
return ch
}
-// length Returns the number of channels
-func (c *messageDispatcher[K, V]) length() int {
+// dispatch Disptaches the response to the channel with the given request id
+func (c *messageDispatcher) dispatch(key message.RequestID, res message.Response) error {
c.Lock()
defer c.Unlock()
- return len(c.chans)
-}
-
-// dispatch Disptaches the msg to the channel with the given key
-func (c *messageDispatcher[K, V]) dispatch(key K, msg V) error {
- c.Lock()
- ch, ok := c.chans[key]
- c.Unlock()
-
- if !ok {
- return fmt.Errorf("Channel not found")
+ if ch, ok := c.chans[key]; ok {
+ ch <- res
+ } else {
+ return requestChannelNotFoundErr
}
- ch <- msg
return nil
}
-// unregister Unregisters the channel with the provided key
-func (c *messageDispatcher[K, V]) unregister(key K) {
+// unregister Unregisters the request with the provided id
+func (c *messageDispatcher) unregister(key message.RequestID) {
c.Lock()
defer c.Unlock()
+
if ch, ok := c.chans[key]; ok {
close(ch)
delete(c.chans, key)
}
}
-// clear Closes all the channels and remove them from the map
-func (c *messageDispatcher[K, V]) clear() {
+// clear Closes all the request channels and remove them from the map
+func (c *messageDispatcher) clear() {
c.Lock()
defer c.Unlock()
- for k, ch := range c.chans {
+ for _, ch := range c.chans {
close(ch)
- delete(c.chans, k)
}
+ c.chans = nil
}
diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go
index a1fc1c6..4d7cc60 100644
--- a/jsonrpc/client/message_dispatcher_test.go
+++ b/jsonrpc/client/message_dispatcher_test.go
@@ -5,58 +5,61 @@ import (
"sync/atomic"
"testing"
+ "github.com/karyontech/karyon-go/jsonrpc/message"
"github.com/stretchr/testify/assert"
)
func TestDispatchToChannel(t *testing.T) {
- messageDispatcher := newMessageDispatcher[int, int](10)
+ messageDispatcher := newMessageDispatcher()
- chanKey := 1
- rx := messageDispatcher.register(chanKey)
+ req1 := "1"
+ rx := messageDispatcher.register(req1)
- chanKey2 := 2
- rx2 := messageDispatcher.register(chanKey2)
+ req2 := "2"
+ rx2 := messageDispatcher.register(req2)
var wg sync.WaitGroup
wg.Add(1)
go func() {
+ defer wg.Done()
for i := 0; i < 50; i++ {
- err := messageDispatcher.dispatch(chanKey, i)
+ res := message.Response{ID: &req1}
+ err := messageDispatcher.dispatch(req1, res)
assert.Nil(t, err)
}
- messageDispatcher.unregister(chanKey)
- wg.Done()
+ messageDispatcher.unregister(req1)
}()
wg.Add(1)
go func() {
+ defer wg.Done()
for i := 0; i < 50; i++ {
- err := messageDispatcher.dispatch(chanKey2, i)
+ res := message.Response{ID: &req2}
+ err := messageDispatcher.dispatch(req2, res)
assert.Nil(t, err)
}
- messageDispatcher.unregister(chanKey2)
- wg.Done()
+ messageDispatcher.unregister(req2)
}()
var receivedItem atomic.Int32
wg.Add(1)
go func() {
+ defer wg.Done()
for range rx {
receivedItem.Add(1)
}
- wg.Done()
}()
wg.Add(1)
go func() {
+ defer wg.Done()
for range rx2 {
receivedItem.Add(1)
}
- wg.Done()
}()
wg.Wait()
@@ -64,33 +67,31 @@ func TestDispatchToChannel(t *testing.T) {
}
func TestUnregisterChannel(t *testing.T) {
- messageDispatcher := newMessageDispatcher[int, int](1)
+ messageDispatcher := newMessageDispatcher()
- chanKey := 1
- rx := messageDispatcher.register(chanKey)
+ req := "1"
+ rx := messageDispatcher.register(req)
- messageDispatcher.unregister(chanKey)
- assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty")
+ messageDispatcher.unregister(req)
_, ok := <-rx
assert.False(t, ok, "chan closed")
- err := messageDispatcher.dispatch(chanKey, 1)
+ err := messageDispatcher.dispatch(req, message.Response{ID: &req})
assert.NotNil(t, err)
}
func TestClearChannels(t *testing.T) {
- messageDispatcher := newMessageDispatcher[int, int](1)
+ messageDispatcher := newMessageDispatcher()
- chanKey := 1
- rx := messageDispatcher.register(chanKey)
+ req := "1"
+ rx := messageDispatcher.register(req)
messageDispatcher.clear()
- assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty")
_, ok := <-rx
assert.False(t, ok, "chan closed")
- err := messageDispatcher.dispatch(chanKey, 1)
+ err := messageDispatcher.dispatch(req, message.Response{ID: &req})
assert.NotNil(t, err)
}
diff --git a/jsonrpc/client/subscription.go b/jsonrpc/client/subscription.go
new file mode 100644
index 0000000..a9a94e7
--- /dev/null
+++ b/jsonrpc/client/subscription.go
@@ -0,0 +1,83 @@
+package client
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync/atomic"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ subscriptionIsClosedErr = errors.New("Subscription is closed")
+)
+
+// Subscription A subscription established when the client's subscribe to a method
+type Subscription struct {
+ ch chan json.RawMessage
+ ID int
+ queue *queue[json.RawMessage]
+ stopSignal chan struct{}
+ isClosed atomic.Bool
+}
+
+// newSubscription Creates a new Subscription
+func newSubscription(subID int, bufferSize int) *Subscription {
+ sub := &Subscription{
+ ch: make(chan json.RawMessage),
+ ID: subID,
+ queue: newQueue[json.RawMessage](bufferSize),
+ stopSignal: make(chan struct{}),
+ }
+ sub.startBackgroundJob()
+
+ return sub
+}
+
+// Recv Receives a new notification.
+func (s *Subscription) Recv() <-chan json.RawMessage {
+ return s.ch
+}
+
+// startBackgroundJob starts waiting for the queue to receive new items.
+// It stops when it receives a stop signal.
+func (s *Subscription) startBackgroundJob() {
+ go func() {
+ logger := log.WithField("Subscription", s.ID)
+ for {
+ msg, ok := s.queue.pop()
+ if !ok {
+ logger.Debug("Background job stopped")
+ return
+ }
+ select {
+ case <-s.stopSignal:
+ logger.Debug("Background job stopped: %w", receivedStopSignalErr)
+ return
+ case s.ch <- msg:
+ }
+ }
+ }()
+}
+
+// notify adds a new notification to the queue.
+func (s *Subscription) notify(nt json.RawMessage) error {
+ if s.isClosed.Load() {
+ return subscriptionIsClosedErr
+ }
+ if err := s.queue.push(nt); err != nil {
+ return fmt.Errorf("Unable to push new notification: %w", err)
+ }
+ return nil
+}
+
+// stop Terminates the subscription, clears the queue, and closes channels.
+func (s *Subscription) stop() {
+ if !s.isClosed.CompareAndSwap(false, true) {
+ return
+ }
+ close(s.stopSignal)
+ close(s.ch)
+ s.queue.clear()
+}
diff --git a/jsonrpc/client/subscription_test.go b/jsonrpc/client/subscription_test.go
new file mode 100644
index 0000000..5928665
--- /dev/null
+++ b/jsonrpc/client/subscription_test.go
@@ -0,0 +1,86 @@
+package client
+
+import (
+ "encoding/json"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSubscriptionFullQueue(t *testing.T) {
+ bufSize := 100
+ sub := newSubscription(1, bufSize)
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer sub.stop()
+ for i := 0; i < bufSize+10; i++ {
+ b, err := json.Marshal(i)
+ assert.Nil(t, err)
+ err = sub.notify(b)
+ if i > bufSize {
+ if assert.Error(t, err) {
+ assert.ErrorIs(t, err, queueIsFullErr)
+ }
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSubscriptionRecv(t *testing.T) {
+ bufSize := 100
+ sub := newSubscription(1, bufSize)
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < bufSize; i++ {
+ b, err := json.Marshal(i)
+ assert.Nil(t, err)
+ err = sub.notify(b)
+ assert.Nil(t, err)
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ i := 0
+ for nt := range sub.Recv() {
+ var v int
+ err := json.Unmarshal(nt, &v)
+ assert.Nil(t, err)
+ assert.Equal(t, v, i)
+ i += 1
+ if i == bufSize {
+ break
+ }
+ }
+ }()
+
+ wg.Wait()
+}
+
+func TestSubscriptionStop(t *testing.T) {
+ sub := newSubscription(1, 10)
+
+ sub.stop()
+
+ _, ok := <-sub.Recv()
+ assert.False(t, ok)
+
+ b, err := json.Marshal(1)
+ assert.Nil(t, err)
+ err = sub.notify(b)
+ if assert.Error(t, err) {
+ assert.ErrorIs(t, err, subscriptionIsClosedErr)
+ }
+}
diff --git a/jsonrpc/client/subscriptions.go b/jsonrpc/client/subscriptions.go
new file mode 100644
index 0000000..0524837
--- /dev/null
+++ b/jsonrpc/client/subscriptions.go
@@ -0,0 +1,81 @@
+package client
+
+import (
+ "encoding/json"
+ "errors"
+ "sync"
+
+ "github.com/karyontech/karyon-go/jsonrpc/message"
+)
+
+var (
+ subscriptionNotFoundErr = errors.New("Subscription not found")
+)
+
+// subscriptions Is a structure that holds a map of subscription IDs and
+// subscriptions
+type subscriptions struct {
+ sync.Mutex
+ subs map[message.SubscriptionID]*Subscription
+ bufferSize int
+}
+
+// newSubscriptions Creates a new subscriptions
+func newSubscriptions(bufferSize int) *subscriptions {
+ subs := make(map[message.SubscriptionID]*Subscription)
+ return &subscriptions{
+ subs: subs,
+ bufferSize: bufferSize,
+ }
+}
+
+// subscribe Subscribes and returns a Subscription.
+func (c *subscriptions) subscribe(key message.SubscriptionID) *Subscription {
+ c.Lock()
+ defer c.Unlock()
+
+ sub := newSubscription(key, c.bufferSize)
+ c.subs[key] = sub
+ return sub
+}
+
+// notify Notifies the msg the subscription with the given id
+func (c *subscriptions) notify(key message.SubscriptionID, msg json.RawMessage) error {
+ c.Lock()
+ defer c.Unlock()
+
+ sub, ok := c.subs[key]
+
+ if !ok {
+ return subscriptionNotFoundErr
+ }
+
+ err := sub.notify(msg)
+
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// unsubscribe Unsubscribe from the subscription with the provided id
+func (c *subscriptions) unsubscribe(key message.SubscriptionID) {
+ c.Lock()
+ defer c.Unlock()
+ if sub, ok := c.subs[key]; ok {
+ sub.stop()
+ delete(c.subs, key)
+ }
+}
+
+// clear Stops all the subscriptions and remove them from the map
+func (c *subscriptions) clear() {
+ c.Lock()
+ defer c.Unlock()
+
+ for _, sub := range c.subs {
+ sub.stop()
+ }
+ c.subs = nil
+}
diff --git a/jsonrpc/client/subscriptions_test.go b/jsonrpc/client/subscriptions_test.go
new file mode 100644
index 0000000..cb5705a
--- /dev/null
+++ b/jsonrpc/client/subscriptions_test.go
@@ -0,0 +1,87 @@
+package client
+
+import (
+ "encoding/json"
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSubscriptionsSubscribe(t *testing.T) {
+ bufSize := 100
+ subs := newSubscriptions(bufSize)
+
+ var receivedNotifications atomic.Int32
+
+ var wg sync.WaitGroup
+
+ runSubNotify := func(sub *Subscription) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < bufSize; i++ {
+ b, err := json.Marshal(i)
+ assert.Nil(t, err)
+ err = sub.notify(b)
+ assert.Nil(t, err)
+ }
+ }()
+ }
+
+ runSubRecv := func(sub *Subscription) {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ i := 0
+ for nt := range sub.Recv() {
+ var v int
+ err := json.Unmarshal(nt, &v)
+ assert.Nil(t, err)
+ assert.Equal(t, v, i)
+ receivedNotifications.Add(1)
+ i += 1
+ if i == bufSize {
+ break
+ }
+ }
+ sub.stop()
+ }()
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < 3; i++ {
+ sub := subs.subscribe(i)
+ runSubNotify(sub)
+ runSubRecv(sub)
+ }
+ }()
+
+ wg.Wait()
+ assert.Equal(t, receivedNotifications.Load(), int32(bufSize*3))
+}
+
+func TestSubscriptionsUnsubscribe(t *testing.T) {
+ bufSize := 100
+ subs := newSubscriptions(bufSize)
+
+ var wg sync.WaitGroup
+
+ sub := subs.subscribe(1)
+ subs.unsubscribe(1)
+
+ _, ok := <-sub.Recv()
+ assert.False(t, ok)
+
+ b, err := json.Marshal(1)
+ assert.Nil(t, err)
+ err = sub.notify(b)
+ if assert.Error(t, err) {
+ assert.ErrorIs(t, err, subscriptionIsClosedErr)
+ }
+
+ wg.Wait()
+}
diff --git a/jsonrpc/message/message.go b/jsonrpc/message/message.go
index ff68da2..b01c05d 100644
--- a/jsonrpc/message/message.go
+++ b/jsonrpc/message/message.go
@@ -26,7 +26,7 @@ type Request struct {
type Response struct {
JSONRPC string `json:"jsonrpc"` // JSON-RPC version, typically "2.0".
ID *RequestID `json:"id,omitempty"` // Unique identifier matching the request ID, can be null for notifications.
- Result *json.RawMessage `json:"result,omitempty"` // Result of the request if it was successful.
+ Result json.RawMessage `json:"result,omitempty"` // Result of the request if it was successful.
Error *Error `json:"error,omitempty"` // Error object if the request failed.
}
@@ -34,13 +34,13 @@ type Response struct {
type Notification struct {
JSONRPC string `json:"jsonrpc"` // JSON-RPC version, typically "2.0".
Method string `json:"method"` // The name of the method to be invoked.
- Params *json.RawMessage `json:"params,omitempty"` // Optional parameters for the method.
+ Params json.RawMessage `json:"params,omitempty"` // Optional parameters for the method.
}
// NotificationResult represents the result of a subscription notification.
// It includes the result and the subscription ID that triggered the notification.
type NotificationResult struct {
- Result *json.RawMessage `json:"result,omitempty"` // Result data of the notification.
+ Result json.RawMessage `json:"result,omitempty"` // Result data of the notification.
Subscription SubscriptionID `json:"subscription"` // ID of the subscription that triggered the notification.
}
@@ -49,7 +49,7 @@ type NotificationResult struct {
type Error struct {
Code int `json:"code"` // Error code indicating the type of error.
Message string `json:"message"` // Human-readable error message.
- Data *json.RawMessage `json:"data,omitempty"` // Optional additional data about the error.
+ Data json.RawMessage `json:"data,omitempty"` // Optional additional data about the error.
}
func (req *Request) String() string {
@@ -57,11 +57,11 @@ func (req *Request) String() string {
}
func (res *Response) String() string {
- return fmt.Sprintf("{JSONRPC: %s, ID: %s, RESULT: %s, ERROR: %v}", res.JSONRPC, *res.ID, *res.Result, res.Error)
+ return fmt.Sprintf("{JSONRPC: %s, ID: %v, RESULT: %v, ERROR: %v}", res.JSONRPC, res.ID, res.Result, res.Error)
}
func (nt *Notification) String() string {
- return fmt.Sprintf("{JSONRPC: %s, METHOD: %s, PARAMS: %s}", nt.JSONRPC, nt.Method, *nt.Params)
+ return fmt.Sprintf("{JSONRPC: %s, METHOD: %s, PARAMS: %s}", nt.JSONRPC, nt.Method, nt.Params)
}
func (err *Error) String() string {