mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-15 03:23:52 +00:00
types: refactor, redesign connectivity
This commit is contained in:
parent
1eb0c1b092
commit
ebab50f30c
|
@ -1,96 +1,9 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectivityGroup struct {
|
|
||||||
connections []*Connectivity
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
|
|
||||||
return &ConnectivityGroup{
|
|
||||||
connections: cons,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *ConnectivityGroup) Add(con *Connectivity) {
|
|
||||||
g.mu.Lock()
|
|
||||||
defer g.mu.Unlock()
|
|
||||||
|
|
||||||
g.connections = append(g.connections, con)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
|
|
||||||
g.mu.Lock()
|
|
||||||
conns := g.connections
|
|
||||||
g.mu.Unlock()
|
|
||||||
|
|
||||||
for _, conn := range conns {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return false
|
|
||||||
|
|
||||||
case <-conn.connectedC:
|
|
||||||
continue
|
|
||||||
|
|
||||||
case <-conn.disconnectedC:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
|
|
||||||
g.mu.Lock()
|
|
||||||
conns := g.connections
|
|
||||||
g.mu.Unlock()
|
|
||||||
|
|
||||||
authedConns := make([]bool, len(conns))
|
|
||||||
allTimeout := time.After(allTimeoutDuration)
|
|
||||||
for {
|
|
||||||
for idx, con := range conns {
|
|
||||||
// if the connection is not authed, mark it as false
|
|
||||||
if !con.authed {
|
|
||||||
// authedConns[idx] = false
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := time.After(3 * time.Second)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-allTimeout:
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-timeout:
|
|
||||||
continue
|
|
||||||
|
|
||||||
case <-con.AuthedC():
|
|
||||||
authedConns[idx] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if allTrue(authedConns) {
|
|
||||||
close(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AllAuthedC returns a channel that will be closed when all connections are authenticated
|
|
||||||
// the returned channel will be closed when all connections are authenticated
|
|
||||||
// and the channel can only be used once (because we can't close a channel twice)
|
|
||||||
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
|
|
||||||
c := make(chan struct{})
|
|
||||||
go g.waitAllAuthed(ctx, c, timeout)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func allTrue(bools []bool) bool {
|
func allTrue(bools []bool) bool {
|
||||||
for _, b := range bools {
|
for _, b := range bools {
|
||||||
if !b {
|
if !b {
|
||||||
|
@ -101,6 +14,7 @@ func allTrue(bools []bool) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:generate callbackgen -type Connectivity
|
||||||
type Connectivity struct {
|
type Connectivity struct {
|
||||||
authed bool
|
authed bool
|
||||||
authedC chan struct{}
|
authedC chan struct{}
|
||||||
|
@ -109,7 +23,12 @@ type Connectivity struct {
|
||||||
connectedC chan struct{}
|
connectedC chan struct{}
|
||||||
disconnectedC chan struct{}
|
disconnectedC chan struct{}
|
||||||
|
|
||||||
mu sync.Mutex
|
connectCallbacks []func()
|
||||||
|
disconnectCallbacks []func()
|
||||||
|
authCallbacks []func()
|
||||||
|
|
||||||
|
stream Stream
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConnectivity() *Connectivity {
|
func NewConnectivity() *Connectivity {
|
||||||
|
@ -141,31 +60,39 @@ func (c *Connectivity) IsAuthed() (authed bool) {
|
||||||
return authed
|
return authed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connectivity) handleConnect() {
|
func (c *Connectivity) setConnect() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
if !c.connected {
|
||||||
|
c.connected = true
|
||||||
c.connected = true
|
close(c.connectedC)
|
||||||
close(c.connectedC)
|
c.disconnectedC = make(chan struct{})
|
||||||
c.disconnectedC = make(chan struct{})
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
c.EmitConnect()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connectivity) handleDisconnect() {
|
func (c *Connectivity) setDisconnect() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
if c.connected {
|
||||||
|
c.connected = false
|
||||||
c.connected = false
|
c.authed = false
|
||||||
c.authedC = make(chan struct{})
|
c.authedC = make(chan struct{})
|
||||||
c.connectedC = make(chan struct{})
|
c.connectedC = make(chan struct{})
|
||||||
close(c.disconnectedC)
|
close(c.disconnectedC)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
c.EmitDisconnect()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connectivity) handleAuth() {
|
func (c *Connectivity) setAuthed() {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
if !c.authed {
|
||||||
|
c.authed = true
|
||||||
|
close(c.authedC)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
c.authed = true
|
c.EmitAuth()
|
||||||
close(c.authedC)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connectivity) AuthedC() chan struct{} {
|
func (c *Connectivity) AuthedC() chan struct{} {
|
||||||
|
@ -187,7 +114,8 @@ func (c *Connectivity) DisconnectedC() chan struct{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connectivity) Bind(stream Stream) {
|
func (c *Connectivity) Bind(stream Stream) {
|
||||||
stream.OnConnect(c.handleConnect)
|
stream.OnConnect(c.setConnect)
|
||||||
stream.OnDisconnect(c.handleDisconnect)
|
stream.OnDisconnect(c.setDisconnect)
|
||||||
stream.OnAuth(c.handleAuth)
|
stream.OnAuth(c.setAuthed)
|
||||||
|
c.stream = stream
|
||||||
}
|
}
|
||||||
|
|
35
pkg/types/connectivity_callbacks.go
Normal file
35
pkg/types/connectivity_callbacks.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
// Code generated by "callbackgen -type Connectivity"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package types
|
||||||
|
|
||||||
|
import ()
|
||||||
|
|
||||||
|
func (c *Connectivity) OnConnect(cb func()) {
|
||||||
|
c.connectCallbacks = append(c.connectCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connectivity) EmitConnect() {
|
||||||
|
for _, cb := range c.connectCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connectivity) OnDisconnect(cb func()) {
|
||||||
|
c.disconnectCallbacks = append(c.disconnectCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connectivity) EmitDisconnect() {
|
||||||
|
for _, cb := range c.disconnectCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connectivity) OnAuth(cb func()) {
|
||||||
|
c.authCallbacks = append(c.authCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connectivity) EmitAuth() {
|
||||||
|
for _, cb := range c.authCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,105 +1,39 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnectivity(t *testing.T) {
|
func TestConnectivity(t *testing.T) {
|
||||||
t.Run("general", func(t *testing.T) {
|
t.Run("general", func(t *testing.T) {
|
||||||
conn1 := NewConnectivity()
|
conn1 := NewConnectivity()
|
||||||
conn1.handleConnect()
|
conn1.setConnect()
|
||||||
conn1.handleAuth()
|
conn1.setAuthed()
|
||||||
conn1.handleDisconnect()
|
conn1.setDisconnect()
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reconnect", func(t *testing.T) {
|
t.Run("reconnect", func(t *testing.T) {
|
||||||
conn1 := NewConnectivity()
|
conn1 := NewConnectivity()
|
||||||
conn1.handleConnect()
|
conn1.setConnect()
|
||||||
conn1.handleAuth()
|
conn1.setAuthed()
|
||||||
conn1.handleDisconnect()
|
conn1.setDisconnect()
|
||||||
|
|
||||||
conn1.handleConnect()
|
conn1.setConnect()
|
||||||
conn1.handleAuth()
|
conn1.setAuthed()
|
||||||
conn1.handleDisconnect()
|
conn1.setDisconnect()
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("no-auth reconnect", func(t *testing.T) {
|
t.Run("no-auth reconnect", func(t *testing.T) {
|
||||||
conn1 := NewConnectivity()
|
conn1 := NewConnectivity()
|
||||||
conn1.handleConnect()
|
conn1.setConnect()
|
||||||
conn1.handleDisconnect()
|
conn1.setDisconnect()
|
||||||
|
|
||||||
conn1.handleConnect()
|
conn1.setConnect()
|
||||||
conn1.handleDisconnect()
|
conn1.setDisconnect()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnectivityGroupAuthC(t *testing.T) {
|
|
||||||
timeout := 100 * time.Millisecond
|
|
||||||
delay := timeout * 2
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
conn1 := NewConnectivity()
|
|
||||||
conn2 := NewConnectivity()
|
|
||||||
group := NewConnectivityGroup(conn1, conn2)
|
|
||||||
allAuthedC := group.AllAuthedC(ctx, time.Second)
|
|
||||||
|
|
||||||
time.Sleep(delay)
|
|
||||||
conn1.handleConnect()
|
|
||||||
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
|
|
||||||
conn1.handleAuth()
|
|
||||||
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
|
|
||||||
|
|
||||||
time.Sleep(delay)
|
|
||||||
conn2.handleConnect()
|
|
||||||
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
|
|
||||||
|
|
||||||
conn2.handleAuth()
|
|
||||||
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
|
|
||||||
|
|
||||||
assert.True(t, waitSigChan(allAuthedC, timeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectivityGroupReconnect(t *testing.T) {
|
|
||||||
timeout := 100 * time.Millisecond
|
|
||||||
delay := timeout * 2
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
conn1 := NewConnectivity()
|
|
||||||
conn2 := NewConnectivity()
|
|
||||||
group := NewConnectivityGroup(conn1, conn2)
|
|
||||||
|
|
||||||
time.Sleep(delay)
|
|
||||||
conn1.handleConnect()
|
|
||||||
conn1.handleAuth()
|
|
||||||
conn1authC := conn1.authedC
|
|
||||||
|
|
||||||
time.Sleep(delay)
|
|
||||||
conn2.handleConnect()
|
|
||||||
conn2.handleAuth()
|
|
||||||
|
|
||||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
|
|
||||||
|
|
||||||
assert.False(t, group.AnyDisconnected(ctx))
|
|
||||||
|
|
||||||
// this should re-allocate authedC
|
|
||||||
conn1.handleDisconnect()
|
|
||||||
assert.NotEqual(t, conn1authC, conn1.authedC)
|
|
||||||
|
|
||||||
assert.True(t, group.AnyDisconnected(ctx))
|
|
||||||
|
|
||||||
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
|
|
||||||
|
|
||||||
time.Sleep(delay)
|
|
||||||
|
|
||||||
conn1.handleConnect()
|
|
||||||
conn1.handleAuth()
|
|
||||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
|
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
|
||||||
select {
|
select {
|
||||||
case <-time.After(timeoutDuration):
|
case <-time.After(timeoutDuration):
|
||||||
|
|
237
pkg/types/connectivitygroup.go
Normal file
237
pkg/types/connectivitygroup.go
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:generate callbackgen -type ConnectivityGroup
|
||||||
|
type ConnectivityGroup struct {
|
||||||
|
*Connectivity
|
||||||
|
|
||||||
|
connections []*Connectivity
|
||||||
|
mu sync.Mutex
|
||||||
|
authedC chan struct{}
|
||||||
|
|
||||||
|
states map[*Connectivity]ConnectivityState
|
||||||
|
sumState ConnectivityState
|
||||||
|
|
||||||
|
connectCallbacks []func()
|
||||||
|
disconnectCallbacks []func()
|
||||||
|
authCallbacks []func()
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectivityState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ConnectivityStateDisconnected ConnectivityState = -1
|
||||||
|
ConnectivityStateUnknown ConnectivityState = 0
|
||||||
|
ConnectivityStateConnected ConnectivityState = 1
|
||||||
|
ConnectivityStateAuthed ConnectivityState = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
func getConnState(con *Connectivity) ConnectivityState {
|
||||||
|
state := ConnectivityStateDisconnected
|
||||||
|
|
||||||
|
if con.IsAuthed() {
|
||||||
|
state = ConnectivityStateAuthed
|
||||||
|
} else if con.IsConnected() {
|
||||||
|
state = ConnectivityStateConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
func initConnStates(cons []*Connectivity) map[*Connectivity]ConnectivityState {
|
||||||
|
states := map[*Connectivity]ConnectivityState{}
|
||||||
|
|
||||||
|
for _, con := range cons {
|
||||||
|
state := ConnectivityStateDisconnected
|
||||||
|
|
||||||
|
if con.IsAuthed() {
|
||||||
|
state = ConnectivityStateAuthed
|
||||||
|
} else if con.IsConnected() {
|
||||||
|
state = ConnectivityStateConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
states[con] = state
|
||||||
|
}
|
||||||
|
|
||||||
|
return states
|
||||||
|
}
|
||||||
|
|
||||||
|
func sumStates(states map[*Connectivity]ConnectivityState) ConnectivityState {
|
||||||
|
disconnected := 0
|
||||||
|
connected := 0
|
||||||
|
authed := 0
|
||||||
|
|
||||||
|
for _, state := range states {
|
||||||
|
switch state {
|
||||||
|
case ConnectivityStateDisconnected:
|
||||||
|
disconnected++
|
||||||
|
case ConnectivityStateConnected:
|
||||||
|
connected++
|
||||||
|
case ConnectivityStateAuthed:
|
||||||
|
authed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
numConn := len(states)
|
||||||
|
|
||||||
|
// if one of the connections is disconnected, the group is disconnected
|
||||||
|
if disconnected > 0 {
|
||||||
|
return ConnectivityStateDisconnected
|
||||||
|
} else if authed == numConn {
|
||||||
|
// if all connections are authed, the group is authed
|
||||||
|
return ConnectivityStateAuthed
|
||||||
|
} else if connected == numConn || (connected+authed) == numConn {
|
||||||
|
// if all connections are connected, the group is connected
|
||||||
|
return ConnectivityStateConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
return ConnectivityStateUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
|
||||||
|
states := initConnStates(cons)
|
||||||
|
sumState := sumStates(states)
|
||||||
|
g := &ConnectivityGroup{
|
||||||
|
Connectivity: NewConnectivity(),
|
||||||
|
connections: cons,
|
||||||
|
authedC: make(chan struct{}),
|
||||||
|
states: states,
|
||||||
|
sumState: sumState,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, con := range cons {
|
||||||
|
g.Add(con)
|
||||||
|
}
|
||||||
|
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) setState(con *Connectivity, state ConnectivityState) {
|
||||||
|
g.mu.Lock()
|
||||||
|
g.states[con] = state
|
||||||
|
prevState := g.sumState
|
||||||
|
curState := sumStates(g.states)
|
||||||
|
g.sumState = curState
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
// if the state is not changed, return
|
||||||
|
if prevState == curState {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
g.setFlags(curState)
|
||||||
|
switch curState {
|
||||||
|
case ConnectivityStateAuthed:
|
||||||
|
g.EmitAuth()
|
||||||
|
case ConnectivityStateConnected:
|
||||||
|
g.EmitConnect()
|
||||||
|
case ConnectivityStateDisconnected:
|
||||||
|
g.EmitDisconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) setFlags(state ConnectivityState) {
|
||||||
|
switch state {
|
||||||
|
case ConnectivityStateAuthed:
|
||||||
|
g.Connectivity.setConnect()
|
||||||
|
g.Connectivity.setAuthed()
|
||||||
|
case ConnectivityStateConnected:
|
||||||
|
g.Connectivity.setConnect()
|
||||||
|
case ConnectivityStateDisconnected:
|
||||||
|
g.Connectivity.setDisconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) Add(con *Connectivity) {
|
||||||
|
g.mu.Lock()
|
||||||
|
g.connections = append(g.connections, con)
|
||||||
|
g.states[con] = getConnState(con)
|
||||||
|
g.sumState = sumStates(g.states)
|
||||||
|
g.setFlags(g.sumState)
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
_con := con
|
||||||
|
con.OnDisconnect(func() {
|
||||||
|
g.setState(_con, ConnectivityStateDisconnected)
|
||||||
|
})
|
||||||
|
|
||||||
|
con.OnConnect(func() {
|
||||||
|
g.setState(_con, ConnectivityStateConnected)
|
||||||
|
})
|
||||||
|
|
||||||
|
con.OnAuth(func() {
|
||||||
|
g.setState(_con, ConnectivityStateAuthed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
|
||||||
|
g.mu.Lock()
|
||||||
|
conns := g.connections
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
for _, conn := range conns {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
|
||||||
|
case <-conn.connectedC:
|
||||||
|
continue
|
||||||
|
|
||||||
|
case <-conn.disconnectedC:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, allTimeoutDuration time.Duration) {
|
||||||
|
g.mu.Lock()
|
||||||
|
conns := g.connections
|
||||||
|
c := g.authedC
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
authedConns := make([]bool, len(conns))
|
||||||
|
allTimeout := time.After(allTimeoutDuration)
|
||||||
|
for {
|
||||||
|
for idx, con := range conns {
|
||||||
|
// if the connection is not authed, mark it as false
|
||||||
|
if !con.authed {
|
||||||
|
// authedConns[idx] = false
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := time.After(3 * time.Second)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-allTimeout:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-timeout:
|
||||||
|
continue
|
||||||
|
|
||||||
|
case <-con.AuthedC():
|
||||||
|
authedConns[idx] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if allTrue(authedConns) {
|
||||||
|
close(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllAuthedC returns a channel that will be closed when all connections are authenticated
|
||||||
|
// the returned channel will be closed when all connections are authenticated
|
||||||
|
// and the channel can only be used once (because we can't close a channel twice)
|
||||||
|
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
|
||||||
|
go g.waitAllAuthed(ctx, timeout)
|
||||||
|
return g.authedC
|
||||||
|
}
|
35
pkg/types/connectivitygroup_callbacks.go
Normal file
35
pkg/types/connectivitygroup_callbacks.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
// Code generated by "callbackgen -type ConnectivityGroup"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package types
|
||||||
|
|
||||||
|
import ()
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) OnConnect(cb func()) {
|
||||||
|
g.connectCallbacks = append(g.connectCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) EmitConnect() {
|
||||||
|
for _, cb := range g.connectCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) OnDisconnect(cb func()) {
|
||||||
|
g.disconnectCallbacks = append(g.disconnectCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) EmitDisconnect() {
|
||||||
|
for _, cb := range g.disconnectCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) OnAuth(cb func()) {
|
||||||
|
g.authCallbacks = append(g.authCallbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *ConnectivityGroup) EmitAuth() {
|
||||||
|
for _, cb := range g.authCallbacks {
|
||||||
|
cb()
|
||||||
|
}
|
||||||
|
}
|
239
pkg/types/connectivitygroup_test.go
Normal file
239
pkg/types/connectivitygroup_test.go
Normal file
|
@ -0,0 +1,239 @@
|
||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnectivityGroupAuthC(t *testing.T) {
|
||||||
|
timeout := 100 * time.Millisecond
|
||||||
|
delay := timeout * 2
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
allAuthedC := group.AllAuthedC(ctx, time.Second)
|
||||||
|
|
||||||
|
time.Sleep(delay)
|
||||||
|
conn1.setConnect()
|
||||||
|
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
|
||||||
|
conn1.setAuthed()
|
||||||
|
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
|
||||||
|
|
||||||
|
time.Sleep(delay)
|
||||||
|
conn2.setConnect()
|
||||||
|
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
|
||||||
|
|
||||||
|
conn2.setAuthed()
|
||||||
|
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
|
||||||
|
|
||||||
|
assert.True(t, waitSigChan(allAuthedC, timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectivityGroup(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
_ = ctx
|
||||||
|
|
||||||
|
t.Run("connected", func(t *testing.T) {
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn1.setConnect()
|
||||||
|
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
conn2.setConnect()
|
||||||
|
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only one connected", func(t *testing.T) {
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn1.setConnect()
|
||||||
|
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
conn2.setDisconnect()
|
||||||
|
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||||
|
assert.False(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("only one authed", func(t *testing.T) {
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn1.setConnect()
|
||||||
|
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
conn2.setAuthed()
|
||||||
|
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all authed", func(t *testing.T) {
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn1.setAuthed()
|
||||||
|
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
conn2.setAuthed()
|
||||||
|
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.True(t, group.IsAuthed())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reconnect", func(t *testing.T) {
|
||||||
|
conn1 := NewConnectivity()
|
||||||
|
conn1.setDisconnect()
|
||||||
|
conn2 := NewConnectivity()
|
||||||
|
conn2.setDisconnect()
|
||||||
|
|
||||||
|
// callbackState is used to test the callback state
|
||||||
|
callbackState := ConnectivityStateDisconnected
|
||||||
|
|
||||||
|
group := NewConnectivityGroup(conn1, conn2)
|
||||||
|
group.OnConnect(func() {
|
||||||
|
callbackState = ConnectivityStateConnected
|
||||||
|
})
|
||||||
|
group.OnAuth(func() {
|
||||||
|
callbackState = ConnectivityStateAuthed
|
||||||
|
})
|
||||||
|
group.OnDisconnect(func() {
|
||||||
|
callbackState = ConnectivityStateDisconnected
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||||
|
assert.False(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||||
|
|
||||||
|
t.Log("conn1 connected")
|
||||||
|
conn1.setConnect()
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||||
|
assert.False(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||||
|
|
||||||
|
t.Log("conn2 connected")
|
||||||
|
conn2.setConnect()
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, callbackState)
|
||||||
|
|
||||||
|
t.Log("conn1 and conn2 authed")
|
||||||
|
conn1.setAuthed()
|
||||||
|
conn2.setAuthed()
|
||||||
|
|
||||||
|
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.True(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateAuthed, callbackState)
|
||||||
|
|
||||||
|
t.Log("one connection gets disconnected should fallback to disconnected state")
|
||||||
|
conn2.setDisconnect()
|
||||||
|
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||||
|
assert.False(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||||
|
|
||||||
|
t.Log("all connections get disconnected should fallback to disconnected state")
|
||||||
|
conn1.setDisconnect()
|
||||||
|
conn2.setDisconnect()
|
||||||
|
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||||
|
assert.False(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||||
|
|
||||||
|
t.Log("all connections are connected again")
|
||||||
|
conn1.setConnect()
|
||||||
|
conn2.setConnect()
|
||||||
|
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||||
|
assert.True(t, group.IsConnected())
|
||||||
|
assert.False(t, group.IsAuthed())
|
||||||
|
assert.Equal(t, ConnectivityStateConnected, callbackState)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_sumStates(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
states map[*Connectivity]ConnectivityState
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want ConnectivityState
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all connected",
|
||||||
|
args: args{
|
||||||
|
states: map[*Connectivity]ConnectivityState{
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ConnectivityStateConnected,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only one connected should return disconnected",
|
||||||
|
args: args{
|
||||||
|
states: map[*Connectivity]ConnectivityState{
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only one connected and others authed should return connected",
|
||||||
|
args: args{
|
||||||
|
states: map[*Connectivity]ConnectivityState{
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateAuthed,
|
||||||
|
NewConnectivity(): ConnectivityStateAuthed,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ConnectivityStateConnected,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all disconnected should return disconnected",
|
||||||
|
args: args{
|
||||||
|
states: map[*Connectivity]ConnectivityState{
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one disconnected should return disconnected",
|
||||||
|
args: args{
|
||||||
|
states: map[*Connectivity]ConnectivityState{
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateConnected,
|
||||||
|
NewConnectivity(): ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: ConnectivityStateDisconnected,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equalf(t, tt.want, sumStates(tt.args.states), "sumStates(%v)", tt.args.states)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user