diff --git a/pkg/types/connectivitygroup.go b/pkg/types/connectivitygroup.go index 7cc177381..4761c1606 100644 --- a/pkg/types/connectivitygroup.go +++ b/pkg/types/connectivitygroup.go @@ -173,28 +173,7 @@ func (g *ConnectivityGroup) Add(con *Connectivity) { }) } -func (g *ConnectivityGroup) AnyDisconnected(ctx context.Context) bool { - g.mu.Lock() - conns := g.connections - g.mu.Unlock() - - for _, conn := range conns { - select { - case <-ctx.Done(): - return false - - case <-conn.connectedC: - continue - - case <-conn.disconnectedC: - return true - } - } - - return false -} - -func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) { +func (g *ConnectivityGroup) waitForState(ctx context.Context, c chan struct{}, expected ConnectivityState) { for { select { case <-ctx.Done(): @@ -202,7 +181,7 @@ func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) default: state := g.GetState() - if state == ConnectivityStateAuthed { + if state == expected { close(c) return } @@ -217,6 +196,6 @@ func (g *ConnectivityGroup) waitAllAuthed(ctx context.Context, c chan struct{}) // and the channel can only be used once (because we can't close a channel twice) func (g *ConnectivityGroup) AllAuthedC(ctx context.Context) <-chan struct{} { c := make(chan struct{}) - go g.waitAllAuthed(ctx, c) + go g.waitForState(ctx, c, ConnectivityStateAuthed) return c }