From 83cdd4e1a47e0224acf53a155c8c57ddfcd8a6c2 Mon Sep 17 00:00:00 2001 From: Edwin Date: Wed, 6 Sep 2023 12:18:17 +0800 Subject: [PATCH] pkg/exchange: update add reconnect and resubscribe func for stream --- pkg/types/stream.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/pkg/types/stream.go b/pkg/types/stream.go index 668a12856..5b481d0d9 100644 --- a/pkg/types/stream.go +++ b/pkg/types/stream.go @@ -27,14 +27,26 @@ var defaultDialer = &websocket.Dialer{ type Stream interface { StandardStreamEventHub + // Subscribe subscribes the specific channel, but not connect to the server. Subscribe(channel Channel, symbol string, options SubscribeOptions) GetSubscriptions() []Subscription + // Resubscribe used to update or renew existing subscriptions. It will reconnect to the server. + Resubscribe(func(oldSubs []Subscription) (newSubs []Subscription, err error)) error + // SetPublicOnly connects to public or private SetPublicOnly() GetPublicOnly() bool + + // Connect connects to websocket server Connect(ctx context.Context) error + Reconnect() Close() error } +type Unsubscriber interface { + // Unsubscribe unsubscribes the all subscriptions. + Unsubscribe() +} + type EndpointCreator func(ctx context.Context) (string, error) type Parser func(message []byte) (interface{}, error) @@ -76,6 +88,10 @@ type StandardStream struct { Subscriptions []Subscription + // subLock is used for locking Subscriptions fields. + // When changing these field values, be sure to call subLock + subLock sync.Mutex + startCallbacks []func() connectCallbacks []func() @@ -290,10 +306,34 @@ func (s *StandardStream) ping(ctx context.Context, conn *websocket.Conn, cancel } func (s *StandardStream) GetSubscriptions() []Subscription { + s.subLock.Lock() + defer s.subLock.Unlock() + return s.Subscriptions } +// Resubscribe synchronizes the new subscriptions based on the provided function. +// The fn function takes the old subscriptions as input and returns the new subscriptions that will replace the old ones +// in the struct then Reconnect. +// This method is thread-safe. +func (s *StandardStream) Resubscribe(fn func(old []Subscription) (new []Subscription, err error)) error { + s.subLock.Lock() + defer s.subLock.Unlock() + + var err error + subs, err := fn(s.Subscriptions) + if err != nil { + return err + } + s.Subscriptions = subs + s.Reconnect() + return nil +} + func (s *StandardStream) Subscribe(channel Channel, symbol string, options SubscribeOptions) { + s.subLock.Lock() + defer s.subLock.Unlock() + s.Subscriptions = append(s.Subscriptions, Subscription{ Channel: channel, Symbol: symbol,