Merge pull request #1177 from zenixls2/fix/indicator_minor_fix

fix: some types and operations in SeriesExtended are not supported
This commit is contained in:
Zenix 2023-05-24 21:03:01 +09:00 committed by GitHub
commit 8aa1f668d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 43 deletions

View File

@ -383,27 +383,8 @@ type AddSeriesResult struct {
// Add two series, result[i] = a[i] + b[i]
func Add(a interface{}, b interface{}) SeriesExtend {
var aa Series
var bb Series
switch tp := a.(type) {
case float64:
aa = NumberSeries(tp)
case Series:
aa = tp
default:
panic("input should be either *Series or float64")
}
switch tp := b.(type) {
case float64:
bb = NumberSeries(tp)
case Series:
bb = tp
default:
panic("input should be either *Series or float64")
}
aa := switchIface(a)
bb := switchIface(b)
return NewSeries(&AddSeriesResult{aa, bb})
}
@ -473,7 +454,7 @@ func switchIface(b interface{}) Series {
return tp
default:
fmt.Println(reflect.TypeOf(b))
panic("input should be either *Series or float64")
panic("input should be either *Series or numbers")
}
}
@ -515,28 +496,9 @@ var _ Series = &DivSeriesResult{}
// Multiple two series, result[i] = a[i] * b[i]
func Mul(a interface{}, b interface{}) SeriesExtend {
var aa Series
var bb Series
switch tp := a.(type) {
case float64:
aa = NumberSeries(tp)
case Series:
aa = tp
default:
panic("input should be either Series or float64")
}
switch tp := b.(type) {
case float64:
bb = NumberSeries(tp)
case Series:
bb = tp
default:
panic("input should be either Series or float64")
}
aa := switchIface(a)
bb := switchIface(b)
return NewSeries(&MulSeriesResult{aa, bb})
}
type MulSeriesResult struct {
@ -577,6 +539,18 @@ func Dot(a interface{}, b interface{}, limit ...int) float64 {
case float64:
aaf = tp
isaf = true
case int32:
aaf = float64(tp)
isaf = true
case int64:
aaf = float64(tp)
isaf = true
case float32:
aaf = float64(tp)
isaf = true
case int:
aaf = float64(tp)
isaf = true
case Series:
aas = tp
isaf = false
@ -587,6 +561,18 @@ func Dot(a interface{}, b interface{}, limit ...int) float64 {
case float64:
bbf = tp
isbf = true
case int32:
aaf = float64(tp)
isaf = true
case int64:
aaf = float64(tp)
isaf = true
case float32:
aaf = float64(tp)
isaf = true
case int:
aaf = float64(tp)
isaf = true
case Series:
bbs = tp
isbf = false

View File

@ -2,6 +2,7 @@ package types
import (
//"os"
"math"
"testing"
"time"
@ -12,6 +13,14 @@ import (
"github.com/c9s/bbgo/pkg/datatype/floats"
)
func TestQueue(t *testing.T) {
zeroq := NewQueue(0)
assert.Equal(t, zeroq.Last(), 0.)
assert.Equal(t, zeroq.Index(0), 0.)
zeroq.Update(1.)
assert.Equal(t, zeroq.Length(), 0)
}
func TestFloat(t *testing.T) {
var a Series = Minus(3., 2.)
assert.Equal(t, a.Last(), 1.)
@ -123,6 +132,64 @@ func TestSigmoid(t *testing.T) {
}
}
func TestHighLowest(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.1}
assert.Equal(t, 3.0, Highest(&a, 4))
assert.Equal(t, 1.0, Lowest(&a, 4))
}
func TestAdd(t *testing.T) {
var a NumberSeries = 3.0
var b NumberSeries = 2.0
out := Add(&a, &b)
assert.Equal(t, out.Last(), 5.0)
assert.Equal(t, out.Index(0), 5.0)
assert.Equal(t, out.Length(), math.MaxInt32)
}
func TestDiv(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
b := NumberSeries(2.0)
out := Div(&a, &b)
assert.Equal(t, out.Last(), 1.0)
assert.Equal(t, out.Length(), 3)
assert.Equal(t, out.Index(1), 0.5)
}
func TestMul(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
b := NumberSeries(2.0)
out := Mul(&a, &b)
assert.Equal(t, out.Last(), 4.0)
assert.Equal(t, out.Length(), 3)
assert.Equal(t, out.Index(1), 2.0)
}
func TestArray(t *testing.T) {
a := floats.Slice{3.0, 1.0, 2.0}
out := Array(&a, 1)
assert.Equal(t, len(out), 1)
out = Array(&a, 4)
assert.Equal(t, len(out), 3)
}
func TestSwitchInterface(t *testing.T) {
var a int = 1
var af float64 = 1.0
var b int32 = 2
var bf float64 = 2.0
var c int64 = 3
var cf float64 = 3.0
var d float32 = 4.0
var df float64 = 4.0
var e float64 = 5.0
assert.Equal(t, switchIface(a).Last(), af)
assert.Equal(t, switchIface(b).Last(), bf)
assert.Equal(t, switchIface(c).Last(), cf)
assert.Equal(t, switchIface(d).Last(), df)
assert.Equal(t, switchIface(e).Last(), e)
}
// from https://en.wikipedia.org/wiki/Logistic_regression
func TestLogisticRegression(t *testing.T) {
a := []floats.Slice{{0.5, 0.75, 1., 1.25, 1.5, 1.75, 1.75, 2.0, 2.25, 2.5, 2.75, 3., 3.25, 3.5, 4., 4.25, 4.5, 4.75, 5., 5.5}}