diff --git a/pkg/optimizer/hpoptimizer.go b/pkg/optimizer/hpoptimizer.go index d2858e6e5..ea4494349 100644 --- a/pkg/optimizer/hpoptimizer.go +++ b/pkg/optimizer/hpoptimizer.go @@ -123,22 +123,46 @@ func (o *HyperparameterOptimizer) buildParamDomains() (map[string]string, []para var domain paramDomain switch selector.Type { case selectorTypeRange, selectorTypeRangeFloat: - domain = &floatRangeDomain{ - paramDomainBase: paramDomainBase{ - label: selector.Label, - path: selector.Path, - }, - min: selector.Min.Float64(), - max: selector.Max.Float64(), + if selector.Step.IsZero() { + domain = &floatRangeDomain{ + paramDomainBase: paramDomainBase{ + label: selector.Label, + path: selector.Path, + }, + min: selector.Min.Float64(), + max: selector.Max.Float64(), + } + } else { + domain = &floatDiscreteRangeDomain{ + paramDomainBase: paramDomainBase{ + label: selector.Label, + path: selector.Path, + }, + min: selector.Min.Float64(), + max: selector.Max.Float64(), + step: selector.Step.Float64(), + } } case selectorTypeRangeInt: - domain = &intRangeDomain{ - paramDomainBase: paramDomainBase{ - label: selector.Label, - path: selector.Path, - }, - min: selector.Min.Int(), - max: selector.Max.Int(), + if selector.Step.IsZero() { + domain = &intRangeDomain{ + paramDomainBase: paramDomainBase{ + label: selector.Label, + path: selector.Path, + }, + min: selector.Min.Int(), + max: selector.Max.Int(), + } + } else { + domain = &intStepRangeDomain{ + paramDomainBase: paramDomainBase{ + label: selector.Label, + path: selector.Path, + }, + min: selector.Min.Int(), + max: selector.Max.Int(), + step: selector.Step.Int(), + } } case selectorTypeIterate, selectorTypeString: domain = &stringDomain{ diff --git a/pkg/optimizer/hpoptimizer_test.go b/pkg/optimizer/hpoptimizer_test.go index 490240f05..2107946b3 100644 --- a/pkg/optimizer/hpoptimizer_test.go +++ b/pkg/optimizer/hpoptimizer_test.go @@ -15,6 +15,15 @@ func TestBuildParamDomains(t *testing.T) { max: expect.Max.Float64(), } } + var floatDiscreteRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { + concrete := domain.(*floatDiscreteRangeDomain) + return *concrete == floatDiscreteRangeDomain{ + paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path}, + min: expect.Min.Float64(), + max: expect.Max.Float64(), + step: expect.Step.Float64(), + } + } var intRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { concrete := domain.(*intRangeDomain) return *concrete == intRangeDomain{ @@ -23,6 +32,15 @@ func TestBuildParamDomains(t *testing.T) { max: expect.Max.Int(), } } + var intStepRangeDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { + concrete := domain.(*intStepRangeDomain) + return *concrete == intStepRangeDomain{ + paramDomainBase: paramDomainBase{label: expect.Label, path: expect.Path}, + min: expect.Min.Int(), + max: expect.Max.Int(), + step: expect.Step.Int(), + } + } var stringDomainVerifier = func(domain paramDomain, expect SelectorConfig) bool { concrete := domain.(*stringDomain) expectBase := paramDomainBase{label: expect.Label, path: expect.Path} @@ -56,8 +74,8 @@ func TestBuildParamDomains(t *testing.T) { Label: "range label", Path: "range path", Values: []string{"ignore", "ignore"}, - Min: fixedpoint.NewFromFloat(0.0), - Max: fixedpoint.NewFromFloat(0.0), + Min: fixedpoint.NewFromFloat(7.0), + Max: fixedpoint.NewFromFloat(80.0), Step: fixedpoint.NewFromFloat(0.0), }, verify: floatRangeDomainVerifier, @@ -67,11 +85,22 @@ func TestBuildParamDomains(t *testing.T) { Label: "rangeFloat label", Path: "rangeFloat path", Values: []string{"ignore", "ignore"}, - Min: fixedpoint.NewFromFloat(0.0), - Max: fixedpoint.NewFromFloat(0.0), + Min: fixedpoint.NewFromFloat(6.0), + Max: fixedpoint.NewFromFloat(10.0), Step: fixedpoint.NewFromFloat(0.0), }, verify: floatRangeDomainVerifier, + }, { + config: SelectorConfig{ + Type: selectorTypeRangeFloat, + Label: "rangeDiscreteFloat label", + Path: "rangeDiscreteFloat path", + Values: []string{"ignore", "ignore"}, + Min: fixedpoint.NewFromFloat(6.0), + Max: fixedpoint.NewFromFloat(10.0), + Step: fixedpoint.NewFromFloat(2.0), + }, + verify: floatDiscreteRangeDomainVerifier, }, { config: SelectorConfig{ Type: selectorTypeRangeInt, @@ -80,9 +109,20 @@ func TestBuildParamDomains(t *testing.T) { Values: []string{"ignore", "ignore"}, Min: fixedpoint.NewFromInt(3), Max: fixedpoint.NewFromInt(100), - Step: fixedpoint.NewFromInt(66), + Step: fixedpoint.NewFromInt(0), }, verify: intRangeDomainVerifier, + }, { + config: SelectorConfig{ + Type: selectorTypeRangeInt, + Label: "rangeInt label", + Path: "rangeInt path", + Values: []string{"ignore", "ignore"}, + Min: fixedpoint.NewFromInt(3), + Max: fixedpoint.NewFromInt(100), + Step: fixedpoint.NewFromInt(7), + }, + verify: intStepRangeDomainVerifier, }, { config: SelectorConfig{ Type: selectorTypeIterate, diff --git a/pkg/optimizer/hyperparam.go b/pkg/optimizer/hyperparam.go index 31b699ea1..ae6222894 100644 --- a/pkg/optimizer/hyperparam.go +++ b/pkg/optimizer/hyperparam.go @@ -30,6 +30,22 @@ func (d *intRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, erro return jsonpatch.DecodePatch(jsonOp) } +type intStepRangeDomain struct { + paramDomainBase + min int + max int + step int +} + +func (d *intStepRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) { + val, err := trial.SuggestStepInt(d.label, d.min, d.max, d.step) + if err != nil { + return nil, err + } + jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val))) + return jsonpatch.DecodePatch(jsonOp) +} + type floatRangeDomain struct { paramDomainBase min float64 @@ -45,6 +61,22 @@ func (d *floatRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, er return jsonpatch.DecodePatch(jsonOp) } +type floatDiscreteRangeDomain struct { + paramDomainBase + min float64 + max float64 + step float64 +} + +func (d *floatDiscreteRangeDomain) buildPatch(trial *goptuna.Trial) (jsonpatch.Patch, error) { + val, err := trial.SuggestDiscreteFloat(d.label, d.min, d.max, d.step) + if err != nil { + return nil, err + } + jsonOp := []byte(reformatJson(fmt.Sprintf(`[{"op": "replace", "path": "%s", "value": %v }]`, d.path, val))) + return jsonpatch.DecodePatch(jsonOp) +} + type stringDomain struct { paramDomainBase options []string