first commit
This commit is contained in:
commit
6194adcc21
114
common/balance.go
Normal file
114
common/balance.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNoBalance = errors.New("no balance")
|
||||
Zero = decimal.NewFromInt(0)
|
||||
)
|
||||
|
||||
type VBalance struct {
|
||||
total decimal.Decimal
|
||||
prevRoundTotal decimal.Decimal
|
||||
position decimal.Decimal
|
||||
feeTotal decimal.Decimal
|
||||
// 开仓的总价值
|
||||
longCost decimal.Decimal
|
||||
shortCost decimal.Decimal
|
||||
fee decimal.Decimal
|
||||
prevFee decimal.Decimal
|
||||
}
|
||||
|
||||
func NewVBalance() *VBalance {
|
||||
b := new(VBalance)
|
||||
b.total = decimal.NewFromFloat(100000)
|
||||
b.prevRoundTotal = b.total
|
||||
b.fee = decimal.NewFromFloat(0.00075)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *VBalance) Set(total float64) {
|
||||
b.total = decimal.NewFromFloat(total)
|
||||
b.prevRoundTotal = b.total
|
||||
}
|
||||
|
||||
func (b *VBalance) SetFee(fee float64) {
|
||||
b.fee = decimal.NewFromFloat(fee)
|
||||
}
|
||||
|
||||
func (b *VBalance) Pos() (pos float64) {
|
||||
pos, _ = b.position.Float64()
|
||||
return
|
||||
}
|
||||
|
||||
func (b *VBalance) Get() (total float64) {
|
||||
// return b.total + b.costTotal
|
||||
total, _ = b.total.Float64()
|
||||
return
|
||||
}
|
||||
|
||||
func (b *VBalance) GetFeeTotal() (fee float64) {
|
||||
fee, _ = b.feeTotal.Float64()
|
||||
return
|
||||
}
|
||||
|
||||
func (b *VBalance) AvgOpenPrice() (price float64) {
|
||||
switch b.position.Sign() {
|
||||
case -1:
|
||||
price, _ = b.shortCost.Div(b.position.Abs()).Float64()
|
||||
case 0:
|
||||
return
|
||||
case 1:
|
||||
price, _ = b.longCost.Div(b.position.Abs()).Float64()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (b *VBalance) AddTrade(tr Trade) (profit, onceFee float64, err error) {
|
||||
amount := decimal.NewFromFloat(tr.Amount).Abs()
|
||||
// 仓位价值
|
||||
cost := amount.Mul(decimal.NewFromFloat(tr.Price)).Abs()
|
||||
fee := cost.Mul(b.fee)
|
||||
onceFee, _ = fee.Float64()
|
||||
costAll, _ := cost.Add(fee).Float64()
|
||||
if tr.Action.IsOpen() && costAll >= b.Get() {
|
||||
err = ErrNoBalance
|
||||
return
|
||||
}
|
||||
// close/stop just return if no position
|
||||
if b.position.Equal(Zero) && !tr.Action.IsOpen() {
|
||||
return
|
||||
}
|
||||
if tr.Action.IsLong() {
|
||||
b.position = b.position.Add(amount)
|
||||
b.longCost = b.longCost.Add(cost)
|
||||
} else {
|
||||
b.position = b.position.Sub(amount)
|
||||
b.shortCost = b.shortCost.Add(cost)
|
||||
}
|
||||
isPositionZero := b.position.Equal(Zero)
|
||||
if tr.Action.IsOpen() && !isPositionZero {
|
||||
b.total = b.total.Sub(cost).Sub(fee)
|
||||
}
|
||||
b.feeTotal = b.feeTotal.Add(fee)
|
||||
// 计算盈利
|
||||
if isPositionZero {
|
||||
totalFee := fee.Add(b.prevFee)
|
||||
prof := b.shortCost.Sub(b.longCost).Sub(totalFee)
|
||||
b.total = b.prevRoundTotal.Add(prof)
|
||||
profit, _ = prof.Float64()
|
||||
b.longCost = decimal.NewFromInt(0)
|
||||
b.shortCost = decimal.NewFromInt(0)
|
||||
b.prevRoundTotal = b.total
|
||||
b.prevFee = decimal.Zero
|
||||
} else {
|
||||
b.prevFee = b.prevFee.Add(fee)
|
||||
}
|
||||
return
|
||||
}
|
122
common/balance_lever.go
Normal file
122
common/balance_lever.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
var (
|
||||
numOne = decimal.NewFromInt(1)
|
||||
)
|
||||
|
||||
type LeverBalance struct {
|
||||
vBalance *VBalance
|
||||
total decimal.Decimal
|
||||
|
||||
// 开仓的总价值
|
||||
lever decimal.Decimal
|
||||
}
|
||||
|
||||
func NewLeverBalance() *LeverBalance {
|
||||
lb := new(LeverBalance)
|
||||
lb.vBalance = NewVBalance()
|
||||
lb.lever = decimal.NewFromFloat(1)
|
||||
return lb
|
||||
}
|
||||
|
||||
func (b *LeverBalance) Set(total float64) {
|
||||
b.total = decimal.NewFromFloat(total)
|
||||
vTotal, _ := b.total.Mul(b.lever).Float64()
|
||||
b.vBalance.Set(vTotal)
|
||||
}
|
||||
|
||||
func (b *LeverBalance) SetFee(fee float64) {
|
||||
b.vBalance.SetFee(fee)
|
||||
}
|
||||
|
||||
func (b *LeverBalance) SetLever(lever float64) {
|
||||
b.lever = decimal.NewFromFloat(lever)
|
||||
vTotal, _ := b.total.Mul(b.lever).Float64()
|
||||
b.vBalance.Set(vTotal)
|
||||
}
|
||||
|
||||
func (b *LeverBalance) Pos() (pos float64) {
|
||||
return b.vBalance.Pos()
|
||||
}
|
||||
|
||||
// func (b *LeverBalance) LiquidationPrice() (price float64, valid bool) {
|
||||
// pos, _ := b.position.Float64()
|
||||
// if pos == 0 {
|
||||
// return
|
||||
// }
|
||||
// valid = true
|
||||
// if pos > 0 {
|
||||
// price, _ = b.openPrice.Sub(b.openPrice.Div(b.lever)).Float64()
|
||||
// } else {
|
||||
// price, _ = b.openPrice.Add(b.openPrice.Div(b.lever)).Float64()
|
||||
// }
|
||||
// return
|
||||
// }
|
||||
|
||||
func (b *LeverBalance) CheckLiquidation(price float64) (liqPrice float64, isLiq bool) {
|
||||
openPrice := decimal.NewFromFloat(b.vBalance.AvgOpenPrice())
|
||||
fee := b.vBalance.fee
|
||||
switch b.vBalance.position.Sign() {
|
||||
// <0
|
||||
case -1:
|
||||
// liqPrice + liqPrice * fee = openPrice + openPrice/lever
|
||||
// liqPrice *(1 + fee) = openPrice * (1 + 1/lever)
|
||||
// liqPrice = (openPrice * (1 + 1/lever))/(1-fee)
|
||||
liqPrice, _ = openPrice.Add(openPrice.Div(b.lever)).Div(numOne.Add(fee)).Float64()
|
||||
if price >= liqPrice {
|
||||
isLiq = true
|
||||
}
|
||||
// =0
|
||||
case 0:
|
||||
return
|
||||
// >0
|
||||
case 1:
|
||||
// liqPrice - liqPrice * fee = openPrice - openPrice/lever
|
||||
// (1-fee) * liqPrice = openPrice * (1 - 1/lever)
|
||||
// liqPrice = (openPrice * (1 - 1/lever))/(1-fee)
|
||||
liqPrice, _ = openPrice.Sub(openPrice.Div(b.lever)).Div(numOne.Sub(fee)).Float64()
|
||||
if price <= liqPrice {
|
||||
isLiq = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (b *LeverBalance) Get() (total float64) {
|
||||
total, _ = b.total.Float64()
|
||||
return
|
||||
}
|
||||
|
||||
func (b *LeverBalance) GetFeeTotal() float64 {
|
||||
return b.vBalance.GetFeeTotal()
|
||||
}
|
||||
|
||||
func (b *LeverBalance) AddTrade(tr Trade) (profit, onceFee float64, err error) {
|
||||
if tr.Action.IsOpen() {
|
||||
// check balance enough when open order
|
||||
amount := decimal.NewFromFloat(tr.Amount).Abs()
|
||||
cost := amount.Mul(decimal.NewFromFloat(tr.Price)).Abs()
|
||||
fee := cost.Mul(b.vBalance.fee)
|
||||
onceCost := cost.Div(b.lever).Add(fee)
|
||||
if b.total.LessThan(onceCost) {
|
||||
err = ErrNoBalance
|
||||
return
|
||||
}
|
||||
} else {
|
||||
liqPrice, isLiq := b.CheckLiquidation(tr.Price)
|
||||
if isLiq {
|
||||
tr.Price = liqPrice
|
||||
}
|
||||
}
|
||||
profit, onceFee, err = b.vBalance.AddTrade(tr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b.total = b.total.Add(decimal.NewFromFloat(profit)).Sub(decimal.NewFromFloat(onceFee))
|
||||
return
|
||||
}
|
61
common/balance_lever_test.go
Normal file
61
common/balance_lever_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
)
|
||||
|
||||
func TestCheckLiquidationLong(t *testing.T) {
|
||||
lb := NewLeverBalance()
|
||||
lb.Set(100)
|
||||
lb.SetFee(0.0002)
|
||||
lb.SetLever(10)
|
||||
_, _, err := lb.AddTrade(Trade{Action: OpenLong, Price: 100, Amount: 9})
|
||||
if err != nil {
|
||||
t.Fatal("Liq lever AddTrade failed:" + err.Error())
|
||||
}
|
||||
liqPrice, isLiq := lb.CheckLiquidation(90.1)
|
||||
if isLiq {
|
||||
t.Fatal("Liq cal too large")
|
||||
}
|
||||
t.Log(liqPrice, isLiq)
|
||||
liqPrice, isLiq = lb.CheckLiquidation(90)
|
||||
if !isLiq {
|
||||
t.Fatal("Liq cal too small")
|
||||
}
|
||||
t.Log(liqPrice, isLiq)
|
||||
}
|
||||
|
||||
func TestCheckLiquidationShort(t *testing.T) {
|
||||
lb := NewLeverBalance()
|
||||
lb.Set(100)
|
||||
lb.SetFee(0.0002)
|
||||
lb.SetLever(10)
|
||||
_, _, err := lb.AddTrade(Trade{Action: OpenShort, Price: 100, Amount: 9})
|
||||
if err != nil {
|
||||
t.Fatal("Liq lever AddTrade failed:" + err.Error())
|
||||
}
|
||||
liqPrice, isLiq := lb.CheckLiquidation(109)
|
||||
if isLiq {
|
||||
t.Fatal("Liq cal too small")
|
||||
}
|
||||
t.Log(liqPrice, isLiq)
|
||||
liqPrice, isLiq = lb.CheckLiquidation(109.99)
|
||||
if !isLiq {
|
||||
t.Fatal("Liq cal too large")
|
||||
}
|
||||
t.Log(liqPrice, isLiq)
|
||||
}
|
||||
|
||||
func TestCheckLeverBalance(t *testing.T) {
|
||||
lb := NewLeverBalance()
|
||||
lb.Set(100)
|
||||
lb.SetFee(0.0002)
|
||||
lb.SetLever(10)
|
||||
_, _, err := lb.AddTrade(Trade{Action: OpenLong, Price: 100, Amount: 10})
|
||||
if err == nil {
|
||||
t.Fatal("Liq not work")
|
||||
}
|
||||
t.Log(err.Error())
|
||||
}
|
254
common/balance_test.go
Normal file
254
common/balance_test.go
Normal file
|
@ -0,0 +1,254 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func calFee(fee decimal.Decimal, trades ...Trade) float64 {
|
||||
var cost decimal.Decimal
|
||||
for _, v := range trades {
|
||||
dec := decimal.NewFromFloat(v.Price).Mul(decimal.NewFromFloat(v.Amount))
|
||||
cost = cost.Add(dec)
|
||||
}
|
||||
realFee, _ := cost.Mul(fee).Float64()
|
||||
return realFee
|
||||
}
|
||||
|
||||
func TestLong(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenLong,
|
||||
Time: tm,
|
||||
Price: 100,
|
||||
Amount: 1,
|
||||
}
|
||||
closeTrade := Trade{
|
||||
ID: "2",
|
||||
Action: CloseLong,
|
||||
Time: tm.Add(time.Second),
|
||||
Price: 110,
|
||||
Amount: 1,
|
||||
}
|
||||
stopTrade := Trade{
|
||||
ID: "3",
|
||||
Action: StopLong,
|
||||
Time: tm.Add(time.Second * 2),
|
||||
Price: 90,
|
||||
Amount: 1,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(closeTrade)
|
||||
fee := calFee(b.fee, openTrade, closeTrade)
|
||||
if b.Get() != 1010-fee {
|
||||
t.Fatal("balance close error:", b.Get(), 1010-fee)
|
||||
}
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(stopTrade)
|
||||
|
||||
fee = calFee(b.fee, openTrade, stopTrade)
|
||||
if b.Get() != 990-fee {
|
||||
t.Fatal("balance stop error:", b.Get())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiLong(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenLong,
|
||||
Time: tm,
|
||||
Price: 100,
|
||||
Amount: 1,
|
||||
}
|
||||
openTrade2 := Trade{
|
||||
ID: "1",
|
||||
Action: OpenLong,
|
||||
Time: tm,
|
||||
Price: 105,
|
||||
Amount: 1,
|
||||
}
|
||||
closeTrade := Trade{
|
||||
ID: "2",
|
||||
Action: CloseLong,
|
||||
Time: tm.Add(time.Second),
|
||||
Price: 110,
|
||||
Amount: 2,
|
||||
}
|
||||
stopTrade := Trade{
|
||||
ID: "3",
|
||||
Action: StopLong,
|
||||
Time: tm.Add(time.Second * 2),
|
||||
Price: 90,
|
||||
Amount: 2,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
b.AddTrade(closeTrade)
|
||||
fee := calFee(b.fee, openTrade, openTrade2, closeTrade)
|
||||
if b.Get() != 1015-fee {
|
||||
t.Fatal("balance close error:", b.Get(), fee)
|
||||
}
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
b.AddTrade(stopTrade)
|
||||
fee = calFee(b.fee, openTrade, openTrade2, stopTrade)
|
||||
if b.Get() != 975-fee {
|
||||
t.Fatal("balance stop error:", b.Get())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShort(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenShort,
|
||||
Time: tm,
|
||||
Price: 110,
|
||||
Amount: 1,
|
||||
}
|
||||
closeTrade := Trade{
|
||||
ID: "2",
|
||||
Action: CloseShort,
|
||||
Time: tm.Add(time.Second),
|
||||
Price: 100,
|
||||
Amount: 1,
|
||||
}
|
||||
stopTrade := Trade{
|
||||
ID: "3",
|
||||
Action: StopShort,
|
||||
Time: tm.Add(time.Second * 2),
|
||||
Price: 120,
|
||||
Amount: 1,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(closeTrade)
|
||||
fee := calFee(b.fee, openTrade, closeTrade)
|
||||
if b.Get() != 1010-fee {
|
||||
t.Fatal("balance close error:", b.Get(), 1010-fee)
|
||||
}
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(stopTrade)
|
||||
fee = calFee(b.fee, openTrade, stopTrade)
|
||||
if b.Get() != 990-fee {
|
||||
t.Fatal("balance stop error:", b.Get())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiShort(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenShort,
|
||||
Time: tm,
|
||||
Price: 110,
|
||||
Amount: 1,
|
||||
}
|
||||
openTrade2 := Trade{
|
||||
ID: "1",
|
||||
Action: OpenShort,
|
||||
Time: tm,
|
||||
Price: 120,
|
||||
Amount: 1,
|
||||
}
|
||||
closeTrade := Trade{
|
||||
ID: "2",
|
||||
Action: CloseShort,
|
||||
Time: tm.Add(time.Second),
|
||||
Price: 100,
|
||||
Amount: 2,
|
||||
}
|
||||
stopTrade := Trade{
|
||||
ID: "3",
|
||||
Action: StopShort,
|
||||
Time: tm.Add(time.Second * 2),
|
||||
Price: 130,
|
||||
Amount: 2,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
b.AddTrade(closeTrade)
|
||||
fee := calFee(b.fee, openTrade, openTrade2, closeTrade)
|
||||
if b.Get() != 1030-fee {
|
||||
t.Fatal("balance close error:", b.Get())
|
||||
}
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
b.AddTrade(stopTrade)
|
||||
fee = calFee(b.fee, openTrade, openTrade2, stopTrade)
|
||||
if b.Get() != 970-fee {
|
||||
t.Fatal("balance stop error:", b.Get())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvgPriceLong(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenLong,
|
||||
Time: tm,
|
||||
Price: 110,
|
||||
Amount: 1,
|
||||
}
|
||||
openTrade2 := Trade{
|
||||
ID: "1",
|
||||
Action: OpenLong,
|
||||
Time: tm,
|
||||
Price: 120,
|
||||
Amount: 3,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
if b.AvgOpenPrice() != 117.5 {
|
||||
t.Fatalf("cal avg failed: %f", b.AvgOpenPrice())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvgPriceShort(t *testing.T) {
|
||||
tm := time.Now()
|
||||
openTrade := Trade{
|
||||
ID: "1",
|
||||
Action: OpenShort,
|
||||
Time: tm,
|
||||
Price: 110,
|
||||
Amount: 1,
|
||||
}
|
||||
openTrade2 := Trade{
|
||||
ID: "1",
|
||||
Action: OpenShort,
|
||||
Time: tm,
|
||||
Price: 120,
|
||||
Amount: 3,
|
||||
}
|
||||
|
||||
b := NewVBalance()
|
||||
b.Set(1000)
|
||||
b.AddTrade(openTrade)
|
||||
b.AddTrade(openTrade2)
|
||||
if b.AvgOpenPrice() != 117.5 {
|
||||
t.Fatalf("cal avg failed: %f", b.AvgOpenPrice())
|
||||
}
|
||||
}
|
84
common/binsize.go
Normal file
84
common/binsize.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBinSizes = "1m, 5m, 15m, 30m, 1h, 4h, 1d"
|
||||
)
|
||||
|
||||
var (
|
||||
Day = time.Hour * 24
|
||||
Week = time.Hour * 24 * 7
|
||||
)
|
||||
|
||||
// ParseBinStrs parse binsizes to strs
|
||||
func ParseBinStrs(str string) (strs []string) {
|
||||
bins := strings.Split(str, ",")
|
||||
var temp string
|
||||
for _, v := range bins {
|
||||
temp = strings.Trim(v, " ")
|
||||
strs = append(strs, temp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ParseBinSizes parse binsizes
|
||||
func ParseBinSizes(str string) (durations []time.Duration, err error) {
|
||||
strs := ParseBinStrs(str)
|
||||
var t time.Duration
|
||||
for _, v := range strs {
|
||||
t, err = GetBinSizeDuration(v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
durations = append(durations, t)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetBinSizeDuration get duration of the binsize
|
||||
func GetBinSizeDuration(str string) (duration time.Duration, err error) {
|
||||
if len(str) == 0 {
|
||||
err = errors.New("binsize is empty")
|
||||
return
|
||||
}
|
||||
n, err := strconv.ParseInt(str, 10, 64)
|
||||
if err == nil {
|
||||
duration = time.Duration(n) * time.Minute
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
char := str[len(str)-1]
|
||||
switch char {
|
||||
case 's', 'S':
|
||||
duration = time.Second
|
||||
case 'm':
|
||||
duration = time.Minute
|
||||
case 'h':
|
||||
duration = time.Hour
|
||||
case 'd', 'D':
|
||||
duration = Day
|
||||
case 'w', 'W':
|
||||
duration = Week
|
||||
default:
|
||||
err = fmt.Errorf("unsupport binsize: %s", str)
|
||||
return
|
||||
}
|
||||
if len(str) == 1 {
|
||||
return
|
||||
}
|
||||
value := str[0 : len(str)-1]
|
||||
n, err = strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("parse binsize error:%s", err.Error())
|
||||
return
|
||||
}
|
||||
duration = time.Duration(n) * duration
|
||||
return
|
||||
}
|
39
common/binsize_test.go
Normal file
39
common/binsize_test.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetBinSizeDuration(t *testing.T) {
|
||||
testMap := map[string]time.Duration{
|
||||
"1s": time.Second,
|
||||
"5s": 5 * time.Second,
|
||||
"m": time.Minute,
|
||||
"1m": time.Minute,
|
||||
"5m": 5 * time.Minute,
|
||||
"15m": 15 * time.Minute,
|
||||
"30m": 30 * time.Minute,
|
||||
"1h": time.Hour,
|
||||
"4h": 4 * time.Hour,
|
||||
"6h": 6 * time.Hour,
|
||||
"1d": Day,
|
||||
"7d": Week,
|
||||
"1w": Week,
|
||||
"1": time.Minute,
|
||||
"15": 15 * time.Minute,
|
||||
"60": time.Hour,
|
||||
}
|
||||
|
||||
var temp time.Duration
|
||||
var err error
|
||||
for k, v := range testMap {
|
||||
temp, err = GetBinSizeDuration(k)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %s failed:%s", k, err.Error())
|
||||
}
|
||||
if temp != v {
|
||||
t.Fatalf("GetBinSizeDuration failed:%s %s", temp, v)
|
||||
}
|
||||
}
|
||||
}
|
23
common/browser.go
Normal file
23
common/browser.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func OpenURL(strURL string) error {
|
||||
var cmd string
|
||||
var args []string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
cmd = "cmd"
|
||||
args = []string{"/c", "start"}
|
||||
case "darwin":
|
||||
cmd = "open"
|
||||
default: // "linux", "freebsd", "openbsd", "netbsd"
|
||||
cmd = "xdg-open"
|
||||
}
|
||||
args = append(args, strURL)
|
||||
return exec.Command(cmd, args...).Start()
|
||||
}
|
129
common/engine.go
Normal file
129
common/engine.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
"github.com/bitly/go-simplejson"
|
||||
)
|
||||
|
||||
type CandleFn func(candle *Candle)
|
||||
|
||||
type Entry struct {
|
||||
Value interface{}
|
||||
Label string
|
||||
}
|
||||
|
||||
type Param struct {
|
||||
Name string
|
||||
Type string
|
||||
Label string
|
||||
Info string
|
||||
DefValue interface{}
|
||||
Enums []Entry
|
||||
ptr interface{}
|
||||
}
|
||||
|
||||
func StringParam(name, label, info, defValue string, ptr *string, enums ...Entry) Param {
|
||||
*ptr = defValue
|
||||
return Param{Name: name, Type: "string", Label: label, Info: info, DefValue: defValue, Enums: enums, ptr: ptr}
|
||||
}
|
||||
|
||||
func IntParam(name, label, info string, defValue int, ptr *int, enums ...Entry) Param {
|
||||
*ptr = defValue
|
||||
return Param{Name: name, Type: "int", Label: label, Info: info, DefValue: defValue, Enums: enums, ptr: ptr}
|
||||
}
|
||||
|
||||
func FloatParam(name, label, info string, defValue float64, ptr *float64, enums ...Entry) Param {
|
||||
*ptr = defValue
|
||||
return Param{Name: name, Type: "float", Label: label, Info: info, DefValue: defValue, Enums: enums, ptr: ptr}
|
||||
}
|
||||
|
||||
func BoolParam(name, label, info string, defValue bool, ptr *bool, enums ...Entry) Param {
|
||||
*ptr = defValue
|
||||
return Param{Name: name, Type: "bool", Label: label, Info: info, DefValue: defValue, Enums: enums, ptr: ptr}
|
||||
}
|
||||
|
||||
func ParseParams(str string, params []Param) (data ParamData, err error) {
|
||||
data = make(ParamData)
|
||||
sj := simplejson.New()
|
||||
err = sj.UnmarshalJSON([]byte(str))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var temp *simplejson.Json
|
||||
var ok, boolV bool
|
||||
var strV string
|
||||
var intV int
|
||||
var floatV float64
|
||||
for _, v := range params {
|
||||
if v.ptr == nil {
|
||||
data[v.Name] = sj.Get(v.Name).Interface()
|
||||
return
|
||||
}
|
||||
temp, ok = sj.CheckGet(v.Name)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch ptr := v.ptr.(type) {
|
||||
case *string:
|
||||
strV, err = temp.String()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*ptr = strV
|
||||
data[v.Name] = strV
|
||||
case *float64:
|
||||
floatV, err = temp.Float64()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*ptr = floatV
|
||||
data[v.Name] = floatV
|
||||
case *int:
|
||||
intV, err = temp.Int()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*ptr = intV
|
||||
data[v.Name] = intV
|
||||
case *bool:
|
||||
boolV, err = temp.Bool()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
*ptr = boolV
|
||||
data[v.Name] = boolV
|
||||
default:
|
||||
err = fmt.Errorf("unsupport value type: %##v", ptr)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type ParamData map[string]interface{}
|
||||
|
||||
func (d ParamData) GetString(key, defaultValue string) string {
|
||||
v, ok := d[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
ret := v.(string)
|
||||
if ret == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return ret
|
||||
}
|
||||
func (d ParamData) GetFloat(key string, defaultValue float64) float64 {
|
||||
v, ok := d[key]
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
ret := v.(float64)
|
||||
if ret == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return ret
|
||||
}
|
37
common/engine_test.go
Normal file
37
common/engine_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParam(t *testing.T) {
|
||||
var str1, str2 string
|
||||
var int1, int2 int
|
||||
var float1, float2 float64
|
||||
params := []Param{
|
||||
StringParam("str1", "str test", "just a simple string", "a", &str1),
|
||||
StringParam("str2", "str test", "enum string", "B", &str2,
|
||||
Entry{Label: "A", Value: "A"},
|
||||
Entry{Label: "B", Value: "B"}),
|
||||
IntParam("int1", "int1 test", "just a simple int", 1, &int1),
|
||||
IntParam("int2", "int2 test", "enum int", 1, &int2,
|
||||
Entry{Label: "A", Value: 1},
|
||||
Entry{Label: "B", Value: 2}),
|
||||
FloatParam("float1", "float1 test", "just a simple int", 1, &float1),
|
||||
FloatParam("float2", "float2 test", "enum float", 1, &float2,
|
||||
Entry{Label: "A", Value: 1.0},
|
||||
Entry{Label: "B", Value: 2.0}),
|
||||
}
|
||||
|
||||
str := `{"str1": "str1", "str2":"A", "int1": 10, "int2": 1, "float1": 3, "float2": 2.0}`
|
||||
rets, err := ParseParams(str, params)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
if str1 != "str1" || str2 != "A" || int1 != 10 || int2 != 1 || float1 != 3 || float2 != 2.0 {
|
||||
t.Fatal("value not match", str1, str2, int1, int2, float1, float2)
|
||||
}
|
||||
if rets["str1"] != str1 || rets["str2"] != str2 || rets["int1"] != int1 || rets["int2"] != int2 || rets["float1"] != float1 || rets["float2"] != float2 {
|
||||
t.Fatal("value not match", str1, str2, int1, int2, float1, float2, rets)
|
||||
}
|
||||
}
|
47
common/float.go
Normal file
47
common/float.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
// FloatMul return a*b
|
||||
func FloatMul(a, b float64) float64 {
|
||||
aDec := decimal.NewFromFloat(a)
|
||||
bDec := decimal.NewFromFloat(b)
|
||||
ret, _ := aDec.Mul(bDec).Float64()
|
||||
return ret
|
||||
}
|
||||
|
||||
// FloatAdd return a*b
|
||||
func FloatAdd(a, b float64) float64 {
|
||||
aDec := decimal.NewFromFloat(a)
|
||||
bDec := decimal.NewFromFloat(b)
|
||||
ret, _ := aDec.Add(bDec).Float64()
|
||||
return ret
|
||||
}
|
||||
|
||||
// FloatSub return a-b
|
||||
func FloatSub(a, b float64) float64 {
|
||||
aDec := decimal.NewFromFloat(a)
|
||||
bDec := decimal.NewFromFloat(b)
|
||||
ret, _ := aDec.Sub(bDec).Float64()
|
||||
return ret
|
||||
}
|
||||
|
||||
// FloatDiv return a/b
|
||||
func FloatDiv(a, b float64) float64 {
|
||||
aDec := decimal.NewFromFloat(a)
|
||||
bDec := decimal.NewFromFloat(b)
|
||||
ret, _ := aDec.Div(bDec).Float64()
|
||||
return ret
|
||||
}
|
||||
|
||||
// FormatFloat format float with precision
|
||||
func FormatFloat(n float64, precision int) float64 {
|
||||
str := fmt.Sprintf("%df", precision)
|
||||
n2, _ := strconv.ParseFloat(fmt.Sprintf("%."+str, n), 64)
|
||||
return n2
|
||||
}
|
153
common/merge.go
Normal file
153
common/merge.go
Normal file
|
@ -0,0 +1,153 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// MergeKlineChan merge kline data
|
||||
func MergeKlineChan(klines chan []interface{}, srcDuration, dstDuration time.Duration) (rets chan []interface{}) {
|
||||
rets = make(chan []interface{}, len(klines))
|
||||
go func() {
|
||||
km := NewKlineMerge(srcDuration, dstDuration)
|
||||
var temp interface{}
|
||||
for v := range klines {
|
||||
tempDatas := []interface{}{}
|
||||
for _, d := range v {
|
||||
temp = km.Update(d)
|
||||
if temp != nil {
|
||||
tempDatas = append(tempDatas, temp)
|
||||
}
|
||||
}
|
||||
if len(tempDatas) != 0 {
|
||||
rets <- tempDatas
|
||||
}
|
||||
}
|
||||
close(rets)
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// KlineMerge merge kline to new duration
|
||||
type KlineMerge struct {
|
||||
src int64 // src kline seconds
|
||||
dst int64 // dst kline seconds
|
||||
ratio int // dst/src kline ration
|
||||
cache CandleList // kline cache
|
||||
bFirst bool
|
||||
nextStart int64
|
||||
}
|
||||
|
||||
// NewKlineMergeStr new KlineMerge with string duration
|
||||
func NewKlineMergeStr(src, dst string) *KlineMerge {
|
||||
srcDur, err := time.ParseDuration(src)
|
||||
if err != nil {
|
||||
log.Errorf("NewKlineMergeStr parse src %s error: %s", src, err.Error())
|
||||
return nil
|
||||
}
|
||||
dstDur, err := time.ParseDuration(dst)
|
||||
if err != nil {
|
||||
log.Errorf("NewKlineMergeStr parse dst %s error: %s", dst, err.Error())
|
||||
return nil
|
||||
}
|
||||
return NewKlineMerge(srcDur, dstDur)
|
||||
}
|
||||
|
||||
// NewKlineMerge merge kline constructor
|
||||
func NewKlineMerge(src, dst time.Duration) *KlineMerge {
|
||||
km := new(KlineMerge)
|
||||
km.src = int64(src / time.Second)
|
||||
km.dst = int64(dst / time.Second)
|
||||
km.ratio = int(dst / src)
|
||||
km.bFirst = true
|
||||
return km
|
||||
}
|
||||
|
||||
// IsFirst is first time
|
||||
func (km *KlineMerge) IsFirst() bool {
|
||||
return km.bFirst
|
||||
}
|
||||
|
||||
// NeedMerge is kline need merge
|
||||
func (km *KlineMerge) NeedMerge() bool {
|
||||
return km.ratio != 1
|
||||
}
|
||||
|
||||
// GetSrc return kline source duration secs
|
||||
func (km *KlineMerge) GetSrc() int64 {
|
||||
return km.src
|
||||
}
|
||||
|
||||
// GetSrcDuration get kline source duration
|
||||
func (km *KlineMerge) GetSrcDuration() time.Duration {
|
||||
return time.Duration(km.src) * time.Second
|
||||
}
|
||||
|
||||
// GetDstDuration get kline dst duration
|
||||
func (km *KlineMerge) GetDstDuration() time.Duration {
|
||||
return time.Duration(km.dst) * time.Second
|
||||
}
|
||||
|
||||
// GetDst return kline dst duration secs
|
||||
func (km *KlineMerge) GetDst() int64 {
|
||||
return km.dst
|
||||
}
|
||||
|
||||
// Update update candle, and return new kline candle
|
||||
// return nil if no new kline candle
|
||||
func (km *KlineMerge) Update(data interface{}) (ret interface{}) {
|
||||
// return if no need to merge
|
||||
if km.ratio == 1 {
|
||||
ret = data
|
||||
return
|
||||
}
|
||||
candle, ok := data.(*Candle)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("KlineMerge data type error:%#v", data))
|
||||
return
|
||||
}
|
||||
n := len(km.cache)
|
||||
if n > 0 && candle.Start <= km.cache[n-1].Start {
|
||||
return
|
||||
}
|
||||
if km.bFirst && candle.Start%km.dst != 0 {
|
||||
return
|
||||
}
|
||||
km.bFirst = false
|
||||
var bNew bool
|
||||
if candle.Start >= km.nextStart {
|
||||
km.nextStart = (candle.Start/km.dst + 1) * km.dst
|
||||
if n != 0 {
|
||||
ret = km.cache.Merge()
|
||||
km.cache = CandleList{}
|
||||
bNew = true
|
||||
}
|
||||
}
|
||||
// add current candle to cache
|
||||
index := int(candle.Start%km.dst)/int(km.src) + 1
|
||||
km.cache = append(km.cache, candle)
|
||||
if bNew || index != km.ratio {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// reset cache after kline merged
|
||||
km.cache = CandleList{}
|
||||
}()
|
||||
// cache length not match,just skip
|
||||
if len(km.cache) != km.ratio {
|
||||
log.Warnf("cache length not match,real:%d, want:%d", len(km.cache), km.ratio)
|
||||
// return
|
||||
}
|
||||
ret = km.cache.Merge()
|
||||
return
|
||||
}
|
||||
|
||||
func (km *KlineMerge) GetUnFinished() (ret interface{}) {
|
||||
if len(km.cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
return km.cache.Merge()
|
||||
}
|
76
common/merge_test.go
Normal file
76
common/merge_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "git.qtrade.icu/coin-quant/trademodel"
|
||||
)
|
||||
|
||||
func getTestData(source, dst time.Duration) (candles CandleList) {
|
||||
nSourceSec := int64(source / time.Second)
|
||||
nStart := nSourceSec * (time.Now().Add(0-source*20).Unix() / nSourceSec)
|
||||
candle := Candle{
|
||||
Start: nStart,
|
||||
Open: 100,
|
||||
High: 200,
|
||||
Low: 50,
|
||||
Close: 110,
|
||||
Turnover: 1,
|
||||
Volume: 1,
|
||||
Trades: 10,
|
||||
}
|
||||
for i := 0; i != 10; i++ {
|
||||
temp := candle
|
||||
temp.Open = candle.Open + float64(i)
|
||||
temp.High = candle.High + float64(i)
|
||||
temp.Low = candle.Low + float64(i)
|
||||
temp.Close = candle.Close + float64(i)
|
||||
temp.Turnover = candle.Turnover + float64(i)
|
||||
temp.Volume = candle.Volume + float64(i)
|
||||
temp.Trades = candle.Trades + int64(i)
|
||||
candles = append(candles, &temp)
|
||||
candle.Start += nSourceSec
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestMergeKline(t *testing.T) {
|
||||
source := time.Minute * 5
|
||||
dst := time.Minute * 15
|
||||
candles := getTestData(source, dst)
|
||||
m := NewKlineMerge(source, dst)
|
||||
for _, v := range candles {
|
||||
t.Log("candle:", v)
|
||||
}
|
||||
var ret interface{}
|
||||
for _, v := range candles {
|
||||
ret = m.Update(v)
|
||||
if ret == nil {
|
||||
continue
|
||||
}
|
||||
t.Log("ret:", ret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeKlineChan(t *testing.T) {
|
||||
klines := make(chan []interface{}, 10)
|
||||
source := time.Minute * 5
|
||||
dst := time.Minute * 15
|
||||
candles := getTestData(source, dst)
|
||||
go func() {
|
||||
datas := make([]interface{}, len(candles))
|
||||
for k, v := range candles {
|
||||
t.Log(v)
|
||||
datas[k] = v
|
||||
}
|
||||
klines <- datas
|
||||
close(klines)
|
||||
}()
|
||||
ret := MergeKlineChan(klines, source, dst)
|
||||
for v := range ret {
|
||||
for _, d := range v {
|
||||
t.Log("ret:", d)
|
||||
}
|
||||
}
|
||||
}
|
69
common/path.go
Normal file
69
common/path.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
pkgRegexp = regexp.MustCompile(`^package \w+\n$`)
|
||||
)
|
||||
|
||||
// GetExecDir return exec dir
|
||||
func GetExecDir() string {
|
||||
dir, _ := os.Executable()
|
||||
exPath := filepath.Dir(dir)
|
||||
return exPath
|
||||
}
|
||||
|
||||
func CopyWithMainPkg(dst, src string) (err error) {
|
||||
fSrc, err := os.Open(src)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("open %s file failed:%w", src, err)
|
||||
return
|
||||
}
|
||||
defer fSrc.Close()
|
||||
fDst, err := os.Create(dst)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create %s file failed:%w", dst, err)
|
||||
return
|
||||
}
|
||||
defer fDst.Close()
|
||||
r := bufio.NewReader(fSrc)
|
||||
var line string
|
||||
for err == nil {
|
||||
line, err = r.ReadString('\n')
|
||||
if err != nil && err != io.EOF {
|
||||
break
|
||||
}
|
||||
if pkgRegexp.MatchString(line) {
|
||||
line = "package main"
|
||||
}
|
||||
fDst.Write([]byte(line))
|
||||
}
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Copy(dst, src string) (err error) {
|
||||
fSrc, err := os.Open(src)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("open %s file failed:%w", src, err)
|
||||
return
|
||||
}
|
||||
defer fSrc.Close()
|
||||
fDst, err := os.Create(dst)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create %s file failed:%w", dst, err)
|
||||
return
|
||||
}
|
||||
defer fDst.Close()
|
||||
_, err = io.Copy(fDst, fSrc)
|
||||
return
|
||||
}
|
34
engine/interface.go
Normal file
34
engine/interface.go
Normal file
|
@ -0,0 +1,34 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"git.qtrade.icu/coin-quant/base/common"
|
||||
"git.qtrade.icu/coin-quant/indicator"
|
||||
"git.qtrade.icu/coin-quant/trademodel"
|
||||
)
|
||||
|
||||
const (
|
||||
StatusRunning = 0
|
||||
StatusSuccess = 1
|
||||
StatusFail = -1
|
||||
)
|
||||
|
||||
type Engine interface {
|
||||
OpenLong(price, amount float64) string
|
||||
CloseLong(price, amount float64) string
|
||||
OpenShort(price, amount float64) string
|
||||
CloseShort(price, amount float64) string
|
||||
StopLong(price, amount float64) string
|
||||
StopShort(price, amount float64) string
|
||||
CancelOrder(string)
|
||||
CancelAllOrder()
|
||||
DoOrder(typ trademodel.TradeType, price, amount float64) string
|
||||
AddIndicator(name string, params ...int) (ind indicator.CommonIndicator)
|
||||
Position() (pos, price float64)
|
||||
Balance() float64
|
||||
Log(v ...interface{})
|
||||
Watch(watchType string)
|
||||
SendNotify(title, content, contentType string)
|
||||
Merge(src, dst string, fn common.CandleFn)
|
||||
SetBalance(balance float64)
|
||||
UpdateStatus(status int, msg string)
|
||||
}
|
67
fsm/fsm.go
Normal file
67
fsm/fsm.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package fsm
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Rule struct {
|
||||
Name string
|
||||
Src []string
|
||||
Dst string
|
||||
}
|
||||
|
||||
type EventDesc struct {
|
||||
Name string
|
||||
Src string
|
||||
Dst string
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
type Callback func(event EventDesc)
|
||||
|
||||
type FSM struct {
|
||||
state string
|
||||
rules []Rule
|
||||
callbacks map[string]Callback
|
||||
}
|
||||
|
||||
func NewFSM(initia string, rules []Rule) *FSM {
|
||||
f := new(FSM)
|
||||
f.callbacks = make(map[string]Callback)
|
||||
f.state = initia
|
||||
f.rules = rules
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *FSM) SetCallback(state string, cb Callback) {
|
||||
f.callbacks[state] = cb
|
||||
}
|
||||
|
||||
func (f *FSM) Event(event string, args ...interface{}) (err error) {
|
||||
var src, dst string
|
||||
Out:
|
||||
for _, v := range f.rules {
|
||||
if v.Name == event {
|
||||
for _, s := range v.Src {
|
||||
if s == f.state {
|
||||
dst = v.Dst
|
||||
src = s
|
||||
break Out
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if dst == "" {
|
||||
err = fmt.Errorf("current state: %s,skip event:%s", f.state, event)
|
||||
return
|
||||
}
|
||||
f.state = dst
|
||||
cb, ok := f.callbacks[dst]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cb(EventDesc{Name: event, Src: src, Dst: dst, Args: args})
|
||||
return
|
||||
}
|
||||
|
||||
func (f *FSM) Current() string {
|
||||
return f.state
|
||||
}
|
30
fsm/fsm_test.go
Normal file
30
fsm/fsm_test.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package fsm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFSM(t *testing.T) {
|
||||
rules := []Rule{
|
||||
{"openOrder", []string{"init"}, "open"},
|
||||
{"stopOrder", []string{"open", "openMore"}, "init"},
|
||||
{"closeOrder", []string{"open", "openMore"}, "init"},
|
||||
{"openOrder", []string{"open"}, "openMore"},
|
||||
}
|
||||
args := []interface{}{"Hello", 1, 2, 3}
|
||||
cb := func(event EventDesc) {
|
||||
assert.Equal(t, event, EventDesc{Name: "openOrder", Src: "init", Dst: "open", Args: args}, "state error")
|
||||
t.Log(event.Name, event.Src, event.Dst, event.Args)
|
||||
}
|
||||
fsm := NewFSM("init", rules)
|
||||
fsm.SetCallback("open", cb)
|
||||
assert.Equal(t, fsm.Current(), "init", "init state error")
|
||||
fsm.Event("openOrder", args...)
|
||||
assert.Equal(t, fsm.Current(), "open", "open state error")
|
||||
err := fsm.Event("test")
|
||||
assert.NotNil(t, err, "error check failed")
|
||||
fsm.Event("closeOrder")
|
||||
assert.Equal(t, fsm.Current(), "init", "close state error")
|
||||
}
|
19
go.mod
Normal file
19
go.mod
Normal file
|
@ -0,0 +1,19 @@
|
|||
module git.qtrade.icu/coin-quant/base
|
||||
|
||||
go 1.22.0
|
||||
|
||||
require (
|
||||
git.qtrade.icu/coin-quant/indicator v0.0.0-20240625151736-c23020eee562
|
||||
git.qtrade.icu/coin-quant/trademodel v0.0.0-20240625151548-cef4b6fc28b9
|
||||
github.com/bitly/go-simplejson v0.5.1
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
26
go.sum
Normal file
26
go.sum
Normal file
|
@ -0,0 +1,26 @@
|
|||
git.qtrade.icu/coin-quant/indicator v0.0.0-20240625151736-c23020eee562 h1:oA06Mq/hJtzJ6k7ZW6kd3RY9EBDLVCEPDFADiP9JwIk=
|
||||
git.qtrade.icu/coin-quant/indicator v0.0.0-20240625151736-c23020eee562/go.mod h1:x1+rqPrwJqPLETFdMQGhzp71Z3ZxAlNFExGVOhk+IT0=
|
||||
git.qtrade.icu/coin-quant/trademodel v0.0.0-20240625151548-cef4b6fc28b9 h1:9T1u+MzfbG9jZU1wzDtmBoOwN1m/fRX0iX7NbLwAHgU=
|
||||
git.qtrade.icu/coin-quant/trademodel v0.0.0-20240625151548-cef4b6fc28b9/go.mod h1:SZnI+IqcRlKVcDSS++NIgthZX4GG1OU4UG+RDrSOD34=
|
||||
github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow=
|
||||
github.com/bitly/go-simplejson v0.5.1/go.mod h1:YOPVLzCfwK14b4Sff3oP1AmGhI9T9Vsg84etUnlyp+Q=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
Loading…
Reference in New Issue
Block a user