diff options
| author | hozan23 <hozan23@karyontech.net> | 2024-06-13 06:02:24 +0200 | 
|---|---|---|
| committer | hozan23 <hozan23@karyontech.net> | 2024-06-13 06:02:24 +0200 | 
| commit | 8c2d37e093ca64d591fc0aec15a7e2ed424b2e47 (patch) | |
| tree | fd9bc62e29087a18e7eb4bdd0a1f587ba63e5dd5 /jsonrpc | |
| parent | a338905a7f8a2206161cc15f07bda872b9bfc09c (diff) | |
use message dispatcher to process responses and notifications & spread out comments
Diffstat (limited to 'jsonrpc')
| -rw-r--r-- | jsonrpc/client/client.go | 287 | ||||
| -rw-r--r-- | jsonrpc/client/message_dispatcher.go | 75 | ||||
| -rw-r--r-- | jsonrpc/client/message_dispatcher_test.go | 101 | ||||
| -rw-r--r-- | jsonrpc/message/message.go | 69 | 
4 files changed, 532 insertions, 0 deletions
diff --git a/jsonrpc/client/client.go b/jsonrpc/client/client.go new file mode 100644 index 0000000..a34828a --- /dev/null +++ b/jsonrpc/client/client.go @@ -0,0 +1,287 @@ +package client + +import ( +	"bytes" +	"encoding/json" +	"fmt" +	"math/rand" +	"strconv" +	"time" + +	"github.com/gorilla/websocket" +	log "github.com/sirupsen/logrus" + +	"github.com/karyontech/karyon-go/jsonrpc/message" +) + +const ( +	// JsonRPCVersion Defines the version of the JSON-RPC protocol being used. +	JsonRPCVersion = "2.0" + +	// Default timeout for receiving requests from the server, in milliseconds. +	DefaultTimeout = 3000 +) + +// RPCClientConfig Holds the configuration settings for the RPC client. +type RPCClientConfig struct { +	Timeout int    // Timeout for receiving requests from the server, in milliseconds. +	Addr    string // Address of the RPC server. +} + +// RPCClient RPC Client +type RPCClient struct { +	config      RPCClientConfig +	conn        *websocket.Conn +	requests    messageDispatcher[message.RequestID, message.Response] +	subscriber  messageDispatcher[message.SubscriptionID, json.RawMessage] +	stop_signal chan struct{} +} + +// NewRPCClient Creates a new instance of RPCClient with the provided configuration. +// It establishes a WebSocket connection to the RPC server. +func NewRPCClient(config RPCClientConfig) (*RPCClient, error) { +	conn, _, err := websocket.DefaultDialer.Dial(config.Addr, nil) +	if err != nil { +		return nil, err +	} +	log.Infof("Successfully connected to the server: %s", config.Addr) + +	if config.Timeout == 0 { +		config.Timeout = DefaultTimeout +	} + +	stop_signal := make(chan struct{}, 2) + +	client := &RPCClient{ +		conn:        conn, +		config:      config, +		requests:    newMessageDispatcher[message.RequestID, message.Response](1), +		subscriber:  newMessageDispatcher[message.SubscriptionID, json.RawMessage](10), +		stop_signal: stop_signal, +	} + +	go func() { +		if err := client.backgroundReceivingLoop(stop_signal); err != nil { +			client.Close() +		} +	}() + +	return client, nil +} + +// Close Closes the underlying websocket connection and stop the receiving loop. +func (client *RPCClient) Close() { +	log.Warn("Close the rpc client...") +	// Send stop signal to the background receiving loop +	client.stop_signal <- struct{}{} + +	// Close the underlying websocket connection +	err := client.conn.Close() +	if err != nil { +		log.WithError(err).Error("Close websocket connection") +	} + +	client.requests.clear() +	client.subscriber.clear() +} + +// Call Sends an RPC call to the server with the specified method and +// parameters, and returns the response. +func (client *RPCClient) Call(method string, params any) (*json.RawMessage, error) { +	log.Tracef("Call -> method: %s, params: %v", method, params) +	response, err := client.sendRequest(method, params) +	if err != nil { +		return nil, err +	} + +	return response.Result, nil +} + +// Subscribe Sends a subscription request to the server with the specified +// method and parameters, and it returns the subscription ID and the channel to +// receive notifications. +func (client *RPCClient) Subscribe(method string, params any) (message.SubscriptionID, <-chan json.RawMessage, error) { +	log.Tracef("Sbuscribe ->  method: %s, params: %v", method, params) +	response, err := client.sendRequest(method, params) +	if err != nil { +		return 0, nil, err +	} + +	if response.Result == nil { +		return 0, nil, fmt.Errorf("Invalid response result") +	} + +	var subID message.SubscriptionID +	err = json.Unmarshal(*response.Result, &subID) +	if err != nil { +		return 0, nil, err +	} + +	// Register a new subscription +	sub := client.subscriber.register(subID) + +	return subID, sub, nil +} + +// Unsubscribe Sends an unsubscription request to the server to cancel the +// given subscription. +func (client *RPCClient) Unsubscribe(method string, subID message.SubscriptionID) error { +	log.Tracef("Unsubscribe -> method: %s, subID: %d", method, subID) +	_, err := client.sendRequest(method, subID) +	if err != nil { +		return err +	} + +	// On success unregister the subscription channel +	client.subscriber.unregister(subID) + +	return nil +} + +// backgroundReceivingLoop Starts reading new messages from the underlying connection. +func (client *RPCClient) backgroundReceivingLoop(stop_signal <-chan struct{}) error { + +	log.Debug("Background loop started") + +	new_msg_ch := make(chan []byte) +	receive_err_ch := make(chan error) + +	// Start listing for new messages +	go func() { +		for { +			_, msg, err := client.conn.ReadMessage() +			if err != nil { +				receive_err_ch <- err +				return +			} +			new_msg_ch <- msg +		} +	}() + +	for { +		select { +		case <-stop_signal: +			log.Warn("Stopping background receiving loop: received stop signal") +			return nil +		case msg := <-new_msg_ch: +			err := client.handleNewMsg(msg) +			if err != nil { +				log.WithError(err).Error("Handle a new received msg") +			} +		case err := <-receive_err_ch: +			log.WithError(err).Error("Receive a new msg") +			return err +		} +	} +} + +// handleNewMsg Attempts to decode the received message into either a Response +// or Notification. +func (client *RPCClient) handleNewMsg(msg []byte) error { +	// Check if the received message is of type Response +	response := message.Response{} +	decoder := json.NewDecoder(bytes.NewReader(msg)) +	decoder.DisallowUnknownFields() +	if err := decoder.Decode(&response); err == nil { +		if response.ID == nil { +			return fmt.Errorf("Response doesn't have an id") +		} + +		err := client.requests.disptach(*response.ID, response) +		if err != nil { +			return fmt.Errorf("Dispatch a response: %w", err) +		} + +		return nil +	} + +	// Check if the received message is of type Notification +	notification := message.Notification{} +	if err := json.Unmarshal(msg, ¬ification); err == nil { + +		ntRes := message.NotificationResult{} +		if err := json.Unmarshal(*notification.Params, &ntRes); err != nil { +			return fmt.Errorf("Failed to unmarshal notification params: %w", err) +		} + +		// Send the notification to the subscription +		err := client.subscriber.disptach(ntRes.Subscription, *ntRes.Result) +		if err != nil { +			return fmt.Errorf("Dispatch a notification: %w", err) +		} + +		log.Debugf("<-- %s", notification.String()) + +		return nil +	} + +	return fmt.Errorf("Receive unexpected msg: %s", msg) +} + +// sendRequest Sends a request and wait for the response +func (client *RPCClient) sendRequest(method string, params any) (message.Response, error) { +	response := message.Response{} + +	params_bytes, err := json.Marshal(params) +	if err != nil { +		return response, err +	} + +	params_raw := json.RawMessage(params_bytes) + +	// Generate a new id +	id := strconv.Itoa(rand.Int()) +	req := message.Request{ +		JSONRPC: JsonRPCVersion, +		ID:      id, +		Method:  method, +		Params:  ¶ms_raw, +	} + +	reqJSON, err := json.Marshal(req) +	if err != nil { +		return response, err +	} + +	err = client.conn.WriteMessage(websocket.TextMessage, []byte(string(reqJSON))) +	if err != nil { +		return response, err +	} + +	log.Debugf("--> %s", req.String()) + +	rx_ch := client.requests.register(id) +	defer client.requests.unregister(id) + +	// Waits the response, it fails and return error if it exceed the timeout +	select { +	case response = <-rx_ch: +	case <-time.After(time.Duration(client.config.Timeout) * time.Millisecond): +		return response, fmt.Errorf("Timeout error") +	} + +	err = validateResponse(&response, id) +	if err != nil { +		return response, err +	} + +	log.Debugf("<-- %s", response.String()) + +	return response, nil +} + +// validateResponse Checks the error field and whether the request id is the +// same as the response id +func validateResponse(res *message.Response, reqID message.RequestID) error { +	if res.Error != nil { +		return fmt.Errorf("Receive An Error: %s", res.Error.String()) +	} + +	if res.ID != nil { +		if *res.ID != reqID { +			return fmt.Errorf("Invalid response id") +		} +	} + +	return nil +} diff --git a/jsonrpc/client/message_dispatcher.go b/jsonrpc/client/message_dispatcher.go new file mode 100644 index 0000000..6484484 --- /dev/null +++ b/jsonrpc/client/message_dispatcher.go @@ -0,0 +1,75 @@ +package client + +import ( +	"fmt" +	"sync" +) + +// messageDispatcher Is a generic structure that holds a map of keys and +// channels, and it is protected by mutex +type messageDispatcher[K comparable, V any] struct { +	sync.Mutex +	chans      map[K]chan<- V +	bufferSize int +} + +// newMessageDispatcher Creates a new messageDispatcher +func newMessageDispatcher[K comparable, V any](bufferSize int) messageDispatcher[K, V] { +	chans := make(map[K]chan<- V) +	return messageDispatcher[K, V]{ +		chans:      chans, +		bufferSize: bufferSize, +	} +} + +// register Registers a new channel with a given key. It returns the receiving channel. +func (c *messageDispatcher[K, V]) register(key K) <-chan V { +	c.Lock() +	defer c.Unlock() + +	ch := make(chan V, c.bufferSize) +	c.chans[key] = ch +	return ch +} + +// length Returns the number of channels +func (c *messageDispatcher[K, V]) length() int { +	c.Lock() +	defer c.Unlock() + +	return len(c.chans) +} + +// disptach Disptaches the msg to the channel with the given key +func (c *messageDispatcher[K, V]) disptach(key K, msg V) error { +	c.Lock() +	defer c.Unlock() + +	if ch, ok := c.chans[key]; ok { +		ch <- msg +		return nil +	} + +	return fmt.Errorf("Channel not found") +} + +// unregister Unregisters the channel with the provided key +func (c *messageDispatcher[K, V]) unregister(key K) { +	c.Lock() +	defer c.Unlock() +	if ch, ok := c.chans[key]; ok { +		close(ch) +		delete(c.chans, key) +	} +} + +// clear Closes all the channels and remove them from the map +func (c *messageDispatcher[K, V]) clear() { +	c.Lock() +	defer c.Unlock() + +	for k, ch := range c.chans { +		close(ch) +		delete(c.chans, k) +	} +} diff --git a/jsonrpc/client/message_dispatcher_test.go b/jsonrpc/client/message_dispatcher_test.go new file mode 100644 index 0000000..7cc1366 --- /dev/null +++ b/jsonrpc/client/message_dispatcher_test.go @@ -0,0 +1,101 @@ +package client + +import ( +	// "sync" +	// "sync/atomic" + +	"sync" +	"sync/atomic" +	"testing" + +	"github.com/stretchr/testify/assert" +) + +func TestDispatchToChannel(t *testing.T) { + +	messageDispatcher := newMessageDispatcher[int, int](10) + +	chanKey := 1 +	rx := messageDispatcher.register(chanKey) + +	chanKey2 := 2 +	rx2 := messageDispatcher.register(chanKey2) + +	var wg sync.WaitGroup + +	wg.Add(1) +	go func() { +		for i := 0; i < 50; i++ { +			err := messageDispatcher.disptach(chanKey, i) +			assert.Nil(t, err) +		} + +		messageDispatcher.unregister(chanKey) +		wg.Done() +	}() + +	wg.Add(1) +	go func() { +		for i := 0; i < 50; i++ { +			err := messageDispatcher.disptach(chanKey2, i) +			assert.Nil(t, err) +		} + +		messageDispatcher.unregister(chanKey2) +		wg.Done() +	}() + +	var receivedItem atomic.Int32 + +	wg.Add(1) +	go func() { +		for range rx { +			receivedItem.Add(1) +		} +		wg.Done() +	}() + +	wg.Add(1) +	go func() { +		for range rx2 { +			receivedItem.Add(1) +		} +		wg.Done() +	}() + +	wg.Wait() +	assert.Equal(t, receivedItem.Load(), int32(100)) +} + +func TestUnregisterChannel(t *testing.T) { +	messageDispatcher := newMessageDispatcher[int, int](1) + +	chanKey := 1 +	rx := messageDispatcher.register(chanKey) + +	messageDispatcher.unregister(chanKey) +	assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") + +	_, ok := <-rx +	assert.False(t, ok, "chan closed") + +	err := messageDispatcher.disptach(chanKey, 1) +	assert.NotNil(t, err) +} + +func TestClearChannels(t *testing.T) { + +	messageDispatcher := newMessageDispatcher[int, int](1) + +	chanKey := 1 +	rx := messageDispatcher.register(chanKey) + +	messageDispatcher.clear() +	assert.Equal(t, messageDispatcher.length(), 0, "channels should be empty") + +	_, ok := <-rx +	assert.False(t, ok, "chan closed") + +	err := messageDispatcher.disptach(chanKey, 1) +	assert.NotNil(t, err) +} diff --git a/jsonrpc/message/message.go b/jsonrpc/message/message.go new file mode 100644 index 0000000..ff68da2 --- /dev/null +++ b/jsonrpc/message/message.go @@ -0,0 +1,69 @@ +package message + +import ( +	"encoding/json" +	"fmt" +) + +// RequestID is used to identify a request. +type RequestID = string + +// SubscriptionID is used to identify a subscription. +type SubscriptionID = int + +// Request represents a JSON-RPC request message. +// It includes the JSON-RPC version, an identifier for the request, the method +// to be invoked, and optional parameters. +type Request struct { +	JSONRPC string           `json:"jsonrpc"`          // JSON-RPC version, typically "2.0". +	ID      RequestID        `json:"id"`               // Unique identifier for the request, can be a number or a string. +	Method  string           `json:"method"`           // The name of the method to be invoked. +	Params  *json.RawMessage `json:"params,omitempty"` // Optional parameters for the method. +} + +// Response represents a JSON-RPC response message. +// It includes the JSON-RPC version, an identifier matching the request, the result of the request, and an optional error. +type Response struct { +	JSONRPC string           `json:"jsonrpc"`          // JSON-RPC version, typically "2.0". +	ID      *RequestID       `json:"id,omitempty"`     // Unique identifier matching the request ID, can be null for notifications. +	Result  *json.RawMessage `json:"result,omitempty"` // Result of the request if it was successful. +	Error   *Error           `json:"error,omitempty"`  // Error object if the request failed. +} + +// Notification represents a JSON-RPC notification message. +type Notification struct { +	JSONRPC string           `json:"jsonrpc"`          // JSON-RPC version, typically "2.0". +	Method  string           `json:"method"`           // The name of the method to be invoked. +	Params  *json.RawMessage `json:"params,omitempty"` // Optional parameters for the method. +} + +// NotificationResult represents the result of a subscription notification. +// It includes the result and the subscription ID that triggered the notification. +type NotificationResult struct { +	Result       *json.RawMessage `json:"result,omitempty"` // Result data of the notification. +	Subscription SubscriptionID   `json:"subscription"`     // ID of the subscription that triggered the notification. +} + +// Error represents an error in a JSON-RPC response. +// It includes an error code, a message, and optional additional data. +type Error struct { +	Code    int              `json:"code"`           // Error code indicating the type of error. +	Message string           `json:"message"`        // Human-readable error message. +	Data    *json.RawMessage `json:"data,omitempty"` // Optional additional data about the error. +} + +func (req *Request) String() string { +	return fmt.Sprintf("{JSONRPC: %s, ID: %s, METHOD: %s, PARAMS: %s}", req.JSONRPC, req.ID, req.Method, *req.Params) +} + +func (res *Response) String() string { +	return fmt.Sprintf("{JSONRPC: %s, ID: %s, RESULT: %s, ERROR: %v}", res.JSONRPC, *res.ID, *res.Result, res.Error) +} + +func (nt *Notification) String() string { +	return fmt.Sprintf("{JSONRPC: %s, METHOD: %s, PARAMS: %s}", nt.JSONRPC, nt.Method, *nt.Params) +} + +func (err *Error) String() string { +	return fmt.Sprintf("{CODE: %d, MESSAGE: %s, DATA: %b}", err.Code, err.Message, err.Data) +}  | 
