mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 09:11:55 +00:00
optimizer: optimizeEx supports discrete parameters
This commit is contained in:
parent
76d908e2bc
commit
09940ed3cd
|
@ -123,6 +123,7 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
||||||
var domain paramDomain
|
var domain paramDomain
|
||||||
switch selector.Type {
|
switch selector.Type {
|
||||||
case selectorTypeRange, selectorTypeRangeFloat:
|
case selectorTypeRange, selectorTypeRangeFloat:
|
||||||
|
if selector.Step.IsZero() {
|
||||||
domain = &floatRangeDomain{
|
domain = &floatRangeDomain{
|
||||||
paramDomainBase: paramDomainBase{
|
paramDomainBase: paramDomainBase{
|
||||||
label: selector.Label,
|
label: selector.Label,
|
||||||
|
@ -131,7 +132,19 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
||||||
min: selector.Min.Float64(),
|
min: selector.Min.Float64(),
|
||||||
max: selector.Max.Float64(),
|
max: selector.Max.Float64(),
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
domain = &floatDiscreteRangeDomain{
|
||||||
|
paramDomainBase: paramDomainBase{
|
||||||
|
label: selector.Label,
|
||||||
|
path: selector.Path,
|
||||||
|
},
|
||||||
|
min: selector.Min.Float64(),
|
||||||
|
max: selector.Max.Float64(),
|
||||||
|
step: selector.Step.Float64(),
|
||||||
|
}
|
||||||
|
}
|
||||||
case selectorTypeRangeInt:
|
case selectorTypeRangeInt:
|
||||||
|
if selector.Step.IsZero() {
|
||||||
domain = &intRangeDomain{
|
domain = &intRangeDomain{
|
||||||
paramDomainBase: paramDomainBase{
|
paramDomainBase: paramDomainBase{
|
||||||
label: selector.Label,
|
label: selector.Label,
|
||||||
|
@ -140,6 +153,17 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
||||||
min: selector.Min.Int(),
|
min: selector.Min.Int(),
|
||||||
max: selector.Max.Int(),
|
max: selector.Max.Int(),
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
domain = &intStepRangeDomain{
|
||||||
|
paramDomainBase: paramDomainBase{
|
||||||
|
label: selector.Label,
|
||||||
|
path: selector.Path,
|
||||||
|
},
|
||||||
|
min: selector.Min.Int(),
|
||||||
|
max: selector.Max.Int(),
|
||||||
|
step: selector.Step.Int(),
|
||||||
|
}
|
||||||
|
}
|
||||||
case selectorTypeIterate, selectorTypeString:
|
case selectorTypeIterate, selectorTypeString:
|
||||||
domain = &stringDomain{
|
domain = &stringDomain{
|
||||||
paramDomainBase: paramDomainBase{
|
paramDomainBase: paramDomainBase{
|
||||||
|
|
|
@ -15,6 +15,15 @@ func TestBuildParamDomains(t *testing.T) {
|
||||||
max: expect.Max.Float64(),
|
max: expect.Max.Float64(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||||
|
concrete := domain.(*floatDiscreteRangeDomain)
|
||||||
|
return *concrete == floatDiscreteRangeDomain{
|
||||||
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
||||||
|
min: expect.Min.Float64(),
|
||||||
|
max: expect.Max.Float64(),
|
||||||
|
step: expect.Step.Float64(),
|
||||||
|
}
|
||||||
|
}
|
||||||
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||||
concrete := domain.(*intRangeDomain)
|
concrete := domain.(*intRangeDomain)
|
||||||
return *concrete == intRangeDomain{
|
return *concrete == intRangeDomain{
|
||||||
|
@ -23,6 +32,15 @@ func TestBuildParamDomains(t *testing.T) {
|
||||||
max: expect.Max.Int(),
|
max: expect.Max.Int(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||||
|
concrete := domain.(*intStepRangeDomain)
|
||||||
|
return *concrete == intStepRangeDomain{
|
||||||
|
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
||||||
|
min: expect.Min.Int(),
|
||||||
|
max: expect.Max.Int(),
|
||||||
|
step: expect.Step.Int(),
|
||||||
|
}
|
||||||
|
}
|
||||||
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||||
concrete := domain.(*stringDomain)
|
concrete := domain.(*stringDomain)
|
||||||
expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
|
expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
|
||||||
|
@ -56,8 +74,8 @@ func TestBuildParamDomains(t *testing.T) {
|
||||||
Label: "range label",
|
Label: "range label",
|
||||||
Path: "range path",
|
Path: "range path",
|
||||||
Values: []string{"ignore", "ignore"},
|
Values: []string{"ignore", "ignore"},
|
||||||
Min: fixedpoint.NewFromFloat(0.0),
|
Min: fixedpoint.NewFromFloat(7.0),
|
||||||
Max: fixedpoint.NewFromFloat(0.0),
|
Max: fixedpoint.NewFromFloat(80.0),
|
||||||
Step: fixedpoint.NewFromFloat(0.0),
|
Step: fixedpoint.NewFromFloat(0.0),
|
||||||
},
|
},
|
||||||
verify: floatRangeDomainVerifier,
|
verify: floatRangeDomainVerifier,
|
||||||
|
@ -67,11 +85,22 @@ func TestBuildParamDomains(t *testing.T) {
|
||||||
Label: "rangeFloat label",
|
Label: "rangeFloat label",
|
||||||
Path: "rangeFloat path",
|
Path: "rangeFloat path",
|
||||||
Values: []string{"ignore", "ignore"},
|
Values: []string{"ignore", "ignore"},
|
||||||
Min: fixedpoint.NewFromFloat(0.0),
|
Min: fixedpoint.NewFromFloat(6.0),
|
||||||
Max: fixedpoint.NewFromFloat(0.0),
|
Max: fixedpoint.NewFromFloat(10.0),
|
||||||
Step: fixedpoint.NewFromFloat(0.0),
|
Step: fixedpoint.NewFromFloat(0.0),
|
||||||
},
|
},
|
||||||
verify: floatRangeDomainVerifier,
|
verify: floatRangeDomainVerifier,
|
||||||
|
}, {
|
||||||
|
config: SelectorConfig{
|
||||||
|
Type: selectorTypeRangeFloat,
|
||||||
|
Label: "rangeDiscreteFloat label",
|
||||||
|
Path: "rangeDiscreteFloat path",
|
||||||
|
Values: []string{"ignore", "ignore"},
|
||||||
|
Min: fixedpoint.NewFromFloat(6.0),
|
||||||
|
Max: fixedpoint.NewFromFloat(10.0),
|
||||||
|
Step: fixedpoint.NewFromFloat(2.0),
|
||||||
|
},
|
||||||
|
verify: floatDiscreteRangeDomainVerifier,
|
||||||
}, {
|
}, {
|
||||||
config: SelectorConfig{
|
config: SelectorConfig{
|
||||||
Type: selectorTypeRangeInt,
|
Type: selectorTypeRangeInt,
|
||||||
|
@ -80,9 +109,20 @@ func TestBuildParamDomains(t *testing.T) {
|
||||||
Values: []string{"ignore", "ignore"},
|
Values: []string{"ignore", "ignore"},
|
||||||
Min: fixedpoint.NewFromInt(3),
|
Min: fixedpoint.NewFromInt(3),
|
||||||
Max: fixedpoint.NewFromInt(100),
|
Max: fixedpoint.NewFromInt(100),
|
||||||
Step: fixedpoint.NewFromInt(66),
|
Step: fixedpoint.NewFromInt(0),
|
||||||
},
|
},
|
||||||
verify: intRangeDomainVerifier,
|
verify: intRangeDomainVerifier,
|
||||||
|
}, {
|
||||||
|
config: SelectorConfig{
|
||||||
|
Type: selectorTypeRangeInt,
|
||||||
|
Label: "rangeInt label",
|
||||||
|
Path: "rangeInt path",
|
||||||
|
Values: []string{"ignore", "ignore"},
|
||||||
|
Min: fixedpoint.NewFromInt(3),
|
||||||
|
Max: fixedpoint.NewFromInt(100),
|
||||||
|
Step: fixedpoint.NewFromInt(7),
|
||||||
|
},
|
||||||
|
verify: intStepRangeDomainVerifier,
|
||||||
}, {
|
}, {
|
||||||
config: SelectorConfig{
|
config: SelectorConfig{
|
||||||
Type: selectorTypeIterate,
|
Type: selectorTypeIterate,
|
||||||
|
|
|
@ -30,6 +30,22 @@ func (d *intRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, erro
|
||||||
return jsonpatch.DecodePatch(jsonOp)
|
return jsonpatch.DecodePatch(jsonOp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type intStepRangeDomain struct {
|
||||||
|
paramDomainBase
|
||||||
|
min int
|
||||||
|
max int
|
||||||
|
step int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *intStepRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
|
||||||
|
val, err := trial.SuggestStepInt(d.label, d.min, d.max, d.step)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
|
||||||
|
return jsonpatch.DecodePatch(jsonOp)
|
||||||
|
}
|
||||||
|
|
||||||
type floatRangeDomain struct {
|
type floatRangeDomain struct {
|
||||||
paramDomainBase
|
paramDomainBase
|
||||||
min float64
|
min float64
|
||||||
|
@ -45,6 +61,22 @@ func (d *floatRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, er
|
||||||
return jsonpatch.DecodePatch(jsonOp)
|
return jsonpatch.DecodePatch(jsonOp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type floatDiscreteRangeDomain struct {
|
||||||
|
paramDomainBase
|
||||||
|
min float64
|
||||||
|
max float64
|
||||||
|
step float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *floatDiscreteRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
|
||||||
|
val, err := trial.SuggestDiscreteFloat(d.label, d.min, d.max, d.step)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
|
||||||
|
return jsonpatch.DecodePatch(jsonOp)
|
||||||
|
}
|
||||||
|
|
||||||
type stringDomain struct {
|
type stringDomain struct {
|
||||||
paramDomainBase
|
paramDomainBase
|
||||||
options []string
|
options []string
|
||||||
|
|
Loading…
Reference in New Issue
Block a user