Merge pull request #8461 from freqtrade/feat/hyperopt_progressbar

hyperopt progressbar -> rich
This commit is contained in:
Matthias 2023-04-13 20:00:27 +02:00 committed by GitHub
commit 3c64c6b034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 68 additions and 58 deletions

View File

@ -1,24 +1,11 @@
import logging
import sys
from logging import Formatter
from logging.handlers import BufferingHandler, RotatingFileHandler, SysLogHandler
from logging.handlers import RotatingFileHandler, SysLogHandler
from freqtrade.constants import Config
from freqtrade.exceptions import OperationalException
class FTBufferingHandler(BufferingHandler):
def flush(self):
"""
Override Flush behaviour - we keep half of the configured capacity
otherwise, we have moments with "empty" logs.
"""
self.acquire()
try:
# Keep half of the records in buffer.
self.buffer = self.buffer[-int(self.capacity / 2):]
finally:
self.release()
from freqtrade.loggers.buffering_handler import FTBufferingHandler
from freqtrade.loggers.std_err_stream_handler import FTStdErrStreamHandler
logger = logging.getLogger(__name__)
@ -69,7 +56,7 @@ def setup_logging_pre() -> None:
logging.basicConfig(
level=logging.INFO,
format=LOGFORMAT,
handlers=[logging.StreamHandler(sys.stderr), bufferHandler]
handlers=[FTStdErrStreamHandler(), bufferHandler]
)

View File

@ -0,0 +1,15 @@
from logging.handlers import BufferingHandler
class FTBufferingHandler(BufferingHandler):
def flush(self):
"""
Override Flush behaviour - we keep half of the configured capacity
otherwise, we have moments with "empty" logs.
"""
self.acquire()
try:
# Keep half of the records in buffer.
self.buffer = self.buffer[-int(self.capacity / 2):]
finally:
self.release()

View File

@ -0,0 +1,26 @@
import sys
from logging import Handler
class FTStdErrStreamHandler(Handler):
def flush(self):
"""
Override Flush behaviour - we keep half of the configured capacity
otherwise, we have moments with "empty" logs.
"""
self.acquire()
try:
sys.stderr.flush()
finally:
self.release()
def emit(self, record):
try:
msg = self.format(record)
# Don't keep a reference to stderr - this can be problematic with progressbars.
sys.stderr.write(msg + '\n')
self.flush()
except RecursionError:
raise
except Exception:
self.handleError(record)

View File

@ -13,13 +13,13 @@ from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import progressbar
import rapidjson
from colorama import Fore, Style
from colorama import init as colorama_init
from joblib import Parallel, cpu_count, delayed, dump, load, wrap_non_picklable_objects
from joblib.externals import cloudpickle
from pandas import DataFrame
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, TaskProgressColumn, TextColumn,
TimeElapsedColumn, TimeRemainingColumn)
from freqtrade.constants import DATETIME_PRINT_FORMAT, FTHYPT_FILEVERSION, LAST_BT_RESULT_FN, Config
from freqtrade.data.converter import trim_dataframes
@ -44,8 +44,6 @@ with warnings.catch_warnings():
from skopt import Optimizer
from skopt.space import Dimension
progressbar.streams.wrap_stderr()
progressbar.streams.wrap_stdout()
logger = logging.getLogger(__name__)
@ -520,29 +518,6 @@ class Hyperopt:
else:
return self.opt.ask(n_points=n_points), [False for _ in range(n_points)]
def get_progressbar_widgets(self):
if self.print_colorized:
widgets = [
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
' (', progressbar.Percentage(), ')] ',
progressbar.Bar(marker=progressbar.AnimatedMarker(
fill='\N{FULL BLOCK}',
fill_wrap=Fore.GREEN + '{}' + Fore.RESET,
marker_wrap=Style.BRIGHT + '{}' + Style.RESET_ALL,
)),
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
]
else:
widgets = [
' [Epoch ', progressbar.Counter(), ' of ', str(self.total_epochs),
' (', progressbar.Percentage(), ')] ',
progressbar.Bar(marker=progressbar.AnimatedMarker(
fill='\N{FULL BLOCK}',
)),
' [', progressbar.ETA(), ', ', progressbar.Timer(), ']',
]
return widgets
def evaluate_result(self, val: Dict[str, Any], current: int, is_random: bool):
"""
Evaluate results returned from generate_optimizer
@ -602,11 +577,19 @@ class Hyperopt:
logger.info(f'Effective number of parallel workers used: {jobs}')
# Define progressbar
widgets = self.get_progressbar_widgets()
with progressbar.ProgressBar(
max_value=self.total_epochs, redirect_stdout=False, redirect_stderr=False,
widgets=widgets
with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(bar_width=None),
MofNCompleteColumn(),
TaskProgressColumn(),
"",
TimeElapsedColumn(),
"",
TimeRemainingColumn(),
expand=True,
) as pbar:
task = pbar.add_task("Epochs", total=self.total_epochs)
start = 0
if self.analyze_per_epoch:
@ -616,7 +599,7 @@ class Hyperopt:
f_val0 = self.generate_optimizer(asked[0])
self.opt.tell(asked, [f_val0['loss']])
self.evaluate_result(f_val0, 1, is_random[0])
pbar.update(1)
pbar.update(task, advance=1)
start += 1
evals = ceil((self.total_epochs - start) / jobs)
@ -630,14 +613,12 @@ class Hyperopt:
f_val = self.run_optimizer_parallel(parallel, asked)
self.opt.tell(asked, [v['loss'] for v in f_val])
# Calculate progressbar outputs
for j, val in enumerate(f_val):
# Use human-friendly indexes here (starting from 1)
current = i * jobs + j + 1 + start
self.evaluate_result(val, current, is_random[j])
pbar.update(current)
pbar.update(task, advance=1)
except KeyboardInterrupt:
print('User interrupted..')

View File

@ -6,4 +6,3 @@ scipy==1.10.1
scikit-learn==1.1.3
scikit-optimize==0.9.0
filelock==3.11.0
progressbar2==4.2.0

View File

@ -20,6 +20,7 @@ jinja2==3.1.2
tables==3.8.0
blosc==1.11.1
joblib==1.2.0
rich==13.3.3
pyarrow==11.0.0; platform_machine != 'armv7l'
# find first, C search in arrays

View File

@ -8,7 +8,6 @@ hyperopt = [
'scikit-learn',
'scikit-optimize>=0.7.0',
'filelock',
'progressbar2',
]
freqai = [
@ -82,6 +81,7 @@ setup(
'numpy',
'pandas',
'joblib>=1.2.0',
'rich',
'pyarrow; platform_machine != "armv7l"',
'fastapi',
'pydantic>=1.8.0',

View File

@ -23,7 +23,8 @@ from freqtrade.configuration.load_config import (load_config_file, load_file, lo
from freqtrade.constants import DEFAULT_DB_DRYRUN_URL, DEFAULT_DB_PROD_URL, ENV_VAR_PREFIX
from freqtrade.enums import RunMode
from freqtrade.exceptions import OperationalException
from freqtrade.loggers import FTBufferingHandler, _set_loggers, setup_logging, setup_logging_pre
from freqtrade.loggers import (FTBufferingHandler, FTStdErrStreamHandler, _set_loggers,
setup_logging, setup_logging_pre)
from tests.conftest import (CURRENT_TEST_STRATEGY, log_has, log_has_re,
patched_configuration_load_config_file)
@ -658,7 +659,7 @@ def test_set_loggers_syslog():
setup_logging(config)
assert len(logger.handlers) == 3
assert [x for x in logger.handlers if type(x) == logging.handlers.SysLogHandler]
assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
assert [x for x in logger.handlers if type(x) == FTBufferingHandler]
# setting up logging again should NOT cause the loggers to be added a second time.
setup_logging(config)
@ -681,7 +682,7 @@ def test_set_loggers_Filehandler(tmpdir):
setup_logging(config)
assert len(logger.handlers) == 3
assert [x for x in logger.handlers if type(x) == logging.handlers.RotatingFileHandler]
assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
assert [x for x in logger.handlers if type(x) == FTBufferingHandler]
# setting up logging again should NOT cause the loggers to be added a second time.
setup_logging(config)
@ -706,7 +707,7 @@ def test_set_loggers_journald(mocker):
setup_logging(config)
assert len(logger.handlers) == 3
assert [x for x in logger.handlers if type(x).__name__ == "JournaldLogHandler"]
assert [x for x in logger.handlers if type(x) == logging.StreamHandler]
assert [x for x in logger.handlers if type(x) == FTStdErrStreamHandler]
# reset handlers to not break pytest
logger.handlers = orig_handlers