diff --git a/session.go b/session.go index d80dd19..675b326 100644 --- a/session.go +++ b/session.go @@ -26,12 +26,14 @@ type Session struct { features *tdproto.Features } +const defaultTimeout = 10 * time.Second + func NewSession(server string) (Session, error) { s := Session{ - Timeout: 10 * time.Second, + Timeout: defaultTimeout, + logger: log.New(os.Stdout, "tdclient: ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix), } - s.logger = log.New(os.Stdout, "tdclient: ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix) s.SetVerbose(false) u, err := url.Parse(server) diff --git a/session_test.go b/session_test.go index c4b1b64..67979af 100644 --- a/session_test.go +++ b/session_test.go @@ -15,19 +15,19 @@ func TestSession(t *testing.T) { testAccountPhone := mustEnv("TEST_ACCOUNT_PHONE") testAccountCode := mustEnv("TEST_ACCOUNT_CODE") - c, err := NewSession(testServer) + s, err := NewSession(testServer) if err != nil { t.Fatal(err) } t.Run("http ping", func(t *testing.T) { - if err := c.Ping(); err != nil { + if err := s.Ping(); err != nil { t.Fatal(err) } }) t.Run("features smoke test", func(t *testing.T) { - if _, err := c.Features(); err != nil { + if _, err := s.Features(); err != nil { t.Fatal(err) } }) @@ -35,7 +35,7 @@ func TestSession(t *testing.T) { var team tdproto.Team t.Run("sms login", func(t *testing.T) { - codeResp, err := c.AuthBySmsSendCode(testAccountPhone) + codeResp, err := s.AuthBySmsSendCode(testAccountPhone) if err != nil { t.Fatal(err) } @@ -44,7 +44,7 @@ func TestSession(t *testing.T) { t.Fatalf("invalid code length: %+v", codeResp) } - tokenResp, err := c.AuthBySmsGetToken(testAccountPhone, testAccountCode) + tokenResp, err := s.AuthBySmsGetToken(testAccountPhone, testAccountCode) if err != nil { t.Fatal(err) } @@ -60,7 +60,7 @@ func TestSession(t *testing.T) { } } - c.SetToken(tokenResp.Token) + s.SetToken(tokenResp.Token) }) if team.Uid == "" { @@ -70,12 +70,12 @@ func TestSession(t *testing.T) { var newContact tdproto.Contact t.Run("contacts list", func(t *testing.T) { anyPhone := "+79870000000" - newContact, err = c.AddContact(team.Uid, anyPhone) + newContact, err = s.AddContact(team.Uid, anyPhone) if err != nil { t.Fatal(err) } - contacts, err := c.Contacts(team.Uid) + contacts, err := s.Contacts(team.Uid) if err != nil { t.Fatal(err) } @@ -92,7 +92,7 @@ func TestSession(t *testing.T) { }) t.Run("messages", func(t *testing.T) { - message, err := c.SendPlaintextMessage(team.Uid, newContact.Jid, kozma.Say()) + message, err := s.SendPlaintextMessage(team.Uid, newContact.Jid, kozma.Say()) if err != nil { t.Fatal(err) } @@ -105,7 +105,7 @@ func TestSession(t *testing.T) { filter := new(tdapi.MessageFilter) filter.Lang = "ru" filter.Limit = 200 - messages, err := c.GetMessages(team.Uid, newContact.Jid, filter) + messages, err := s.GetMessages(team.Uid, newContact.Jid, filter) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestSession(t *testing.T) { }) t.Run("delete messages", func(t *testing.T) { - _, err := c.DeleteMessage(team.Uid, newContact.Jid, message.MessageId) + _, err := s.DeleteMessage(team.Uid, newContact.Jid, message.MessageId) if err != nil { t.Fatal(err) } @@ -124,7 +124,7 @@ func TestSession(t *testing.T) { }) t.Run("me smoke test", func(t *testing.T) { - me, err := c.Me(team.Uid) + me, err := s.Me(team.Uid) if err != nil { t.Fatal(err) } @@ -134,7 +134,7 @@ func TestSession(t *testing.T) { }) t.Run("ws", func(t *testing.T) { - ws, err := c.Ws(team.Uid, func(err error) { + ws, err := s.Ws(team.Uid, func(err error) { t.Fatal(err) }) if err != nil { @@ -172,12 +172,18 @@ func TestSession(t *testing.T) { t.Fatal("invalid message uid") } }) + + t.Run("close", func(t *testing.T) { + if err := ws.Close(); err != nil { + t.Fatal("close ws session fail") + } + }) }) }) t.Run("create task", func(t *testing.T) { text := kozma.Say() - task, err := c.CreateTask(team.Uid, tdapi.Task{ + task, err := s.CreateTask(team.Uid, tdapi.Task{ Description: text, Tags: []string{"autotest"}, Assignee: newContact.Jid, @@ -194,7 +200,7 @@ func TestSession(t *testing.T) { }) t.Run("chats", func(t *testing.T) { - chats, err := c.GetChats(team.Uid, &tdapi.ChatFilter{ + chats, err := s.GetChats(team.Uid, &tdapi.ChatFilter{ ChatType: "direct", Paginator: tdapi.Paginator{ Limit: 1, @@ -216,7 +222,7 @@ func TestSession(t *testing.T) { }) t.Run("groups", func(t *testing.T) { - group, err := c.CreateGroup(team.Uid, tdapi.Group{ + group, err := s.CreateGroup(team.Uid, tdapi.Group{ DisplayName: "test group", Public: false, }) @@ -225,7 +231,7 @@ func TestSession(t *testing.T) { } t.Run("add member", func(t *testing.T) { - member, err := c.AddGroupMember(team.Uid, group.Jid, newContact.Jid) + member, err := s.AddGroupMember(team.Uid, group.Jid, newContact.Jid) if err != nil { t.Fatal(err) } @@ -235,7 +241,7 @@ func TestSession(t *testing.T) { }) t.Run("get members", func(t *testing.T) { - members, err := c.GroupMembers(team.Uid, group.Jid) + members, err := s.GroupMembers(team.Uid, group.Jid) if err != nil { t.Fatal(err) } @@ -245,19 +251,19 @@ func TestSession(t *testing.T) { }) t.Run("remove member", func(t *testing.T) { - if err := c.DropGroupMember(team.Uid, group.Jid, newContact.Jid); err != nil { + if err := s.DropGroupMember(team.Uid, group.Jid, newContact.Jid); err != nil { t.Fatal(err) } }) t.Run("remove group", func(t *testing.T) { - if err := c.DropGroup(team.Uid, group.Jid); err != nil { + if err := s.DropGroup(team.Uid, group.Jid); err != nil { t.Fatal(err) } }) t.Run("group list smoke test", func(t *testing.T) { - _, err := c.GetGroups(team.Uid) + _, err := s.GetGroups(team.Uid) if err != nil { t.Fatal(err) } diff --git a/ws.go b/ws.go index a856140..3cc5fb5 100644 --- a/ws.go +++ b/ws.go @@ -1,6 +1,7 @@ package tdclient import ( + "context" "fmt" "log" "net/http" @@ -45,6 +46,8 @@ func (s *Session) Ws(team string, onfail func(error)) (*WsSession, error) { fail: make(chan error), } + w.ctx, w.cancel = context.WithCancel(context.Background()) + go func() { err := <-w.fail if err != nil { @@ -70,11 +73,12 @@ type WsSession struct { session *Session team string conn *websocket.Conn - closed bool inbox chan serverEvent outBytes chan []byte fail chan error listeners map[string]chan []byte + ctx context.Context + cancel context.CancelFunc } func (w *WsSession) Ping() string { @@ -176,13 +180,15 @@ func (w *WsSession) SendRaw(b []byte) { } func (w *WsSession) Close() error { - w.closed = true + w.cancel() return w.conn.Close() } func (w *WsSession) outboxLoop() { - for !w.closed { + for { select { + case <-w.ctx.Done(): + return case b := <-w.outBytes: w.session.logger.Println("send:", string(b)) if err := w.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { @@ -193,12 +199,14 @@ func (w *WsSession) outboxLoop() { } } -func (w WsSession) inboxLoop() { +func (w *WsSession) inboxLoop() { var parser fastjson.Parser - for !w.closed { + for { _, data, err := w.conn.ReadMessage() if err != nil { - w.fail <- errors.Wrap(err, "conn read fail") + if w.ctx.Err() == nil { + w.fail <- errors.Wrap(err, "conn read fail") + } return } @@ -231,6 +239,8 @@ func (w WsSession) inboxLoop() { select { case w.inbox <- ev: + case <-w.ctx.Done(): + return default: w.fail <- errors.Wrapf(err, "full inbox") } @@ -253,4 +263,3 @@ func (w *WsSession) SendCallLeave(jid tdproto.JID) { callLeave.Params.Reason = "" w.Send(callLeave) } -