Skip to content

Commit

Permalink
close session fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sdfsdhgjkbmnmxc committed Mar 18, 2021
1 parent 9a34e84 commit e5ce91e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 30 deletions.
6 changes: 4 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ type Session struct {
features *tdproto.Features
}

const defaultTimeout = 10 * time.Second

func NewSession(server string) (Session, error) {
s := Session{
Timeout: 10 * time.Second,
Timeout: defaultTimeout,
logger: log.New(os.Stdout, "tdclient: ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix),
}

s.logger = log.New(os.Stdout, "tdclient: ", log.LstdFlags|log.Lmicroseconds|log.Lmsgprefix)
s.SetVerbose(false)

u, err := url.Parse(server)
Expand Down
48 changes: 27 additions & 21 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@ func TestSession(t *testing.T) {
testAccountPhone := mustEnv("TEST_ACCOUNT_PHONE")
testAccountCode := mustEnv("TEST_ACCOUNT_CODE")

c, err := NewSession(testServer)
s, err := NewSession(testServer)
if err != nil {
t.Fatal(err)
}

t.Run("http ping", func(t *testing.T) {
if err := c.Ping(); err != nil {
if err := s.Ping(); err != nil {
t.Fatal(err)
}
})

t.Run("features smoke test", func(t *testing.T) {
if _, err := c.Features(); err != nil {
if _, err := s.Features(); err != nil {
t.Fatal(err)
}
})

var team tdproto.Team

t.Run("sms login", func(t *testing.T) {
codeResp, err := c.AuthBySmsSendCode(testAccountPhone)
codeResp, err := s.AuthBySmsSendCode(testAccountPhone)
if err != nil {
t.Fatal(err)
}
Expand All @@ -44,7 +44,7 @@ func TestSession(t *testing.T) {
t.Fatalf("invalid code length: %+v", codeResp)
}

tokenResp, err := c.AuthBySmsGetToken(testAccountPhone, testAccountCode)
tokenResp, err := s.AuthBySmsGetToken(testAccountPhone, testAccountCode)
if err != nil {
t.Fatal(err)
}
Expand All @@ -60,7 +60,7 @@ func TestSession(t *testing.T) {
}
}

c.SetToken(tokenResp.Token)
s.SetToken(tokenResp.Token)
})

if team.Uid == "" {
Expand All @@ -70,12 +70,12 @@ func TestSession(t *testing.T) {
var newContact tdproto.Contact
t.Run("contacts list", func(t *testing.T) {
anyPhone := "+79870000000"
newContact, err = c.AddContact(team.Uid, anyPhone)
newContact, err = s.AddContact(team.Uid, anyPhone)
if err != nil {
t.Fatal(err)
}

contacts, err := c.Contacts(team.Uid)
contacts, err := s.Contacts(team.Uid)
if err != nil {
t.Fatal(err)
}
Expand All @@ -92,7 +92,7 @@ func TestSession(t *testing.T) {
})

t.Run("messages", func(t *testing.T) {
message, err := c.SendPlaintextMessage(team.Uid, newContact.Jid, kozma.Say())
message, err := s.SendPlaintextMessage(team.Uid, newContact.Jid, kozma.Say())
if err != nil {
t.Fatal(err)
}
Expand All @@ -105,7 +105,7 @@ func TestSession(t *testing.T) {
filter := new(tdapi.MessageFilter)
filter.Lang = "ru"
filter.Limit = 200
messages, err := c.GetMessages(team.Uid, newContact.Jid, filter)
messages, err := s.GetMessages(team.Uid, newContact.Jid, filter)
if err != nil {
t.Fatal(err)
}
Expand All @@ -115,7 +115,7 @@ func TestSession(t *testing.T) {
})

t.Run("delete messages", func(t *testing.T) {
_, err := c.DeleteMessage(team.Uid, newContact.Jid, message.MessageId)
_, err := s.DeleteMessage(team.Uid, newContact.Jid, message.MessageId)
if err != nil {
t.Fatal(err)
}
Expand All @@ -124,7 +124,7 @@ func TestSession(t *testing.T) {
})

t.Run("me smoke test", func(t *testing.T) {
me, err := c.Me(team.Uid)
me, err := s.Me(team.Uid)
if err != nil {
t.Fatal(err)
}
Expand All @@ -134,7 +134,7 @@ func TestSession(t *testing.T) {
})

t.Run("ws", func(t *testing.T) {
ws, err := c.Ws(team.Uid, func(err error) {
ws, err := s.Ws(team.Uid, func(err error) {
t.Fatal(err)
})
if err != nil {
Expand Down Expand Up @@ -172,12 +172,18 @@ func TestSession(t *testing.T) {
t.Fatal("invalid message uid")
}
})

t.Run("close", func(t *testing.T) {
if err := ws.Close(); err != nil {
t.Fatal("close ws session fail")
}
})
})
})

t.Run("create task", func(t *testing.T) {
text := kozma.Say()
task, err := c.CreateTask(team.Uid, tdapi.Task{
task, err := s.CreateTask(team.Uid, tdapi.Task{
Description: text,
Tags: []string{"autotest"},
Assignee: newContact.Jid,
Expand All @@ -194,7 +200,7 @@ func TestSession(t *testing.T) {
})

t.Run("chats", func(t *testing.T) {
chats, err := c.GetChats(team.Uid, &tdapi.ChatFilter{
chats, err := s.GetChats(team.Uid, &tdapi.ChatFilter{
ChatType: "direct",
Paginator: tdapi.Paginator{
Limit: 1,
Expand All @@ -216,7 +222,7 @@ func TestSession(t *testing.T) {
})

t.Run("groups", func(t *testing.T) {
group, err := c.CreateGroup(team.Uid, tdapi.Group{
group, err := s.CreateGroup(team.Uid, tdapi.Group{
DisplayName: "test group",
Public: false,
})
Expand All @@ -225,7 +231,7 @@ func TestSession(t *testing.T) {
}

t.Run("add member", func(t *testing.T) {
member, err := c.AddGroupMember(team.Uid, group.Jid, newContact.Jid)
member, err := s.AddGroupMember(team.Uid, group.Jid, newContact.Jid)
if err != nil {
t.Fatal(err)
}
Expand All @@ -235,7 +241,7 @@ func TestSession(t *testing.T) {
})

t.Run("get members", func(t *testing.T) {
members, err := c.GroupMembers(team.Uid, group.Jid)
members, err := s.GroupMembers(team.Uid, group.Jid)
if err != nil {
t.Fatal(err)
}
Expand All @@ -245,19 +251,19 @@ func TestSession(t *testing.T) {
})

t.Run("remove member", func(t *testing.T) {
if err := c.DropGroupMember(team.Uid, group.Jid, newContact.Jid); err != nil {
if err := s.DropGroupMember(team.Uid, group.Jid, newContact.Jid); err != nil {
t.Fatal(err)
}
})

t.Run("remove group", func(t *testing.T) {
if err := c.DropGroup(team.Uid, group.Jid); err != nil {
if err := s.DropGroup(team.Uid, group.Jid); err != nil {
t.Fatal(err)
}
})

t.Run("group list smoke test", func(t *testing.T) {
_, err := c.GetGroups(team.Uid)
_, err := s.GetGroups(team.Uid)
if err != nil {
t.Fatal(err)
}
Expand Down
23 changes: 16 additions & 7 deletions ws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tdclient

import (
"context"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -45,6 +46,8 @@ func (s *Session) Ws(team string, onfail func(error)) (*WsSession, error) {
fail: make(chan error),
}

w.ctx, w.cancel = context.WithCancel(context.Background())

go func() {
err := <-w.fail
if err != nil {
Expand All @@ -70,11 +73,12 @@ type WsSession struct {
session *Session
team string
conn *websocket.Conn
closed bool
inbox chan serverEvent
outBytes chan []byte
fail chan error
listeners map[string]chan []byte
ctx context.Context
cancel context.CancelFunc
}

func (w *WsSession) Ping() string {
Expand Down Expand Up @@ -176,13 +180,15 @@ func (w *WsSession) SendRaw(b []byte) {
}

func (w *WsSession) Close() error {
w.closed = true
w.cancel()
return w.conn.Close()
}

func (w *WsSession) outboxLoop() {
for !w.closed {
for {
select {
case <-w.ctx.Done():
return
case b := <-w.outBytes:
w.session.logger.Println("send:", string(b))
if err := w.conn.WriteMessage(websocket.BinaryMessage, b); err != nil {
Expand All @@ -193,12 +199,14 @@ func (w *WsSession) outboxLoop() {
}
}

func (w WsSession) inboxLoop() {
func (w *WsSession) inboxLoop() {
var parser fastjson.Parser
for !w.closed {
for {
_, data, err := w.conn.ReadMessage()
if err != nil {
w.fail <- errors.Wrap(err, "conn read fail")
if w.ctx.Err() == nil {
w.fail <- errors.Wrap(err, "conn read fail")
}
return
}

Expand Down Expand Up @@ -231,6 +239,8 @@ func (w WsSession) inboxLoop() {

select {
case w.inbox <- ev:
case <-w.ctx.Done():
return
default:
w.fail <- errors.Wrapf(err, "full inbox")
}
Expand All @@ -253,4 +263,3 @@ func (w *WsSession) SendCallLeave(jid tdproto.JID) {
callLeave.Params.Reason = ""
w.Send(callLeave)
}

0 comments on commit e5ce91e

Please sign in to comment.