From f46ca57bb2593c50808b1c52970cd63ace1ffe61 Mon Sep 17 00:00:00 2001 From: Edwin Date: Fri, 17 Nov 2023 16:01:15 +0800 Subject: [PATCH] pkg/types: refactor exchange name --- pkg/types/exchange.go | 66 +++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/pkg/types/exchange.go b/pkg/types/exchange.go index 06491224f..b67a9b2b2 100644 --- a/pkg/types/exchange.go +++ b/pkg/types/exchange.go @@ -15,30 +15,6 @@ const DateFormat = "2006-01-02" type ExchangeName string -func (n *ExchangeName) Value() (driver.Value, error) { - return n.String(), nil -} - -func (n *ExchangeName) UnmarshalJSON(data []byte) error { - var s string - if err := json.Unmarshal(data, &s); err != nil { - return err - } - - switch s { - case "binance", "bitget", "bybit", "max", "okex", "kucoin": - *n = ExchangeName(s) - return nil - - } - - return fmt.Errorf("unknown or unsupported exchange name: %s, valid names are: binance, bitget, bybit, max, okex, kucoin", s) -} - -func (n ExchangeName) String() string { - return string(n) -} - const ( ExchangeMax ExchangeName = "max" ExchangeBinance ExchangeName = "binance" @@ -59,15 +35,43 @@ var SupportedExchanges = []ExchangeName{ // note: we are not using "backtest" } -func ValidExchangeName(a string) (ExchangeName, error) { - aa := strings.ToLower(a) - for _, n := range SupportedExchanges { - if string(n) == aa { - return n, nil - } +func (n *ExchangeName) Value() (driver.Value, error) { + return n.String(), nil +} + +func (n *ExchangeName) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err } - return "", fmt.Errorf("invalid exchange name: %s", a) + *n = ExchangeName(s) + if !n.IsValid() { + return fmt.Errorf("%s is an invalid exchange name", s) + } + + return nil +} + +func (n ExchangeName) IsValid() bool { + switch n { + case ExchangeBinance, ExchangeBitget, ExchangeBybit, ExchangeMax, ExchangeOKEx, ExchangeKucoin: + return true + } + return false +} + +func (n ExchangeName) String() string { + return string(n) +} + +func ValidExchangeName(a string) (ExchangeName, error) { + exName := ExchangeName(strings.ToLower(a)) + if !exName.IsValid() { + return "", fmt.Errorf("invalid exchange name: %s", a) + } + + return exName, nil } type ExchangeMinimal interface {