mirror of
https://github.com/c9s/bbgo.git
synced 2024-11-10 01:01:56 +00:00
optimizer: optimizeEx supports discrete parameters
This commit is contained in:
parent
76d908e2bc
commit
09940ed3cd
|
@ -123,6 +123,7 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
|||
var domain paramDomain
|
||||
switch selector.Type {
|
||||
case selectorTypeRange, selectorTypeRangeFloat:
|
||||
if selector.Step.IsZero() {
|
||||
domain = &floatRangeDomain{
|
||||
paramDomainBase: paramDomainBase{
|
||||
label: selector.Label,
|
||||
|
@ -131,7 +132,19 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
|||
min: selector.Min.Float64(),
|
||||
max: selector.Max.Float64(),
|
||||
}
|
||||
} else {
|
||||
domain = &floatDiscreteRangeDomain{
|
||||
paramDomainBase: paramDomainBase{
|
||||
label: selector.Label,
|
||||
path: selector.Path,
|
||||
},
|
||||
min: selector.Min.Float64(),
|
||||
max: selector.Max.Float64(),
|
||||
step: selector.Step.Float64(),
|
||||
}
|
||||
}
|
||||
case selectorTypeRangeInt:
|
||||
if selector.Step.IsZero() {
|
||||
domain = &intRangeDomain{
|
||||
paramDomainBase: paramDomainBase{
|
||||
label: selector.Label,
|
||||
|
@ -140,6 +153,17 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
|
|||
min: selector.Min.Int(),
|
||||
max: selector.Max.Int(),
|
||||
}
|
||||
} else {
|
||||
domain = &intStepRangeDomain{
|
||||
paramDomainBase: paramDomainBase{
|
||||
label: selector.Label,
|
||||
path: selector.Path,
|
||||
},
|
||||
min: selector.Min.Int(),
|
||||
max: selector.Max.Int(),
|
||||
step: selector.Step.Int(),
|
||||
}
|
||||
}
|
||||
case selectorTypeIterate, selectorTypeString:
|
||||
domain = &stringDomain{
|
||||
paramDomainBase: paramDomainBase{
|
||||
|
|
|
@ -15,6 +15,15 @@ func TestBuildParamDomains(t *testing.T) {
|
|||
max: expect.Max.Float64(),
|
||||
}
|
||||
}
|
||||
var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||
concrete := domain.(*floatDiscreteRangeDomain)
|
||||
return *concrete == floatDiscreteRangeDomain{
|
||||
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
||||
min: expect.Min.Float64(),
|
||||
max: expect.Max.Float64(),
|
||||
step: expect.Step.Float64(),
|
||||
}
|
||||
}
|
||||
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||
concrete := domain.(*intRangeDomain)
|
||||
return *concrete == intRangeDomain{
|
||||
|
@ -23,6 +32,15 @@ func TestBuildParamDomains(t *testing.T) {
|
|||
max: expect.Max.Int(),
|
||||
}
|
||||
}
|
||||
var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||
concrete := domain.(*intStepRangeDomain)
|
||||
return *concrete == intStepRangeDomain{
|
||||
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
|
||||
min: expect.Min.Int(),
|
||||
max: expect.Max.Int(),
|
||||
step: expect.Step.Int(),
|
||||
}
|
||||
}
|
||||
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
|
||||
concrete := domain.(*stringDomain)
|
||||
expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
|
||||
|
@ -56,8 +74,8 @@ func TestBuildParamDomains(t *testing.T) {
|
|||
Label: "range label",
|
||||
Path: "range path",
|
||||
Values: []string{"ignore", "ignore"},
|
||||
Min: fixedpoint.NewFromFloat(0.0),
|
||||
Max: fixedpoint.NewFromFloat(0.0),
|
||||
Min: fixedpoint.NewFromFloat(7.0),
|
||||
Max: fixedpoint.NewFromFloat(80.0),
|
||||
Step: fixedpoint.NewFromFloat(0.0),
|
||||
},
|
||||
verify: floatRangeDomainVerifier,
|
||||
|
@ -67,11 +85,22 @@ func TestBuildParamDomains(t *testing.T) {
|
|||
Label: "rangeFloat label",
|
||||
Path: "rangeFloat path",
|
||||
Values: []string{"ignore", "ignore"},
|
||||
Min: fixedpoint.NewFromFloat(0.0),
|
||||
Max: fixedpoint.NewFromFloat(0.0),
|
||||
Min: fixedpoint.NewFromFloat(6.0),
|
||||
Max: fixedpoint.NewFromFloat(10.0),
|
||||
Step: fixedpoint.NewFromFloat(0.0),
|
||||
},
|
||||
verify: floatRangeDomainVerifier,
|
||||
}, {
|
||||
config: SelectorConfig{
|
||||
Type: selectorTypeRangeFloat,
|
||||
Label: "rangeDiscreteFloat label",
|
||||
Path: "rangeDiscreteFloat path",
|
||||
Values: []string{"ignore", "ignore"},
|
||||
Min: fixedpoint.NewFromFloat(6.0),
|
||||
Max: fixedpoint.NewFromFloat(10.0),
|
||||
Step: fixedpoint.NewFromFloat(2.0),
|
||||
},
|
||||
verify: floatDiscreteRangeDomainVerifier,
|
||||
}, {
|
||||
config: SelectorConfig{
|
||||
Type: selectorTypeRangeInt,
|
||||
|
@ -80,9 +109,20 @@ func TestBuildParamDomains(t *testing.T) {
|
|||
Values: []string{"ignore", "ignore"},
|
||||
Min: fixedpoint.NewFromInt(3),
|
||||
Max: fixedpoint.NewFromInt(100),
|
||||
Step: fixedpoint.NewFromInt(66),
|
||||
Step: fixedpoint.NewFromInt(0),
|
||||
},
|
||||
verify: intRangeDomainVerifier,
|
||||
}, {
|
||||
config: SelectorConfig{
|
||||
Type: selectorTypeRangeInt,
|
||||
Label: "rangeInt label",
|
||||
Path: "rangeInt path",
|
||||
Values: []string{"ignore", "ignore"},
|
||||
Min: fixedpoint.NewFromInt(3),
|
||||
Max: fixedpoint.NewFromInt(100),
|
||||
Step: fixedpoint.NewFromInt(7),
|
||||
},
|
||||
verify: intStepRangeDomainVerifier,
|
||||
}, {
|
||||
config: SelectorConfig{
|
||||
Type: selectorTypeIterate,
|
||||
|
|
|
@ -30,6 +30,22 @@ func (d *intRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, erro
|
|||
return jsonpatch.DecodePatch(jsonOp)
|
||||
}
|
||||
|
||||
type intStepRangeDomain struct {
|
||||
paramDomainBase
|
||||
min int
|
||||
max int
|
||||
step int
|
||||
}
|
||||
|
||||
func (d *intStepRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
|
||||
val, err := trial.SuggestStepInt(d.label, d.min, d.max, d.step)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
|
||||
return jsonpatch.DecodePatch(jsonOp)
|
||||
}
|
||||
|
||||
type floatRangeDomain struct {
|
||||
paramDomainBase
|
||||
min float64
|
||||
|
@ -45,6 +61,22 @@ func (d *floatRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, er
|
|||
return jsonpatch.DecodePatch(jsonOp)
|
||||
}
|
||||
|
||||
type floatDiscreteRangeDomain struct {
|
||||
paramDomainBase
|
||||
min float64
|
||||
max float64
|
||||
step float64
|
||||
}
|
||||
|
||||
func (d *floatDiscreteRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
|
||||
val, err := trial.SuggestDiscreteFloat(d.label, d.min, d.max, d.step)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
|
||||
return jsonpatch.DecodePatch(jsonOp)
|
||||
}
|
||||
|
||||
type stringDomain struct {
|
||||
paramDomainBase
|
||||
options []string
|
||||
|
|
Loading…
Reference in New Issue
Block a user