don't overwrite is_random

this should fix issue #6746
This commit is contained in:
Italo 2022-06-09 20:06:23 +01:00
parent dd32127014
commit dce9fdd0e4

View File

@ -429,18 +429,19 @@ class Hyperopt:
return new_list return new_list
i = 0 i = 0
asked_non_tried: List[List[Any]] = [] asked_non_tried: List[List[Any]] = []
is_random: List[bool] = [] is_random_non_tried: List[bool] = []
while i < 5 and len(asked_non_tried) < n_points: while i < 5 and len(asked_non_tried) < n_points:
if i < 3: if i < 3:
self.opt.cache_ = {} self.opt.cache_ = {}
asked = unique_list(self.opt.ask(n_points=n_points * 5)) asked = unique_list(self.opt.ask(n_points=n_points * 5))
is_random = [False for _ in range(len(asked))] is_random = [False for _ in range(len(asked))]
else: else:
asked = unique_list(self.opt.space.rvs(n_samples=n_points * 5)) asked = unique_list(self.opt.space.rvs(
n_samples=n_points * 5, random_state=self.random_state + i))
is_random = [True for _ in range(len(asked))] is_random = [True for _ in range(len(asked))]
is_random += [rand for x, rand in zip(asked, is_random) is_random_non_tried += [rand for x, rand in zip(asked, is_random)
if x not in self.opt.Xi if x not in self.opt.Xi
and x not in asked_non_tried] and x not in asked_non_tried]
asked_non_tried += [x for x in asked asked_non_tried += [x for x in asked
if x not in self.opt.Xi if x not in self.opt.Xi
and x not in asked_non_tried] and x not in asked_non_tried]
@ -449,7 +450,7 @@ class Hyperopt:
if asked_non_tried: if asked_non_tried:
return ( return (
asked_non_tried[:min(len(asked_non_tried), n_points)], asked_non_tried[:min(len(asked_non_tried), n_points)],
is_random[:min(len(asked_non_tried), n_points)] is_random_non_tried[:min(len(asked_non_tried), n_points)]
) )
else: else:
return self.opt.ask(n_points=n_points), [False for _ in range(n_points)] return self.opt.ask(n_points=n_points), [False for _ in range(n_points)]