qbtrade/pkg/optimizer/hpoptimizer_test.go

202 lines
6.2 KiB
Go
Raw Normal View History

2024-06-27 14:42:38 +00:00
package optimizer
import (
"git.qtrade.icu/lychiyu/qbtrade/pkg/fixedpoint"
"reflect"
"testing"
)
// TestBuildParamDomains checks that HyperparameterOptimizer.buildParamDomains
// turns each SelectorConfig in the matrix into the expected concrete
// paramDomain, and that it reports the label -> path mapping for every
// selector that produced a domain.
//
// The test expects domains to be returned in matrix order with selectors of an
// unknown type skipped. The n-th returned domain is therefore compared against
// the n-th test case that has a non-nil verify function.
func TestBuildParamDomains(t *testing.T) {
	// Each verifier uses a checked type assertion. A domain of the wrong
	// concrete type is reported as a test failure instead of panicking.
	floatRangeDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*floatRangeDomain)
		if !ok || concrete == nil {
			return false
		}
		return *concrete == floatRangeDomain{
			paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
			min:             expect.Min.Float64(),
			max:             expect.Max.Float64(),
		}
	}
	floatDiscreteRangeDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*floatDiscreteRangeDomain)
		if !ok || concrete == nil {
			return false
		}
		return *concrete == floatDiscreteRangeDomain{
			paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
			min:             expect.Min.Float64(),
			max:             expect.Max.Float64(),
			step:            expect.Step.Float64(),
		}
	}
	intRangeDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*intRangeDomain)
		if !ok || concrete == nil {
			return false
		}
		return *concrete == intRangeDomain{
			paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
			min:             expect.Min.Int(),
			max:             expect.Max.Int(),
		}
	}
	intStepRangeDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*intStepRangeDomain)
		if !ok || concrete == nil {
			return false
		}
		return *concrete == intStepRangeDomain{
			paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
			min:             expect.Min.Int(),
			max:             expect.Max.Int(),
			step:            expect.Step.Int(),
		}
	}
	// stringDomain holds a slice of options, so it is compared field by field
	// rather than with ==.
	stringDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*stringDomain)
		if !ok || concrete == nil {
			return false
		}
		expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
		if concrete.paramDomainBase != expectBase {
			return false
		}
		if len(concrete.options) != len(expect.Values) {
			return false
		}
		for i, item := range concrete.options {
			if item != expect.Values[i] {
				return false
			}
		}
		return true
	}
	boolDomainVerifier := func(domain paramDomain, expect SelectorConfig) bool {
		concrete, ok := domain.(*boolDomain)
		if !ok || concrete == nil {
			return false
		}
		return *concrete == boolDomain{
			paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
		}
	}

	// Fields that a selector type should not use are filled with placeholder
	// values ("ignore", odd bounds). These check that the placeholders do not
	// leak into the resulting domain. A nil verify marks a selector that is
	// expected to produce no domain at all.
	tests := []struct {
		config SelectorConfig
		verify func(domain paramDomain, expect SelectorConfig) bool
	}{
		{
			config: SelectorConfig{
				Type:   selectorTypeRange,
				Label:  "range label",
				Path:   "range path",
				Values: []string{"ignore", "ignore"},
				Min:    fixedpoint.NewFromFloat(7.0),
				Max:    fixedpoint.NewFromFloat(80.0),
				Step:   fixedpoint.NewFromFloat(0.0),
			},
			verify: floatRangeDomainVerifier,
		}, {
			// Zero step: expected to yield the continuous float domain.
			config: SelectorConfig{
				Type:   selectorTypeRangeFloat,
				Label:  "rangeFloat label",
				Path:   "rangeFloat path",
				Values: []string{"ignore", "ignore"},
				Min:    fixedpoint.NewFromFloat(6.0),
				Max:    fixedpoint.NewFromFloat(10.0),
				Step:   fixedpoint.NewFromFloat(0.0),
			},
			verify: floatRangeDomainVerifier,
		}, {
			// Non-zero step: expected to yield the discrete float domain.
			config: SelectorConfig{
				Type:   selectorTypeRangeFloat,
				Label:  "rangeDiscreteFloat label",
				Path:   "rangeDiscreteFloat path",
				Values: []string{"ignore", "ignore"},
				Min:    fixedpoint.NewFromFloat(6.0),
				Max:    fixedpoint.NewFromFloat(10.0),
				Step:   fixedpoint.NewFromFloat(2.0),
			},
			verify: floatDiscreteRangeDomainVerifier,
		}, {
			config: SelectorConfig{
				Type:   selectorTypeRangeInt,
				Label:  "rangeInt label",
				Path:   "rangeInt path",
				Values: []string{"ignore", "ignore"},
				Min:    fixedpoint.NewFromInt(3),
				Max:    fixedpoint.NewFromInt(100),
				Step:   fixedpoint.NewFromInt(0),
			},
			verify: intRangeDomainVerifier,
		}, {
			// NOTE: shares label and path with the previous case, so
			// expectLabelPaths holds a single entry for both.
			config: SelectorConfig{
				Type:   selectorTypeRangeInt,
				Label:  "rangeInt label",
				Path:   "rangeInt path",
				Values: []string{"ignore", "ignore"},
				Min:    fixedpoint.NewFromInt(3),
				Max:    fixedpoint.NewFromInt(100),
				Step:   fixedpoint.NewFromInt(7),
			},
			verify: intStepRangeDomainVerifier,
		}, {
			// Iterate selectors are expected to map to a stringDomain; with
			// nil Values it should carry no options.
			config: SelectorConfig{
				Type:   selectorTypeIterate,
				Label:  "iterate label",
				Path:   "iterate path",
				Values: nil,
				Min:    fixedpoint.NewFromInt(0),
				Max:    fixedpoint.NewFromInt(-8),
				Step:   fixedpoint.NewFromInt(-1),
			},
			verify: stringDomainVerifier,
		}, {
			config: SelectorConfig{
				Type:   selectorTypeString,
				Label:  "string label",
				Path:   "string path",
				Values: []string{"option1", "option2", "option3"},
				Min:    fixedpoint.NewFromInt(0),
				Max:    fixedpoint.NewFromInt(-8),
				Step:   fixedpoint.NewFromInt(-1),
			},
			verify: stringDomainVerifier,
		}, {
			config: SelectorConfig{
				Type:   selectorTypeBool,
				Label:  "bool label",
				Path:   "bool path",
				Values: []string{"ignore"},
				Min:    fixedpoint.NewFromInt(99),
				Max:    fixedpoint.NewFromInt(1064),
				Step:   fixedpoint.NewFromInt(-89),
			},
			verify: boolDomainVerifier,
		}, {
			// Unknown selector type: expected to be skipped entirely.
			config: SelectorConfig{
				Type:   "unknown type",
				Label:  "unknown label",
				Path:   "unknown path",
				Values: []string{"unknown option"},
				Min:    fixedpoint.NewFromInt(99),
				Max:    fixedpoint.NewFromFloat(1064),
				Step:   fixedpoint.NewFromInt(0),
			},
			verify: nil,
		},
	}

	selectors := make([]SelectorConfig, len(tests))
	expectLabelPaths := make(map[string]string)
	// expectedCases holds the indexes (into tests) of the selectors that should
	// produce a domain, in the order the domains are expected to be returned.
	expectedCases := make([]int, 0, len(tests))
	for i, tc := range tests {
		selectors[i] = tc.config
		if tc.verify == nil {
			continue
		}
		expectLabelPaths[tc.config.Label] = tc.config.Path
		expectedCases = append(expectedCases, i)
	}

	optimizer := &HyperparameterOptimizer{Config: &Config{Matrix: selectors}}
	gotLabelPaths, gotParamDomains := optimizer.buildParamDomains()
	if !reflect.DeepEqual(gotLabelPaths, expectLabelPaths) {
		t.Errorf("expectLabelPaths=%v, exactLabelPaths=%v", expectLabelPaths, gotLabelPaths)
	}
	// Fatal rather than Error: if fewer domains came back than expected,
	// indexing gotParamDomains below would panic with index out of range.
	if len(gotParamDomains) != len(expectedCases) {
		t.Fatalf("expect %d param domains, got %d", len(expectedCases), len(gotParamDomains))
	}
	for i, caseIdx := range expectedCases {
		tc := tests[caseIdx]
		if pd := gotParamDomains[i]; !tc.verify(pd, tc.config) {
			t.Errorf("unexpect param domain at #%d (label %q): %#v", i, tc.config.Label, pd)
		}
	}
}