From fb35a2b79f94cff933de4e15f817958df3bc53bf Mon Sep 17 00:00:00 2001 From: c9s Date: Fri, 27 Sep 2024 13:31:37 +0800 Subject: [PATCH] types: add AnyDisconnected() method on ConnectivityGroup --- pkg/types/connectivity.go | 28 +++++++++++++++++++++++++--- pkg/types/connectivity_test.go | 9 +++++++-- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pkg/types/connectivity.go b/pkg/types/connectivity.go index 1745ba866..fdb1e88a9 100644 --- a/pkg/types/connectivity.go +++ b/pkg/types/connectivity.go @@ -24,14 +24,36 @@ func (g *ConnectivityGroup) Add(con *Connectivity) { g.connections = append(g.connections, con) } +func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool { + g.mu.Lock() + conns := g.connections + g.mu.Unlock() + + for _, conn := range conns { + select { + case <-ctx.Done(): + return false + + case <-conn.connectedC: + continue + + case <-conn.disconnectedC: + return true + } + } + + return false +} + func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}, allTimeoutDuration time.Duration) { g.mu.Lock() - defer g.mu.Unlock() + conns := g.connections + g.mu.Unlock() - authedConns := make([]bool, len(g.connections)) + authedConns := make([]bool, len(conns)) allTimeout := time.After(allTimeoutDuration) for { - for idx, con := range g.connections { + for idx, con := range conns { // if the connection is not authed, mark it as false if !con.authed { // authedConns[idx] = false diff --git a/pkg/types/connectivity_test.go b/pkg/types/connectivity_test.go index 5d36ede37..7a9849863 100644 --- a/pkg/types/connectivity_test.go +++ b/pkg/types/connectivity_test.go @@ -39,6 +39,7 @@ func TestConnectivity(t *testing.T) { func TestConnectivityGroupAuthC(t *testing.T) { timeout := 100 * time.Millisecond + delay := timeout * 2 ctx := context.Background() conn1 := NewConnectivity() @@ -46,13 +47,13 @@ func TestConnectivityGroupAuthC(t *testing.T) { group := NewConnectivityGroup(conn1, conn2) allAuthedC := group.AllAuthedC(ctx, time.Second) - time.Sleep(timeout) + time.Sleep(delay) conn1.handleConnect() assert.True(t, waitSigChan(conn1.ConnectedC(), timeout)) conn1.handleAuth() assert.True(t, waitSigChan(conn1.AuthedC(), timeout)) - time.Sleep(timeout) + time.Sleep(delay) conn2.handleConnect() assert.True(t, waitSigChan(conn2.ConnectedC(), timeout)) @@ -82,10 +83,14 @@ func TestConnectivityGroupReconnect(t *testing.T) { assert.True(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "all connections are authenticated") + assert.False(t, group.AnyDisconnected(ctx)) + // this should re-allocate authedC conn1.handleDisconnect() assert.NotEqual(t, conn1authC, conn1.authedC) + assert.True(t, group.AnyDisconnected(ctx)) + assert.False(t, waitSigChan(group.AllAuthedC(ctx, time.Second), timeout), "one connection should be un-authed") time.Sleep(delay)