mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-15 11:33:50 +00:00
types: refactor, redesign connectivity
This commit is contained in:
parent
1eb0c1b092
commit
ebab50f30c
|
@ -1,96 +1,9 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnectivityGroup struct {
|
||||
connections []*Connectivity
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
|
||||
return &ConnectivityGroup{
|
||||
connections: cons,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) Add(con *Connectivity) {
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
g.connections = append(g.connections, con)
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
|
||||
g.mu.Lock()
|
||||
conns := g.connections
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
|
||||
case <-conn.connectedC:
|
||||
continue
|
||||
|
||||
case <-conn.disconnectedC:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
|
||||
g.mu.Lock()
|
||||
conns := g.connections
|
||||
g.mu.Unlock()
|
||||
|
||||
authedConns := make([]bool, len(conns))
|
||||
allTimeout := time.After(allTimeoutDuration)
|
||||
for {
|
||||
for idx, con := range conns {
|
||||
// if the connection is not authed, mark it as false
|
||||
if !con.authed {
|
||||
// authedConns[idx] = false
|
||||
}
|
||||
|
||||
timeout := time.After(3 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-allTimeout:
|
||||
return
|
||||
|
||||
case <-timeout:
|
||||
continue
|
||||
|
||||
case <-con.AuthedC():
|
||||
authedConns[idx] = true
|
||||
}
|
||||
}
|
||||
|
||||
if allTrue(authedConns) {
|
||||
close(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AllAuthedC returns a channel that will be closed when all connections are authenticated
|
||||
// the returned channel will be closed when all connections are authenticated
|
||||
// and the channel can only be used once (because we can't close a channel twice)
|
||||
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
|
||||
c := make(chan struct{})
|
||||
go g.waitAllAuthed(ctx, c, timeout)
|
||||
return c
|
||||
}
|
||||
|
||||
func allTrue(bools []bool) bool {
|
||||
for _, b := range bools {
|
||||
if !b {
|
||||
|
@ -101,6 +14,7 @@ func allTrue(bools []bool) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
//go:generate callbackgen -type Connectivity
|
||||
type Connectivity struct {
|
||||
authed bool
|
||||
authedC chan struct{}
|
||||
|
@ -109,7 +23,12 @@ type Connectivity struct {
|
|||
connectedC chan struct{}
|
||||
disconnectedC chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
connectCallbacks []func()
|
||||
disconnectCallbacks []func()
|
||||
authCallbacks []func()
|
||||
|
||||
stream Stream
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectivity() *Connectivity {
|
||||
|
@ -141,31 +60,39 @@ func (c *Connectivity) IsAuthed() (authed bool) {
|
|||
return authed
|
||||
}
|
||||
|
||||
func (c *Connectivity) handleConnect() {
|
||||
func (c *Connectivity) setConnect() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.connected = true
|
||||
close(c.connectedC)
|
||||
c.disconnectedC = make(chan struct{})
|
||||
if !c.connected {
|
||||
c.connected = true
|
||||
close(c.connectedC)
|
||||
c.disconnectedC = make(chan struct{})
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.EmitConnect()
|
||||
}
|
||||
|
||||
func (c *Connectivity) handleDisconnect() {
|
||||
func (c *Connectivity) setDisconnect() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.connected = false
|
||||
c.authedC = make(chan struct{})
|
||||
c.connectedC = make(chan struct{})
|
||||
close(c.disconnectedC)
|
||||
if c.connected {
|
||||
c.connected = false
|
||||
c.authed = false
|
||||
c.authedC = make(chan struct{})
|
||||
c.connectedC = make(chan struct{})
|
||||
close(c.disconnectedC)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.EmitDisconnect()
|
||||
}
|
||||
|
||||
func (c *Connectivity) handleAuth() {
|
||||
func (c *Connectivity) setAuthed() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if !c.authed {
|
||||
c.authed = true
|
||||
close(c.authedC)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.authed = true
|
||||
close(c.authedC)
|
||||
c.EmitAuth()
|
||||
}
|
||||
|
||||
func (c *Connectivity) AuthedC() chan struct{} {
|
||||
|
@ -187,7 +114,8 @@ func (c *Connectivity) DisconnectedC() chan struct{} {
|
|||
}
|
||||
|
||||
func (c *Connectivity) Bind(stream Stream) {
|
||||
stream.OnConnect(c.handleConnect)
|
||||
stream.OnDisconnect(c.handleDisconnect)
|
||||
stream.OnAuth(c.handleAuth)
|
||||
stream.OnConnect(c.setConnect)
|
||||
stream.OnDisconnect(c.setDisconnect)
|
||||
stream.OnAuth(c.setAuthed)
|
||||
c.stream = stream
|
||||
}
|
||||
|
|
35
pkg/types/connectivity_callbacks.go
Normal file
35
pkg/types/connectivity_callbacks.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
// Code generated by "callbackgen -type Connectivity"; DO NOT EDIT.
|
||||
|
||||
package types
|
||||
|
||||
import ()
|
||||
|
||||
func (c *Connectivity) OnConnect(cb func()) {
|
||||
c.connectCallbacks = append(c.connectCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *Connectivity) EmitConnect() {
|
||||
for _, cb := range c.connectCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connectivity) OnDisconnect(cb func()) {
|
||||
c.disconnectCallbacks = append(c.disconnectCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *Connectivity) EmitDisconnect() {
|
||||
for _, cb := range c.disconnectCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connectivity) OnAuth(cb func()) {
|
||||
c.authCallbacks = append(c.authCallbacks, cb)
|
||||
}
|
||||
|
||||
func (c *Connectivity) EmitAuth() {
|
||||
for _, cb := range c.authCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
|
@ -1,105 +1,39 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectivity(t *testing.T) {
|
||||
t.Run("general", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
conn1.setConnect()
|
||||
conn1.setAuthed()
|
||||
conn1.setDisconnect()
|
||||
})
|
||||
|
||||
t.Run("reconnect", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
conn1.setConnect()
|
||||
conn1.setAuthed()
|
||||
conn1.setDisconnect()
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1.handleDisconnect()
|
||||
conn1.setConnect()
|
||||
conn1.setAuthed()
|
||||
conn1.setDisconnect()
|
||||
})
|
||||
|
||||
t.Run("no-auth reconnect", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.handleConnect()
|
||||
conn1.handleDisconnect()
|
||||
conn1.setConnect()
|
||||
conn1.setDisconnect()
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleDisconnect()
|
||||
conn1.setConnect()
|
||||
conn1.setDisconnect()
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectivityGroupAuthC(t *testing.T) {
|
||||
timeout := 100 * time.Millisecond
|
||||
delay := timeout * 2
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := NewConnectivity()
|
||||
conn2 := NewConnectivity()
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
allAuthedC := group.AllAuthedC(ctx, time.Second)
|
||||
|
||||
time.Sleep(delay)
|
||||
conn1.handleConnect()
|
||||
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
|
||||
conn1.handleAuth()
|
||||
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
|
||||
|
||||
time.Sleep(delay)
|
||||
conn2.handleConnect()
|
||||
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
|
||||
|
||||
conn2.handleAuth()
|
||||
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
|
||||
|
||||
assert.True(t, waitSigChan(allAuthedC, timeout))
|
||||
}
|
||||
|
||||
func TestConnectivityGroupReconnect(t *testing.T) {
|
||||
timeout := 100 * time.Millisecond
|
||||
delay := timeout * 2
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := NewConnectivity()
|
||||
conn2 := NewConnectivity()
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
|
||||
time.Sleep(delay)
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
conn1authC := conn1.authedC
|
||||
|
||||
time.Sleep(delay)
|
||||
conn2.handleConnect()
|
||||
conn2.handleAuth()
|
||||
|
||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
|
||||
|
||||
assert.False(t, group.AnyDisconnected(ctx))
|
||||
|
||||
// this should re-allocate authedC
|
||||
conn1.handleDisconnect()
|
||||
assert.NotEqual(t, conn1authC, conn1.authedC)
|
||||
|
||||
assert.True(t, group.AnyDisconnected(ctx))
|
||||
|
||||
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
|
||||
|
||||
time.Sleep(delay)
|
||||
|
||||
conn1.handleConnect()
|
||||
conn1.handleAuth()
|
||||
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
|
||||
}
|
||||
|
||||
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
|
||||
select {
|
||||
case <-time.After(timeoutDuration):
|
||||
|
|
237
pkg/types/connectivitygroup.go
Normal file
237
pkg/types/connectivitygroup.go
Normal file
|
@ -0,0 +1,237 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:generate callbackgen -type ConnectivityGroup
|
||||
type ConnectivityGroup struct {
|
||||
*Connectivity
|
||||
|
||||
connections []*Connectivity
|
||||
mu sync.Mutex
|
||||
authedC chan struct{}
|
||||
|
||||
states map[*Connectivity]ConnectivityState
|
||||
sumState ConnectivityState
|
||||
|
||||
connectCallbacks []func()
|
||||
disconnectCallbacks []func()
|
||||
authCallbacks []func()
|
||||
}
|
||||
|
||||
type ConnectivityState int
|
||||
|
||||
const (
|
||||
ConnectivityStateDisconnected ConnectivityState = -1
|
||||
ConnectivityStateUnknown ConnectivityState = 0
|
||||
ConnectivityStateConnected ConnectivityState = 1
|
||||
ConnectivityStateAuthed ConnectivityState = 2
|
||||
)
|
||||
|
||||
func getConnState(con *Connectivity) ConnectivityState {
|
||||
state := ConnectivityStateDisconnected
|
||||
|
||||
if con.IsAuthed() {
|
||||
state = ConnectivityStateAuthed
|
||||
} else if con.IsConnected() {
|
||||
state = ConnectivityStateConnected
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
func initConnStates(cons []*Connectivity) map[*Connectivity]ConnectivityState {
|
||||
states := map[*Connectivity]ConnectivityState{}
|
||||
|
||||
for _, con := range cons {
|
||||
state := ConnectivityStateDisconnected
|
||||
|
||||
if con.IsAuthed() {
|
||||
state = ConnectivityStateAuthed
|
||||
} else if con.IsConnected() {
|
||||
state = ConnectivityStateConnected
|
||||
}
|
||||
|
||||
states[con] = state
|
||||
}
|
||||
|
||||
return states
|
||||
}
|
||||
|
||||
func sumStates(states map[*Connectivity]ConnectivityState) ConnectivityState {
|
||||
disconnected := 0
|
||||
connected := 0
|
||||
authed := 0
|
||||
|
||||
for _, state := range states {
|
||||
switch state {
|
||||
case ConnectivityStateDisconnected:
|
||||
disconnected++
|
||||
case ConnectivityStateConnected:
|
||||
connected++
|
||||
case ConnectivityStateAuthed:
|
||||
authed++
|
||||
}
|
||||
}
|
||||
|
||||
numConn := len(states)
|
||||
|
||||
// if one of the connections is disconnected, the group is disconnected
|
||||
if disconnected > 0 {
|
||||
return ConnectivityStateDisconnected
|
||||
} else if authed == numConn {
|
||||
// if all connections are authed, the group is authed
|
||||
return ConnectivityStateAuthed
|
||||
} else if connected == numConn || (connected+authed) == numConn {
|
||||
// if all connections are connected, the group is connected
|
||||
return ConnectivityStateConnected
|
||||
}
|
||||
|
||||
return ConnectivityStateUnknown
|
||||
}
|
||||
|
||||
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
|
||||
states := initConnStates(cons)
|
||||
sumState := sumStates(states)
|
||||
g := &ConnectivityGroup{
|
||||
Connectivity: NewConnectivity(),
|
||||
connections: cons,
|
||||
authedC: make(chan struct{}),
|
||||
states: states,
|
||||
sumState: sumState,
|
||||
}
|
||||
|
||||
for _, con := range cons {
|
||||
g.Add(con)
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) setState(con *Connectivity, state ConnectivityState) {
|
||||
g.mu.Lock()
|
||||
g.states[con] = state
|
||||
prevState := g.sumState
|
||||
curState := sumStates(g.states)
|
||||
g.sumState = curState
|
||||
g.mu.Unlock()
|
||||
|
||||
// if the state is not changed, return
|
||||
if prevState == curState {
|
||||
return
|
||||
}
|
||||
|
||||
g.setFlags(curState)
|
||||
switch curState {
|
||||
case ConnectivityStateAuthed:
|
||||
g.EmitAuth()
|
||||
case ConnectivityStateConnected:
|
||||
g.EmitConnect()
|
||||
case ConnectivityStateDisconnected:
|
||||
g.EmitDisconnect()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) setFlags(state ConnectivityState) {
|
||||
switch state {
|
||||
case ConnectivityStateAuthed:
|
||||
g.Connectivity.setConnect()
|
||||
g.Connectivity.setAuthed()
|
||||
case ConnectivityStateConnected:
|
||||
g.Connectivity.setConnect()
|
||||
case ConnectivityStateDisconnected:
|
||||
g.Connectivity.setDisconnect()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) Add(con *Connectivity) {
|
||||
g.mu.Lock()
|
||||
g.connections = append(g.connections, con)
|
||||
g.states[con] = getConnState(con)
|
||||
g.sumState = sumStates(g.states)
|
||||
g.setFlags(g.sumState)
|
||||
g.mu.Unlock()
|
||||
|
||||
_con := con
|
||||
con.OnDisconnect(func() {
|
||||
g.setState(_con, ConnectivityStateDisconnected)
|
||||
})
|
||||
|
||||
con.OnConnect(func() {
|
||||
g.setState(_con, ConnectivityStateConnected)
|
||||
})
|
||||
|
||||
con.OnAuth(func() {
|
||||
g.setState(_con, ConnectivityStateAuthed)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
|
||||
g.mu.Lock()
|
||||
conns := g.connections
|
||||
g.mu.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
|
||||
case <-conn.connectedC:
|
||||
continue
|
||||
|
||||
case <-conn.disconnectedC:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, allTimeoutDuration time.Duration) {
|
||||
g.mu.Lock()
|
||||
conns := g.connections
|
||||
c := g.authedC
|
||||
g.mu.Unlock()
|
||||
|
||||
authedConns := make([]bool, len(conns))
|
||||
allTimeout := time.After(allTimeoutDuration)
|
||||
for {
|
||||
for idx, con := range conns {
|
||||
// if the connection is not authed, mark it as false
|
||||
if !con.authed {
|
||||
// authedConns[idx] = false
|
||||
}
|
||||
|
||||
timeout := time.After(3 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case <-allTimeout:
|
||||
return
|
||||
|
||||
case <-timeout:
|
||||
continue
|
||||
|
||||
case <-con.AuthedC():
|
||||
authedConns[idx] = true
|
||||
}
|
||||
}
|
||||
|
||||
if allTrue(authedConns) {
|
||||
close(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AllAuthedC returns a channel that will be closed when all connections are authenticated
|
||||
// the returned channel will be closed when all connections are authenticated
|
||||
// and the channel can only be used once (because we can't close a channel twice)
|
||||
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
|
||||
go g.waitAllAuthed(ctx, timeout)
|
||||
return g.authedC
|
||||
}
|
35
pkg/types/connectivitygroup_callbacks.go
Normal file
35
pkg/types/connectivitygroup_callbacks.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
// Code generated by "callbackgen -type ConnectivityGroup"; DO NOT EDIT.
|
||||
|
||||
package types
|
||||
|
||||
import ()
|
||||
|
||||
func (g *ConnectivityGroup) OnConnect(cb func()) {
|
||||
g.connectCallbacks = append(g.connectCallbacks, cb)
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) EmitConnect() {
|
||||
for _, cb := range g.connectCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) OnDisconnect(cb func()) {
|
||||
g.disconnectCallbacks = append(g.disconnectCallbacks, cb)
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) EmitDisconnect() {
|
||||
for _, cb := range g.disconnectCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) OnAuth(cb func()) {
|
||||
g.authCallbacks = append(g.authCallbacks, cb)
|
||||
}
|
||||
|
||||
func (g *ConnectivityGroup) EmitAuth() {
|
||||
for _, cb := range g.authCallbacks {
|
||||
cb()
|
||||
}
|
||||
}
|
239
pkg/types/connectivitygroup_test.go
Normal file
239
pkg/types/connectivitygroup_test.go
Normal file
|
@ -0,0 +1,239 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectivityGroupAuthC(t *testing.T) {
|
||||
timeout := 100 * time.Millisecond
|
||||
delay := timeout * 2
|
||||
|
||||
ctx := context.Background()
|
||||
conn1 := NewConnectivity()
|
||||
conn2 := NewConnectivity()
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
allAuthedC := group.AllAuthedC(ctx, time.Second)
|
||||
|
||||
time.Sleep(delay)
|
||||
conn1.setConnect()
|
||||
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
|
||||
conn1.setAuthed()
|
||||
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
|
||||
|
||||
time.Sleep(delay)
|
||||
conn2.setConnect()
|
||||
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
|
||||
|
||||
conn2.setAuthed()
|
||||
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
|
||||
|
||||
assert.True(t, waitSigChan(allAuthedC, timeout))
|
||||
}
|
||||
|
||||
func TestConnectivityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_ = ctx
|
||||
|
||||
t.Run("connected", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.setConnect()
|
||||
|
||||
conn2 := NewConnectivity()
|
||||
conn2.setConnect()
|
||||
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
})
|
||||
|
||||
t.Run("only one connected", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.setConnect()
|
||||
|
||||
conn2 := NewConnectivity()
|
||||
conn2.setDisconnect()
|
||||
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||
assert.False(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
})
|
||||
|
||||
t.Run("only one authed", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.setConnect()
|
||||
|
||||
conn2 := NewConnectivity()
|
||||
conn2.setAuthed()
|
||||
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
})
|
||||
|
||||
t.Run("all authed", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.setAuthed()
|
||||
|
||||
conn2 := NewConnectivity()
|
||||
conn2.setAuthed()
|
||||
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.True(t, group.IsAuthed())
|
||||
})
|
||||
|
||||
t.Run("reconnect", func(t *testing.T) {
|
||||
conn1 := NewConnectivity()
|
||||
conn1.setDisconnect()
|
||||
conn2 := NewConnectivity()
|
||||
conn2.setDisconnect()
|
||||
|
||||
// callbackState is used to test the callback state
|
||||
callbackState := ConnectivityStateDisconnected
|
||||
|
||||
group := NewConnectivityGroup(conn1, conn2)
|
||||
group.OnConnect(func() {
|
||||
callbackState = ConnectivityStateConnected
|
||||
})
|
||||
group.OnAuth(func() {
|
||||
callbackState = ConnectivityStateAuthed
|
||||
})
|
||||
group.OnDisconnect(func() {
|
||||
callbackState = ConnectivityStateDisconnected
|
||||
})
|
||||
|
||||
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||
assert.False(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||
|
||||
t.Log("conn1 connected")
|
||||
conn1.setConnect()
|
||||
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||
assert.False(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||
|
||||
t.Log("conn2 connected")
|
||||
conn2.setConnect()
|
||||
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateConnected, callbackState)
|
||||
|
||||
t.Log("conn1 and conn2 authed")
|
||||
conn1.setAuthed()
|
||||
conn2.setAuthed()
|
||||
|
||||
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.True(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateAuthed, callbackState)
|
||||
|
||||
t.Log("one connection gets disconnected should fallback to disconnected state")
|
||||
conn2.setDisconnect()
|
||||
|
||||
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||
assert.False(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||
|
||||
t.Log("all connections get disconnected should fallback to disconnected state")
|
||||
conn1.setDisconnect()
|
||||
conn2.setDisconnect()
|
||||
|
||||
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
|
||||
assert.False(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
|
||||
|
||||
t.Log("all connections are connected again")
|
||||
conn1.setConnect()
|
||||
conn2.setConnect()
|
||||
|
||||
assert.Equal(t, ConnectivityStateConnected, group.sumState)
|
||||
assert.True(t, group.IsConnected())
|
||||
assert.False(t, group.IsAuthed())
|
||||
assert.Equal(t, ConnectivityStateConnected, callbackState)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_sumStates(t *testing.T) {
|
||||
type args struct {
|
||||
states map[*Connectivity]ConnectivityState
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want ConnectivityState
|
||||
}{
|
||||
{
|
||||
name: "all connected",
|
||||
args: args{
|
||||
states: map[*Connectivity]ConnectivityState{
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
},
|
||||
},
|
||||
want: ConnectivityStateConnected,
|
||||
},
|
||||
{
|
||||
name: "only one connected should return disconnected",
|
||||
args: args{
|
||||
states: map[*Connectivity]ConnectivityState{
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
},
|
||||
},
|
||||
want: ConnectivityStateDisconnected,
|
||||
},
|
||||
{
|
||||
name: "only one connected and others authed should return connected",
|
||||
args: args{
|
||||
states: map[*Connectivity]ConnectivityState{
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateAuthed,
|
||||
NewConnectivity(): ConnectivityStateAuthed,
|
||||
},
|
||||
},
|
||||
want: ConnectivityStateConnected,
|
||||
},
|
||||
{
|
||||
name: "all disconnected should return disconnected",
|
||||
args: args{
|
||||
states: map[*Connectivity]ConnectivityState{
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
},
|
||||
},
|
||||
want: ConnectivityStateDisconnected,
|
||||
},
|
||||
{
|
||||
name: "one disconnected should return disconnected",
|
||||
args: args{
|
||||
states: map[*Connectivity]ConnectivityState{
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateConnected,
|
||||
NewConnectivity(): ConnectivityStateDisconnected,
|
||||
},
|
||||
},
|
||||
want: ConnectivityStateDisconnected,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, sumStates(tt.args.states), "sumStates(%v)", tt.args.states)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user