From 9249694a05335b6a4ff70a1168bd5ebb263f4b71 Mon Sep 17 00:00:00 2001 From: Dmytro Steblyna Date: Thu, 22 May 2025 12:01:42 +0100 Subject: [PATCH] feat: filter channels by status, return channels event on auth, use usified method to send updates --- channel.go | 15 ++++-- docs/API.md | 51 ++++++++++++++++--- handlers.go | 4 +- handlers_test.go | 102 +++++++++++++++++++++++++++++++++++++ ws.go | 129 +++++++++++++++++++++-------------------------- 5 files changed, 218 insertions(+), 83 deletions(-) diff --git a/channel.go b/channel.go index e585acc..6fb9196 100644 --- a/channel.go +++ b/channel.go @@ -76,17 +76,22 @@ func GetChannelByID(tx *gorm.DB, channelID string) (*Channel, error) { return &channel, nil } -// getChannelsForParticipant finds all channels for a participant -func getChannelsForParticipant(tx *gorm.DB, participant string) ([]Channel, error) { +// getChannelsByParticipant finds all channels for a participant +func getChannelsByParticipant(tx *gorm.DB, participant string, status string) ([]Channel, error) { var channels []Channel - if err := tx.Where("participant = ?", - participant).Order("created_at DESC").Find(&channels).Error; err != nil { + q := tx.Where("participant = ?", participant) + if status != "" { + q = q.Where("status = ?", status) + } + + if err := q.Order("created_at DESC").Find(&channels).Error; err != nil { return nil, fmt.Errorf("error finding channels for participant %s: %w", participant, err) } + return channels, nil } -// CheckExistingChannels checks if there is an existing open channel on the same network between participant A and B +// CheckExistingChannels checks if there is an existing open channel on the same network between participant and broker func CheckExistingChannels(tx *gorm.DB, participantA, token string, chainID uint32) (*Channel, error) { var channel Channel err := tx.Where("participant = ? AND token = ? AND chain_id = ? AND status = ?", participantA, token, chainID, ChannelStatusOpen). diff --git a/docs/API.md b/docs/API.md index f085665..57c8089 100644 --- a/docs/API.md +++ b/docs/API.md @@ -591,15 +591,54 @@ Balance updates are sent as unsolicited server messages with the "bu" method: The balance update provides the latest balances for all assets in the participant's unified ledger, allowing clients to maintain an up-to-date view of available funds without explicitly requesting them. +### Open Channels + +The server automatically sends all open channels as a batch update to clients after successful authentication. + +```json +{ + "res": [1234567890123, "channels", [[ + { + "channel_id": "0xfedcba9876543210...", + "participant": "0x1234567890abcdef...", + "status": "open", + "token": "0xeeee567890abcdef...", + "amount": "100000", + "chain_id": 137, + "adjudicator": "0xAdjudicatorContractAddress...", + "challenge": 86400, + "nonce": 1, + "version": 2, + "created_at": "2023-05-01T12:00:00Z", + "updated_at": "2023-05-01T12:30:00Z" + }, + { + "channel_id": "0xabcdef1234567890...", + "participant": "0x1234567890abcdef...", + "status": "open", + "token": "0xeeee567890abcdef...", + "amount": "50000", + "chain_id": 42220, + "adjudicator": "0xAdjudicatorContractAddress...", + "challenge": 86400, + "nonce": 1, + "version": 3, + "created_at": "2023-04-15T10:00:00Z", + "updated_at": "2023-04-20T14:30:00Z" + } + ]], 1619123456789], + "sig": ["0xabcd1234..."] +} +``` + ### Channel Updates -The server automatically sends channel updates to clients in these scenarios: -1. After successful authentication (for all existing channels) -2. When a channel is created -3. When a channel's status changes (open, joined, closed) -4. When a channel is resized +For channel updates, the server sends them in these scenarios: +1. When a channel is created +2. When a channel's status changes (open, joined, closed) +3. When a channel is resized -Channel updates are sent as unsolicited server messages with the "cu" method: +Individual channel updates are sent as unsolicited server messages with the "cu" method: ```json { diff --git a/handlers.go b/handlers.go index 8743c2b..98c2379 100644 --- a/handlers.go +++ b/handlers.go @@ -872,6 +872,7 @@ func HandleCloseChannel(rpc *RPCMessage, db *gorm.DB, signer *Signer) (*RPCMessa // TODO: add filters, pagination, etc. func HandleGetChannels(rpc *RPCMessage, db *gorm.DB) (*RPCMessage, error) { var participant string + var status string if len(rpc.Req.Params) > 0 { paramsJSON, err := json.Marshal(rpc.Req.Params[0]) @@ -879,6 +880,7 @@ func HandleGetChannels(rpc *RPCMessage, db *gorm.DB) (*RPCMessage, error) { var params map[string]string if err := json.Unmarshal(paramsJSON, ¶ms); err == nil { participant = params["participant"] + status = params["status"] } } } @@ -887,7 +889,7 @@ func HandleGetChannels(rpc *RPCMessage, db *gorm.DB) (*RPCMessage, error) { return nil, errors.New("missing participant parameter") } - channels, err := getChannelsForParticipant(db, participant) + channels, err := getChannelsByParticipant(db, participant, status) if err != nil { return nil, fmt.Errorf("failed to get channels: %w", err) } diff --git a/handlers_test.go b/handlers_test.go index 58d5d9f..0839396 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -605,6 +605,108 @@ func TestHandleGetChannels(t *testing.T) { assert.NotEmpty(t, ch.UpdatedAt, "UpdatedAt should not be empty") } + // Test with status filter for "open" channels + openStatusParams := map[string]string{ + "participant": participantAddr, + "status": string(ChannelStatusOpen), + } + openStatusParamsJSON, err := json.Marshal(openStatusParams) + require.NoError(t, err) + + openStatusRequest := &RPCMessage{ + Req: &RPCData{ + RequestID: 456, + Method: "get_channels", + Params: []any{json.RawMessage(openStatusParamsJSON)}, + Timestamp: uint64(time.Now().Unix()), + }, + } + + reqBytes, err = json.Marshal(openStatusRequest.Req) + require.NoError(t, err) + signed, err = signer.Sign(reqBytes) + require.NoError(t, err) + openStatusRequest.Sig = []string{hexutil.Encode(signed)} + + openStatusResponse, err := HandleGetChannels(openStatusRequest, db) + require.NoError(t, err) + require.NotNil(t, openStatusResponse) + + // Extract and verify filtered channels + openChannels, ok := openStatusResponse.Res.Params[0].([]ChannelResponse) + require.True(t, ok, "Response parameter should be a slice of ChannelResponse") + assert.Len(t, openChannels, 1, "Should return only 1 open channel") + assert.Equal(t, "0xChannel1", openChannels[0].ChannelID, "Should return the open channel") + assert.Equal(t, ChannelStatusOpen, openChannels[0].Status, "Status should be open") + + // Test with status filter for "closed" channels + closedStatusParams := map[string]string{ + "participant": participantAddr, + "status": string(ChannelStatusClosed), + } + closedStatusParamsJSON, err := json.Marshal(closedStatusParams) + require.NoError(t, err) + + closedStatusRequest := &RPCMessage{ + Req: &RPCData{ + RequestID: 457, + Method: "get_channels", + Params: []any{json.RawMessage(closedStatusParamsJSON)}, + Timestamp: uint64(time.Now().Unix()), + }, + } + + reqBytes, err = json.Marshal(closedStatusRequest.Req) + require.NoError(t, err) + signed, err = signer.Sign(reqBytes) + require.NoError(t, err) + closedStatusRequest.Sig = []string{hexutil.Encode(signed)} + + closedStatusResponse, err := HandleGetChannels(closedStatusRequest, db) + require.NoError(t, err) + require.NotNil(t, closedStatusResponse) + + // Extract and verify filtered channels + closedChannels, ok := closedStatusResponse.Res.Params[0].([]ChannelResponse) + require.True(t, ok, "Response parameter should be a slice of ChannelResponse") + assert.Len(t, closedChannels, 1, "Should return only 1 closed channel") + assert.Equal(t, "0xChannel2", closedChannels[0].ChannelID, "Should return the closed channel") + assert.Equal(t, ChannelStatusClosed, closedChannels[0].Status, "Status should be closed") + + // Test with status filter for "joining" channels + joiningStatusParams := map[string]string{ + "participant": participantAddr, + "status": string(ChannelStatusJoining), + } + joiningStatusParamsJSON, err := json.Marshal(joiningStatusParams) + require.NoError(t, err) + + joiningStatusRequest := &RPCMessage{ + Req: &RPCData{ + RequestID: 458, + Method: "get_channels", + Params: []any{json.RawMessage(joiningStatusParamsJSON)}, + Timestamp: uint64(time.Now().Unix()), + }, + } + + reqBytes, err = json.Marshal(joiningStatusRequest.Req) + require.NoError(t, err) + signed, err = signer.Sign(reqBytes) + require.NoError(t, err) + joiningStatusRequest.Sig = []string{hexutil.Encode(signed)} + + joiningStatusResponse, err := HandleGetChannels(joiningStatusRequest, db) + require.NoError(t, err) + require.NotNil(t, joiningStatusResponse) + + // Extract and verify filtered channels + joiningChannels, ok := joiningStatusResponse.Res.Params[0].([]ChannelResponse) + require.True(t, ok, "Response parameter should be a slice of ChannelResponse") + assert.Len(t, joiningChannels, 1, "Should return only 1 joining channel") + assert.Equal(t, "0xChannel3", joiningChannels[0].ChannelID, "Should return the joining channel") + assert.Equal(t, ChannelStatusJoining, joiningChannels[0].Status, "Status should be joining") + // Test with missing participant parameter missingParamReq := &RPCMessage{ Req: &RPCData{ diff --git a/ws.go b/ws.go index f2e3735..aef9dd1 100644 --- a/ws.go +++ b/ws.go @@ -153,15 +153,13 @@ func (h *UnifiedWSHandler) HandleConnection(w http.ResponseWriter, r *http.Reque log.Printf("Participant authenticated: %s", address) // Send initial balance and channels information in form of balance and channel updates - h.sendBalanceUpdate(address) - channels, err := getChannelsForParticipant(h.db, address) + channels, err := getChannelsByParticipant(h.db, address, string(ChannelStatusOpen)) if err != nil { log.Printf("Error retrieving channels for participant %s: %v", address, err) } - for _, channel := range channels { - h.sendChannelUpdate(channel) - } + h.sendChannelsUpdate(address, channels) + h.sendBalanceUpdate(address) for { _, messageBytes, err := conn.ReadMessage() @@ -490,10 +488,9 @@ func (h *UnifiedWSHandler) sendErrorResponse(sender string, rpc *RPCMessage, con conn.SetWriteDeadline(time.Time{}) } -// sendErrorResponse creates and sends an error response to the client -func (h *UnifiedWSHandler) sendBalanceUpdate(sender string) { - balances, err := GetParticipantLedger(h.db, sender).GetBalances(sender) - response := CreateResponse(uint64(time.Now().UnixMilli()), "bu", []any{balances}, time.Now()) +// sendResponse sends a response with a given method and payload to a recipient +func (h *UnifiedWSHandler) sendResponse(recipient string, method string, payload []any, updateType string) { + response := CreateResponse(uint64(time.Now().UnixMilli()), method, payload, time.Now()) byteData, _ := json.Marshal(response.Req) signature, _ := h.signer.Sign(byteData) @@ -501,101 +498,91 @@ func (h *UnifiedWSHandler) sendBalanceUpdate(sender string) { responseData, err := json.Marshal(response) if err != nil { - log.Printf("Error marshaling error response: %v", err) + log.Printf("Error marshaling %s response: %v", updateType, err) return } h.connectionsMu.RLock() - recipientConn, exists := h.connections[sender] + recipientConn, exists := h.connections[recipient] h.connectionsMu.RUnlock() if exists { // Use NextWriter for safer message delivery w, err := recipientConn.NextWriter(websocket.TextMessage) if err != nil { - log.Printf("Error getting writer for balance update to %s: %v", sender, err) + log.Printf("Error getting writer for %s update to %s: %v", updateType, recipient, err) return } if _, err := w.Write(responseData); err != nil { - log.Printf("Error writing balance update to %s: %v", sender, err) + log.Printf("Error writing %s update to %s: %v", updateType, recipient, err) w.Close() return } if err := w.Close(); err != nil { - log.Printf("Error closing writer for balance update to %s: %v", sender, err) + log.Printf("Error closing writer for %s update to %s: %v", updateType, recipient, err) return } - // Increment sent message counter for each forwarded message + // Increment sent message counter h.metrics.MessageSent.Inc() - log.Printf("Successfully forwarded message to %s", sender) + log.Printf("Successfully sent %s update to %s", updateType, recipient) } else { - log.Printf("Recipient %s not connected", sender) + log.Printf("Recipient %s not connected", recipient) return } } -// sendErrorResponse creates and sends an error response to the client -func (h *UnifiedWSHandler) sendChannelUpdate(channel Channel) { - response := CreateResponse(uint64(time.Now().UnixMilli()), "cu", []any{ - ChannelResponse{ - ChannelID: channel.ChannelID, - Participant: channel.Participant, - Status: channel.Status, - Token: channel.Token, - Amount: big.NewInt(int64(channel.Amount)), - ChainID: channel.ChainID, - Adjudicator: channel.Adjudicator, - Challenge: channel.Challenge, - Nonce: channel.Nonce, - Version: channel.Version, - CreatedAt: channel.CreatedAt.Format(time.RFC3339), - UpdatedAt: channel.UpdatedAt.Format(time.RFC3339), - }, - }, time.Now()) - - byteData, _ := json.Marshal(response.Req) - signature, _ := h.signer.Sign(byteData) - response.Sig = []string{hexutil.Encode(signature)} - - responseData, err := json.Marshal(response) +// sendBalanceUpdate sends balance updates to the client +func (h *UnifiedWSHandler) sendBalanceUpdate(sender string) { + balances, err := GetParticipantLedger(h.db, sender).GetBalances(sender) if err != nil { - log.Printf("Error marshaling error response: %v", err) + log.Printf("Error getting balances for %s: %v", sender, err) return } + h.sendResponse(sender, "bu", []any{balances}, "balance") +} - h.connectionsMu.RLock() - recipientConn, exists := h.connections[channel.Participant] - h.connectionsMu.RUnlock() - if exists { - // Use NextWriter for safer message delivery - w, err := recipientConn.NextWriter(websocket.TextMessage) - if err != nil { - log.Printf("Error getting writer for balance update to %s: %v", channel.Participant, err) - return - } - - if _, err := w.Write(responseData); err != nil { - log.Printf("Error writing balance update to %s: %v", channel.Participant, err) - w.Close() - return - } - - if err := w.Close(); err != nil { - log.Printf("Error closing writer for balance update to %s: %v", channel.Participant, err) - return - } - - // Increment sent message counter for each forwarded message - h.metrics.MessageSent.Inc() +// sendChannelsUpdate sends multiple channels updates to the client +func (h *UnifiedWSHandler) sendChannelsUpdate(address string, channels []Channel) { + resp := []ChannelResponse{} + for _, ch := range channels { + resp = append(resp, ChannelResponse{ + ChannelID: ch.ChannelID, + Participant: ch.Participant, + Status: ch.Status, + Token: ch.Token, + Amount: big.NewInt(int64(ch.Amount)), + ChainID: ch.ChainID, + Adjudicator: ch.Adjudicator, + Challenge: ch.Challenge, + Nonce: ch.Nonce, + Version: ch.Version, + CreatedAt: ch.CreatedAt.Format(time.RFC3339), + UpdatedAt: ch.UpdatedAt.Format(time.RFC3339), + }) + } + h.sendResponse(address, "channels", []any{resp}, "channels") +} - log.Printf("Successfully forwarded message to %s", channel.Participant) - } else { - log.Printf("Recipient %s not connected", channel.Participant) - return - } +// sendChannelUpdate sends a single channel update to the client +func (h *UnifiedWSHandler) sendChannelUpdate(channel Channel) { + channelResponse := ChannelResponse{ + ChannelID: channel.ChannelID, + Participant: channel.Participant, + Status: channel.Status, + Token: channel.Token, + Amount: big.NewInt(int64(channel.Amount)), + ChainID: channel.ChainID, + Adjudicator: channel.Adjudicator, + Challenge: channel.Challenge, + Nonce: channel.Nonce, + Version: channel.Version, + CreatedAt: channel.CreatedAt.Format(time.RFC3339), + UpdatedAt: channel.UpdatedAt.Format(time.RFC3339), + } + h.sendResponse(channel.Participant, "cu", []any{channelResponse}, "channel") } // CloseAllConnections closes all open WebSocket connections during shutdown