optimizer: workaround for data race in TPE optimization

This commit is contained in:
Raphanus Lo 2022-07-30 09:57:15 +08:00
parent ae3eaaaeb3
commit 76d908e2bc

View File

@ -12,6 +12,7 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"math"
"sync"
)
const (
@ -76,6 +77,10 @@ func buildHyperparameterOptimizeTrialResults(study *goptuna.Study) []*Hyperparam
type HyperparameterOptimizer struct {
StudyName string
Config *Config
// Workaround for goptuna/tpe parameter suggestion. Remove this after fixed.
// ref: https://github.com/c-bata/goptuna/issues/236
paramSuggestionLock sync.Mutex
}
func (o *HyperparameterOptimizer) buildStudy(trialFinishChan chan goptuna.FrozenTrial) (*goptuna.Study, error) {
@ -172,15 +177,23 @@ func (o *HyperparameterOptimizer) buildObjective(executor Executor, configJson [
}
return func(trial goptuna.Trial) (float64, error) {
var trialConfig = configJson
for _, domain := range paramDomains {
if patch, err := domain.buildPatch(&trial); err != nil {
return 0.0, err
} else if patchedConfig, err := patch.ApplyIndent(trialConfig, " "); err != nil {
return 0.0, err
} else {
trialConfig = patchedConfig
trialConfig, err := func(trialConfig []byte) ([]byte, error) {
o.paramSuggestionLock.Lock()
defer o.paramSuggestionLock.Unlock()
for _, domain := range paramDomains {
if patch, err := domain.buildPatch(&trial); err != nil {
return nil, err
} else if patchedConfig, err := patch.ApplyIndent(trialConfig, " "); err != nil {
return nil, err
} else {
trialConfig = patchedConfig
}
}
return trialConfig, nil
}(configJson)
if err != nil {
return 0.0, err
}
summary, err := executor.Execute(trialConfig)