Merge pull request #8461 from freqtrade/feat/hyperopt_progressbar

hyperopt progressbar -> rich
Matthias 2023-04-13 20:00:27 +02:00 committed by GitHub
commit 3c64c6b034
8 changed files with 68 additions and 58 deletions
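
The change in a nutshell: hyperopt's epoch progressbar moves from progressbar2 to rich. A minimal standalone sketch of the new setup (the column stack mirrors the hyperopt diff below; total_epochs and the sleep are illustrative stand-ins for real work):

    import time

    from rich.progress import (BarColumn, MofNCompleteColumn, Progress, TaskProgressColumn,
                               TextColumn, TimeElapsedColumn, TimeRemainingColumn)

    total_epochs = 20

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(bar_width=None),
        MofNCompleteColumn(),
        TaskProgressColumn(),
        "•",
        TimeElapsedColumn(),
        "•",
        TimeRemainingColumn(),
        expand=True,
    ) as pbar:
        task = pbar.add_task("Epochs", total=total_epochs)
        for _ in range(total_epochs):
            time.sleep(0.1)  # one "epoch" of pretend work
            pbar.update(task, advance=1)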

freqtrade/loggers/__init__.py

@@ -1,24 +1,11 @@
 import logging
-import sys
 from logging import Formatter
-from logging.handlers import BufferingHandler, RotatingFileHandler, SysLogHandler
+from logging.handlers import RotatingFileHandler, SysLogHandler

 from freqtrade.constants import Config
 from freqtrade.exceptions import OperationalException
-
-
-class FTBufferingHandler(BufferingHandler):
-    def flush(self):
-        """
-        Override Flush behaviour - we keep half of the configured capacity
-        otherwise, we have moments with "empty" logs.
-        """
-        self.acquire()
-        try:
-            # Keep half of the records in buffer.
-            self.buffer = self.buffer[-int(self.capacity / 2):]
-        finally:
-            self.release()
+from freqtrade.loggers.buffering_handler import FTBufferingHandler
+from freqtrade.loggers.std_err_stream_handler import FTStdErrStreamHandler


 logger = logging.getLogger(__name__)

@@ -69,7 +56,7 @@ def setup_logging_pre() -> None:
     logging.basicConfig(
         level=logging.INFO,
         format=LOGFORMAT,
-        handlers=[logging.StreamHandler(sys.stderr), bufferHandler]
+        handlers=[FTStdErrStreamHandler(), bufferHandler]
     )

freqtrade/loggers/buffering_handler.py (new file)

@@ -0,0 +1,15 @@
+from logging.handlers import BufferingHandler
+
+
+class FTBufferingHandler(BufferingHandler):
+    def flush(self):
+        """
+        Override Flush behaviour - we keep half of the configured capacity
+        otherwise, we have moments with "empty" logs.
+        """
+        self.acquire()
+        try:
+            # Keep half of the records in buffer.
+            self.buffer = self.buffer[-int(self.capacity / 2):]
+        finally:
+            self.release()
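
For context on the file above: FTBufferingHandler changes only what happens when the buffer fills - instead of discarding everything, flush() keeps the newest half of the records. A hedged demo (the "demo" logger name and record counts are illustrative):

    import logging

    from freqtrade.loggers.buffering_handler import FTBufferingHandler

    handler = FTBufferingHandler(1000)  # BufferingHandler flushes once 1000 records are buffered
    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.setLevel(logging.INFO)

    for i in range(1200):
        log.info("message %d", i)

    # At record 1000, BufferingHandler.shouldFlush() fired and the overridden
    # flush() trimmed the buffer to the newest 500 records instead of emptying
    # it; the remaining 200 emits leave 700 buffered, so recent logs survive.
    print(len(handler.buffer))  # 700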

freqtrade/loggers/std_err_stream_handler.py (new file)

@@ -0,0 +1,26 @@
+import sys
+from logging import Handler
+
+
+class FTStdErrStreamHandler(Handler):
+    def flush(self):
+        """
+        Override Flush behaviour - we keep half of the configured capacity
+        otherwise, we have moments with "empty" logs.
+        """
+        self.acquire()
+        try:
+            sys.stderr.flush()
+        finally:
+            self.release()
+
+    def emit(self, record):
+        try:
+            msg = self.format(record)
+            # Don't keep a reference to stderr - this can be problematic with progressbars.
+            sys.stderr.write(msg + '\n')
+            self.flush()
+        except RecursionError:
+            raise
+        except Exception:
+            self.handleError(record)
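
The comment in emit() above is the point of this file: the handler resolves sys.stderr on every write rather than caching the stream object the way logging.StreamHandler(sys.stderr) does, so log lines follow any stderr redirection a progressbar library performs. An illustrative sketch (the StringIO swap stands in for a progressbar's stream proxy):

    import io
    import logging
    import sys

    from freqtrade.loggers.std_err_stream_handler import FTStdErrStreamHandler

    log = logging.getLogger("demo")
    log.addHandler(FTStdErrStreamHandler())

    log.warning("written to the real stderr")

    # Simulate a progressbar temporarily swapping sys.stderr for its own proxy:
    proxy = io.StringIO()
    original, sys.stderr = sys.stderr, proxy
    log.warning("written to the proxy")
    sys.stderr = original

    print(proxy.getvalue(), end="")  # -> written to the proxy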

freqtrade/optimize/hyperopt.py

@@ -13,13 +13,13 @@ from math import ceil
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

-import progressbar
 import rapidjson
-from colorama import Fore, Style
 from colorama import init as colorama_init
 from joblib import Parallel, cpu_count, delayed, dump, load, wrap_non_picklable_objects
 from joblib.externals import cloudpickle
 from pandas import DataFrame
+from rich.progress import (BarColumn, MofNCompleteColumn, Progress, TaskProgressColumn, TextColumn,
+                           TimeElapsedColumn, TimeRemainingColumn)

 from freqtrade.constants import DATETIME_PRINT_FORMAT, FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
 from freqtrade.data.converter import trim_dataframes
@@ -44,8 +44,6 @@ with warnings.catch_warnings():
     from skopt import Optimizer
     from skopt.space import Dimension

-progressbar.streams.wrap_stderr()
-progressbar.streams.wrap_stdout()

 logger = logging.getLogger(__name__)
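
The module-level stream wrapping can simply be dropped because rich approaches the problem from the other side: while a Progress is live it redirects stdout/stderr itself (redirect_stdout and redirect_stderr default to True), rendering prints above the bar, so nothing needs to be patched at import time. A small sketch (the loop body is illustrative):

    import time

    from rich.progress import Progress

    with Progress() as progress:  # redirect_stdout/redirect_stderr default to True
        task = progress.add_task("work", total=3)
        for i in range(3):
            print(f"step {i}")  # rendered above the live bar instead of through it
            progress.update(task, advance=1)
            time.sleep(0.1)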
@@ -520,29 +518,6 @@ class Hyperopt:
         else:
             return self.opt.ask(n_points=n_points), [False for _ in range(n_points)]

-    def get_progressbar_widgets(self):
-        if self.print_colorized:
-            widgets = [
-                ' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
-                ' (', progressbar.Percentage(), ')] ',
-                progressbar.Bar(marker=progressbar.AnimatedMarker(
-                    fill='\N{FULL BLOCK}',
-                    fill_wrap=Fore.GREEN + '{}' + Fore.RESET,
-                    marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL,
-                )),
-                ' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
-            ]
-        else:
-            widgets = [
-                ' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
-                ' (', progressbar.Percentage(), ')] ',
-                progressbar.Bar(marker=progressbar.AnimatedMarker(
-                    fill='\N{FULL BLOCK}',
-                )),
-                ' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
-            ]
-        return widgets
-
     def evaluate_result(self, val: Dict[str, Any], current: int, is_random: bool):
         """
         Evaluate results returned from generate_optimizer
@@ -602,11 +577,19 @@
             logger.info(f'Effective number of parallel workers used: {jobs}')

             # Define progressbar
-            widgets = self.get_progressbar_widgets()
-            with progressbar.ProgressBar(
-                max_value=self.total_epochs, redirect_stdout=False, redirect_stderr=False,
-                widgets=widgets
+            with Progress(
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(bar_width=None),
+                MofNCompleteColumn(),
+                TaskProgressColumn(),
+                "•",
+                TimeElapsedColumn(),
+                "•",
+                TimeRemainingColumn(),
+                expand=True,
             ) as pbar:
+                task = pbar.add_task("Epochs", total=self.total_epochs)

                 start = 0
                 if self.analyze_per_epoch:
@@ -616,7 +599,7 @@
                     f_val0 = self.generate_optimizer(asked[0])
                     self.opt.tell(asked, [f_val0['loss']])
                     self.evaluate_result(f_val0, 1, is_random[0])
-                    pbar.update(1)
+                    pbar.update(task, advance=1)
                     start += 1

                 evals = ceil((self.total_epochs - start) / jobs)
@@ -630,14 +613,12 @@
                     f_val = self.run_optimizer_parallel(parallel, asked)
                     self.opt.tell(asked, [v['loss'] for v in f_val])

-                    # Calculate progressbar outputs
                     for j, val in enumerate(f_val):
                         # Use human-friendly indexes here (starting from 1)
                         current = i * jobs + j + 1 + start

                         self.evaluate_result(val, current, is_random[j])
-
-                        pbar.update(current)
+                        pbar.update(task, advance=1)

         except KeyboardInterrupt:
             print('User interrupted..')
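
Note the changed update semantics: progressbar2's pbar.update(current) set an absolute position, while rich's pbar.update(task, advance=1) advances by a delta, which is why the loop above no longer feeds the computed current to the bar. A minimal comparison (the task total is illustrative):

    from rich.progress import Progress

    with Progress() as pbar:
        task = pbar.add_task("Epochs", total=10)
        for _ in range(10):
            pbar.update(task, advance=1)   # relative: add 1 to the completed count
        # pbar.update(task, completed=10)  # absolute: the progressbar2-style update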

requirements-hyperopt.txt

@@ -6,4 +6,3 @@ scipy==1.10.1
 scikit-learn==1.1.3
 scikit-optimize==0.9.0
 filelock==3.11.0
-progressbar2==4.2.0

requirements.txt

@@ -20,6 +20,7 @@ jinja2==3.1.2
 tables==3.8.0
 blosc==1.11.1
 joblib==1.2.0
+rich==13.3.3
 pyarrow==11.0.0; platform_machine != 'armv7l'
 # find first, C search in arrays

setup.py

@@ -8,7 +8,6 @@ hyperopt = [
     'scikit-learn',
     'scikit-optimize>=0.7.0',
     'filelock',
-    'progressbar2',
 ]

 freqai = [

@@ -82,6 +81,7 @@ setup(
         'numpy',
         'pandas',
         'joblib>=1.2.0',
+        'rich',
         'pyarrow; platform_machine != "armv7l"',
         'fastapi',
         'pydantic>=1.8.0',

tests/test_configuration.py

@@ -23,7 +23,8 @@ from freqtrade.configuration.load_config import (load_config_file, load_file, lo
 from freqtrade.constants import DEFAULT_DB_DRYRUN_URL, DEFAULT_DB_PROD_URL, ENV_VAR_PREFIX
 from freqtrade.enums import RunMode
 from freqtrade.exceptions import OperationalException
-from freqtrade.loggers import FTBufferingHandler, _set_loggers, setup_logging, setup_logging_pre
+from freqtrade.loggers import (FTBufferingHandler, FTStdErrStreamHandler, _set_loggers,
+                               setup_logging, setup_logging_pre)
 from tests.conftest import (CURRENT_TEST_STRATEGY, log_has, log_has_re,
                             patched_configuration_load_config_file)

@@ -658,7 +659,7 @@ def test_set_loggers_syslog():
     setup_logging(config)
     assert len(logger.handlers) == 3
     assert [x for x in logger.handlers if type(x) == logging.handlers.SysLogHandler]
-    assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
+    assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
     assert [x for x in logger.handlers if type(x) == FTBufferingHandler]
     # setting up logging again should NOT cause the loggers to be added a second time.
     setup_logging(config)

@@ -681,7 +682,7 @@ def test_set_loggers_Filehandler(tmpdir):
     setup_logging(config)
     assert len(logger.handlers) == 3
     assert [x for x in logger.handlers if type(x) == logging.handlers.RotatingFileHandler]
-    assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
+    assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
     assert [x for x in logger.handlers if type(x) == FTBufferingHandler]
     # setting up logging again should NOT cause the loggers to be added a second time.
     setup_logging(config)

@@ -706,7 +707,7 @@ def test_set_loggers_journald(mocker):
     setup_logging(config)
     assert len(logger.handlers) == 3
     assert [x for x in logger.handlers if type(x).__name__ == "JournaldLogHandler"]
-    assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
+    assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
     # reset handlers to not break pytest
     logger.handlers = orig_handlers