diff options
| author | hozan23 <hozan23@karyontech.net> | 2024-06-23 15:57:43 +0200 | 
|---|---|---|
| committer | hozan23 <hozan23@karyontech.net> | 2024-07-09 11:46:03 +0200 | 
| commit | 6355144b8c3514cccc5c2ab4f7c4fd8e76a1a9fc (patch) | |
| tree | 3c31e350c8da79198f6127398905461addccef1e | |
| parent | 223d80fa52d3efd2909b7061e3c42a0ed930b4ff (diff) | |
Fix the issue with message dispatcher and channels
Resolved a previous error where each subscription would create a new
channel with the fixed buffer size. This caused blocking when the
channel buffer was full, preventing the client from handling additional messages.
Now, there is a `subscriptions` struct that holds a queue for receiving
notifications, ensuring the notify function does not block.
| -rw-r--r-- | Makefile | 17 | ||||
| -rw-r--r-- | README.md | 6 | ||||
| -rw-r--r-- | go.mod | 9 | ||||
| -rw-r--r-- | go.sum | 16 | ||||
| -rw-r--r-- | jsonrpc/client/client.go | 124 | ||||
| -rw-r--r-- | jsonrpc/client/concurrent_queue.go | 88 | ||||
| -rw-r--r-- | jsonrpc/client/message_dispatcher.go | 64 | ||||
| -rw-r--r-- | jsonrpc/client/message_dispatcher_test.go | 49 | ||||
| -rw-r--r-- | jsonrpc/client/subscription.go | 83 | ||||
| -rw-r--r-- | jsonrpc/client/subscription_test.go | 86 | ||||
| -rw-r--r-- | jsonrpc/client/subscriptions.go | 81 | ||||
| -rw-r--r-- | jsonrpc/client/subscriptions_test.go | 87 | ||||
| -rw-r--r-- | jsonrpc/message/message.go | 12 | 
13 files changed, 595 insertions, 127 deletions
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..12f1a55 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +GO := go +LINTER := golangci-lint +PKGS := ./... + +.PHONY: all test lint clean + +all: lint test + +test: +	$(GO) test $(PKGS) + +lint: +	$(LINTER) run $(PKGS) + +clean: +	$(GO) clean + @@ -21,14 +21,14 @@ if err != nil {  }  defer client.Close() -subID, ch, err := client.Subscribe("RPCService.log_subscribe", nil) +sub, err := client.Subscribe("RPCService.log_subscribe", nil)  if err != nil {  	log.Fatal(err)  } -log.Infof("Subscribed successfully: %d\n", subID) +log.Infof("Subscribed successfully: %d\n", sub.ID)  go func() { -	for notification := range ch { +	for notification := range sub.Recv() {  		log.Infof("Receive new notification: %s\n", notification)  	}  }() @@ -3,15 +3,14 @@ module github.com/karyontech/karyon-go  go 1.22  require ( -	github.com/gorilla/websocket v1.5.1 +	github.com/gorilla/websocket v1.5.3  	github.com/sirupsen/logrus v1.9.3 -	github.com/stretchr/testify v1.7.0 +	github.com/stretchr/testify v1.9.0  )  require (  	github.com/davecgh/go-spew v1.1.1 // indirect  	github.com/pmezard/go-difflib v1.0.0 // indirect -	golang.org/x/net v0.17.0 // indirect -	golang.org/x/sys v0.13.0 // indirect -	gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +	golang.org/x/sys v0.21.0 // indirect +	gopkg.in/yaml.v3 v3.0.1 // indirect  ) @@ -1,21 +1,21 @@  github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=  github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=  github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=  github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=  github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=  github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=  github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=  github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=  github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=  golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=  gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=  gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=  gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jsonrpc/client/client.go b/jsonrpc/client/client.go index afcc5c7..1f12ced 100644 --- a/jsonrpc/client/client.go +++ b/jsonrpc/client/client.go @@ -3,9 +3,11 @@ package client  import (  	"bytes"  	"encoding/json" +	"errors"  	"fmt"  	"math/rand"  	"strconv" +	"sync/atomic"  	"time"  	"github.com/gorilla/websocket" @@ -20,21 +22,34 @@ const (  	// Default timeout for receiving requests from the server, in milliseconds.  	DefaultTimeout = 3000 + +	// The default buffer size for a subscription. +	DefaultSubscriptionBufferSize = 10000 +) + +var ( +	ClientIsDisconnectedErr  = errors.New("Client is disconnected and closed") +	TimeoutError             = errors.New("Timeout Error") +	InvalidResponseIDErr     = errors.New("Invalid response ID") +	InvalidResponseResultErr = errors.New("Invalid response result") +	receivedStopSignalErr    = errors.New("Received stop signal")  )  // RPCClientConfig Holds the configuration settings for the RPC client.  type RPCClientConfig struct { -	Timeout int    // Timeout for receiving requests from the server, in milliseconds. -	Addr    string // Address of the RPC server. +	Timeout                int    // Timeout for receiving requests from the server, in milliseconds. +	Addr                   string // Address of the RPC server. +	SubscriptionBufferSize int    // The buffer size for a subscription.  }  // RPCClient RPC Client  type RPCClient struct {  	config        RPCClientConfig  	conn          *websocket.Conn -	requests      *messageDispatcher[message.RequestID, message.Response] -	subscriptions *messageDispatcher[message.SubscriptionID, json.RawMessage] -	stop_signal   chan struct{} +	requests      *messageDispatcher +	subscriptions *subscriptions +	stopSignal    chan struct{} +	isClosed      atomic.Bool  }  // NewRPCClient Creates a new instance of RPCClient with the provided configuration. @@ -46,25 +61,29 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {  	}  	log.Infof("Successfully connected to the server: %s", config.Addr) -	if config.Timeout == 0 { +	if config.Timeout <= 0 {  		config.Timeout = DefaultTimeout  	} -	stop_signal := make(chan struct{}) +	if config.SubscriptionBufferSize <= 0 { +		config.SubscriptionBufferSize = DefaultSubscriptionBufferSize +	} + +	stopSignal := make(chan struct{}) -	requests := newMessageDispatcher[message.RequestID, message.Response](0) -	subscriptions := newMessageDispatcher[message.SubscriptionID, json.RawMessage](100) +	requests := newMessageDispatcher() +	subscriptions := newSubscriptions(config.SubscriptionBufferSize)  	client := &RPCClient{  		conn:          conn,  		config:        config,  		requests:      requests,  		subscriptions: subscriptions, -		stop_signal:   stop_signal, +		stopSignal:    stopSignal,  	}  	go func() { -		if err := client.backgroundReceivingLoop(stop_signal); err != nil { +		if err := client.backgroundReceivingLoop(stopSignal); err != nil {  			client.Close()  		}  	}() @@ -74,9 +93,14 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {  // Close Closes the underlying websocket connection and stop the receiving loop.  func (client *RPCClient) Close() { +	// Check if it's already closed +	if !client.isClosed.CompareAndSwap(false, true) { +		return +	} +  	log.Warn("Close the rpc client...")  	// Send stop signal to the background receiving loop -	close(client.stop_signal) +	close(client.stopSignal)  	// Close the underlying websocket connection  	err := client.conn.Close() @@ -90,7 +114,7 @@ func (client *RPCClient) Close() {  // Call Sends an RPC call to the server with the specified method and  // parameters, and returns the response. -func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) { +func (client *RPCClient) Call(method string, params any) (json.RawMessage, error) {  	log.Tracef("Call -> method: %s, params: %v", method, params)  	response, err := client.sendRequest(method, params)  	if err != nil { @@ -101,29 +125,27 @@ func (client *RPCClient) Call(method string, params any) (*json.RawMessage, erro  }  // Subscribe Sends a subscription request to the server with the specified -// method and parameters, and it returns the subscription ID and the channel to -// receive notifications. -func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) { +// method and parameters, and it returns the subscription. +func (client *RPCClient) Subscribe(method string, params any) (*Subscription, error) {  	log.Tracef("Sbuscribe ->  method: %s, params: %v", method, params)  	response, err := client.sendRequest(method, params)  	if err != nil { -		return 0, nil, err +		return nil, err  	}  	if response.Result == nil { -		return 0, nil, fmt.Errorf("Invalid response result") +		return nil, InvalidResponseResultErr  	}  	var subID message.SubscriptionID -	err = json.Unmarshal(*response.Result, &subID) +	err = json.Unmarshal(response.Result, &subID)  	if err != nil { -		return 0, nil, err +		return nil, err  	} -	// Register a new subscription -	sub := client.subscriptions.register(subID) +	sub := client.subscriptions.subscribe(subID) -	return subID, sub, nil +	return sub, nil  }  // Unsubscribe Sends an unsubscription request to the server to cancel the @@ -135,44 +157,49 @@ func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID  		return err  	} -	// On success unregister the subscription channel -	client.subscriptions.unregister(subID) +	// On success unsubscribe +	client.subscriptions.unsubscribe(subID)  	return nil  }  // backgroundReceivingLoop Starts reading new messages from the underlying connection. -func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error { +func (client *RPCClient) backgroundReceivingLoop(stopSignal <-chan struct{}) error {  	log.Debug("Background loop started") -	new_msg := make(chan []byte) -	receive_err := make(chan error) +	newMsgCh := make(chan []byte) +	receiveErrCh := make(chan error)  	// Start listing for new messages  	go func() {  		for {  			_, msg, err := client.conn.ReadMessage()  			if err != nil { -				receive_err <- err +				receiveErrCh <- err +				return +			} +			select { +			case <-client.stopSignal:  				return +			case newMsgCh <- msg:  			} -			new_msg <- msg  		}  	}()  	for {  		select { -		case <-stop_signal: -			log.Warn("Stopping background receiving loop: received stop signal") +		case err := <-receiveErrCh: +			log.WithError(err).Error("Read a new msg") +			return err +		case <-stopSignal: +			log.Debug("Background receiving loop stopped %w", receivedStopSignalErr)  			return nil -		case msg := <-new_msg: +		case msg := <-newMsgCh:  			err := client.handleNewMsg(msg)  			if err != nil { -				log.WithError(err).Error("Handle a new received msg") +				log.WithError(err).Error("Handle a msg") +				return err  			} -		case err := <-receive_err: -			log.WithError(err).Error("Receive a new msg") -			return err  		}  	}  } @@ -186,7 +213,7 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {  	decoder.DisallowUnknownFields()  	if err := decoder.Decode(&response); err == nil {  		if response.ID == nil { -			return fmt.Errorf("Response doesn't have an id") +			return InvalidResponseIDErr  		}  		err := client.requests.dispatch(*response.ID, response) @@ -202,18 +229,13 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {  	if err := json.Unmarshal(msg, ¬ification); err == nil {  		ntRes := message.NotificationResult{} -		if err := json.Unmarshal(*notification.Params, &ntRes); err != nil { +		if err := json.Unmarshal(notification.Params, &ntRes); err != nil {  			return fmt.Errorf("Failed to unmarshal notification params: %w", err)  		} -		// TODO: Consider using a more efficient design here because, -		// on each registration, messageDispatcher will create a new subscription -		// channel with the provided buffer size. If the buffer for the subscription -		// channel is full, calling the dispatch method will block and in this -		// case the client will not be able to handle any more messages. -		err := client.subscriptions.dispatch(ntRes.Subscription, *ntRes.Result) +		err := client.subscriptions.notify(ntRes.Subscription, ntRes.Result)  		if err != nil { -			return fmt.Errorf("Dispatch a notification: %w", err) +			return fmt.Errorf("Notify a subscriber: %w", err)  		}  		log.Debugf("<-- %s", notification.String()) @@ -228,6 +250,10 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {  func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) {  	response := message.Response{} +	if client.isClosed.Load() { +		return response, ClientIsDisconnectedErr +	} +  	params_bytes, err := json.Marshal(params)  	if err != nil {  		return response, err @@ -262,8 +288,10 @@ func (client *RPCClient) sendRequest(method string, params any) (message.Respons  	// Waits the response, it fails and return error if it exceed the timeout  	select {  	case response = <-rx_ch: +	case <-client.stopSignal: +		return response, ClientIsDisconnectedErr  	case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): -		return response, fmt.Errorf("Timeout error") +		return response, TimeoutError  	}  	err = validateResponse(&response, id) @@ -285,7 +313,7 @@ func validateResponse(res *message.Response, reqID message.RequestID) error {  	if res.ID != nil {  		if *res.ID != reqID { -			return fmt.Errorf("Invalid response id") +			return InvalidResponseIDErr  		}  	} diff --git a/jsonrpc/client/concurrent_queue.go b/jsonrpc/client/concurrent_queue.go new file mode 100644 index 0000000..daf92cd --- /dev/null +++ b/jsonrpc/client/concurrent_queue.go @@ -0,0 +1,88 @@ +package client + +import ( +	"errors" +	"sync" +	"sync/atomic" +) + +// queue A concurrent queue. +type queue[T any] struct { +	lock       sync.Mutex +	cond       *sync.Cond +	items      []T +	bufferSize int +	stopSignal chan struct{} +	isClosed   atomic.Bool +} + +var ( +	queueIsFullErr   = errors.New("Queue is full") +	queueIsClosedErr = errors.New("Queue is closed") +) + +// newQueue creates a new queue with the specified buffer size. +func newQueue[T any](bufferSize int) *queue[T] { +	q := &queue[T]{ +		bufferSize: bufferSize, +		items:      make([]T, 0), +		stopSignal: make(chan struct{}), +	} +	q.cond = sync.NewCond(&q.lock) +	return q +} + +// push Adds a new item to the queue and returns an error if the queue +// is full. +func (q *queue[T]) push(item T) error { +	if q.isClosed.Load() { +		return queueIsClosedErr +	} + +	q.lock.Lock() +	defer q.lock.Unlock() +	if len(q.items) >= q.bufferSize { +		return queueIsFullErr +	} + +	q.items = append(q.items, item) +	q.cond.Signal() +	return nil +} + +// pop waits for and removes the first element from the queue, then returns it. +func (q *queue[T]) pop() (T, bool) { +	var t T +	if q.isClosed.Load() { +		return t, false +	} +	q.lock.Lock() +	defer q.lock.Unlock() + +	// Wait for an item to be available or for a stop signal. +	// This ensures that the waiting stops once the queue is cleared. +	for len(q.items) == 0 { +		select { +		case <-q.stopSignal: +			return t, false +		default: +			q.cond.Wait() +		} +	} + +	item := q.items[0] +	q.items = q.items[1:] +	return item, true +} + +// clear Clears all elements from the queue. +func (q *queue[T]) clear() { +	if !q.isClosed.CompareAndSwap(false, true) { +		return +	} +	q.lock.Lock() +	defer q.lock.Unlock() +	close(q.stopSignal) +	q.items = nil +	q.cond.Broadcast() +} diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go index 9177f6e..adedd1c 100644 --- a/jsonrpc/client/message_dispatcher.go +++ b/jsonrpc/client/message_dispatcher.go @@ -1,76 +1,74 @@  package client  import ( -	"fmt" +	"errors"  	"sync" + +	"github.com/karyontech/karyon-go/jsonrpc/message" +) + +var ( +	requestChannelNotFoundErr = errors.New("Request channel not found")  ) -// messageDispatcher Is a generic structure that holds a map of keys and +// messageDispatcher Is a structure that holds a map of request IDs and  // channels, and it is protected by mutex -type messageDispatcher[K comparable, V any] struct { +type messageDispatcher struct {  	sync.Mutex -	chans      map[K]chan<- V -	bufferSize int +	chans map[message.RequestID]chan<- message.Response  }  // newMessageDispatcher Creates a new messageDispatcher -func newMessageDispatcher[K comparable, V any](bufferSize int) *messageDispatcher[K, V] { -	chans := make(map[K]chan<- V) -	return &messageDispatcher[K, V]{ -		chans:      chans, -		bufferSize: bufferSize, +func newMessageDispatcher() *messageDispatcher { +	chans := make(map[message.RequestID]chan<- message.Response) +	return &messageDispatcher{ +		chans: chans,  	}  } -// register Registers a new channel with a given key. It returns the receiving channel. -func (c *messageDispatcher[K, V]) register(key K) <-chan V { +// register Registers a new request channel with the given id. It returns a +// channel for receiving response. +func (c *messageDispatcher) register(key message.RequestID) <-chan message.Response {  	c.Lock()  	defer c.Unlock() -	ch := make(chan V, c.bufferSize) +	ch := make(chan message.Response)  	c.chans[key] = ch  	return ch  } -// length Returns the number of channels -func (c *messageDispatcher[K, V]) length() int { +// dispatch Disptaches the response to the channel with the given request id +func (c *messageDispatcher) dispatch(key message.RequestID, res message.Response) error {  	c.Lock()  	defer c.Unlock() -	return len(c.chans) -} - -// dispatch Disptaches the msg to the channel with the given key -func (c *messageDispatcher[K, V]) dispatch(key K, msg V) error { -	c.Lock() -	ch, ok := c.chans[key] -	c.Unlock() - -	if !ok { -		return fmt.Errorf("Channel not found") +	if ch, ok := c.chans[key]; ok { +		ch <- res +	} else { +		return requestChannelNotFoundErr  	} -	ch <- msg  	return nil  } -// unregister Unregisters the channel with the provided key -func (c *messageDispatcher[K, V]) unregister(key K) { +// unregister Unregisters the request with the provided id +func (c *messageDispatcher) unregister(key message.RequestID) {  	c.Lock()  	defer c.Unlock() +  	if ch, ok := c.chans[key]; ok {  		close(ch)  		delete(c.chans, key)  	}  } -// clear Closes all the channels and remove them from the map -func (c *messageDispatcher[K, V]) clear() { +// clear Closes all the request channels and remove them from the map +func (c *messageDispatcher) clear() {  	c.Lock()  	defer c.Unlock() -	for k, ch := range c.chans { +	for _, ch := range c.chans {  		close(ch) -		delete(c.chans, k)  	} +	c.chans = nil  } diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go index a1fc1c6..4d7cc60 100644 --- a/jsonrpc/client/message_dispatcher_test.go +++ b/jsonrpc/client/message_dispatcher_test.go @@ -5,58 +5,61 @@ import (  	"sync/atomic"  	"testing" +	"github.com/karyontech/karyon-go/jsonrpc/message"  	"github.com/stretchr/testify/assert"  )  func TestDispatchToChannel(t *testing.T) { -	messageDispatcher := newMessageDispatcher[int, int](10) +	messageDispatcher := newMessageDispatcher() -	chanKey := 1 -	rx := messageDispatcher.register(chanKey) +	req1 := "1" +	rx := messageDispatcher.register(req1) -	chanKey2 := 2 -	rx2 := messageDispatcher.register(chanKey2) +	req2 := "2" +	rx2 := messageDispatcher.register(req2)  	var wg sync.WaitGroup  	wg.Add(1)  	go func() { +		defer wg.Done()  		for i := 0; i < 50; i++ { -			err := messageDispatcher.dispatch(chanKey, i) +			res := message.Response{ID: &req1} +			err := messageDispatcher.dispatch(req1, res)  			assert.Nil(t, err)  		} -		messageDispatcher.unregister(chanKey) -		wg.Done() +		messageDispatcher.unregister(req1)  	}()  	wg.Add(1)  	go func() { +		defer wg.Done()  		for i := 0; i < 50; i++ { -			err := messageDispatcher.dispatch(chanKey2, i) +			res := message.Response{ID: &req2} +			err := messageDispatcher.dispatch(req2, res)  			assert.Nil(t, err)  		} -		messageDispatcher.unregister(chanKey2) -		wg.Done() +		messageDispatcher.unregister(req2)  	}()  	var receivedItem atomic.Int32  	wg.Add(1)  	go func() { +		defer wg.Done()  		for range rx {  			receivedItem.Add(1)  		} -		wg.Done()  	}()  	wg.Add(1)  	go func() { +		defer wg.Done()  		for range rx2 {  			receivedItem.Add(1)  		} -		wg.Done()  	}()  	wg.Wait() @@ -64,33 +67,31 @@ func TestDispatchToChannel(t *testing.T) {  }  func TestUnregisterChannel(t *testing.T) { -	messageDispatcher := newMessageDispatcher[int, int](1) +	messageDispatcher := newMessageDispatcher() -	chanKey := 1 -	rx := messageDispatcher.register(chanKey) +	req := "1" +	rx := messageDispatcher.register(req) -	messageDispatcher.unregister(chanKey) -	assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") +	messageDispatcher.unregister(req)  	_, ok := <-rx  	assert.False(t, ok, "chan closed") -	err := messageDispatcher.dispatch(chanKey, 1) +	err := messageDispatcher.dispatch(req, message.Response{ID: &req})  	assert.NotNil(t, err)  }  func TestClearChannels(t *testing.T) { -	messageDispatcher := newMessageDispatcher[int, int](1) +	messageDispatcher := newMessageDispatcher() -	chanKey := 1 -	rx := messageDispatcher.register(chanKey) +	req := "1" +	rx := messageDispatcher.register(req)  	messageDispatcher.clear() -	assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty")  	_, ok := <-rx  	assert.False(t, ok, "chan closed") -	err := messageDispatcher.dispatch(chanKey, 1) +	err := messageDispatcher.dispatch(req, message.Response{ID: &req})  	assert.NotNil(t, err)  } diff --git a/jsonrpc/client/subscription.go b/jsonrpc/client/subscription.go new file mode 100644 index 0000000..a9a94e7 --- /dev/null +++ b/jsonrpc/client/subscription.go @@ -0,0 +1,83 @@ +package client + +import ( +	"encoding/json" +	"errors" +	"fmt" +	"sync/atomic" + +	log "github.com/sirupsen/logrus" +) + +var ( +	subscriptionIsClosedErr = errors.New("Subscription is closed") +) + +// Subscription A subscription established when the client's subscribe to a method +type Subscription struct { +	ch         chan json.RawMessage +	ID         int +	queue      *queue[json.RawMessage] +	stopSignal chan struct{} +	isClosed   atomic.Bool +} + +// newSubscription Creates a new Subscription +func newSubscription(subID int, bufferSize int) *Subscription { +	sub := &Subscription{ +		ch:         make(chan json.RawMessage), +		ID:         subID, +		queue:      newQueue[json.RawMessage](bufferSize), +		stopSignal: make(chan struct{}), +	} +	sub.startBackgroundJob() + +	return sub +} + +// Recv Receives a new notification. +func (s *Subscription) Recv() <-chan json.RawMessage { +	return s.ch +} + +// startBackgroundJob starts waiting for the queue to receive new items. +// It stops when it receives a stop signal. +func (s *Subscription) startBackgroundJob() { +	go func() { +		logger := log.WithField("Subscription", s.ID) +		for { +			msg, ok := s.queue.pop() +			if !ok { +				logger.Debug("Background job stopped") +				return +			} +			select { +			case <-s.stopSignal: +				logger.Debug("Background job stopped: %w", receivedStopSignalErr) +				return +			case s.ch <- msg: +			} +		} +	}() +} + +// notify adds a new notification to the queue. +func (s *Subscription) notify(nt json.RawMessage) error { +	if s.isClosed.Load() { +		return subscriptionIsClosedErr +	} +	if err := s.queue.push(nt); err != nil { +		return fmt.Errorf("Unable to push new notification: %w", err) +	} +	return nil +} + +// stop Terminates the subscription, clears the queue, and closes channels. +func (s *Subscription) stop() { +	if !s.isClosed.CompareAndSwap(false, true) { +		return +	} +	close(s.stopSignal) +	close(s.ch) +	s.queue.clear() +} diff --git a/jsonrpc/client/subscription_test.go b/jsonrpc/client/subscription_test.go new file mode 100644 index 0000000..5928665 --- /dev/null +++ b/jsonrpc/client/subscription_test.go @@ -0,0 +1,86 @@ +package client + +import ( +	"encoding/json" +	"sync" +	"testing" + +	"github.com/stretchr/testify/assert" +) + +func TestSubscriptionFullQueue(t *testing.T) { +	bufSize := 100 +	sub := newSubscription(1, bufSize) + +	var wg sync.WaitGroup + +	wg.Add(1) +	go func() { +		defer wg.Done() +		defer sub.stop() +		for i := 0; i < bufSize+10; i++ { +			b, err := json.Marshal(i) +			assert.Nil(t, err) +			err = sub.notify(b) +			if i > bufSize { +				if assert.Error(t, err) { +					assert.ErrorIs(t, err, queueIsFullErr) +				} +			} +		} +	}() + +	wg.Wait() +} + +func TestSubscriptionRecv(t *testing.T) { +	bufSize := 100 +	sub := newSubscription(1, bufSize) + +	var wg sync.WaitGroup + +	wg.Add(1) +	go func() { +		defer wg.Done() +		for i := 0; i < bufSize; i++ { +			b, err := json.Marshal(i) +			assert.Nil(t, err) +			err = sub.notify(b) +			assert.Nil(t, err) +		} +	}() + +	wg.Add(1) +	go func() { +		defer wg.Done() +		i := 0 +		for nt := range sub.Recv() { +			var v int +			err := json.Unmarshal(nt, &v) +			assert.Nil(t, err) +			assert.Equal(t, v, i) +			i += 1 +			if i == bufSize { +				break +			} +		} +	}() + +	wg.Wait() +} + +func TestSubscriptionStop(t *testing.T) { +	sub := newSubscription(1, 10) + +	sub.stop() + +	_, ok := <-sub.Recv() +	assert.False(t, ok) + +	b, err := json.Marshal(1) +	assert.Nil(t, err) +	err = sub.notify(b) +	if assert.Error(t, err) { +		assert.ErrorIs(t, err, subscriptionIsClosedErr) +	} +} diff --git a/jsonrpc/client/subscriptions.go b/jsonrpc/client/subscriptions.go new file mode 100644 index 0000000..0524837 --- /dev/null +++ b/jsonrpc/client/subscriptions.go @@ -0,0 +1,81 @@ +package client + +import ( +	"encoding/json" +	"errors" +	"sync" + +	"github.com/karyontech/karyon-go/jsonrpc/message" +) + +var ( +	subscriptionNotFoundErr = errors.New("Subscription not found") +) + +// subscriptions Is a structure that holds a map of subscription IDs and +// subscriptions +type subscriptions struct { +	sync.Mutex +	subs       map[message.SubscriptionID]*Subscription +	bufferSize int +} + +// newSubscriptions Creates a new subscriptions +func newSubscriptions(bufferSize int) *subscriptions { +	subs := make(map[message.SubscriptionID]*Subscription) +	return &subscriptions{ +		subs:       subs, +		bufferSize: bufferSize, +	} +} + +// subscribe Subscribes and returns a Subscription. +func (c *subscriptions) subscribe(key message.SubscriptionID) *Subscription { +	c.Lock() +	defer c.Unlock() + +	sub := newSubscription(key, c.bufferSize) +	c.subs[key] = sub +	return sub +} + +// notify Notifies the msg the subscription with the given id +func (c *subscriptions) notify(key message.SubscriptionID, msg json.RawMessage) error { +	c.Lock() +	defer c.Unlock() + +	sub, ok := c.subs[key] + +	if !ok { +		return subscriptionNotFoundErr +	} + +	err := sub.notify(msg) + +	if err != nil { +		return err +	} + +	return nil +} + +// unsubscribe Unsubscribe from the subscription with the provided id +func (c *subscriptions) unsubscribe(key message.SubscriptionID) { +	c.Lock() +	defer c.Unlock() +	if sub, ok := c.subs[key]; ok { +		sub.stop() +		delete(c.subs, key) +	} +} + +// clear Stops all the subscriptions and remove them from the map +func (c *subscriptions) clear() { +	c.Lock() +	defer c.Unlock() + +	for _, sub := range c.subs { +		sub.stop() +	} +	c.subs = nil +} diff --git a/jsonrpc/client/subscriptions_test.go b/jsonrpc/client/subscriptions_test.go new file mode 100644 index 0000000..cb5705a --- /dev/null +++ b/jsonrpc/client/subscriptions_test.go @@ -0,0 +1,87 @@ +package client + +import ( +	"encoding/json" +	"sync" +	"sync/atomic" +	"testing" + +	"github.com/stretchr/testify/assert" +) + +func TestSubscriptionsSubscribe(t *testing.T) { +	bufSize := 100 +	subs := newSubscriptions(bufSize) + +	var receivedNotifications atomic.Int32 + +	var wg sync.WaitGroup + +	runSubNotify := func(sub *Subscription) { +		wg.Add(1) +		go func() { +			defer wg.Done() +			for i := 0; i < bufSize; i++ { +				b, err := json.Marshal(i) +				assert.Nil(t, err) +				err = sub.notify(b) +				assert.Nil(t, err) +			} +		}() +	} + +	runSubRecv := func(sub *Subscription) { +		wg.Add(1) +		go func() { +			defer wg.Done() +			i := 0 +			for nt := range sub.Recv() { +				var v int +				err := json.Unmarshal(nt, &v) +				assert.Nil(t, err) +				assert.Equal(t, v, i) +				receivedNotifications.Add(1) +				i += 1 +				if i == bufSize { +					break +				} +			} +			sub.stop() +		}() +	} + +	wg.Add(1) +	go func() { +		defer wg.Done() +		for i := 0; i < 3; i++ { +			sub := subs.subscribe(i) +			runSubNotify(sub) +			runSubRecv(sub) +		} +	}() + +	wg.Wait() +	assert.Equal(t, receivedNotifications.Load(), int32(bufSize*3)) +} + +func TestSubscriptionsUnsubscribe(t *testing.T) { +	bufSize := 100 +	subs := newSubscriptions(bufSize) + +	var wg sync.WaitGroup + +	sub := subs.subscribe(1) +	subs.unsubscribe(1) + +	_, ok := <-sub.Recv() +	assert.False(t, ok) + +	b, err := json.Marshal(1) +	assert.Nil(t, err) +	err = sub.notify(b) +	if assert.Error(t, err) { +		assert.ErrorIs(t, err, subscriptionIsClosedErr) +	} + +	wg.Wait() +} diff --git a/jsonrpc/message/message.go b/jsonrpc/message/message.go index ff68da2..b01c05d 100644 --- a/jsonrpc/message/message.go +++ b/jsonrpc/message/message.go @@ -26,7 +26,7 @@ type Request struct {  type Response struct {  	JSONRPC string           `json:"jsonrpc"`          // JSON-RPC version, typically "2.0".  	ID      *RequestID       `json:"id,omitempty"`     // Unique identifier matching the request ID, can be null for notifications. -	Result  *json.RawMessage `json:"result,omitempty"` // Result of the request if it was successful. +	Result  json.RawMessage `json:"result,omitempty"` // Result of the request if it was successful.  	Error   *Error           `json:"error,omitempty"`  // Error object if the request failed.  } @@ -34,13 +34,13 @@ type Response struct {  type Notification struct {  	JSONRPC string           `json:"jsonrpc"`          // JSON-RPC version, typically "2.0".  	Method  string           `json:"method"`           // The name of the method to be invoked. -	Params  *json.RawMessage `json:"params,omitempty"` // Optional parameters for the method. +	Params  json.RawMessage `json:"params,omitempty"` // Optional parameters for the method.  }  // NotificationResult represents the result of a subscription notification.  // It includes the result and the subscription ID that triggered the notification.  type NotificationResult struct { -	Result       *json.RawMessage `json:"result,omitempty"` // Result data of the notification. +	Result       json.RawMessage `json:"result,omitempty"` // Result data of the notification.  	Subscription SubscriptionID   `json:"subscription"`     // ID of the subscription that triggered the notification.  } @@ -49,7 +49,7 @@ type NotificationResult struct {  type Error struct {  	Code    int              `json:"code"`           // Error code indicating the type of error.  	Message string           `json:"message"`        // Human-readable error message. -	Data    *json.RawMessage `json:"data,omitempty"` // Optional additional data about the error. +	Data    json.RawMessage `json:"data,omitempty"` // Optional additional data about the error.  }  func (req *Request) String() string { @@ -57,11 +57,11 @@ func (req *Request) String() string {  }  func (res *Response) String() string { -	return fmt.Sprintf("{JSONRPC: %s, ID: %s, RESULT: %s, ERROR: %v}", res.JSONRPC, *res.ID, *res.Result, res.Error) +	return fmt.Sprintf("{JSONRPC: %s, ID: %v, RESULT: %v, ERROR: %v}", res.JSONRPC, res.ID, res.Result, res.Error)  }  func (nt *Notification) String() string { -	return fmt.Sprintf("{JSONRPC: %s, METHOD: %s, PARAMS: %s}", nt.JSONRPC, nt.Method, *nt.Params) +	return fmt.Sprintf("{JSONRPC: %s, METHOD: %s, PARAMS: %s}", nt.JSONRPC, nt.Method, nt.Params)  }  func (err *Error) String() string {  | 
