types: refactor, redesign connectivity

This commit is contained in:
c9s 2024-11-14 22:19:07 +08:00
parent 1eb0c1b092
commit ebab50f30c
No known key found for this signature in database
GPG Key ID: 7385E7E464CB0A54
6 changed files with 595 additions and 187 deletions

View File

@ -1,96 +1,9 @@
package types package types
import ( import (
"context"
"sync" "sync"
"time"
) )
type ConnectivityGroup struct {
connections []*Connectivity
mu sync.Mutex
}
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
return &ConnectivityGroup{
connections: cons,
}
}
func (g *ConnectivityGroup) Add(con *Connectivity) {
g.mu.Lock()
defer g.mu.Unlock()
g.connections = append(g.connections, con)
}
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()
for _, conn := range conns {
select {
case <-ctx.Done():
return false
case <-conn.connectedC:
continue
case <-conn.disconnectedC:
return true
}
}
return false
}
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()
authedConns := make([]bool, len(conns))
allTimeout := time.After(allTimeoutDuration)
for {
for idx, con := range conns {
// if the connection is not authed, mark it as false
if !con.authed {
// authedConns[idx] = false
}
timeout := time.After(3 * time.Second)
select {
case <-ctx.Done():
return
case <-allTimeout:
return
case <-timeout:
continue
case <-con.AuthedC():
authedConns[idx] = true
}
}
if allTrue(authedConns) {
close(c)
return
}
}
}
// AllAuthedC returns a channel that will be closed when all connections are authenticated
// the returned channel will be closed when all connections are authenticated
// and the channel can only be used once (because we can't close a channel twice)
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
c := make(chan struct{})
go g.waitAllAuthed(ctx, c, timeout)
return c
}
func allTrue(bools []bool) bool { func allTrue(bools []bool) bool {
for _, b := range bools { for _, b := range bools {
if !b { if !b {
@ -101,6 +14,7 @@ func allTrue(bools []bool) bool {
return true return true
} }
//go:generate callbackgen -type Connectivity
type Connectivity struct { type Connectivity struct {
authed bool authed bool
authedC chan struct{} authedC chan struct{}
@ -109,6 +23,11 @@ type Connectivity struct {
connectedC chan struct{} connectedC chan struct{}
disconnectedC chan struct{} disconnectedC chan struct{}
connectCallbacks []func()
disconnectCallbacks []func()
authCallbacks []func()
stream Stream
mu sync.Mutex mu sync.Mutex
} }
@ -141,31 +60,39 @@ func (c *Connectivity) IsAuthed() (authed bool) {
return authed return authed
} }
func (c *Connectivity) handleConnect() { func (c *Connectivity) setConnect() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() if !c.connected {
c.connected = true c.connected = true
close(c.connectedC) close(c.connectedC)
c.disconnectedC = make(chan struct{}) c.disconnectedC = make(chan struct{})
}
c.mu.Unlock()
c.EmitConnect()
} }
func (c *Connectivity) handleDisconnect() { func (c *Connectivity) setDisconnect() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() if c.connected {
c.connected = false c.connected = false
c.authed = false
c.authedC = make(chan struct{}) c.authedC = make(chan struct{})
c.connectedC = make(chan struct{}) c.connectedC = make(chan struct{})
close(c.disconnectedC) close(c.disconnectedC)
}
c.mu.Unlock()
c.EmitDisconnect()
} }
func (c *Connectivity) handleAuth() { func (c *Connectivity) setAuthed() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() if !c.authed {
c.authed = true c.authed = true
close(c.authedC) close(c.authedC)
}
c.mu.Unlock()
c.EmitAuth()
} }
func (c *Connectivity) AuthedC() chan struct{} { func (c *Connectivity) AuthedC() chan struct{} {
@ -187,7 +114,8 @@ func (c *Connectivity) DisconnectedC() chan struct{} {
} }
func (c *Connectivity) Bind(stream Stream) { func (c *Connectivity) Bind(stream Stream) {
stream.OnConnect(c.handleConnect) stream.OnConnect(c.setConnect)
stream.OnDisconnect(c.handleDisconnect) stream.OnDisconnect(c.setDisconnect)
stream.OnAuth(c.handleAuth) stream.OnAuth(c.setAuthed)
c.stream = stream
} }

View File

@ -0,0 +1,35 @@
// Code generated by "callbackgen -type Connectivity"; DO NOT EDIT.
package types
import ()
func (c *Connectivity) OnConnect(cb func()) {
c.connectCallbacks = append(c.connectCallbacks, cb)
}
func (c *Connectivity) EmitConnect() {
for _, cb := range c.connectCallbacks {
cb()
}
}
func (c *Connectivity) OnDisconnect(cb func()) {
c.disconnectCallbacks = append(c.disconnectCallbacks, cb)
}
func (c *Connectivity) EmitDisconnect() {
for _, cb := range c.disconnectCallbacks {
cb()
}
}
func (c *Connectivity) OnAuth(cb func()) {
c.authCallbacks = append(c.authCallbacks, cb)
}
func (c *Connectivity) EmitAuth() {
for _, cb := range c.authCallbacks {
cb()
}
}

View File

@ -1,105 +1,39 @@
package types package types
import ( import (
"context"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestConnectivity(t *testing.T) { func TestConnectivity(t *testing.T) {
t.Run("general", func(t *testing.T) { t.Run("general", func(t *testing.T) {
conn1 := NewConnectivity() conn1 := NewConnectivity()
conn1.handleConnect() conn1.setConnect()
conn1.handleAuth() conn1.setAuthed()
conn1.handleDisconnect() conn1.setDisconnect()
}) })
t.Run("reconnect", func(t *testing.T) { t.Run("reconnect", func(t *testing.T) {
conn1 := NewConnectivity() conn1 := NewConnectivity()
conn1.handleConnect() conn1.setConnect()
conn1.handleAuth() conn1.setAuthed()
conn1.handleDisconnect() conn1.setDisconnect()
conn1.handleConnect() conn1.setConnect()
conn1.handleAuth() conn1.setAuthed()
conn1.handleDisconnect() conn1.setDisconnect()
}) })
t.Run("no-auth reconnect", func(t *testing.T) { t.Run("no-auth reconnect", func(t *testing.T) {
conn1 := NewConnectivity() conn1 := NewConnectivity()
conn1.handleConnect() conn1.setConnect()
conn1.handleDisconnect() conn1.setDisconnect()
conn1.handleConnect() conn1.setConnect()
conn1.handleDisconnect() conn1.setDisconnect()
}) })
} }
func TestConnectivityGroupAuthC(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
allAuthedC := group.AllAuthedC(ctx, time.Second)
time.Sleep(delay)
conn1.handleConnect()
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
conn1.handleAuth()
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
time.Sleep(delay)
conn2.handleConnect()
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
conn2.handleAuth()
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
assert.True(t, waitSigChan(allAuthedC, timeout))
}
func TestConnectivityGroupReconnect(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
conn1authC := conn1.authedC
time.Sleep(delay)
conn2.handleConnect()
conn2.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated")
assert.False(t, group.AnyDisconnected(ctx))
// this should re-allocate authedC
conn1.handleDisconnect()
assert.NotEqual(t, conn1authC, conn1.authedC)
assert.True(t, group.AnyDisconnected(ctx))
assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed")
time.Sleep(delay)
conn1.handleConnect()
conn1.handleAuth()
assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again")
}
func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool { func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool {
select { select {
case <-time.After(timeoutDuration): case <-time.After(timeoutDuration):

View File

@ -0,0 +1,237 @@
package types
import (
"context"
"sync"
"time"
)
//go:generate callbackgen -type ConnectivityGroup
type ConnectivityGroup struct {
*Connectivity
connections []*Connectivity
mu sync.Mutex
authedC chan struct{}
states map[*Connectivity]ConnectivityState
sumState ConnectivityState
connectCallbacks []func()
disconnectCallbacks []func()
authCallbacks []func()
}
type ConnectivityState int
const (
ConnectivityStateDisconnected ConnectivityState = -1
ConnectivityStateUnknown ConnectivityState = 0
ConnectivityStateConnected ConnectivityState = 1
ConnectivityStateAuthed ConnectivityState = 2
)
func getConnState(con *Connectivity) ConnectivityState {
state := ConnectivityStateDisconnected
if con.IsAuthed() {
state = ConnectivityStateAuthed
} else if con.IsConnected() {
state = ConnectivityStateConnected
}
return state
}
func initConnStates(cons []*Connectivity) map[*Connectivity]ConnectivityState {
states := map[*Connectivity]ConnectivityState{}
for _, con := range cons {
state := ConnectivityStateDisconnected
if con.IsAuthed() {
state = ConnectivityStateAuthed
} else if con.IsConnected() {
state = ConnectivityStateConnected
}
states[con] = state
}
return states
}
func sumStates(states map[*Connectivity]ConnectivityState) ConnectivityState {
disconnected := 0
connected := 0
authed := 0
for _, state := range states {
switch state {
case ConnectivityStateDisconnected:
disconnected++
case ConnectivityStateConnected:
connected++
case ConnectivityStateAuthed:
authed++
}
}
numConn := len(states)
// if one of the connections is disconnected, the group is disconnected
if disconnected > 0 {
return ConnectivityStateDisconnected
} else if authed == numConn {
// if all connections are authed, the group is authed
return ConnectivityStateAuthed
} else if connected == numConn || (connected+authed) == numConn {
// if all connections are connected, the group is connected
return ConnectivityStateConnected
}
return ConnectivityStateUnknown
}
func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup {
states := initConnStates(cons)
sumState := sumStates(states)
g := &ConnectivityGroup{
Connectivity: NewConnectivity(),
connections: cons,
authedC: make(chan struct{}),
states: states,
sumState: sumState,
}
for _, con := range cons {
g.Add(con)
}
return g
}
func (g *ConnectivityGroup) setState(con *Connectivity, state ConnectivityState) {
g.mu.Lock()
g.states[con] = state
prevState := g.sumState
curState := sumStates(g.states)
g.sumState = curState
g.mu.Unlock()
// if the state is not changed, return
if prevState == curState {
return
}
g.setFlags(curState)
switch curState {
case ConnectivityStateAuthed:
g.EmitAuth()
case ConnectivityStateConnected:
g.EmitConnect()
case ConnectivityStateDisconnected:
g.EmitDisconnect()
}
}
func (g *ConnectivityGroup) setFlags(state ConnectivityState) {
switch state {
case ConnectivityStateAuthed:
g.Connectivity.setConnect()
g.Connectivity.setAuthed()
case ConnectivityStateConnected:
g.Connectivity.setConnect()
case ConnectivityStateDisconnected:
g.Connectivity.setDisconnect()
}
}
func (g *ConnectivityGroup) Add(con *Connectivity) {
g.mu.Lock()
g.connections = append(g.connections, con)
g.states[con] = getConnState(con)
g.sumState = sumStates(g.states)
g.setFlags(g.sumState)
g.mu.Unlock()
_con := con
con.OnDisconnect(func() {
g.setState(_con, ConnectivityStateDisconnected)
})
con.OnConnect(func() {
g.setState(_con, ConnectivityStateConnected)
})
con.OnAuth(func() {
g.setState(_con, ConnectivityStateAuthed)
})
}
func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool {
g.mu.Lock()
conns := g.connections
g.mu.Unlock()
for _, conn := range conns {
select {
case <-ctx.Done():
return false
case <-conn.connectedC:
continue
case <-conn.disconnectedC:
return true
}
}
return false
}
func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, allTimeoutDuration time.Duration) {
g.mu.Lock()
conns := g.connections
c := g.authedC
g.mu.Unlock()
authedConns := make([]bool, len(conns))
allTimeout := time.After(allTimeoutDuration)
for {
for idx, con := range conns {
// if the connection is not authed, mark it as false
if !con.authed {
// authedConns[idx] = false
}
timeout := time.After(3 * time.Second)
select {
case <-ctx.Done():
return
case <-allTimeout:
return
case <-timeout:
continue
case <-con.AuthedC():
authedConns[idx] = true
}
}
if allTrue(authedConns) {
close(c)
return
}
}
}
// AllAuthedC returns a channel that will be closed when all connections are authenticated
// the returned channel will be closed when all connections are authenticated
// and the channel can only be used once (because we can't close a channel twice)
func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} {
go g.waitAllAuthed(ctx, timeout)
return g.authedC
}

View File

@ -0,0 +1,35 @@
// Code generated by "callbackgen -type ConnectivityGroup"; DO NOT EDIT.
package types
import ()
func (g *ConnectivityGroup) OnConnect(cb func()) {
g.connectCallbacks = append(g.connectCallbacks, cb)
}
func (g *ConnectivityGroup) EmitConnect() {
for _, cb := range g.connectCallbacks {
cb()
}
}
func (g *ConnectivityGroup) OnDisconnect(cb func()) {
g.disconnectCallbacks = append(g.disconnectCallbacks, cb)
}
func (g *ConnectivityGroup) EmitDisconnect() {
for _, cb := range g.disconnectCallbacks {
cb()
}
}
func (g *ConnectivityGroup) OnAuth(cb func()) {
g.authCallbacks = append(g.authCallbacks, cb)
}
func (g *ConnectivityGroup) EmitAuth() {
for _, cb := range g.authCallbacks {
cb()
}
}

View File

@ -0,0 +1,239 @@
package types
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestConnectivityGroupAuthC(t *testing.T) {
timeout := 100 * time.Millisecond
delay := timeout * 2
ctx := context.Background()
conn1 := NewConnectivity()
conn2 := NewConnectivity()
group := NewConnectivityGroup(conn1, conn2)
allAuthedC := group.AllAuthedC(ctx, time.Second)
time.Sleep(delay)
conn1.setConnect()
assert.True(t, waitSigChan(conn1.ConnectedC(), timeout))
conn1.setAuthed()
assert.True(t, waitSigChan(conn1.AuthedC(), timeout))
time.Sleep(delay)
conn2.setConnect()
assert.True(t, waitSigChan(conn2.ConnectedC(), timeout))
conn2.setAuthed()
assert.True(t, waitSigChan(conn2.AuthedC(), timeout))
assert.True(t, waitSigChan(allAuthedC, timeout))
}
func TestConnectivityGroup(t *testing.T) {
ctx := context.Background()
_ = ctx
t.Run("connected", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.setConnect()
conn2 := NewConnectivity()
conn2.setConnect()
group := NewConnectivityGroup(conn1, conn2)
assert.Equal(t, ConnectivityStateConnected, group.sumState)
assert.True(t, group.IsConnected())
assert.False(t, group.IsAuthed())
})
t.Run("only one connected", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.setConnect()
conn2 := NewConnectivity()
conn2.setDisconnect()
group := NewConnectivityGroup(conn1, conn2)
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
assert.False(t, group.IsConnected())
assert.False(t, group.IsAuthed())
})
t.Run("only one authed", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.setConnect()
conn2 := NewConnectivity()
conn2.setAuthed()
group := NewConnectivityGroup(conn1, conn2)
assert.Equal(t, ConnectivityStateConnected, group.sumState)
assert.True(t, group.IsConnected())
assert.False(t, group.IsAuthed())
})
t.Run("all authed", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.setAuthed()
conn2 := NewConnectivity()
conn2.setAuthed()
group := NewConnectivityGroup(conn1, conn2)
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
assert.True(t, group.IsConnected())
assert.True(t, group.IsAuthed())
})
t.Run("reconnect", func(t *testing.T) {
conn1 := NewConnectivity()
conn1.setDisconnect()
conn2 := NewConnectivity()
conn2.setDisconnect()
// callbackState is used to test the callback state
callbackState := ConnectivityStateDisconnected
group := NewConnectivityGroup(conn1, conn2)
group.OnConnect(func() {
callbackState = ConnectivityStateConnected
})
group.OnAuth(func() {
callbackState = ConnectivityStateAuthed
})
group.OnDisconnect(func() {
callbackState = ConnectivityStateDisconnected
})
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
assert.False(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
t.Log("conn1 connected")
conn1.setConnect()
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
assert.False(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
t.Log("conn2 connected")
conn2.setConnect()
assert.Equal(t, ConnectivityStateConnected, group.sumState)
assert.True(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateConnected, callbackState)
t.Log("conn1 and conn2 authed")
conn1.setAuthed()
conn2.setAuthed()
assert.Equal(t, ConnectivityStateAuthed, group.sumState)
assert.True(t, group.IsConnected())
assert.True(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateAuthed, callbackState)
t.Log("one connection gets disconnected should fallback to disconnected state")
conn2.setDisconnect()
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
assert.False(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
t.Log("all connections get disconnected should fallback to disconnected state")
conn1.setDisconnect()
conn2.setDisconnect()
assert.Equal(t, ConnectivityStateDisconnected, group.sumState)
assert.False(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateDisconnected, callbackState)
t.Log("all connections are connected again")
conn1.setConnect()
conn2.setConnect()
assert.Equal(t, ConnectivityStateConnected, group.sumState)
assert.True(t, group.IsConnected())
assert.False(t, group.IsAuthed())
assert.Equal(t, ConnectivityStateConnected, callbackState)
})
}
func Test_sumStates(t *testing.T) {
type args struct {
states map[*Connectivity]ConnectivityState
}
tests := []struct {
name string
args args
want ConnectivityState
}{
{
name: "all connected",
args: args{
states: map[*Connectivity]ConnectivityState{
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateConnected,
},
},
want: ConnectivityStateConnected,
},
{
name: "only one connected should return disconnected",
args: args{
states: map[*Connectivity]ConnectivityState{
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateDisconnected,
NewConnectivity(): ConnectivityStateDisconnected,
},
},
want: ConnectivityStateDisconnected,
},
{
name: "only one connected and others authed should return connected",
args: args{
states: map[*Connectivity]ConnectivityState{
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateAuthed,
NewConnectivity(): ConnectivityStateAuthed,
},
},
want: ConnectivityStateConnected,
},
{
name: "all disconnected should return disconnected",
args: args{
states: map[*Connectivity]ConnectivityState{
NewConnectivity(): ConnectivityStateDisconnected,
NewConnectivity(): ConnectivityStateDisconnected,
NewConnectivity(): ConnectivityStateDisconnected,
},
},
want: ConnectivityStateDisconnected,
},
{
name: "one disconnected should return disconnected",
args: args{
states: map[*Connectivity]ConnectivityState{
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateConnected,
NewConnectivity(): ConnectivityStateDisconnected,
},
},
want: ConnectivityStateDisconnected,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, sumStates(tt.args.states), "sumStates(%v)", tt.args.states)
})
}
}