From a7ee1965c76c6f582e75e60cacc17567cb496882 Mon Sep 17 00:00:00 2001 From: Yilun Date: Tue, 17 Mar 2020 17:22:13 -0700 Subject: [PATCH] Will put message into buffer if client timeout This should help prevent client message lost when client conn accidentally drop off. Signed-off-by: Yilun --- api/websocket/messagebuffer/messagebuffer.go | 3 -- api/websocket/server/delayedchan.go | 47 ++++++++++++++++++++ api/websocket/server/relay.go | 38 +++++++++++++++- api/websocket/server/server.go | 45 +++++++++++-------- api/websocket/session/session.go | 29 +++++++++--- 5 files changed, 131 insertions(+), 31 deletions(-) create mode 100644 api/websocket/server/delayedchan.go diff --git a/api/websocket/messagebuffer/messagebuffer.go b/api/websocket/messagebuffer/messagebuffer.go index 40a781265..ccea2476d 100644 --- a/api/websocket/messagebuffer/messagebuffer.go +++ b/api/websocket/messagebuffer/messagebuffer.go @@ -22,9 +22,6 @@ func NewMessageBuffer() *MessageBuffer { // AddMessage adds a message to message buffer func (messageBuffer *MessageBuffer) AddMessage(clientID []byte, msg *pb.Relay) { - if msg.MaxHoldingSeconds == 0 { - return - } clientIDStr := hex.EncodeToString(clientID) messageBuffer.Lock() defer messageBuffer.Unlock() diff --git a/api/websocket/server/delayedchan.go b/api/websocket/server/delayedchan.go new file mode 100644 index 000000000..e5597bbe5 --- /dev/null +++ b/api/websocket/server/delayedchan.go @@ -0,0 +1,47 @@ +package server + +import ( + "time" +) + +type DelayedChan struct { + buffer chan *delayedValue + delay time.Duration +} + +type delayedValue struct { + value interface{} + releaseTime time.Time +} + +func NewDelayedChan(size int, delay time.Duration) *DelayedChan { + buffer := make(chan *delayedValue, size) + return &DelayedChan{ + buffer: buffer, + delay: delay, + } +} + +func (dc *DelayedChan) Push(v interface{}) bool { + dv := &delayedValue{ + value: v, + releaseTime: time.Now().Add(dc.delay), + } + select { + case dc.buffer <- dv: + return true + default: + return false + } +} + +func (dc *DelayedChan) Pop() (interface{}, bool) { + dv, ok := <-dc.buffer + if !ok { + return nil, false + } + if dv.releaseTime.After(time.Now()) { + time.Sleep(time.Until(dv.releaseTime)) + } + return dv.value, true +} diff --git a/api/websocket/server/relay.go b/api/websocket/server/relay.go index dabb63888..4ef1ec476 100644 --- a/api/websocket/server/relay.go +++ b/api/websocket/server/relay.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "errors" "fmt" + "time" "github.com/gogo/protobuf/proto" "github.com/nknorg/nkn/pb" @@ -65,7 +66,7 @@ func (ws *WsServer) sendOutboundRelayMessage(srcAddrStrPtr *string, msg *pb.Outb func (ws *WsServer) sendInboundMessage(clientID string, inboundMsg *pb.InboundMessage) bool { clients := ws.SessionList.GetSessionsById(clientID) if clients == nil { - log.Infof("Client Not Online: %s", clientID) + log.Debugf("Client Not Online: %s", clientID) return false } @@ -124,11 +125,44 @@ func (ws *WsServer) sendInboundRelayMessage(relayMessage *pb.Relay) { sigChainLen: int(relayMessage.SigChainLen), }) } - } else { + if time.Duration(relayMessage.MaxHoldingSeconds) > pongTimeout/time.Second { + ok := ws.messageDeliveredCache.Push(relayMessage) + if !ok { + log.Warningf("MessageDeliveredCache full, discarding messages.") + } + } + } else if relayMessage.MaxHoldingSeconds > 0 { ws.messageBuffer.AddMessage(clientID, relayMessage) } } +func (ws *WsServer) startCheckingLostMessages() { + for { + v, ok := ws.messageDeliveredCache.Pop() + if !ok { + break + } + if relayMessage, ok := v.(*pb.Relay); ok { + clientID := relayMessage.DestId + clients := ws.SessionList.GetSessionsById(hex.EncodeToString(clientID)) + if len(clients) > 0 { + threshold := time.Now().Add(-pongTimeout) + success := false + for _, client := range clients { + if client.GetLastReadTime().After(threshold) { + success = true + break + } + } + if success { + continue + } + } + ws.messageBuffer.AddMessage(clientID, relayMessage) + } + } +} + func (ws *WsServer) handleReceipt(receipt *pb.Receipt) error { v, ok := ws.sigChainCache.Get(receipt.PrevSignature) if !ok { diff --git a/api/websocket/server/server.go b/api/websocket/server/server.go index 8b5c8f2eb..17369df28 100644 --- a/api/websocket/server/server.go +++ b/api/websocket/server/server.go @@ -41,6 +41,7 @@ const ( pingInterval = 8 * time.Second pongTimeout = 10 * time.Second // should be greater than pingInterval maxMessageSize = config.MaxClientMessageSize + messageDeliveredCacheSize = 65536 ) type Handler struct { @@ -50,29 +51,31 @@ type Handler struct { type WsServer struct { sync.RWMutex - Upgrader websocket.Upgrader - listener net.Listener - tlsListener net.Listener - server *http.Server - tlsServer *http.Server - SessionList *session.SessionList - ActionMap map[string]Handler - TxHashMap map[string]string //key: txHash value:sessionid - localNode *node.LocalNode - wallet vault.Wallet - messageBuffer *messagebuffer.MessageBuffer - sigChainCache Cache + Upgrader websocket.Upgrader + listener net.Listener + tlsListener net.Listener + server *http.Server + tlsServer *http.Server + SessionList *session.SessionList + ActionMap map[string]Handler + TxHashMap map[string]string //key: txHash value:sessionid + localNode *node.LocalNode + wallet vault.Wallet + messageBuffer *messagebuffer.MessageBuffer + messageDeliveredCache *DelayedChan + sigChainCache Cache } func InitWsServer(localNode *node.LocalNode, wallet vault.Wallet) *WsServer { ws := &WsServer{ - Upgrader: websocket.Upgrader{}, - SessionList: session.NewSessionList(), - TxHashMap: make(map[string]string), - localNode: localNode, - wallet: wallet, - messageBuffer: messagebuffer.NewMessageBuffer(), - sigChainCache: NewGoCache(sigChainCacheExpiration, sigChainCacheCleanupInterval), + Upgrader: websocket.Upgrader{}, + SessionList: session.NewSessionList(), + TxHashMap: make(map[string]string), + localNode: localNode, + wallet: wallet, + messageBuffer: messagebuffer.NewMessageBuffer(), + messageDeliveredCache: NewDelayedChan(messageDeliveredCacheSize, pongTimeout), + sigChainCache: NewGoCache(sigChainCacheExpiration, sigChainCacheCleanupInterval), } return ws } @@ -108,6 +111,8 @@ func (ws *WsServer) Start() error { ws.tlsServer = &http.Server{Handler: http.HandlerFunc(ws.websocketHandler)} go ws.tlsServer.Serve(ws.tlsListener) + go ws.startCheckingLostMessages() + return nil } @@ -259,6 +264,7 @@ func (ws *WsServer) websocketHandler(w http.ResponseWriter, r *http.Request) { wsConn.SetReadDeadline(time.Now().Add(pongTimeout)) wsConn.SetPongHandler(func(string) error { wsConn.SetReadDeadline(time.Now().Add(pongTimeout)) + sess.UpdateLastReadTime() return nil }) @@ -289,6 +295,7 @@ func (ws *WsServer) websocketHandler(w http.ResponseWriter, r *http.Request) { } wsConn.SetReadDeadline(time.Now().Add(pongTimeout)) + sess.UpdateLastReadTime() err = ws.OnDataHandle(sess, messageType, bysMsg, r) if err != nil { diff --git a/api/websocket/session/session.go b/api/websocket/session/session.go index 37fd6e88d..68087e68a 100644 --- a/api/websocket/session/session.go +++ b/api/websocket/session/session.go @@ -14,14 +14,16 @@ const ( ) type Session struct { - ws *websocket.Conn - sessionID string - sync.RWMutex + sessionID string clientChordID []byte clientPubKey []byte clientAddrStr *string isTlsClient bool + lastReadTime time.Time + + wsLock sync.Mutex + ws *websocket.Conn } func (s *Session) GetSessionId() string { @@ -31,8 +33,9 @@ func (s *Session) GetSessionId() string { func newSession(wsConn *websocket.Conn) (session *Session, err error) { sessionID := uuid.NewUUID().String() session = &Session{ - ws: wsConn, - sessionID: sessionID, + ws: wsConn, + sessionID: sessionID, + lastReadTime: time.Now(), } return session, nil } @@ -48,8 +51,8 @@ func (s *Session) close() { } func (s *Session) Send(msgType int, data []byte) error { - s.RLock() - defer s.RUnlock() + s.wsLock.Lock() + defer s.wsLock.Unlock() if s.ws == nil { return errors.New("Websocket is null") } @@ -126,3 +129,15 @@ func (s *Session) IsTlsClient() bool { defer s.RUnlock() return s.isTlsClient } + +func (s *Session) GetLastReadTime() time.Time { + s.RLock() + defer s.RUnlock() + return s.lastReadTime +} + +func (s *Session) UpdateLastReadTime() { + s.Lock() + s.lastReadTime = time.Now() + s.Unlock() +}