fix: some types in SeriesExtended are not supported

This commit is contained in:
zenix 2023-05-24 11:31:28 +09:00
parent 14849afe4e
commit 508f42663d
2 changed files with 96 additions and 43 deletions

View File

@ -383,27 +383,8 @@ type AddSeriesResult struct {
// Add two series, result[i] = a[i] + b[i] // Add two series, result[i] = a[i] + b[i]
func Add(a interface{}, b interface{}) SeriesExtend { func Add(a interface{}, b interface{}) SeriesExtend {
var aa Series aa := switchIface(a)
var bb Series bb := switchIface(b)
switch tp := a.(type) {
case float64:
aa = NumberSeries(tp)
case Series:
aa = tp
default:
panic("input should be either *Series or float64")
}
switch tp := b.(type) {
case float64:
bb = NumberSeries(tp)
case Series:
bb = tp
default:
panic("input should be either *Series or float64")
}
return NewSeries(&AddSeriesResult{aa, bb}) return NewSeries(&AddSeriesResult{aa, bb})
} }
@ -473,7 +454,7 @@ func switchIface(b interface{}) Series {
return tp return tp
default: default:
fmt.Println(reflect.TypeOf(b)) fmt.Println(reflect.TypeOf(b))
panic("input should be either *Series or float64") panic("input should be either *Series or numbers")
} }
} }
@ -515,28 +496,9 @@ var _ Series = &DivSeriesResult{}
// Multiple two series, result[i] = a[i] * b[i] // Multiple two series, result[i] = a[i] * b[i]
func Mul(a interface{}, b interface{}) SeriesExtend { func Mul(a interface{}, b interface{}) SeriesExtend {
var aa Series aa := switchIface(a)
var bb Series bb := switchIface(b)
switch tp := a.(type) {
case float64:
aa = NumberSeries(tp)
case Series:
aa = tp
default:
panic("input should be either Series or float64")
}
switch tp := b.(type) {
case float64:
bb = NumberSeries(tp)
case Series:
bb = tp
default:
panic("input should be either Series or float64")
}
return NewSeries(&MulSeriesResult{aa, bb}) return NewSeries(&MulSeriesResult{aa, bb})
} }
type MulSeriesResult struct { type MulSeriesResult struct {
@ -577,6 +539,18 @@ func Dot(a interface{}, b interface{}, limit ...int) float64 {
case float64: case float64:
aaf = tp aaf = tp
isaf = true isaf = true
case int32:
aaf = float64(tp)
isaf = true
case int64:
aaf = float64(tp)
isaf = true
case float32:
aaf = float64(tp)
isaf = true
case int:
aaf = float64(tp)
isaf = true
case Series: case Series:
aas = tp aas = tp
isaf = false isaf = false
@ -587,6 +561,18 @@ func Dot(a interface{}, b interface{}, limit ...int) float64 {
case float64: case float64:
bbf = tp bbf = tp
isbf = true isbf = true
case int32:
aaf = float64(tp)
isaf = true
case int64:
aaf = float64(tp)
isaf = true
case float32:
aaf = float64(tp)
isaf = true
case int:
aaf = float64(tp)
isaf = true
case Series: case Series:
bbs = tp bbs = tp
isbf = false isbf = false

View File

@ -2,6 +2,7 @@ package types
import ( import (
//"os" //"os"
"math"
"testing" "testing"
"time" "time"
@ -12,6 +13,14 @@ import (
"github.com/c9s/bbgo/pkg/datatype/floats" "github.com/c9s/bbgo/pkg/datatype/floats"
) )
func TestQueue(t *testing.T) {
zeroq := NewQueue(0)
assert.Equal(t, zeroq.Last(), 0.)
assert.Equal(t, zeroq.Index(0), 0.)
zeroq.Update(1.)
assert.Equal(t, zeroq.Length(), 0)
}
func TestFloat(t *testing.T) { func TestFloat(t *testing.T) {
var a Series = Minus(3., 2.) var a Series = Minus(3., 2.)
assert.Equal(t, a.Last(), 1.) assert.Equal(t, a.Last(), 1.)
@ -123,6 +132,64 @@ func TestSigmoid(t *testing.T) {
} }
} }
func TestHighLowest(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.1}
assert.Equal(t, 3.0, Highest(&a, 4))
assert.Equal(t, 1.0, Lowest(&a, 4))
}
func TestAdd(t *testing.T) {
var a NumberSeries = 3.0
var b NumberSeries = 2.0
out := Add(&a, &b)
assert.Equal(t, out.Last(), 5.0)
assert.Equal(t, out.Index(0), 5.0)
assert.Equal(t, out.Length(), math.MaxInt32)
}
func TestDiv(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
b := NumberSeries(2.0)
out := Div(&a, &b)
assert.Equal(t, out.Last(), 1.0)
assert.Equal(t, out.Length(), 3)
assert.Equal(t, out.Index(1), 0.5)
}
func TestMul(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
b := NumberSeries(2.0)
out := Mul(&a, &b)
assert.Equal(t, out.Last(), 4.0)
assert.Equal(t, out.Length(), 3)
assert.Equal(t, out.Index(1), 2.0)
}
func TestArray(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
out := Array(&a, 1)
assert.Equal(t, len(out), 1)
out = Array(&a, 4)
assert.Equal(t, len(out), 3)
}
func TestSwitchInterface(t *testing.T) {
var a int = 1
var af float64 = 1.0
var b int32 = 2
var bf float64 = 2.0
var c int64 = 3
var cf float64 = 3.0
var d float32 = 4.0
var df float64 = 4.0
var e float64 = 5.0
assert.Equal(t, switchIface(a).Last(), af)
assert.Equal(t, switchIface(b).Last(), bf)
assert.Equal(t, switchIface(c).Last(), cf)
assert.Equal(t, switchIface(d).Last(), df)
assert.Equal(t, switchIface(e).Last(), e)
}
// from https://en.wikipedia.org/wiki/Logistic_regression // from https://en.wikipedia.org/wiki/Logistic_regression
func TestLogisticRegression(t *testing.T) { func TestLogisticRegression(t *testing.T) {
a := []floats.Slice{{0.5, 0.75, 1., 1.25, 1.5, 1.75, 1.75, 2.0, 2.25, 2.5, 2.75, 3., 3.25, 3.5, 4., 4.25, 4.5, 4.75, 5., 5.5}} a := []floats.Slice{{0.5, 0.75, 1., 1.25, 1.5, 1.75, 1.75, 2.0, 2.25, 2.5, 2.75, 3., 3.25, 3.5, 4., 4.25, 4.5, 4.75, 5., 5.5}}