202 lines
6.2 KiB
Go
202 lines
6.2 KiB
Go
package optimizer
|
|
|
|
import (
|
|
"git.qtrade.icu/lychiyu/qbtrade/pkg/fixedpoint"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
func TestBuildParamDomains(t *testing.T) {
|
|
var floatRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*floatRangeDomain)
|
|
return *concrete == floatRangeDomain{
|
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
|
min: expect.Min.Float64(),
|
|
max: expect.Max.Float64(),
|
|
}
|
|
}
|
|
var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*floatDiscreteRangeDomain)
|
|
return *concrete == floatDiscreteRangeDomain{
|
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
|
min: expect.Min.Float64(),
|
|
max: expect.Max.Float64(),
|
|
step: expect.Step.Float64(),
|
|
}
|
|
}
|
|
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*intRangeDomain)
|
|
return *concrete == intRangeDomain{
|
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
|
min: expect.Min.Int(),
|
|
max: expect.Max.Int(),
|
|
}
|
|
}
|
|
var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*intStepRangeDomain)
|
|
return *concrete == intStepRangeDomain{
|
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
|
min: expect.Min.Int(),
|
|
max: expect.Max.Int(),
|
|
step: expect.Step.Int(),
|
|
}
|
|
}
|
|
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*stringDomain)
|
|
expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
|
|
if concrete.paramDomainBase != expectBase {
|
|
return false
|
|
}
|
|
if len(concrete.options) != len(expect.Values) {
|
|
return false
|
|
}
|
|
for i, item := range concrete.options {
|
|
if item != expect.Values[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
var boolDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
|
concrete := domain.(*boolDomain)
|
|
return *concrete == boolDomain{
|
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
|
}
|
|
}
|
|
|
|
tests := []struct {
|
|
config SelectorConfig
|
|
verify func(domain paramDomain, expect SelectorConfig) bool
|
|
}{
|
|
{
|
|
config: SelectorConfig{
|
|
Type: selectorTypeRange,
|
|
Label: "range label",
|
|
Path: "range path",
|
|
Values: []string{"ignore", "ignore"},
|
|
Min: fixedpoint.NewFromFloat(7.0),
|
|
Max: fixedpoint.NewFromFloat(80.0),
|
|
Step: fixedpoint.NewFromFloat(0.0),
|
|
},
|
|
verify: floatRangeDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeRangeFloat,
|
|
Label: "rangeFloat label",
|
|
Path: "rangeFloat path",
|
|
Values: []string{"ignore", "ignore"},
|
|
Min: fixedpoint.NewFromFloat(6.0),
|
|
Max: fixedpoint.NewFromFloat(10.0),
|
|
Step: fixedpoint.NewFromFloat(0.0),
|
|
},
|
|
verify: floatRangeDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeRangeFloat,
|
|
Label: "rangeDiscreteFloat label",
|
|
Path: "rangeDiscreteFloat path",
|
|
Values: []string{"ignore", "ignore"},
|
|
Min: fixedpoint.NewFromFloat(6.0),
|
|
Max: fixedpoint.NewFromFloat(10.0),
|
|
Step: fixedpoint.NewFromFloat(2.0),
|
|
},
|
|
verify: floatDiscreteRangeDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeRangeInt,
|
|
Label: "rangeInt label",
|
|
Path: "rangeInt path",
|
|
Values: []string{"ignore", "ignore"},
|
|
Min: fixedpoint.NewFromInt(3),
|
|
Max: fixedpoint.NewFromInt(100),
|
|
Step: fixedpoint.NewFromInt(0),
|
|
},
|
|
verify: intRangeDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeRangeInt,
|
|
Label: "rangeInt label",
|
|
Path: "rangeInt path",
|
|
Values: []string{"ignore", "ignore"},
|
|
Min: fixedpoint.NewFromInt(3),
|
|
Max: fixedpoint.NewFromInt(100),
|
|
Step: fixedpoint.NewFromInt(7),
|
|
},
|
|
verify: intStepRangeDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeIterate,
|
|
Label: "iterate label",
|
|
Path: "iterate path",
|
|
Values: nil,
|
|
Min: fixedpoint.NewFromInt(0),
|
|
Max: fixedpoint.NewFromInt(-8),
|
|
Step: fixedpoint.NewFromInt(-1),
|
|
},
|
|
verify: stringDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeString,
|
|
Label: "string label",
|
|
Path: "string path",
|
|
Values: []string{"option1", "option2", "option3"},
|
|
Min: fixedpoint.NewFromInt(0),
|
|
Max: fixedpoint.NewFromInt(-8),
|
|
Step: fixedpoint.NewFromInt(-1),
|
|
},
|
|
verify: stringDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: selectorTypeBool,
|
|
Label: "bool label",
|
|
Path: "bool path",
|
|
Values: []string{"ignore"},
|
|
Min: fixedpoint.NewFromInt(99),
|
|
Max: fixedpoint.NewFromInt(1064),
|
|
Step: fixedpoint.NewFromInt(-89),
|
|
},
|
|
verify: boolDomainVerifier,
|
|
}, {
|
|
config: SelectorConfig{
|
|
Type: "unknown type",
|
|
Label: "unknown label",
|
|
Path: "unknown path",
|
|
Values: []string{"unknown option"},
|
|
Min: fixedpoint.NewFromInt(99),
|
|
Max: fixedpoint.NewFromFloat(1064),
|
|
Step: fixedpoint.NewFromInt(0),
|
|
},
|
|
verify: nil,
|
|
},
|
|
}
|
|
|
|
selectors := make([]SelectorConfig, len(tests))
|
|
expectLabelPaths := make(map[string]string)
|
|
verifiers := make([]func(domain paramDomain) bool, 0, len(tests))
|
|
for i, testItem := range tests {
|
|
itemConfig, itemVerify := testItem.config, testItem.verify
|
|
selectors[i] = itemConfig
|
|
if itemVerify != nil {
|
|
expectLabelPaths[testItem.config.Label] = testItem.config.Path
|
|
verifiers = append(verifiers, func(domain paramDomain) bool {
|
|
return itemVerify(domain, itemConfig)
|
|
})
|
|
}
|
|
}
|
|
optimizer := &HyperparameterOptimizer{Config: &Config{Matrix: selectors}}
|
|
exactLabelPaths, exactParamDomains := optimizer.buildParamDomains()
|
|
|
|
if !reflect.DeepEqual(exactLabelPaths, expectLabelPaths) {
|
|
t.Errorf("expectLabelPaths=%v, exactLabelPaths=%v", expectLabelPaths, exactLabelPaths)
|
|
}
|
|
if len(exactParamDomains) != len(verifiers) {
|
|
t.Errorf("expect %d param domains, got %d", len(verifiers), len(exactParamDomains))
|
|
}
|
|
for i, verifier := range verifiers {
|
|
pd := exactParamDomains[i]
|
|
if !verifier(pd) {
|
|
t.Errorf("unexpect param domain at #%d: %#v", i, pd)
|
|
}
|
|
}
|
|
}
|