diff --git a/mtproto.go b/mtproto.go index 07d71bd..4cdd888 100644 --- a/mtproto.go +++ b/mtproto.go @@ -14,14 +14,12 @@ import ( ) type MTProto struct { - addr string - conn *net.TCPConn - f *os.File - queueSend chan packetToSend - stopSend chan struct{} - stopRead chan struct{} - stopPing chan struct{} - allDone chan struct{} + addr string + conn *net.TCPConn + f *os.File + queueSend chan packetToSend + stopRoutines chan struct{} + allDone sync.WaitGroup authKey []byte authKeyHash []byte @@ -188,10 +186,8 @@ func (m *MTProto) Connect() error { // start goroutines m.queueSend = make(chan packetToSend, 64) - m.stopSend = make(chan struct{}, 1) - m.stopRead = make(chan struct{}, 1) - m.stopPing = make(chan struct{}, 1) - m.allDone = make(chan struct{}, 3) + m.stopRoutines = make(chan struct{}) + m.allDone = sync.WaitGroup{} m.msgsIdToAck = make(map[int64]packetToSend) m.msgsIdToResp = make(map[int64]chan response) m.mutex = &sync.Mutex{} @@ -237,21 +233,11 @@ func (m *MTProto) Connect() error { } func (m *MTProto) Disconnect() error { - // stop ping routine - m.stopPing <- struct{}{} - close(m.stopPing) + // stop ping, send and read routine by closing channel stopRoutines + close(m.stopRoutines) - // stop send routine - m.stopSend <- struct{}{} - close(m.stopSend) - - // stop read routine - m.stopRead <- struct{}{} - close(m.stopRead) - - <-m.allDone - <-m.allDone - <-m.allDone + // Wait until all goroutines stopped + m.allDone.Wait() // close send queue close(m.queueSend) @@ -283,10 +269,11 @@ func (m *MTProto) reconnect(newaddr string) error { } func (m *MTProto) pingRoutine() { - defer func() { m.allDone <- struct{}{} }() + m.allDone.Add(1) + defer func() { m.allDone.Done() }() for { select { - case <-m.stopPing: + case <-m.stopRoutines: return case <-time.After(60 * time.Second): m.queueSend <- packetToSend{TL_ping{0xCADACADA}, nil} @@ -295,10 +282,11 @@ func (m *MTProto) pingRoutine() { } func (m *MTProto) sendRoutine() { - defer func() { m.allDone <- struct{}{} }() + m.allDone.Add(1) + defer func() { m.allDone.Done() }() for { select { - case <-m.stopSend: + case <-m.stopRoutines: return case x := <-m.queueSend: err := m.sendPacket(x.msg, x.resp) @@ -310,7 +298,8 @@ func (m *MTProto) sendRoutine() { } func (m *MTProto) readRoutine() { - defer func() { m.allDone <- struct{}{} }() + m.allDone.Add(1) + defer func() { m.allDone.Done() }() for { // Run async wait for data from server ch := make(chan interface{}, 1) @@ -330,7 +319,7 @@ func (m *MTProto) readRoutine() { }(ch) select { - case <-m.stopRead: + case <-m.stopRoutines: return case data := <-ch: if data == nil {