diff options
author | hozan23 <hozan23@karyontech.net> | 2024-05-31 02:17:56 +0200 |
---|---|---|
committer | hozan23 <hozan23@karyontech.net> | 2024-05-31 02:17:56 +0200 |
commit | fa0b0efc14f84ff87789cabe0010f3240245407c (patch) | |
tree | 6b0b40e81a7b589511a0bd5cd0be7ab785ec5c96 /client |
init commit
Diffstat (limited to 'client')
-rw-r--r-- | client/channels.go | 78 | ||||
-rw-r--r-- | client/channels_test.go | 110 | ||||
-rw-r--r-- | client/client.go | 289 |
3 files changed, 477 insertions, 0 deletions
diff --git a/client/channels.go b/client/channels.go new file mode 100644 index 0000000..673e366 --- /dev/null +++ b/client/channels.go @@ -0,0 +1,78 @@ +package client + +import ( + "fmt" + "sync" +) + +// channels is a generic structure that holds a map of keys and channels. +// It is protected by mutex +type channels[K comparable, V any] struct { + sync.Mutex + chans map[K]chan<- V + bufferSize int +} + +// newChannels creates a new channels +func newChannels[K comparable, V any](bufferSize int) channels[K, V] { + chans := make(map[K]chan<- V) + return channels[K, V]{ + chans: chans, + bufferSize: bufferSize, + } +} + +// add adds a new channel and returns the receiving channel +func (c *channels[K, V]) add(key K) <-chan V { + c.Lock() + defer c.Unlock() + + ch := make(chan V, c.bufferSize) + c.chans[key] = ch + return ch +} + +// length returns the number of channels +func (c *channels[K, V]) length() int { + c.Lock() + defer c.Unlock() + + return len(c.chans) +} + +// notify notifies the channel with the given key +func (c *channels[K, V]) notify(key K, msg V) error { + c.Lock() + defer c.Unlock() + + if ch, ok := c.chans[key]; ok { + ch <- msg + return nil + } + + return fmt.Errorf("Channel not found") +} + +// remove removes and returns the channel. +func (c *channels[K, V]) remove(key K) chan<- V { + c.Lock() + defer c.Unlock() + + if ch, ok := c.chans[key]; ok { + delete(c.chans, key) + return ch + } + + return nil +} + +// clear close all the channels and remove them from the map +func (c *channels[K, V]) clear() { + c.Lock() + defer c.Unlock() + + for k, ch := range c.chans { + close(ch) + delete(c.chans, k) + } +} diff --git a/client/channels_test.go b/client/channels_test.go new file mode 100644 index 0000000..4465fed --- /dev/null +++ b/client/channels_test.go @@ -0,0 +1,110 @@ +package client + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNotifyChannel(t *testing.T) { + + chans := newChannels[int, int](10) + + chanKey := 1 + rx := chans.add(chanKey) + + chanKey2 := 2 + rx2 := chans.add(chanKey2) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + for i := 0; i < 50; i++ { + err := chans.notify(chanKey, i) + assert.Nil(t, err) + } + + // drop the channel + tx := chans.remove(chanKey) + close(tx) + wg.Done() + }() + + wg.Add(1) + go func() { + for i := 0; i < 50; i++ { + err := chans.notify(chanKey2, i) + assert.Nil(t, err) + } + + // drop the channel + tx := chans.remove(chanKey2) + close(tx) + wg.Done() + }() + + var receivedItem atomic.Int32 + + wg.Add(1) + go func() { + for range rx { + receivedItem.Add(1) + } + wg.Done() + }() + + wg.Add(1) + go func() { + for range rx2 { + receivedItem.Add(1) + } + wg.Done() + }() + + wg.Wait() + assert.Equal(t, receivedItem.Load(), int32(100)) +} + +func TestRemoveChannel(t *testing.T) { + + chans := newChannels[int, int](1) + + chanKey := 1 + rx := chans.add(chanKey) + + tx := chans.remove(chanKey) + assert.Equal(t, chans.length(), 0, "channels should be empty") + + tx <- 3 + val := <-rx + assert.Equal(t, val, 3) + + tx = chans.remove(chanKey) + assert.Nil(t, tx) + + err := chans.notify(chanKey, 1) + assert.NotNil(t, err) +} + +func TestClearChannels(t *testing.T) { + + chans := newChannels[int, int](1) + + chanKey := 1 + rx := chans.add(chanKey) + + chans.clear() + assert.Equal(t, chans.length(), 0, "channels should be empty") + + _, ok := <-rx + assert.False(t, ok, "chan closed") + + tx := chans.remove(chanKey) + assert.Nil(t, tx) + + err := chans.notify(chanKey, 1) + assert.NotNil(t, err) +} diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..7ad64fd --- /dev/null +++ b/client/client.go @@ -0,0 +1,289 @@ +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "math/rand" + "strconv" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" + + "github.com/karyontech/karyon-go/message" +) + +const ( + // JsonRPCVersion defines the version of the JSON-RPC protocol being used. + JsonRPCVersion = "2.0" + + // Default timeout for receiving requests from the server, in milliseconds. + DefaultTimeout = 3000 +) + +// RPCClientConfig holds the configuration settings for the RPC client. +type RPCClientConfig struct { + Timeout int // Timeout for receiving requests from the server, in milliseconds. + Addr string // Address of the RPC server. +} + +// RPCClient +type RPCClient struct { + config RPCClientConfig + conn *websocket.Conn + request_chans channels[message.RequestID, message.Response] + subscriptions channels[message.SubscriptionID, json.RawMessage] + stop_signal chan struct{} +} + +// NewRPCClient creates a new instance of RPCClient with the provided configuration. +// It establishes a WebSocket connection to the RPC server. +func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { + conn, _, err := websocket.DefaultDialer.Dial(config.Addr, nil) + if err != nil { + return nil, err + } + log.Infof("Successfully connected to the server: %s", config.Addr) + + if config.Timeout == 0 { + config.Timeout = DefaultTimeout + } + + stop_signal := make(chan struct{}, 2) + + client := &RPCClient{ + conn: conn, + config: config, + request_chans: newChannels[message.RequestID, message.Response](1), + subscriptions: newChannels[message.SubscriptionID, json.RawMessage](10), + stop_signal: stop_signal, + } + + go func() { + if err := client.backgroundReceivingLoop(stop_signal); err != nil { + client.Close() + } + }() + + return client, nil +} + +// Close closes the underlying websocket connection and stop the receiving loop. +func (client *RPCClient) Close() error { + log.Warn("Close the rpc client...") + client.stop_signal <- struct{}{} + + err := client.conn.Close() + if err != nil { + log.WithError(err).Error("Close websocket connection") + } + + client.request_chans.clear() + client.subscriptions.clear() + return nil +} + +// Call sends an RPC call to the server with the specified method and parameters. +// It returns the response from the server. +func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) { + log.Tracef("Call -> method: %s, params: %v", method, params) + param_raw, err := json.Marshal(params) + if err != nil { + return nil, err + } + + response, err := client.sendRequest(method, param_raw) + if err != nil { + return nil, err + } + + return response.Result, nil +} + +// Subscribe sends a subscription request to the server with the specified method and parameters. +// It returns the subscription ID and a channel to receive notifications. +func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) { + log.Tracef("Sbuscribe -> method: %s, params: %v", method, params) + param_raw, err := json.Marshal(params) + if err != nil { + return 0, nil, err + } + + response, err := client.sendRequest(method, param_raw) + if err != nil { + return 0, nil, err + } + + if response.Result == nil { + return 0, nil, fmt.Errorf("Invalid response result") + } + + var subID message.SubscriptionID + err = json.Unmarshal(*response.Result, &subID) + if err != nil { + return 0, nil, err + } + + ch := client.subscriptions.add(subID) + + return subID, ch, nil +} + +// Unsubscribe sends an unsubscription request to the server to cancel the given subscription. +func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID) error { + log.Tracef("Unsubscribe -> method: %s, subID: %d", method, subID) + subIDJSON, err := json.Marshal(subID) + if err != nil { + return err + } + + _, err = client.sendRequest(method, subIDJSON) + if err != nil { + return err + } + + // on success remove the subscription from the map + client.subscriptions.remove(subID) + + return nil +} + +// backgroundReceivingLoop starts reading new messages from the underlying connection. +func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error { + log.Debug("Background loop started") + for { + select { + case <-stop_signal: + log.Warn("Stopping background receiving loop: received stop signal") + return nil + default: + _, msg, err := client.conn.ReadMessage() + if err != nil { + log.WithError(err).Error("Receive a new msg") + return err + } + + err = client.handleNewMsg(msg) + if err != nil { + log.WithError(err).Error("Handle a new received msg") + } + } + } +} + +// handleNewMsg attempts to decode the received message into either a Response +// or Notification struct. +func (client *RPCClient) handleNewMsg(msg []byte) error { + // try to decode the msg into message.Response + response := message.Response{} + decoder := json.NewDecoder(bytes.NewReader(msg)) + decoder.DisallowUnknownFields() + if err := decoder.Decode(&response); err == nil { + if response.ID == nil { + return fmt.Errorf("Response doesn't have an id") + } + + if v := client.request_chans.remove(*response.ID); v != nil { + v <- response + } + + return nil + } + + // try to decode the msg into message.Notification + notification := message.Notification{} + if err := json.Unmarshal(msg, ¬ification); err == nil { + + notificationResult := message.NotificationResult{} + if err := json.Unmarshal(*notification.Params, ¬ificationResult); err != nil { + return fmt.Errorf("Failed to unmarshal notification params: %w", err) + } + + err := client.subscriptions.notify( + notificationResult.Subscription, + *notificationResult.Result, + ) + if err != nil { + return fmt.Errorf("Notify a subscriber: %w", err) + } + + log.Debugf("<-- %s", notification.String()) + + return nil + } + + return fmt.Errorf("Receive unexpected msg: %s", msg) +} + +// sendRequest sends a request and wait the response +func (client *RPCClient) sendRequest(method string, params []byte) (message.Response, error) { + id := strconv.Itoa(rand.Int()) + params_raw := json.RawMessage(params) + + req := message.Request{ + JSONRPC: JsonRPCVersion, + ID: id, + Method: method, + Params: ¶ms_raw, + } + + response := message.Response{} + + reqJSON, err := json.Marshal(req) + if err != nil { + return response, err + } + + err = client.conn.WriteMessage(websocket.TextMessage, []byte(string(reqJSON))) + if err != nil { + return response, err + } + + log.Debugf("--> %s", req.String()) + + req_chan := client.request_chans.add(id) + + response, err = client.waitResponse(req_chan) + if err != nil { + log.WithError(err).Errorf("Receive a response from the server") + client.request_chans.remove(id) + return response, err + } + + err = validateResponse(&response, id) + if err != nil { + return response, err + } + + log.Debugf("<-- %s", response.String()) + + return response, nil +} + +// waitResponse waits the response, it fails and return error if it exceed the timeout +func (client *RPCClient) waitResponse(ch <-chan message.Response) (message.Response, error) { + response := message.Response{} + select { + case response = <-ch: + return response, nil + case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): + return response, fmt.Errorf("Timeout error") + } +} + +// validateResponse Checks the error field and whether the request id is the +// same as the response id +func validateResponse(res *message.Response, reqID message.RequestID) error { + if res.Error != nil { + return fmt.Errorf("Receive An Error: %s", res.Error.String()) + } + + if res.ID != nil { + if *res.ID != reqID { + return fmt.Errorf("Invalid response id") + } + } + + return nil +} |