diff options
-rw-r--r-- | client/channels.go | 78 | ||||
-rw-r--r-- | client/channels_test.go | 110 | ||||
-rw-r--r-- | jsonrpc/client/client.go (renamed from client/client.go) | 179 | ||||
-rw-r--r-- | jsonrpc/client/message_dispatcher.go | 75 | ||||
-rw-r--r-- | jsonrpc/client/message_dispatcher_test.go | 101 | ||||
-rw-r--r-- | jsonrpc/message/message.go (renamed from message/message.go) | 2 |
6 files changed, 266 insertions, 279 deletions
diff --git a/client/channels.go b/client/channels.go deleted file mode 100644 index 673e366..0000000 --- a/client/channels.go +++ /dev/null @@ -1,78 +0,0 @@ -package client - -import ( - "fmt" - "sync" -) - -// channels is a generic structure that holds a map of keys and channels. -// It is protected by mutex -type channels[K comparable, V any] struct { - sync.Mutex - chans map[K]chan<- V - bufferSize int -} - -// newChannels creates a new channels -func newChannels[K comparable, V any](bufferSize int) channels[K, V] { - chans := make(map[K]chan<- V) - return channels[K, V]{ - chans: chans, - bufferSize: bufferSize, - } -} - -// add adds a new channel and returns the receiving channel -func (c *channels[K, V]) add(key K) <-chan V { - c.Lock() - defer c.Unlock() - - ch := make(chan V, c.bufferSize) - c.chans[key] = ch - return ch -} - -// length returns the number of channels -func (c *channels[K, V]) length() int { - c.Lock() - defer c.Unlock() - - return len(c.chans) -} - -// notify notifies the channel with the given key -func (c *channels[K, V]) notify(key K, msg V) error { - c.Lock() - defer c.Unlock() - - if ch, ok := c.chans[key]; ok { - ch <- msg - return nil - } - - return fmt.Errorf("Channel not found") -} - -// remove removes and returns the channel. -func (c *channels[K, V]) remove(key K) chan<- V { - c.Lock() - defer c.Unlock() - - if ch, ok := c.chans[key]; ok { - delete(c.chans, key) - return ch - } - - return nil -} - -// clear close all the channels and remove them from the map -func (c *channels[K, V]) clear() { - c.Lock() - defer c.Unlock() - - for k, ch := range c.chans { - close(ch) - delete(c.chans, k) - } -} diff --git a/client/channels_test.go b/client/channels_test.go deleted file mode 100644 index 4465fed..0000000 --- a/client/channels_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package client - -import ( - "sync" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestNotifyChannel(t *testing.T) { - - chans := newChannels[int, int](10) - - chanKey := 1 - rx := chans.add(chanKey) - - chanKey2 := 2 - rx2 := chans.add(chanKey2) - - var wg sync.WaitGroup - - wg.Add(1) - go func() { - for i := 0; i < 50; i++ { - err := chans.notify(chanKey, i) - assert.Nil(t, err) - } - - // drop the channel - tx := chans.remove(chanKey) - close(tx) - wg.Done() - }() - - wg.Add(1) - go func() { - for i := 0; i < 50; i++ { - err := chans.notify(chanKey2, i) - assert.Nil(t, err) - } - - // drop the channel - tx := chans.remove(chanKey2) - close(tx) - wg.Done() - }() - - var receivedItem atomic.Int32 - - wg.Add(1) - go func() { - for range rx { - receivedItem.Add(1) - } - wg.Done() - }() - - wg.Add(1) - go func() { - for range rx2 { - receivedItem.Add(1) - } - wg.Done() - }() - - wg.Wait() - assert.Equal(t, receivedItem.Load(), int32(100)) -} - -func TestRemoveChannel(t *testing.T) { - - chans := newChannels[int, int](1) - - chanKey := 1 - rx := chans.add(chanKey) - - tx := chans.remove(chanKey) - assert.Equal(t, chans.length(), 0, "channels should be empty") - - tx <- 3 - val := <-rx - assert.Equal(t, val, 3) - - tx = chans.remove(chanKey) - assert.Nil(t, tx) - - err := chans.notify(chanKey, 1) - assert.NotNil(t, err) -} - -func TestClearChannels(t *testing.T) { - - chans := newChannels[int, int](1) - - chanKey := 1 - rx := chans.add(chanKey) - - chans.clear() - assert.Equal(t, chans.length(), 0, "channels should be empty") - - _, ok := <-rx - assert.False(t, ok, "chan closed") - - tx := chans.remove(chanKey) - assert.Nil(t, tx) - - err := chans.notify(chanKey, 1) - assert.NotNil(t, err) -} diff --git a/client/client.go b/jsonrpc/client/client.go index 7379e43..a34828a 100644 --- a/client/client.go +++ b/jsonrpc/client/client.go @@ -11,33 +11,33 @@ import ( "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" - "github.com/karyontech/karyon-go/message" + "github.com/karyontech/karyon-go/jsonrpc/message" ) const ( - // JsonRPCVersion defines the version of the JSON-RPC protocol being used. + // JsonRPCVersion Defines the version of the JSON-RPC protocol being used. JsonRPCVersion = "2.0" // Default timeout for receiving requests from the server, in milliseconds. DefaultTimeout = 3000 ) -// RPCClientConfig holds the configuration settings for the RPC client. +// RPCClientConfig Holds the configuration settings for the RPC client. type RPCClientConfig struct { Timeout int // Timeout for receiving requests from the server, in milliseconds. Addr string // Address of the RPC server. } -// RPCClient +// RPCClient RPC Client type RPCClient struct { - config RPCClientConfig - conn *websocket.Conn - request_chans channels[message.RequestID, message.Response] - subscriptions channels[message.SubscriptionID, json.RawMessage] - stop_signal chan struct{} + config RPCClientConfig + conn *websocket.Conn + requests messageDispatcher[message.RequestID, message.Response] + subscriber messageDispatcher[message.SubscriptionID, json.RawMessage] + stop_signal chan struct{} } -// NewRPCClient creates a new instance of RPCClient with the provided configuration. +// NewRPCClient Creates a new instance of RPCClient with the provided configuration. // It establishes a WebSocket connection to the RPC server. func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { conn, _, err := websocket.DefaultDialer.Dial(config.Addr, nil) @@ -53,11 +53,11 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { stop_signal := make(chan struct{}, 2) client := &RPCClient{ - conn: conn, - config: config, - request_chans: newChannels[message.RequestID, message.Response](1), - subscriptions: newChannels[message.SubscriptionID, json.RawMessage](10), - stop_signal: stop_signal, + conn: conn, + config: config, + requests: newMessageDispatcher[message.RequestID, message.Response](1), + subscriber: newMessageDispatcher[message.SubscriptionID, json.RawMessage](10), + stop_signal: stop_signal, } go func() { @@ -69,30 +69,27 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { return client, nil } -// Close closes the underlying websocket connection and stop the receiving loop. +// Close Closes the underlying websocket connection and stop the receiving loop. func (client *RPCClient) Close() { log.Warn("Close the rpc client...") + // Send stop signal to the background receiving loop client.stop_signal <- struct{}{} + // Close the underlying websocket connection err := client.conn.Close() if err != nil { log.WithError(err).Error("Close websocket connection") } - client.request_chans.clear() - client.subscriptions.clear() + client.requests.clear() + client.subscriber.clear() } -// Call sends an RPC call to the server with the specified method and parameters. -// It returns the response from the server. +// Call Sends an RPC call to the server with the specified method and +// parameters, and returns the response. func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) { log.Tracef("Call -> method: %s, params: %v", method, params) - param_raw, err := json.Marshal(params) - if err != nil { - return nil, err - } - - response, err := client.sendRequest(method, param_raw) + response, err := client.sendRequest(method, params) if err != nil { return nil, err } @@ -100,16 +97,12 @@ func (client *RPCClient) Call(method string, params any) (*json.RawMessage, erro return response.Result, nil } -// Subscribe sends a subscription request to the server with the specified method and parameters. -// It returns the subscription ID and a channel to receive notifications. +// Subscribe Sends a subscription request to the server with the specified +// method and parameters, and it returns the subscription ID and the channel to +// receive notifications. func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) { log.Tracef("Sbuscribe -> method: %s, params: %v", method, params) - param_raw, err := json.Marshal(params) - if err != nil { - return 0, nil, err - } - - response, err := client.sendRequest(method, param_raw) + response, err := client.sendRequest(method, params) if err != nil { return 0, nil, err } @@ -124,57 +117,68 @@ func (client *RPCClient) Subscribe(method string, params any) (message.Subscript return 0, nil, err } - ch := client.subscriptions.add(subID) + // Register a new subscription + sub := client.subscriber.register(subID) - return subID, ch, nil + return subID, sub, nil } -// Unsubscribe sends an unsubscription request to the server to cancel the given subscription. +// Unsubscribe Sends an unsubscription request to the server to cancel the +// given subscription. func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID) error { log.Tracef("Unsubscribe -> method: %s, subID: %d", method, subID) - subIDJSON, err := json.Marshal(subID) - if err != nil { - return err - } - - _, err = client.sendRequest(method, subIDJSON) + _, err := client.sendRequest(method, subID) if err != nil { return err } - // on success remove the subscription from the map - client.subscriptions.remove(subID) + // On success unregister the subscription channel + client.subscriber.unregister(subID) return nil } -// backgroundReceivingLoop starts reading new messages from the underlying connection. +// backgroundReceivingLoop Starts reading new messages from the underlying connection. func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error { + log.Debug("Background loop started") + + new_msg_ch := make(chan []byte) + receive_err_ch := make(chan error) + + // Start listing for new messages + go func() { + for { + _, msg, err := client.conn.ReadMessage() + if err != nil { + receive_err_ch <- err + return + } + new_msg_ch <- msg + } + }() + for { select { case <-stop_signal: log.Warn("Stopping background receiving loop: received stop signal") return nil - default: - _, msg, err := client.conn.ReadMessage() - if err != nil { - log.WithError(err).Error("Receive a new msg") - return err - } - - err = client.handleNewMsg(msg) + case msg := <-new_msg_ch: + err := client.handleNewMsg(msg) if err != nil { log.WithError(err).Error("Handle a new received msg") } + case err := <-receive_err_ch: + log.WithError(err).Error("Receive a new msg") + return err } } } -// handleNewMsg attempts to decode the received message into either a Response -// or Notification struct. +// handleNewMsg Attempts to decode the received message into either a Response +// or Notification. func (client *RPCClient) handleNewMsg(msg []byte) error { - // try to decode the msg into message.Response + // Check if the received message is of type Response response := message.Response{} decoder := json.NewDecoder(bytes.NewReader(msg)) decoder.DisallowUnknownFields() @@ -183,28 +187,27 @@ func (client *RPCClient) handleNewMsg(msg []byte) error { return fmt.Errorf("Response doesn't have an id") } - if v := client.request_chans.remove(*response.ID); v != nil { - v <- response + err := client.requests.disptach(*response.ID, response) + if err != nil { + return fmt.Errorf("Dispatch a response: %w", err) } return nil } - // try to decode the msg into message.Notification + // Check if the received message is of type Notification notification := message.Notification{} if err := json.Unmarshal(msg, ¬ification); err == nil { - notificationResult := message.NotificationResult{} - if err := json.Unmarshal(*notification.Params, ¬ificationResult); err != nil { + ntRes := message.NotificationResult{} + if err := json.Unmarshal(*notification.Params, &ntRes); err != nil { return fmt.Errorf("Failed to unmarshal notification params: %w", err) } - err := client.subscriptions.notify( - notificationResult.Subscription, - *notificationResult.Result, - ) + // Send the notification to the subscription + err := client.subscriber.disptach(ntRes.Subscription, *ntRes.Result) if err != nil { - return fmt.Errorf("Notify a subscriber: %w", err) + return fmt.Errorf("Dispatch a notification: %w", err) } log.Debugf("<-- %s", notification.String()) @@ -215,11 +218,19 @@ func (client *RPCClient) handleNewMsg(msg []byte) error { return fmt.Errorf("Receive unexpected msg: %s", msg) } -// sendRequest sends a request and wait the response -func (client *RPCClient) sendRequest(method string, params []byte) (message.Response, error) { - id := strconv.Itoa(rand.Int()) - params_raw := json.RawMessage(params) +// sendRequest Sends a request and wait for the response +func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) { + response := message.Response{} + + params_bytes, err := json.Marshal(params) + if err != nil { + return response, err + } + + params_raw := json.RawMessage(params_bytes) + // Generate a new id + id := strconv.Itoa(rand.Int()) req := message.Request{ JSONRPC: JsonRPCVersion, ID: id, @@ -227,8 +238,6 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp Params: ¶ms_raw, } - response := message.Response{} - reqJSON, err := json.Marshal(req) if err != nil { return response, err @@ -241,13 +250,14 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp log.Debugf("--> %s", req.String()) - req_chan := client.request_chans.add(id) + rx_ch := client.requests.register(id) + defer client.requests.unregister(id) - response, err = client.waitResponse(req_chan) - if err != nil { - log.WithError(err).Errorf("Receive a response from the server") - client.request_chans.remove(id) - return response, err + // Waits the response, it fails and return error if it exceed the timeout + select { + case response = <-rx_ch: + case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): + return response, fmt.Errorf("Timeout error") } err = validateResponse(&response, id) @@ -260,17 +270,6 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp return response, nil } -// waitResponse waits the response, it fails and return error if it exceed the timeout -func (client *RPCClient) waitResponse(ch <-chan message.Response) (message.Response, error) { - response := message.Response{} - select { - case response = <-ch: - return response, nil - case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): - return response, fmt.Errorf("Timeout error") - } -} - // validateResponse Checks the error field and whether the request id is the // same as the response id func validateResponse(res *message.Response, reqID message.RequestID) error { diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go new file mode 100644 index 0000000..6484484 --- /dev/null +++ b/jsonrpc/client/message_dispatcher.go @@ -0,0 +1,75 @@ +package client + +import ( + "fmt" + "sync" +) + +// messageDispatcher Is a generic structure that holds a map of keys and +// channels, and it is protected by mutex +type messageDispatcher[K comparable, V any] struct { + sync.Mutex + chans map[K]chan<- V + bufferSize int +} + +// newMessageDispatcher Creates a new messageDispatcher +func newMessageDispatcher[K comparable, V any](bufferSize int) messageDispatcher[K, V] { + chans := make(map[K]chan<- V) + return messageDispatcher[K, V]{ + chans: chans, + bufferSize: bufferSize, + } +} + +// register Registers a new channel with a given key. It returns the receiving channel. +func (c *messageDispatcher[K, V]) register(key K) <-chan V { + c.Lock() + defer c.Unlock() + + ch := make(chan V, c.bufferSize) + c.chans[key] = ch + return ch +} + +// length Returns the number of channels +func (c *messageDispatcher[K, V]) length() int { + c.Lock() + defer c.Unlock() + + return len(c.chans) +} + +// disptach Disptaches the msg to the channel with the given key +func (c *messageDispatcher[K, V]) disptach(key K, msg V) error { + c.Lock() + defer c.Unlock() + + if ch, ok := c.chans[key]; ok { + ch <- msg + return nil + } + + return fmt.Errorf("Channel not found") +} + +// unregister Unregisters the channel with the provided key +func (c *messageDispatcher[K, V]) unregister(key K) { + c.Lock() + defer c.Unlock() + if ch, ok := c.chans[key]; ok { + close(ch) + delete(c.chans, key) + } +} + +// clear Closes all the channels and remove them from the map +func (c *messageDispatcher[K, V]) clear() { + c.Lock() + defer c.Unlock() + + for k, ch := range c.chans { + close(ch) + delete(c.chans, k) + } +} diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go new file mode 100644 index 0000000..7cc1366 --- /dev/null +++ b/jsonrpc/client/message_dispatcher_test.go @@ -0,0 +1,101 @@ +package client + +import ( + // "sync" + // "sync/atomic" + + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDispatchToChannel(t *testing.T) { + + messageDispatcher := newMessageDispatcher[int, int](10) + + chanKey := 1 + rx := messageDispatcher.register(chanKey) + + chanKey2 := 2 + rx2 := messageDispatcher.register(chanKey2) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + for i := 0; i < 50; i++ { + err := messageDispatcher.disptach(chanKey, i) + assert.Nil(t, err) + } + + messageDispatcher.unregister(chanKey) + wg.Done() + }() + + wg.Add(1) + go func() { + for i := 0; i < 50; i++ { + err := messageDispatcher.disptach(chanKey2, i) + assert.Nil(t, err) + } + + messageDispatcher.unregister(chanKey2) + wg.Done() + }() + + var receivedItem atomic.Int32 + + wg.Add(1) + go func() { + for range rx { + receivedItem.Add(1) + } + wg.Done() + }() + + wg.Add(1) + go func() { + for range rx2 { + receivedItem.Add(1) + } + wg.Done() + }() + + wg.Wait() + assert.Equal(t, receivedItem.Load(), int32(100)) +} + +func TestUnregisterChannel(t *testing.T) { + messageDispatcher := newMessageDispatcher[int, int](1) + + chanKey := 1 + rx := messageDispatcher.register(chanKey) + + messageDispatcher.unregister(chanKey) + assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") + + _, ok := <-rx + assert.False(t, ok, "chan closed") + + err := messageDispatcher.disptach(chanKey, 1) + assert.NotNil(t, err) +} + +func TestClearChannels(t *testing.T) { + + messageDispatcher := newMessageDispatcher[int, int](1) + + chanKey := 1 + rx := messageDispatcher.register(chanKey) + + messageDispatcher.clear() + assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") + + _, ok := <-rx + assert.False(t, ok, "chan closed") + + err := messageDispatcher.disptach(chanKey, 1) + assert.NotNil(t, err) +} diff --git a/message/message.go b/jsonrpc/message/message.go index 33a18c5..ff68da2 100644 --- a/message/message.go +++ b/jsonrpc/message/message.go @@ -65,5 +65,5 @@ func (nt *Notification) String() string { } func (err *Error) String() string { - return fmt.Sprintf("{CODE: %d, MESSAGE: %s, DATA: %s}", err.Code, err.Message, *err.Data) + return fmt.Sprintf("{CODE: %d, MESSAGE: %s, DATA: %b}", err.Code, err.Message, err.Data) } |