optimizer: optimizeEx supports discrete parameters

This commit is contained in:
Raphanus Lo 2022-07-30 20:34:28 +08:00
parent 76d908e2bc
commit 09940ed3cd
3 changed files with 115 additions and 19 deletions

View File

@ -123,6 +123,7 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
var domain paramDomain var domain paramDomain
switch selector.Type { switch selector.Type {
case selectorTypeRange, selectorTypeRangeFloat: case selectorTypeRange, selectorTypeRangeFloat:
if selector.Step.IsZero() {
domain = &floatRangeDomain{ domain = &floatRangeDomain{
paramDomainBase: paramDomainBase{ paramDomainBase: paramDomainBase{
label: selector.Label, label: selector.Label,
@ -131,7 +132,19 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
min: selector.Min.Float64(), min: selector.Min.Float64(),
max: selector.Max.Float64(), max: selector.Max.Float64(),
} }
} else {
domain = &floatDiscreteRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Float64(),
max: selector.Max.Float64(),
step: selector.Step.Float64(),
}
}
case selectorTypeRangeInt: case selectorTypeRangeInt:
if selector.Step.IsZero() {
domain = &intRangeDomain{ domain = &intRangeDomain{
paramDomainBase: paramDomainBase{ paramDomainBase: paramDomainBase{
label: selector.Label, label: selector.Label,
@ -140,6 +153,17 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
min: selector.Min.Int(), min: selector.Min.Int(),
max: selector.Max.Int(), max: selector.Max.Int(),
} }
} else {
domain = &intStepRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Int(),
max: selector.Max.Int(),
step: selector.Step.Int(),
}
}
case selectorTypeIterate, selectorTypeString: case selectorTypeIterate, selectorTypeString:
domain = &stringDomain{ domain = &stringDomain{
paramDomainBase: paramDomainBase{ paramDomainBase: paramDomainBase{

View File

@ -15,6 +15,15 @@ func TestBuildParamDomains(t *testing.T) {
max: expect.Max.Float64(), max: expect.Max.Float64(),
} }
} }
var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*floatDiscreteRangeDomain)
return *concrete == floatDiscreteRangeDomain{
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
min: expect.Min.Float64(),
max: expect.Max.Float64(),
step: expect.Step.Float64(),
}
}
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*intRangeDomain) concrete := domain.(*intRangeDomain)
return *concrete == intRangeDomain{ return *concrete == intRangeDomain{
@ -23,6 +32,15 @@ func TestBuildParamDomains(t *testing.T) {
max: expect.Max.Int(), max: expect.Max.Int(),
} }
} }
var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*intStepRangeDomain)
return *concrete == intStepRangeDomain{
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
min: expect.Min.Int(),
max: expect.Max.Int(),
step: expect.Step.Int(),
}
}
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*stringDomain) concrete := domain.(*stringDomain)
expectBase := paramDomainBase{label: expect.Label, path: expect.Path} expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
@ -56,8 +74,8 @@ func TestBuildParamDomains(t *testing.T) {
Label: "range label", Label: "range label",
Path: "range path", Path: "range path",
Values: []string{"ignore", "ignore"}, Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(0.0), Min: fixedpoint.NewFromFloat(7.0),
Max: fixedpoint.NewFromFloat(0.0), Max: fixedpoint.NewFromFloat(80.0),
Step: fixedpoint.NewFromFloat(0.0), Step: fixedpoint.NewFromFloat(0.0),
}, },
verify: floatRangeDomainVerifier, verify: floatRangeDomainVerifier,
@ -67,11 +85,22 @@ func TestBuildParamDomains(t *testing.T) {
Label: "rangeFloat label", Label: "rangeFloat label",
Path: "rangeFloat path", Path: "rangeFloat path",
Values: []string{"ignore", "ignore"}, Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(0.0), Min: fixedpoint.NewFromFloat(6.0),
Max: fixedpoint.NewFromFloat(0.0), Max: fixedpoint.NewFromFloat(10.0),
Step: fixedpoint.NewFromFloat(0.0), Step: fixedpoint.NewFromFloat(0.0),
}, },
verify: floatRangeDomainVerifier, verify: floatRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeRangeFloat,
Label: "rangeDiscreteFloat label",
Path: "rangeDiscreteFloat path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(6.0),
Max: fixedpoint.NewFromFloat(10.0),
Step: fixedpoint.NewFromFloat(2.0),
},
verify: floatDiscreteRangeDomainVerifier,
}, { }, {
config: SelectorConfig{ config: SelectorConfig{
Type: selectorTypeRangeInt, Type: selectorTypeRangeInt,
@ -80,9 +109,20 @@ func TestBuildParamDomains(t *testing.T) {
Values: []string{"ignore", "ignore"}, Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromInt(3), Min: fixedpoint.NewFromInt(3),
Max: fixedpoint.NewFromInt(100), Max: fixedpoint.NewFromInt(100),
Step: fixedpoint.NewFromInt(66), Step: fixedpoint.NewFromInt(0),
}, },
verify: intRangeDomainVerifier, verify: intRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeRangeInt,
Label: "rangeInt label",
Path: "rangeInt path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromInt(3),
Max: fixedpoint.NewFromInt(100),
Step: fixedpoint.NewFromInt(7),
},
verify: intStepRangeDomainVerifier,
}, { }, {
config: SelectorConfig{ config: SelectorConfig{
Type: selectorTypeIterate, Type: selectorTypeIterate,

View File

@ -30,6 +30,22 @@ func (d *intRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, erro
return jsonpatch.DecodePatch(jsonOp) return jsonpatch.DecodePatch(jsonOp)
} }
type intStepRangeDomain struct {
paramDomainBase
min int
max int
step int
}
func (d *intStepRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
val, err := trial.SuggestStepInt(d.label, d.min, d.max, d.step)
if err != nil {
return nil, err
}
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
return jsonpatch.DecodePatch(jsonOp)
}
type floatRangeDomain struct { type floatRangeDomain struct {
paramDomainBase paramDomainBase
min float64 min float64
@ -45,6 +61,22 @@ func (d *floatRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, er
return jsonpatch.DecodePatch(jsonOp) return jsonpatch.DecodePatch(jsonOp)
} }
type floatDiscreteRangeDomain struct {
paramDomainBase
min float64
max float64
step float64
}
func (d *floatDiscreteRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
val, err := trial.SuggestDiscreteFloat(d.label, d.min, d.max, d.step)
if err != nil {
return nil, err
}
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
return jsonpatch.DecodePatch(jsonOp)
}
type stringDomain struct { type stringDomain struct {
paramDomainBase paramDomainBase
options []string options []string