aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-06-13 06:02:24 +0200
committerhozan23 <hozan23@karyontech.net>2024-06-13 06:02:24 +0200
commit8c2d37e093ca64d591fc0aec15a7e2ed424b2e47 (patch)
treefd9bc62e29087a18e7eb4bdd0a1f587ba63e5dd5
parenta338905a7f8a2206161cc15f07bda872b9bfc09c (diff)
use message dispatcher to process responses and notifications & spread out comments
-rw-r--r--client/channels.go78
-rw-r--r--client/channels_test.go110
-rw-r--r--jsonrpc/client/client.go (renamed from client/client.go)179
-rw-r--r--jsonrpc/client/message_dispatcher.go75
-rw-r--r--jsonrpc/client/message_dispatcher_test.go101
-rw-r--r--jsonrpc/message/message.go (renamed from message/message.go)2
6 files changed, 266 insertions, 279 deletions
diff --git a/client/channels.go b/client/channels.go
deleted file mode 100644
index 673e366..0000000
--- a/client/channels.go
+++ /dev/null
@@ -1,78 +0,0 @@
-package client
-
-import (
- "fmt"
- "sync"
-)
-
-// channels is a generic structure that holds a map of keys and channels.
-// It is protected by mutex
-type channels[K comparable, V any] struct {
- sync.Mutex
- chans map[K]chan<- V
- bufferSize int
-}
-
-// newChannels creates a new channels
-func newChannels[K comparable, V any](bufferSize int) channels[K, V] {
- chans := make(map[K]chan<- V)
- return channels[K, V]{
- chans: chans,
- bufferSize: bufferSize,
- }
-}
-
-// add adds a new channel and returns the receiving channel
-func (c *channels[K, V]) add(key K) <-chan V {
- c.Lock()
- defer c.Unlock()
-
- ch := make(chan V, c.bufferSize)
- c.chans[key] = ch
- return ch
-}
-
-// length returns the number of channels
-func (c *channels[K, V]) length() int {
- c.Lock()
- defer c.Unlock()
-
- return len(c.chans)
-}
-
-// notify notifies the channel with the given key
-func (c *channels[K, V]) notify(key K, msg V) error {
- c.Lock()
- defer c.Unlock()
-
- if ch, ok := c.chans[key]; ok {
- ch <- msg
- return nil
- }
-
- return fmt.Errorf("Channel not found")
-}
-
-// remove removes and returns the channel.
-func (c *channels[K, V]) remove(key K) chan<- V {
- c.Lock()
- defer c.Unlock()
-
- if ch, ok := c.chans[key]; ok {
- delete(c.chans, key)
- return ch
- }
-
- return nil
-}
-
-// clear close all the channels and remove them from the map
-func (c *channels[K, V]) clear() {
- c.Lock()
- defer c.Unlock()
-
- for k, ch := range c.chans {
- close(ch)
- delete(c.chans, k)
- }
-}
diff --git a/client/channels_test.go b/client/channels_test.go
deleted file mode 100644
index 4465fed..0000000
--- a/client/channels_test.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package client
-
-import (
- "sync"
- "sync/atomic"
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestNotifyChannel(t *testing.T) {
-
- chans := newChannels[int, int](10)
-
- chanKey := 1
- rx := chans.add(chanKey)
-
- chanKey2 := 2
- rx2 := chans.add(chanKey2)
-
- var wg sync.WaitGroup
-
- wg.Add(1)
- go func() {
- for i := 0; i < 50; i++ {
- err := chans.notify(chanKey, i)
- assert.Nil(t, err)
- }
-
- // drop the channel
- tx := chans.remove(chanKey)
- close(tx)
- wg.Done()
- }()
-
- wg.Add(1)
- go func() {
- for i := 0; i < 50; i++ {
- err := chans.notify(chanKey2, i)
- assert.Nil(t, err)
- }
-
- // drop the channel
- tx := chans.remove(chanKey2)
- close(tx)
- wg.Done()
- }()
-
- var receivedItem atomic.Int32
-
- wg.Add(1)
- go func() {
- for range rx {
- receivedItem.Add(1)
- }
- wg.Done()
- }()
-
- wg.Add(1)
- go func() {
- for range rx2 {
- receivedItem.Add(1)
- }
- wg.Done()
- }()
-
- wg.Wait()
- assert.Equal(t, receivedItem.Load(), int32(100))
-}
-
-func TestRemoveChannel(t *testing.T) {
-
- chans := newChannels[int, int](1)
-
- chanKey := 1
- rx := chans.add(chanKey)
-
- tx := chans.remove(chanKey)
- assert.Equal(t, chans.length(), 0, "channels should be empty")
-
- tx <- 3
- val := <-rx
- assert.Equal(t, val, 3)
-
- tx = chans.remove(chanKey)
- assert.Nil(t, tx)
-
- err := chans.notify(chanKey, 1)
- assert.NotNil(t, err)
-}
-
-func TestClearChannels(t *testing.T) {
-
- chans := newChannels[int, int](1)
-
- chanKey := 1
- rx := chans.add(chanKey)
-
- chans.clear()
- assert.Equal(t, chans.length(), 0, "channels should be empty")
-
- _, ok := <-rx
- assert.False(t, ok, "chan closed")
-
- tx := chans.remove(chanKey)
- assert.Nil(t, tx)
-
- err := chans.notify(chanKey, 1)
- assert.NotNil(t, err)
-}
diff --git a/client/client.go b/jsonrpc/client/client.go
index 7379e43..a34828a 100644
--- a/client/client.go
+++ b/jsonrpc/client/client.go
@@ -11,33 +11,33 @@ import (
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
- "github.com/karyontech/karyon-go/message"
+ "github.com/karyontech/karyon-go/jsonrpc/message"
)
const (
- // JsonRPCVersion defines the version of the JSON-RPC protocol being used.
+ // JsonRPCVersion Defines the version of the JSON-RPC protocol being used.
JsonRPCVersion = "2.0"
// Default timeout for receiving requests from the server, in milliseconds.
DefaultTimeout = 3000
)
-// RPCClientConfig holds the configuration settings for the RPC client.
+// RPCClientConfig Holds the configuration settings for the RPC client.
type RPCClientConfig struct {
Timeout int // Timeout for receiving requests from the server, in milliseconds.
Addr string // Address of the RPC server.
}
-// RPCClient
+// RPCClient RPC Client
type RPCClient struct {
- config RPCClientConfig
- conn *websocket.Conn
- request_chans channels[message.RequestID, message.Response]
- subscriptions channels[message.SubscriptionID, json.RawMessage]
- stop_signal chan struct{}
+ config RPCClientConfig
+ conn *websocket.Conn
+ requests messageDispatcher[message.RequestID, message.Response]
+ subscriber messageDispatcher[message.SubscriptionID, json.RawMessage]
+ stop_signal chan struct{}
}
-// NewRPCClient creates a new instance of RPCClient with the provided configuration.
+// NewRPCClient Creates a new instance of RPCClient with the provided configuration.
// It establishes a WebSocket connection to the RPC server.
func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
conn, _, err := websocket.DefaultDialer.Dial(config.Addr, nil)
@@ -53,11 +53,11 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
stop_signal := make(chan struct{}, 2)
client := &RPCClient{
- conn: conn,
- config: config,
- request_chans: newChannels[message.RequestID, message.Response](1),
- subscriptions: newChannels[message.SubscriptionID, json.RawMessage](10),
- stop_signal: stop_signal,
+ conn: conn,
+ config: config,
+ requests: newMessageDispatcher[message.RequestID, message.Response](1),
+ subscriber: newMessageDispatcher[message.SubscriptionID, json.RawMessage](10),
+ stop_signal: stop_signal,
}
go func() {
@@ -69,30 +69,27 @@ func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
return client, nil
}
-// Close closes the underlying websocket connection and stop the receiving loop.
+// Close Closes the underlying websocket connection and stop the receiving loop.
func (client *RPCClient) Close() {
log.Warn("Close the rpc client...")
+ // Send stop signal to the background receiving loop
client.stop_signal <- struct{}{}
+ // Close the underlying websocket connection
err := client.conn.Close()
if err != nil {
log.WithError(err).Error("Close websocket connection")
}
- client.request_chans.clear()
- client.subscriptions.clear()
+ client.requests.clear()
+ client.subscriber.clear()
}
-// Call sends an RPC call to the server with the specified method and parameters.
-// It returns the response from the server.
+// Call Sends an RPC call to the server with the specified method and
+// parameters, and returns the response.
func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) {
log.Tracef("Call -> method: %s, params: %v", method, params)
- param_raw, err := json.Marshal(params)
- if err != nil {
- return nil, err
- }
-
- response, err := client.sendRequest(method, param_raw)
+ response, err := client.sendRequest(method, params)
if err != nil {
return nil, err
}
@@ -100,16 +97,12 @@ func (client *RPCClient) Call(method string, params any) (*json.RawMessage, erro
return response.Result, nil
}
-// Subscribe sends a subscription request to the server with the specified method and parameters.
-// It returns the subscription ID and a channel to receive notifications.
+// Subscribe Sends a subscription request to the server with the specified
+// method and parameters, and it returns the subscription ID and the channel to
+// receive notifications.
func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) {
log.Tracef("Sbuscribe -> method: %s, params: %v", method, params)
- param_raw, err := json.Marshal(params)
- if err != nil {
- return 0, nil, err
- }
-
- response, err := client.sendRequest(method, param_raw)
+ response, err := client.sendRequest(method, params)
if err != nil {
return 0, nil, err
}
@@ -124,57 +117,68 @@ func (client *RPCClient) Subscribe(method string, params any) (message.Subscript
return 0, nil, err
}
- ch := client.subscriptions.add(subID)
+ // Register a new subscription
+ sub := client.subscriber.register(subID)
- return subID, ch, nil
+ return subID, sub, nil
}
-// Unsubscribe sends an unsubscription request to the server to cancel the given subscription.
+// Unsubscribe Sends an unsubscription request to the server to cancel the
+// given subscription.
func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID) error {
log.Tracef("Unsubscribe -> method: %s, subID: %d", method, subID)
- subIDJSON, err := json.Marshal(subID)
- if err != nil {
- return err
- }
-
- _, err = client.sendRequest(method, subIDJSON)
+ _, err := client.sendRequest(method, subID)
if err != nil {
return err
}
- // on success remove the subscription from the map
- client.subscriptions.remove(subID)
+ // On success unregister the subscription channel
+ client.subscriber.unregister(subID)
return nil
}
-// backgroundReceivingLoop starts reading new messages from the underlying connection.
+// backgroundReceivingLoop Starts reading new messages from the underlying connection.
func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error {
+
log.Debug("Background loop started")
+
+ new_msg_ch := make(chan []byte)
+ receive_err_ch := make(chan error)
+
+ // Start listing for new messages
+ go func() {
+ for {
+ _, msg, err := client.conn.ReadMessage()
+ if err != nil {
+ receive_err_ch <- err
+ return
+ }
+ new_msg_ch <- msg
+ }
+ }()
+
for {
select {
case <-stop_signal:
log.Warn("Stopping background receiving loop: received stop signal")
return nil
- default:
- _, msg, err := client.conn.ReadMessage()
- if err != nil {
- log.WithError(err).Error("Receive a new msg")
- return err
- }
-
- err = client.handleNewMsg(msg)
+ case msg := <-new_msg_ch:
+ err := client.handleNewMsg(msg)
if err != nil {
log.WithError(err).Error("Handle a new received msg")
}
+ case err := <-receive_err_ch:
+ log.WithError(err).Error("Receive a new msg")
+ return err
}
}
}
-// handleNewMsg attempts to decode the received message into either a Response
-// or Notification struct.
+// handleNewMsg Attempts to decode the received message into either a Response
+// or Notification.
func (client *RPCClient) handleNewMsg(msg []byte) error {
- // try to decode the msg into message.Response
+ // Check if the received message is of type Response
response := message.Response{}
decoder := json.NewDecoder(bytes.NewReader(msg))
decoder.DisallowUnknownFields()
@@ -183,28 +187,27 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {
return fmt.Errorf("Response doesn't have an id")
}
- if v := client.request_chans.remove(*response.ID); v != nil {
- v <- response
+ err := client.requests.disptach(*response.ID, response)
+ if err != nil {
+ return fmt.Errorf("Dispatch a response: %w", err)
}
return nil
}
- // try to decode the msg into message.Notification
+ // Check if the received message is of type Notification
notification := message.Notification{}
if err := json.Unmarshal(msg, &notification); err == nil {
- notificationResult := message.NotificationResult{}
- if err := json.Unmarshal(*notification.Params, &notificationResult); err != nil {
+ ntRes := message.NotificationResult{}
+ if err := json.Unmarshal(*notification.Params, &ntRes); err != nil {
return fmt.Errorf("Failed to unmarshal notification params: %w", err)
}
- err := client.subscriptions.notify(
- notificationResult.Subscription,
- *notificationResult.Result,
- )
+ // Send the notification to the subscription
+ err := client.subscriber.disptach(ntRes.Subscription, *ntRes.Result)
if err != nil {
- return fmt.Errorf("Notify a subscriber: %w", err)
+ return fmt.Errorf("Dispatch a notification: %w", err)
}
log.Debugf("<-- %s", notification.String())
@@ -215,11 +218,19 @@ func (client *RPCClient) handleNewMsg(msg []byte) error {
return fmt.Errorf("Receive unexpected msg: %s", msg)
}
-// sendRequest sends a request and wait the response
-func (client *RPCClient) sendRequest(method string, params []byte) (message.Response, error) {
- id := strconv.Itoa(rand.Int())
- params_raw := json.RawMessage(params)
+// sendRequest Sends a request and wait for the response
+func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) {
+ response := message.Response{}
+
+ params_bytes, err := json.Marshal(params)
+ if err != nil {
+ return response, err
+ }
+
+ params_raw := json.RawMessage(params_bytes)
+ // Generate a new id
+ id := strconv.Itoa(rand.Int())
req := message.Request{
JSONRPC: JsonRPCVersion,
ID: id,
@@ -227,8 +238,6 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp
Params: &params_raw,
}
- response := message.Response{}
-
reqJSON, err := json.Marshal(req)
if err != nil {
return response, err
@@ -241,13 +250,14 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp
log.Debugf("--> %s", req.String())
- req_chan := client.request_chans.add(id)
+ rx_ch := client.requests.register(id)
+ defer client.requests.unregister(id)
- response, err = client.waitResponse(req_chan)
- if err != nil {
- log.WithError(err).Errorf("Receive a response from the server")
- client.request_chans.remove(id)
- return response, err
+ // Waits the response, it fails and return error if it exceed the timeout
+ select {
+ case response = <-rx_ch:
+ case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond):
+ return response, fmt.Errorf("Timeout error")
}
err = validateResponse(&response, id)
@@ -260,17 +270,6 @@ func (client *RPCClient) sendRequest(method string, params []byte) (message.Resp
return response, nil
}
-// waitResponse waits the response, it fails and return error if it exceed the timeout
-func (client *RPCClient) waitResponse(ch <-chan message.Response) (message.Response, error) {
- response := message.Response{}
- select {
- case response = <-ch:
- return response, nil
- case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond):
- return response, fmt.Errorf("Timeout error")
- }
-}
-
// validateResponse Checks the error field and whether the request id is the
// same as the response id
func validateResponse(res *message.Response, reqID message.RequestID) error {
diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go
new file mode 100644
index 0000000..6484484
--- /dev/null
+++ b/jsonrpc/client/message_dispatcher.go
@@ -0,0 +1,75 @@
+package client
+
+import (
+ "fmt"
+ "sync"
+)
+
+// messageDispatcher Is a generic structure that holds a map of keys and
+// channels, and it is protected by mutex
+type messageDispatcher[K comparable, V any] struct {
+ sync.Mutex
+ chans map[K]chan<- V
+ bufferSize int
+}
+
+// newMessageDispatcher Creates a new messageDispatcher
+func newMessageDispatcher[K comparable, V any](bufferSize int) messageDispatcher[K, V] {
+ chans := make(map[K]chan<- V)
+ return messageDispatcher[K, V]{
+ chans: chans,
+ bufferSize: bufferSize,
+ }
+}
+
+// register Registers a new channel with a given key. It returns the receiving channel.
+func (c *messageDispatcher[K, V]) register(key K) <-chan V {
+ c.Lock()
+ defer c.Unlock()
+
+ ch := make(chan V, c.bufferSize)
+ c.chans[key] = ch
+ return ch
+}
+
+// length Returns the number of channels
+func (c *messageDispatcher[K, V]) length() int {
+ c.Lock()
+ defer c.Unlock()
+
+ return len(c.chans)
+}
+
+// disptach Disptaches the msg to the channel with the given key
+func (c *messageDispatcher[K, V]) disptach(key K, msg V) error {
+ c.Lock()
+ defer c.Unlock()
+
+ if ch, ok := c.chans[key]; ok {
+ ch <- msg
+ return nil
+ }
+
+ return fmt.Errorf("Channel not found")
+}
+
+// unregister Unregisters the channel with the provided key
+func (c *messageDispatcher[K, V]) unregister(key K) {
+ c.Lock()
+ defer c.Unlock()
+ if ch, ok := c.chans[key]; ok {
+ close(ch)
+ delete(c.chans, key)
+ }
+}
+
+// clear Closes all the channels and remove them from the map
+func (c *messageDispatcher[K, V]) clear() {
+ c.Lock()
+ defer c.Unlock()
+
+ for k, ch := range c.chans {
+ close(ch)
+ delete(c.chans, k)
+ }
+}
diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go
new file mode 100644
index 0000000..7cc1366
--- /dev/null
+++ b/jsonrpc/client/message_dispatcher_test.go
@@ -0,0 +1,101 @@
+package client
+
+import (
+ // "sync"
+ // "sync/atomic"
+
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDispatchToChannel(t *testing.T) {
+
+ messageDispatcher := newMessageDispatcher[int, int](10)
+
+ chanKey := 1
+ rx := messageDispatcher.register(chanKey)
+
+ chanKey2 := 2
+ rx2 := messageDispatcher.register(chanKey2)
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ for i := 0; i < 50; i++ {
+ err := messageDispatcher.disptach(chanKey, i)
+ assert.Nil(t, err)
+ }
+
+ messageDispatcher.unregister(chanKey)
+ wg.Done()
+ }()
+
+ wg.Add(1)
+ go func() {
+ for i := 0; i < 50; i++ {
+ err := messageDispatcher.disptach(chanKey2, i)
+ assert.Nil(t, err)
+ }
+
+ messageDispatcher.unregister(chanKey2)
+ wg.Done()
+ }()
+
+ var receivedItem atomic.Int32
+
+ wg.Add(1)
+ go func() {
+ for range rx {
+ receivedItem.Add(1)
+ }
+ wg.Done()
+ }()
+
+ wg.Add(1)
+ go func() {
+ for range rx2 {
+ receivedItem.Add(1)
+ }
+ wg.Done()
+ }()
+
+ wg.Wait()
+ assert.Equal(t, receivedItem.Load(), int32(100))
+}
+
+func TestUnregisterChannel(t *testing.T) {
+ messageDispatcher := newMessageDispatcher[int, int](1)
+
+ chanKey := 1
+ rx := messageDispatcher.register(chanKey)
+
+ messageDispatcher.unregister(chanKey)
+ assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty")
+
+ _, ok := <-rx
+ assert.False(t, ok, "chan closed")
+
+ err := messageDispatcher.disptach(chanKey, 1)
+ assert.NotNil(t, err)
+}
+
+func TestClearChannels(t *testing.T) {
+
+ messageDispatcher := newMessageDispatcher[int, int](1)
+
+ chanKey := 1
+ rx := messageDispatcher.register(chanKey)
+
+ messageDispatcher.clear()
+ assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty")
+
+ _, ok := <-rx
+ assert.False(t, ok, "chan closed")
+
+ err := messageDispatcher.disptach(chanKey, 1)
+ assert.NotNil(t, err)
+}
diff --git a/message/message.go b/jsonrpc/message/message.go
index 33a18c5..ff68da2 100644
--- a/message/message.go
+++ b/jsonrpc/message/message.go
@@ -65,5 +65,5 @@ func (nt *Notification) String() string {
}
func (err *Error) String() string {
- return fmt.Sprintf("{CODE: %d, MESSAGE: %s, DATA: %s}", err.Code, err.Message, *err.Data)
+ return fmt.Sprintf("{CODE: %d, MESSAGE: %s, DATA: %b}", err.Code, err.Message, err.Data)
}