From ebab50f30c87d810ba7b719e1f21dec941c6dc1c Mon Sep 17 00:00:00 2001 From: c9s Date: Thu, 14 Nov 2024 22:19:07 +0800 Subject: [PATCH] types: refactor, redesign connectivity --- pkg/types/connectivity.go | 144 ++++---------- pkg/types/connectivity_callbacks.go | 35 ++++ pkg/types/connectivity_test.go | 92 ++------- pkg/types/connectivitygroup.go | 237 ++++++++++++++++++++++ pkg/types/connectivitygroup_callbacks.go | 35 ++++ pkg/types/connectivitygroup_test.go | 239 +++++++++++++++++++++++ 6 files changed, 595 insertions(+), 187 deletions(-) create mode 100644 pkg/types/connectivity_callbacks.go create mode 100644 pkg/types/connectivitygroup.go create mode 100644 pkg/types/connectivitygroup_callbacks.go create mode 100644 pkg/types/connectivitygroup_test.go diff --git a/pkg/types/connectivity.go b/pkg/types/connectivity.go index 047a3e041..93640dc59 100644 --- a/pkg/types/connectivity.go +++ b/pkg/types/connectivity.go @@ -1,96 +1,9 @@ package types import ( - "context" "sync" - "time" ) -type ConnectivityGroup struct { - connections []*Connectivity - mu sync.Mutex -} - -func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup { - return &ConnectivityGroup{ - connections: cons, - } -} - -func (g *ConnectivityGroup) Add(con *Connectivity) { - g.mu.Lock() - defer g.mu.Unlock() - - g.connections = append(g.connections, con) -} - -func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool { - g.mu.Lock() - conns := g.connections - g.mu.Unlock() - - for _, conn := range conns { - select { - case <-ctx.Done(): - return false - - case <-conn.connectedC: - continue - - case <-conn.disconnectedC: - return true - } - } - - return false -} - -func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) { - g.mu.Lock() - conns := g.connections - g.mu.Unlock() - - authedConns := make([]bool, len(conns)) - allTimeout := time.After(allTimeoutDuration) - for { - for idx, con := range conns { - // if the connection is not authed, mark it as false - if !con.authed { - // authedConns[idx] = false - } - - timeout := time.After(3 * time.Second) - select { - case <-ctx.Done(): - return - - case <-allTimeout: - return - - case <-timeout: - continue - - case <-con.AuthedC(): - authedConns[idx] = true - } - } - - if allTrue(authedConns) { - close(c) - return - } - } -} - -// AllAuthedC returns a channel that will be closed when all connections are authenticated -// the returned channel will be closed when all connections are authenticated -// and the channel can only be used once (because we can't close a channel twice) -func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} { - c := make(chan struct{}) - go g.waitAllAuthed(ctx, c, timeout) - return c -} - func allTrue(bools []bool) bool { for _, b := range bools { if !b { @@ -101,6 +14,7 @@ func allTrue(bools []bool) bool { return true } +//go:generate callbackgen -type Connectivity type Connectivity struct { authed bool authedC chan struct{} @@ -109,7 +23,12 @@ type Connectivity struct { connectedC chan struct{} disconnectedC chan struct{} - mu sync.Mutex + connectCallbacks []func() + disconnectCallbacks []func() + authCallbacks []func() + + stream Stream + mu sync.Mutex } func NewConnectivity() *Connectivity { @@ -141,31 +60,39 @@ func (c *Connectivity) IsAuthed() (authed bool) { return authed } -func (c *Connectivity) handleConnect() { +func (c *Connectivity) setConnect() { c.mu.Lock() - defer c.mu.Unlock() - - c.connected = true - close(c.connectedC) - c.disconnectedC = make(chan struct{}) + if !c.connected { + c.connected = true + close(c.connectedC) + c.disconnectedC = make(chan struct{}) + } + c.mu.Unlock() + c.EmitConnect() } -func (c *Connectivity) handleDisconnect() { +func (c *Connectivity) setDisconnect() { c.mu.Lock() - defer c.mu.Unlock() - - c.connected = false - c.authedC = make(chan struct{}) - c.connectedC = make(chan struct{}) - close(c.disconnectedC) + if c.connected { + c.connected = false + c.authed = false + c.authedC = make(chan struct{}) + c.connectedC = make(chan struct{}) + close(c.disconnectedC) + } + c.mu.Unlock() + c.EmitDisconnect() } -func (c *Connectivity) handleAuth() { +func (c *Connectivity) setAuthed() { c.mu.Lock() - defer c.mu.Unlock() + if !c.authed { + c.authed = true + close(c.authedC) + } + c.mu.Unlock() - c.authed = true - close(c.authedC) + c.EmitAuth() } func (c *Connectivity) AuthedC() chan struct{} { @@ -187,7 +114,8 @@ func (c *Connectivity) DisconnectedC() chan struct{} { } func (c *Connectivity) Bind(stream Stream) { - stream.OnConnect(c.handleConnect) - stream.OnDisconnect(c.handleDisconnect) - stream.OnAuth(c.handleAuth) + stream.OnConnect(c.setConnect) + stream.OnDisconnect(c.setDisconnect) + stream.OnAuth(c.setAuthed) + c.stream = stream } diff --git a/pkg/types/connectivity_callbacks.go b/pkg/types/connectivity_callbacks.go new file mode 100644 index 000000000..bb6ab24e3 --- /dev/null +++ b/pkg/types/connectivity_callbacks.go @@ -0,0 +1,35 @@ +// Code generated by "callbackgen -type Connectivity"; DO NOT EDIT. + +package types + +import () + +func (c *Connectivity) OnConnect(cb func()) { + c.connectCallbacks = append(c.connectCallbacks, cb) +} + +func (c *Connectivity) EmitConnect() { + for _, cb := range c.connectCallbacks { + cb() + } +} + +func (c *Connectivity) OnDisconnect(cb func()) { + c.disconnectCallbacks = append(c.disconnectCallbacks, cb) +} + +func (c *Connectivity) EmitDisconnect() { + for _, cb := range c.disconnectCallbacks { + cb() + } +} + +func (c *Connectivity) OnAuth(cb func()) { + c.authCallbacks = append(c.authCallbacks, cb) +} + +func (c *Connectivity) EmitAuth() { + for _, cb := range c.authCallbacks { + cb() + } +} diff --git a/pkg/types/connectivity_test.go b/pkg/types/connectivity_test.go index 7a9849863..5681f1506 100644 --- a/pkg/types/connectivity_test.go +++ b/pkg/types/connectivity_test.go @@ -1,105 +1,39 @@ package types import ( - "context" "testing" "time" - - "github.com/stretchr/testify/assert" ) func TestConnectivity(t *testing.T) { t.Run("general", func(t *testing.T) { conn1 := NewConnectivity() - conn1.handleConnect() - conn1.handleAuth() - conn1.handleDisconnect() + conn1.setConnect() + conn1.setAuthed() + conn1.setDisconnect() }) t.Run("reconnect", func(t *testing.T) { conn1 := NewConnectivity() - conn1.handleConnect() - conn1.handleAuth() - conn1.handleDisconnect() + conn1.setConnect() + conn1.setAuthed() + conn1.setDisconnect() - conn1.handleConnect() - conn1.handleAuth() - conn1.handleDisconnect() + conn1.setConnect() + conn1.setAuthed() + conn1.setDisconnect() }) t.Run("no-auth reconnect", func(t *testing.T) { conn1 := NewConnectivity() - conn1.handleConnect() - conn1.handleDisconnect() + conn1.setConnect() + conn1.setDisconnect() - conn1.handleConnect() - conn1.handleDisconnect() + conn1.setConnect() + conn1.setDisconnect() }) } -func TestConnectivityGroupAuthC(t *testing.T) { - timeout := 100 * time.Millisecond - delay := timeout * 2 - - ctx := context.Background() - conn1 := NewConnectivity() - conn2 := NewConnectivity() - group := NewConnectivityGroup(conn1, conn2) - allAuthedC := group.AllAuthedC(ctx, time.Second) - - time.Sleep(delay) - conn1.handleConnect() - assert.True(t, waitSigChan(conn1.ConnectedC(), timeout)) - conn1.handleAuth() - assert.True(t, waitSigChan(conn1.AuthedC(), timeout)) - - time.Sleep(delay) - conn2.handleConnect() - assert.True(t, waitSigChan(conn2.ConnectedC(), timeout)) - - conn2.handleAuth() - assert.True(t, waitSigChan(conn2.AuthedC(), timeout)) - - assert.True(t, waitSigChan(allAuthedC, timeout)) -} - -func TestConnectivityGroupReconnect(t *testing.T) { - timeout := 100 * time.Millisecond - delay := timeout * 2 - - ctx := context.Background() - conn1 := NewConnectivity() - conn2 := NewConnectivity() - group := NewConnectivityGroup(conn1, conn2) - - time.Sleep(delay) - conn1.handleConnect() - conn1.handleAuth() - conn1authC := conn1.authedC - - time.Sleep(delay) - conn2.handleConnect() - conn2.handleAuth() - - assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated") - - assert.False(t, group.AnyDisconnected(ctx)) - - // this should re-allocate authedC - conn1.handleDisconnect() - assert.NotEqual(t, conn1authC, conn1.authedC) - - assert.True(t, group.AnyDisconnected(ctx)) - - assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed") - - time.Sleep(delay) - - conn1.handleConnect() - conn1.handleAuth() - assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated, again") -} - func waitSigChan(c <-chan struct{}, timeoutDuration time.Duration) bool { select { case <-time.After(timeoutDuration): diff --git a/pkg/types/connectivitygroup.go b/pkg/types/connectivitygroup.go new file mode 100644 index 000000000..0b500e92d --- /dev/null +++ b/pkg/types/connectivitygroup.go @@ -0,0 +1,237 @@ +package types + +import ( + "context" + "sync" + "time" +) + +//go:generate callbackgen -type ConnectivityGroup +type ConnectivityGroup struct { + *Connectivity + + connections []*Connectivity + mu sync.Mutex + authedC chan struct{} + + states map[*Connectivity]ConnectivityState + sumState ConnectivityState + + connectCallbacks []func() + disconnectCallbacks []func() + authCallbacks []func() +} + +type ConnectivityState int + +const ( + ConnectivityStateDisconnected ConnectivityState = -1 + ConnectivityStateUnknown ConnectivityState = 0 + ConnectivityStateConnected ConnectivityState = 1 + ConnectivityStateAuthed ConnectivityState = 2 +) + +func getConnState(con *Connectivity) ConnectivityState { + state := ConnectivityStateDisconnected + + if con.IsAuthed() { + state = ConnectivityStateAuthed + } else if con.IsConnected() { + state = ConnectivityStateConnected + } + + return state +} + +func initConnStates(cons []*Connectivity) map[*Connectivity]ConnectivityState { + states := map[*Connectivity]ConnectivityState{} + + for _, con := range cons { + state := ConnectivityStateDisconnected + + if con.IsAuthed() { + state = ConnectivityStateAuthed + } else if con.IsConnected() { + state = ConnectivityStateConnected + } + + states[con] = state + } + + return states +} + +func sumStates(states map[*Connectivity]ConnectivityState) ConnectivityState { + disconnected := 0 + connected := 0 + authed := 0 + + for _, state := range states { + switch state { + case ConnectivityStateDisconnected: + disconnected++ + case ConnectivityStateConnected: + connected++ + case ConnectivityStateAuthed: + authed++ + } + } + + numConn := len(states) + + // if one of the connections is disconnected, the group is disconnected + if disconnected > 0 { + return ConnectivityStateDisconnected + } else if authed == numConn { + // if all connections are authed, the group is authed + return ConnectivityStateAuthed + } else if connected == numConn || (connected+authed) == numConn { + // if all connections are connected, the group is connected + return ConnectivityStateConnected + } + + return ConnectivityStateUnknown +} + +func NewConnectivityGroup(cons ...*Connectivity) *ConnectivityGroup { + states := initConnStates(cons) + sumState := sumStates(states) + g := &ConnectivityGroup{ + Connectivity: NewConnectivity(), + connections: cons, + authedC: make(chan struct{}), + states: states, + sumState: sumState, + } + + for _, con := range cons { + g.Add(con) + } + + return g +} + +func (g *ConnectivityGroup) setState(con *Connectivity, state ConnectivityState) { + g.mu.Lock() + g.states[con] = state + prevState := g.sumState + curState := sumStates(g.states) + g.sumState = curState + g.mu.Unlock() + + // if the state is not changed, return + if prevState == curState { + return + } + + g.setFlags(curState) + switch curState { + case ConnectivityStateAuthed: + g.EmitAuth() + case ConnectivityStateConnected: + g.EmitConnect() + case ConnectivityStateDisconnected: + g.EmitDisconnect() + } +} + +func (g *ConnectivityGroup) setFlags(state ConnectivityState) { + switch state { + case ConnectivityStateAuthed: + g.Connectivity.setConnect() + g.Connectivity.setAuthed() + case ConnectivityStateConnected: + g.Connectivity.setConnect() + case ConnectivityStateDisconnected: + g.Connectivity.setDisconnect() + } +} + +func (g *ConnectivityGroup) Add(con *Connectivity) { + g.mu.Lock() + g.connections = append(g.connections, con) + g.states[con] = getConnState(con) + g.sumState = sumStates(g.states) + g.setFlags(g.sumState) + g.mu.Unlock() + + _con := con + con.OnDisconnect(func() { + g.setState(_con, ConnectivityStateDisconnected) + }) + + con.OnConnect(func() { + g.setState(_con, ConnectivityStateConnected) + }) + + con.OnAuth(func() { + g.setState(_con, ConnectivityStateAuthed) + }) +} + +func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool { + g.mu.Lock() + conns := g.connections + g.mu.Unlock() + + for _, conn := range conns { + select { + case <-ctx.Done(): + return false + + case <-conn.connectedC: + continue + + case <-conn.disconnectedC: + return true + } + } + + return false +} + +func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, allTimeoutDuration time.Duration) { + g.mu.Lock() + conns := g.connections + c := g.authedC + g.mu.Unlock() + + authedConns := make([]bool, len(conns)) + allTimeout := time.After(allTimeoutDuration) + for { + for idx, con := range conns { + // if the connection is not authed, mark it as false + if !con.authed { + // authedConns[idx] = false + } + + timeout := time.After(3 * time.Second) + select { + case <-ctx.Done(): + return + + case <-allTimeout: + return + + case <-timeout: + continue + + case <-con.AuthedC(): + authedConns[idx] = true + } + } + + if allTrue(authedConns) { + close(c) + return + } + } +} + +// AllAuthedC returns a channel that will be closed when all connections are authenticated +// the returned channel will be closed when all connections are authenticated +// and the channel can only be used once (because we can't close a channel twice) +func (g *ConnectivityGroup) AllAuthedC(ctx context.Context, timeout time.Duration) <-chan struct{} { + go g.waitAllAuthed(ctx, timeout) + return g.authedC +} diff --git a/pkg/types/connectivitygroup_callbacks.go b/pkg/types/connectivitygroup_callbacks.go new file mode 100644 index 000000000..c5d5f3bf4 --- /dev/null +++ b/pkg/types/connectivitygroup_callbacks.go @@ -0,0 +1,35 @@ +// Code generated by "callbackgen -type ConnectivityGroup"; DO NOT EDIT. + +package types + +import () + +func (g *ConnectivityGroup) OnConnect(cb func()) { + g.connectCallbacks = append(g.connectCallbacks, cb) +} + +func (g *ConnectivityGroup) EmitConnect() { + for _, cb := range g.connectCallbacks { + cb() + } +} + +func (g *ConnectivityGroup) OnDisconnect(cb func()) { + g.disconnectCallbacks = append(g.disconnectCallbacks, cb) +} + +func (g *ConnectivityGroup) EmitDisconnect() { + for _, cb := range g.disconnectCallbacks { + cb() + } +} + +func (g *ConnectivityGroup) OnAuth(cb func()) { + g.authCallbacks = append(g.authCallbacks, cb) +} + +func (g *ConnectivityGroup) EmitAuth() { + for _, cb := range g.authCallbacks { + cb() + } +} diff --git a/pkg/types/connectivitygroup_test.go b/pkg/types/connectivitygroup_test.go new file mode 100644 index 000000000..b75aa6524 --- /dev/null +++ b/pkg/types/connectivitygroup_test.go @@ -0,0 +1,239 @@ +package types + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConnectivityGroupAuthC(t *testing.T) { + timeout := 100 * time.Millisecond + delay := timeout * 2 + + ctx := context.Background() + conn1 := NewConnectivity() + conn2 := NewConnectivity() + group := NewConnectivityGroup(conn1, conn2) + allAuthedC := group.AllAuthedC(ctx, time.Second) + + time.Sleep(delay) + conn1.setConnect() + assert.True(t, waitSigChan(conn1.ConnectedC(), timeout)) + conn1.setAuthed() + assert.True(t, waitSigChan(conn1.AuthedC(), timeout)) + + time.Sleep(delay) + conn2.setConnect() + assert.True(t, waitSigChan(conn2.ConnectedC(), timeout)) + + conn2.setAuthed() + assert.True(t, waitSigChan(conn2.AuthedC(), timeout)) + + assert.True(t, waitSigChan(allAuthedC, timeout)) +} + +func TestConnectivityGroup(t *testing.T) { + ctx := context.Background() + _ = ctx + + t.Run("connected", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.setConnect() + + conn2 := NewConnectivity() + conn2.setConnect() + + group := NewConnectivityGroup(conn1, conn2) + assert.Equal(t, ConnectivityStateConnected, group.sumState) + assert.True(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + }) + + t.Run("only one connected", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.setConnect() + + conn2 := NewConnectivity() + conn2.setDisconnect() + + group := NewConnectivityGroup(conn1, conn2) + assert.Equal(t, ConnectivityStateDisconnected, group.sumState) + assert.False(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + }) + + t.Run("only one authed", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.setConnect() + + conn2 := NewConnectivity() + conn2.setAuthed() + + group := NewConnectivityGroup(conn1, conn2) + assert.Equal(t, ConnectivityStateConnected, group.sumState) + assert.True(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + }) + + t.Run("all authed", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.setAuthed() + + conn2 := NewConnectivity() + conn2.setAuthed() + + group := NewConnectivityGroup(conn1, conn2) + assert.Equal(t, ConnectivityStateAuthed, group.sumState) + assert.True(t, group.IsConnected()) + assert.True(t, group.IsAuthed()) + }) + + t.Run("reconnect", func(t *testing.T) { + conn1 := NewConnectivity() + conn1.setDisconnect() + conn2 := NewConnectivity() + conn2.setDisconnect() + + // callbackState is used to test the callback state + callbackState := ConnectivityStateDisconnected + + group := NewConnectivityGroup(conn1, conn2) + group.OnConnect(func() { + callbackState = ConnectivityStateConnected + }) + group.OnAuth(func() { + callbackState = ConnectivityStateAuthed + }) + group.OnDisconnect(func() { + callbackState = ConnectivityStateDisconnected + }) + + assert.Equal(t, ConnectivityStateDisconnected, group.sumState) + assert.False(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateDisconnected, callbackState) + + t.Log("conn1 connected") + conn1.setConnect() + assert.Equal(t, ConnectivityStateDisconnected, group.sumState) + assert.False(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateDisconnected, callbackState) + + t.Log("conn2 connected") + conn2.setConnect() + assert.Equal(t, ConnectivityStateConnected, group.sumState) + assert.True(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateConnected, callbackState) + + t.Log("conn1 and conn2 authed") + conn1.setAuthed() + conn2.setAuthed() + + assert.Equal(t, ConnectivityStateAuthed, group.sumState) + assert.True(t, group.IsConnected()) + assert.True(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateAuthed, callbackState) + + t.Log("one connection gets disconnected should fallback to disconnected state") + conn2.setDisconnect() + + assert.Equal(t, ConnectivityStateDisconnected, group.sumState) + assert.False(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateDisconnected, callbackState) + + t.Log("all connections get disconnected should fallback to disconnected state") + conn1.setDisconnect() + conn2.setDisconnect() + + assert.Equal(t, ConnectivityStateDisconnected, group.sumState) + assert.False(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateDisconnected, callbackState) + + t.Log("all connections are connected again") + conn1.setConnect() + conn2.setConnect() + + assert.Equal(t, ConnectivityStateConnected, group.sumState) + assert.True(t, group.IsConnected()) + assert.False(t, group.IsAuthed()) + assert.Equal(t, ConnectivityStateConnected, callbackState) + }) +} + +func Test_sumStates(t *testing.T) { + type args struct { + states map[*Connectivity]ConnectivityState + } + tests := []struct { + name string + args args + want ConnectivityState + }{ + { + name: "all connected", + args: args{ + states: map[*Connectivity]ConnectivityState{ + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateConnected, + }, + }, + want: ConnectivityStateConnected, + }, + { + name: "only one connected should return disconnected", + args: args{ + states: map[*Connectivity]ConnectivityState{ + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateDisconnected, + NewConnectivity(): ConnectivityStateDisconnected, + }, + }, + want: ConnectivityStateDisconnected, + }, + { + name: "only one connected and others authed should return connected", + args: args{ + states: map[*Connectivity]ConnectivityState{ + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateAuthed, + NewConnectivity(): ConnectivityStateAuthed, + }, + }, + want: ConnectivityStateConnected, + }, + { + name: "all disconnected should return disconnected", + args: args{ + states: map[*Connectivity]ConnectivityState{ + NewConnectivity(): ConnectivityStateDisconnected, + NewConnectivity(): ConnectivityStateDisconnected, + NewConnectivity(): ConnectivityStateDisconnected, + }, + }, + want: ConnectivityStateDisconnected, + }, + { + name: "one disconnected should return disconnected", + args: args{ + states: map[*Connectivity]ConnectivityState{ + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateConnected, + NewConnectivity(): ConnectivityStateDisconnected, + }, + }, + want: ConnectivityStateDisconnected, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, sumStates(tt.args.states), "sumStates(%v)", tt.args.states) + }) + } +}