aboutsummaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorhozan23 <hozan23@karyontech.net>2024-05-31 02:17:56 +0200
committerhozan23 <hozan23@karyontech.net>2024-05-31 02:17:56 +0200
commitfa0b0efc14f84ff87789cabe0010f3240245407c (patch)
tree6b0b40e81a7b589511a0bd5cd0be7ab785ec5c96 /client
init commit
Diffstat (limited to 'client')
-rw-r--r--client/channels.go78
-rw-r--r--client/channels_test.go110
-rw-r--r--client/client.go289
3 files changed, 477 insertions, 0 deletions
diff --git a/client/channels.go b/client/channels.go
new file mode 100644
index 0000000..673e366
--- /dev/null
+++ b/client/channels.go
@@ -0,0 +1,78 @@
+package client
+
+import (
+ "fmt"
+ "sync"
+)
+
+// channels is a generic structure that holds a map of keys and channels.
+// It is protected by mutex
+type channels[K comparable, V any] struct {
+ sync.Mutex
+ chans map[K]chan<- V
+ bufferSize int
+}
+
+// newChannels creates a new channels
+func newChannels[K comparable, V any](bufferSize int) channels[K, V] {
+ chans := make(map[K]chan<- V)
+ return channels[K, V]{
+ chans: chans,
+ bufferSize: bufferSize,
+ }
+}
+
+// add adds a new channel and returns the receiving channel
+func (c *channels[K, V]) add(key K) <-chan V {
+ c.Lock()
+ defer c.Unlock()
+
+ ch := make(chan V, c.bufferSize)
+ c.chans[key] = ch
+ return ch
+}
+
+// length returns the number of channels
+func (c *channels[K, V]) length() int {
+ c.Lock()
+ defer c.Unlock()
+
+ return len(c.chans)
+}
+
+// notify notifies the channel with the given key
+func (c *channels[K, V]) notify(key K, msg V) error {
+ c.Lock()
+ defer c.Unlock()
+
+ if ch, ok := c.chans[key]; ok {
+ ch <- msg
+ return nil
+ }
+
+ return fmt.Errorf("Channel not found")
+}
+
+// remove removes and returns the channel.
+func (c *channels[K, V]) remove(key K) chan<- V {
+ c.Lock()
+ defer c.Unlock()
+
+ if ch, ok := c.chans[key]; ok {
+ delete(c.chans, key)
+ return ch
+ }
+
+ return nil
+}
+
+// clear close all the channels and remove them from the map
+func (c *channels[K, V]) clear() {
+ c.Lock()
+ defer c.Unlock()
+
+ for k, ch := range c.chans {
+ close(ch)
+ delete(c.chans, k)
+ }
+}
diff --git a/client/channels_test.go b/client/channels_test.go
new file mode 100644
index 0000000..4465fed
--- /dev/null
+++ b/client/channels_test.go
@@ -0,0 +1,110 @@
+package client
+
+import (
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNotifyChannel(t *testing.T) {
+
+ chans := newChannels[int, int](10)
+
+ chanKey := 1
+ rx := chans.add(chanKey)
+
+ chanKey2 := 2
+ rx2 := chans.add(chanKey2)
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ for i := 0; i < 50; i++ {
+ err := chans.notify(chanKey, i)
+ assert.Nil(t, err)
+ }
+
+ // drop the channel
+ tx := chans.remove(chanKey)
+ close(tx)
+ wg.Done()
+ }()
+
+ wg.Add(1)
+ go func() {
+ for i := 0; i < 50; i++ {
+ err := chans.notify(chanKey2, i)
+ assert.Nil(t, err)
+ }
+
+ // drop the channel
+ tx := chans.remove(chanKey2)
+ close(tx)
+ wg.Done()
+ }()
+
+ var receivedItem atomic.Int32
+
+ wg.Add(1)
+ go func() {
+ for range rx {
+ receivedItem.Add(1)
+ }
+ wg.Done()
+ }()
+
+ wg.Add(1)
+ go func() {
+ for range rx2 {
+ receivedItem.Add(1)
+ }
+ wg.Done()
+ }()
+
+ wg.Wait()
+ assert.Equal(t, receivedItem.Load(), int32(100))
+}
+
+func TestRemoveChannel(t *testing.T) {
+
+ chans := newChannels[int, int](1)
+
+ chanKey := 1
+ rx := chans.add(chanKey)
+
+ tx := chans.remove(chanKey)
+ assert.Equal(t, chans.length(), 0, "channels should be empty")
+
+ tx <- 3
+ val := <-rx
+ assert.Equal(t, val, 3)
+
+ tx = chans.remove(chanKey)
+ assert.Nil(t, tx)
+
+ err := chans.notify(chanKey, 1)
+ assert.NotNil(t, err)
+}
+
+func TestClearChannels(t *testing.T) {
+
+ chans := newChannels[int, int](1)
+
+ chanKey := 1
+ rx := chans.add(chanKey)
+
+ chans.clear()
+ assert.Equal(t, chans.length(), 0, "channels should be empty")
+
+ _, ok := <-rx
+ assert.False(t, ok, "chan closed")
+
+ tx := chans.remove(chanKey)
+ assert.Nil(t, tx)
+
+ err := chans.notify(chanKey, 1)
+ assert.NotNil(t, err)
+}
diff --git a/client/client.go b/client/client.go
new file mode 100644
index 0000000..7ad64fd
--- /dev/null
+++ b/client/client.go
@@ -0,0 +1,289 @@
+package client
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "math/rand"
+ "strconv"
+ "time"
+
+ "github.com/gorilla/websocket"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/karyontech/karyon-go/message"
+)
+
+const (
+ // JsonRPCVersion defines the version of the JSON-RPC protocol being used.
+ JsonRPCVersion = "2.0"
+
+ // Default timeout for receiving requests from the server, in milliseconds.
+ DefaultTimeout = 3000
+)
+
+// RPCClientConfig holds the configuration settings for the RPC client.
+type RPCClientConfig struct {
+ Timeout int // Timeout for receiving requests from the server, in milliseconds.
+ Addr string // Address of the RPC server.
+}
+
+// RPCClient
+type RPCClient struct {
+ config RPCClientConfig
+ conn *websocket.Conn
+ request_chans channels[message.RequestID, message.Response]
+ subscriptions channels[message.SubscriptionID, json.RawMessage]
+ stop_signal chan struct{}
+}
+
+// NewRPCClient creates a new instance of RPCClient with the provided configuration.
+// It establishes a WebSocket connection to the RPC server.
+func NewRPCClient(config RPCClientConfig) (*RPCClient, error) {
+ conn, _, err := websocket.DefaultDialer.Dial(config.Addr, nil)
+ if err != nil {
+ return nil, err
+ }
+ log.Infof("Successfully connected to the server: %s", config.Addr)
+
+ if config.Timeout == 0 {
+ config.Timeout = DefaultTimeout
+ }
+
+ stop_signal := make(chan struct{}, 2)
+
+ client := &RPCClient{
+ conn: conn,
+ config: config,
+ request_chans: newChannels[message.RequestID, message.Response](1),
+ subscriptions: newChannels[message.SubscriptionID, json.RawMessage](10),
+ stop_signal: stop_signal,
+ }
+
+ go func() {
+ if err := client.backgroundReceivingLoop(stop_signal); err != nil {
+ client.Close()
+ }
+ }()
+
+ return client, nil
+}
+
+// Close closes the underlying websocket connection and stop the receiving loop.
+func (client *RPCClient) Close() error {
+ log.Warn("Close the rpc client...")
+ client.stop_signal <- struct{}{}
+
+ err := client.conn.Close()
+ if err != nil {
+ log.WithError(err).Error("Close websocket connection")
+ }
+
+ client.request_chans.clear()
+ client.subscriptions.clear()
+ return nil
+}
+
+// Call sends an RPC call to the server with the specified method and parameters.
+// It returns the response from the server.
+func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) {
+ log.Tracef("Call -> method: %s, params: %v", method, params)
+ param_raw, err := json.Marshal(params)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := client.sendRequest(method, param_raw)
+ if err != nil {
+ return nil, err
+ }
+
+ return response.Result, nil
+}
+
+// Subscribe sends a subscription request to the server with the specified method and parameters.
+// It returns the subscription ID and a channel to receive notifications.
+func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) {
+ log.Tracef("Sbuscribe -> method: %s, params: %v", method, params)
+ param_raw, err := json.Marshal(params)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ response, err := client.sendRequest(method, param_raw)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if response.Result == nil {
+ return 0, nil, fmt.Errorf("Invalid response result")
+ }
+
+ var subID message.SubscriptionID
+ err = json.Unmarshal(*response.Result, &subID)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ ch := client.subscriptions.add(subID)
+
+ return subID, ch, nil
+}
+
+// Unsubscribe sends an unsubscription request to the server to cancel the given subscription.
+func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID) error {
+ log.Tracef("Unsubscribe -> method: %s, subID: %d", method, subID)
+ subIDJSON, err := json.Marshal(subID)
+ if err != nil {
+ return err
+ }
+
+ _, err = client.sendRequest(method, subIDJSON)
+ if err != nil {
+ return err
+ }
+
+ // on success remove the subscription from the map
+ client.subscriptions.remove(subID)
+
+ return nil
+}
+
+// backgroundReceivingLoop starts reading new messages from the underlying connection.
+func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error {
+ log.Debug("Background loop started")
+ for {
+ select {
+ case <-stop_signal:
+ log.Warn("Stopping background receiving loop: received stop signal")
+ return nil
+ default:
+ _, msg, err := client.conn.ReadMessage()
+ if err != nil {
+ log.WithError(err).Error("Receive a new msg")
+ return err
+ }
+
+ err = client.handleNewMsg(msg)
+ if err != nil {
+ log.WithError(err).Error("Handle a new received msg")
+ }
+ }
+ }
+}
+
+// handleNewMsg attempts to decode the received message into either a Response
+// or Notification struct.
+func (client *RPCClient) handleNewMsg(msg []byte) error {
+ // try to decode the msg into message.Response
+ response := message.Response{}
+ decoder := json.NewDecoder(bytes.NewReader(msg))
+ decoder.DisallowUnknownFields()
+ if err := decoder.Decode(&response); err == nil {
+ if response.ID == nil {
+ return fmt.Errorf("Response doesn't have an id")
+ }
+
+ if v := client.request_chans.remove(*response.ID); v != nil {
+ v <- response
+ }
+
+ return nil
+ }
+
+ // try to decode the msg into message.Notification
+ notification := message.Notification{}
+ if err := json.Unmarshal(msg, &notification); err == nil {
+
+ notificationResult := message.NotificationResult{}
+ if err := json.Unmarshal(*notification.Params, &notificationResult); err != nil {
+ return fmt.Errorf("Failed to unmarshal notification params: %w", err)
+ }
+
+ err := client.subscriptions.notify(
+ notificationResult.Subscription,
+ *notificationResult.Result,
+ )
+ if err != nil {
+ return fmt.Errorf("Notify a subscriber: %w", err)
+ }
+
+ log.Debugf("<-- %s", notification.String())
+
+ return nil
+ }
+
+ return fmt.Errorf("Receive unexpected msg: %s", msg)
+}
+
+// sendRequest sends a request and wait the response
+func (client *RPCClient) sendRequest(method string, params []byte) (message.Response, error) {
+ id := strconv.Itoa(rand.Int())
+ params_raw := json.RawMessage(params)
+
+ req := message.Request{
+ JSONRPC: JsonRPCVersion,
+ ID: id,
+ Method: method,
+ Params: &params_raw,
+ }
+
+ response := message.Response{}
+
+ reqJSON, err := json.Marshal(req)
+ if err != nil {
+ return response, err
+ }
+
+ err = client.conn.WriteMessage(websocket.TextMessage, []byte(string(reqJSON)))
+ if err != nil {
+ return response, err
+ }
+
+ log.Debugf("--> %s", req.String())
+
+ req_chan := client.request_chans.add(id)
+
+ response, err = client.waitResponse(req_chan)
+ if err != nil {
+ log.WithError(err).Errorf("Receive a response from the server")
+ client.request_chans.remove(id)
+ return response, err
+ }
+
+ err = validateResponse(&response, id)
+ if err != nil {
+ return response, err
+ }
+
+ log.Debugf("<-- %s", response.String())
+
+ return response, nil
+}
+
+// waitResponse waits the response, it fails and return error if it exceed the timeout
+func (client *RPCClient) waitResponse(ch <-chan message.Response) (message.Response, error) {
+ response := message.Response{}
+ select {
+ case response = <-ch:
+ return response, nil
+ case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond):
+ return response, fmt.Errorf("Timeout error")
+ }
+}
+
+// validateResponse Checks the error field and whether the request id is the
+// same as the response id
+func validateResponse(res *message.Response, reqID message.RequestID) error {
+ if res.Error != nil {
+ return fmt.Errorf("Receive An Error: %s", res.Error.String())
+ }
+
+ if res.ID != nil {
+ if *res.ID != reqID {
+ return fmt.Errorf("Invalid response id")
+ }
+ }
+
+ return nil
+}