optimizer: optimizeEx supports discrete parameters

This commit is contained in:
Raphanus Lo 2022-07-30 20:34:28 +08:00
parent 76d908e2bc
commit 09940ed3cd
3 changed files with 115 additions and 19 deletions

View File

@ -123,22 +123,46 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para
var domain paramDomain
switch selector.Type {
case selectorTypeRange, selectorTypeRangeFloat:
domain = &floatRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Float64(),
max: selector.Max.Float64(),
if selector.Step.IsZero() {
domain = &floatRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Float64(),
max: selector.Max.Float64(),
}
} else {
domain = &floatDiscreteRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Float64(),
max: selector.Max.Float64(),
step: selector.Step.Float64(),
}
}
case selectorTypeRangeInt:
domain = &intRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Int(),
max: selector.Max.Int(),
if selector.Step.IsZero() {
domain = &intRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Int(),
max: selector.Max.Int(),
}
} else {
domain = &intStepRangeDomain{
paramDomainBase: paramDomainBase{
label: selector.Label,
path: selector.Path,
},
min: selector.Min.Int(),
max: selector.Max.Int(),
step: selector.Step.Int(),
}
}
case selectorTypeIterate, selectorTypeString:
domain = &stringDomain{

View File

@ -15,6 +15,15 @@ func TestBuildParamDomains(t *testing.T) {
max: expect.Max.Float64(),
}
}
var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*floatDiscreteRangeDomain)
return *concrete == floatDiscreteRangeDomain{
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
min: expect.Min.Float64(),
max: expect.Max.Float64(),
step: expect.Step.Float64(),
}
}
var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*intRangeDomain)
return *concrete == intRangeDomain{
@ -23,6 +32,15 @@ func TestBuildParamDomains(t *testing.T) {
max: expect.Max.Int(),
}
}
var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*intStepRangeDomain)
return *concrete == intStepRangeDomain{
paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path},
min: expect.Min.Int(),
max: expect.Max.Int(),
step: expect.Step.Int(),
}
}
var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool {
concrete := domain.(*stringDomain)
expectBase := paramDomainBase{label: expect.Label, path: expect.Path}
@ -56,8 +74,8 @@ func TestBuildParamDomains(t *testing.T) {
Label: "range label",
Path: "range path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(0.0),
Max: fixedpoint.NewFromFloat(0.0),
Min: fixedpoint.NewFromFloat(7.0),
Max: fixedpoint.NewFromFloat(80.0),
Step: fixedpoint.NewFromFloat(0.0),
},
verify: floatRangeDomainVerifier,
@ -67,11 +85,22 @@ func TestBuildParamDomains(t *testing.T) {
Label: "rangeFloat label",
Path: "rangeFloat path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(0.0),
Max: fixedpoint.NewFromFloat(0.0),
Min: fixedpoint.NewFromFloat(6.0),
Max: fixedpoint.NewFromFloat(10.0),
Step: fixedpoint.NewFromFloat(0.0),
},
verify: floatRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeRangeFloat,
Label: "rangeDiscreteFloat label",
Path: "rangeDiscreteFloat path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromFloat(6.0),
Max: fixedpoint.NewFromFloat(10.0),
Step: fixedpoint.NewFromFloat(2.0),
},
verify: floatDiscreteRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeRangeInt,
@ -80,9 +109,20 @@ func TestBuildParamDomains(t *testing.T) {
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromInt(3),
Max: fixedpoint.NewFromInt(100),
Step: fixedpoint.NewFromInt(66),
Step: fixedpoint.NewFromInt(0),
},
verify: intRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeRangeInt,
Label: "rangeInt label",
Path: "rangeInt path",
Values: []string{"ignore", "ignore"},
Min: fixedpoint.NewFromInt(3),
Max: fixedpoint.NewFromInt(100),
Step: fixedpoint.NewFromInt(7),
},
verify: intStepRangeDomainVerifier,
}, {
config: SelectorConfig{
Type: selectorTypeIterate,

View File

@ -30,6 +30,22 @@ func (d *intRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, erro
return jsonpatch.DecodePatch(jsonOp)
}
type intStepRangeDomain struct {
paramDomainBase
min int
max int
step int
}
func (d *intStepRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
val, err := trial.SuggestStepInt(d.label, d.min, d.max, d.step)
if err != nil {
return nil, err
}
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
return jsonpatch.DecodePatch(jsonOp)
}
type floatRangeDomain struct {
paramDomainBase
min float64
@ -45,6 +61,22 @@ func (d *floatRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, er
return jsonpatch.DecodePatch(jsonOp)
}
type floatDiscreteRangeDomain struct {
paramDomainBase
min float64
max float64
step float64
}
func (d *floatDiscreteRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) {
val, err := trial.SuggestDiscreteFloat(d.label, d.min, d.max, d.step)
if err != nil {
return nil, err
}
jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val)))
return jsonpatch.DecodePatch(jsonOp)
}
type stringDomain struct {
paramDomainBase
options []string