diff --git a/pkg/service/trade.go b/pkg/service/trade.go index 16a64d076..26b221ed4 100644 --- a/pkg/service/trade.go +++ b/pkg/service/trade.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "strconv" "strings" "time" @@ -32,7 +33,11 @@ type QueryTradesOptions struct { // ASC or DESC Ordering string - Limit uint64 + + // OrderByColumn is the column name to order by + // Currently we only support traded_at and gid column. + OrderByColumn string + Limit uint64 } type TradingVolume struct { @@ -304,12 +309,29 @@ func (s *TradeService) Query(options QueryTradesOptions) ([]types.Trade, error) sel = sel.Where(sq.Eq{"exchange": options.Sessions}) } - if options.Ordering != "" { - sel = sel.OrderBy("traded_at " + options.Ordering) - } else { - sel = sel.OrderBy("traded_at ASC") + var orderByColumn string + switch options.OrderByColumn { + case "": + orderByColumn = "traded_at" + case "traded_at", "gid": + orderByColumn = options.OrderByColumn + default: + return nil, fmt.Errorf("invalid order by column: %s", options.OrderByColumn) } + var ordering string + + switch strings.ToUpper(options.Ordering) { + case "": + ordering = "ASC" + case "ASC", "DESC": + ordering = strings.ToUpper(options.Ordering) + default: + return nil, fmt.Errorf("invalid ordering: %s", options.Ordering) + } + + sel = sel.OrderBy(orderByColumn + " " + ordering) + if options.Limit > 0 { sel = sel.Limit(options.Limit) } diff --git a/pkg/service/trade_test.go b/pkg/service/trade_test.go index d24e0b16f..d0e9c97ad 100644 --- a/pkg/service/trade_test.go +++ b/pkg/service/trade_test.go @@ -1,9 +1,11 @@ package service import ( + "database/sql" "testing" "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" @@ -87,3 +89,36 @@ func Test_queryTradesSQL(t *testing.T) { })) }) } + +func TestTradeService_Query(t *testing.T) { + db, mock, err := sqlmock.New() + if !assert.NoError(t, err) { + return + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "mysql") + defer sqlxDB.Close() + + s := NewTradeService(sqlxDB) + + _, err = s.Query(QueryTradesOptions{Ordering: "test_ordering"}) + assert.Error(t, err) + assert.Equal(t, "invalid ordering: test_ordering", err.Error()) + + _, err = s.Query(QueryTradesOptions{OrderByColumn: "invalid_column"}) + assert.Error(t, err) + assert.Equal(t, "invalid order by column: invalid_column", err.Error()) + + mock.ExpectQuery("SELECT \\* FROM trades WHERE gid > \\? ORDER BY gid ASC").WithArgs(1234).WillReturnError(sql.ErrNoRows) + _, err = s.Query(QueryTradesOptions{LastGID: 1234, Ordering: "ASC", OrderByColumn: "gid"}) + assert.Equal(t, sql.ErrNoRows, err) + + mock.ExpectQuery("SELECT \\* FROM trades ORDER BY gid DESC").WillReturnError(sql.ErrNoRows) + _, err = s.Query(QueryTradesOptions{Ordering: "DESC", OrderByColumn: "gid"}) + assert.Equal(t, sql.ErrNoRows, err) + + mock.ExpectQuery("SELECT \\* FROM trades ORDER BY traded_at ASC").WillReturnError(sql.ErrNoRows) + _, err = s.Query(QueryTradesOptions{Ordering: "ASC", OrderByColumn: "traded_at"}) + assert.Equal(t, sql.ErrNoRows, err) +}