changed to ast_comments, added tests for comments.

This commit is contained in:
hippocritical 2023-02-17 21:01:09 +01:00
parent 69a63975c1
commit 06edc5c044
3 changed files with 78 additions and 39 deletions

View File

@ -1,10 +1,12 @@
import logging import logging
import os
import time import time
from typing import Any, Dict from typing import Any, Dict
from freqtrade.configuration import setup_utils_configuration from freqtrade.configuration import setup_utils_configuration
from freqtrade.enums import RunMode from freqtrade.enums import RunMode
from freqtrade.resolvers import StrategyResolver from freqtrade.resolvers import StrategyResolver
from freqtrade.strategy.strategyupdater import StrategyUpdater
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,24 +19,37 @@ def start_strategy_update(args: Dict[str, Any]) -> None:
:return: None :return: None
""" """
# Import here to avoid loading backtesting module when it's not used
from freqtrade.strategy.strategyupdater import StrategyUpdater
config = setup_utils_configuration(args, RunMode.UTIL_NO_EXCHANGE) config = setup_utils_configuration(args, RunMode.UTIL_NO_EXCHANGE)
strategy_objs = StrategyResolver.search_all_objects( strategy_objs = StrategyResolver.search_all_objects(
config, enum_failed=False, recursive=config.get('recursive_strategy_search', False)) config, enum_failed=False, recursive=config.get('recursive_strategy_search', False))
filtered_strategy_objs = [] filtered_strategy_objs = []
if hasattr(args, "strategy_list"):
for args_strategy in args['strategy_list']: for args_strategy in args['strategy_list']:
for strategy_obj in strategy_objs: for strategy_obj in strategy_objs:
if strategy_obj['name'] == args_strategy and strategy_obj not in filtered_strategy_objs: if (strategy_obj['name'] == args_strategy
and strategy_obj not in filtered_strategy_objs):
filtered_strategy_objs.append(strategy_obj) filtered_strategy_objs.append(strategy_obj)
break break
for filtered_strategy_obj in filtered_strategy_objs: for filtered_strategy_obj in filtered_strategy_objs:
start_conversion(filtered_strategy_obj, config)
else:
processed_locations = set()
for strategy_obj in strategy_objs:
if strategy_obj['location'] not in processed_locations:
processed_locations.add(strategy_obj['location'])
start_conversion(strategy_obj, config)
def start_conversion(strategy_obj, config):
# try:
print(f"Conversion of {os.path.basename(strategy_obj['location'])} started.")
instance_strategy_updater = StrategyUpdater() instance_strategy_updater = StrategyUpdater()
start = time.perf_counter() start = time.perf_counter()
instance_strategy_updater.start(config, filtered_strategy_obj) instance_strategy_updater.start(config, strategy_obj)
elapsed = time.perf_counter() - start elapsed = time.perf_counter() - start
print(f"Conversion of {filtered_strategy_obj['name']} took {elapsed:.1f} seconds.") print(f"Conversion of {os.path.basename(strategy_obj['location'])} took {elapsed:.1f} seconds.")
# except:
# pass

View File

@ -1,9 +1,8 @@
import ast
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
import astor import ast_comments
class StrategyUpdater: class StrategyUpdater:
@ -76,7 +75,7 @@ class StrategyUpdater:
# define the function to update the code # define the function to update the code
def update_code(self, code): def update_code(self, code):
# parse the code into an AST # parse the code into an AST
tree = ast.parse(code) tree = ast_comments.parse(code)
# use the AST to update the code # use the AST to update the code
updated_code = self.modify_ast(tree) updated_code = self.modify_ast(tree)
@ -90,38 +89,43 @@ class StrategyUpdater:
NameUpdater().visit(tree) NameUpdater().visit(tree)
# first fix the comments, so it understands "\n" properly inside multi line comments. # first fix the comments, so it understands "\n" properly inside multi line comments.
ast.fix_missing_locations(tree) ast_comments.fix_missing_locations(tree)
ast.increment_lineno(tree, n=1) ast_comments.increment_lineno(tree, n=1)
# generate the new code from the updated AST # generate the new code from the updated AST
# without indent {} parameters would just be written straight one after the other. # without indent {} parameters would just be written straight one after the other.
return astor.to_source(tree)
# ast_comments would be amazing since this is the only solution that carries over comments,
# but it does currently not have an unparse function, hopefully in the future ... !
# return ast_comments.unparse(tree)
return ast_comments.unparse(tree)
# Here we go through each respective node, slice, elt, key ... to replace outdated entries. # Here we go through each respective node, slice, elt, key ... to replace outdated entries.
class NameUpdater(ast.NodeTransformer): class NameUpdater(ast_comments.NodeTransformer):
def generic_visit(self, node): def generic_visit(self, node):
# space is not yet transferred from buy/sell to entry/exit and thereby has to be skipped. # space is not yet transferred from buy/sell to entry/exit and thereby has to be skipped.
if isinstance(node, ast.keyword): if isinstance(node, ast_comments.keyword):
if node.arg == "space": if node.arg == "space":
return node return node
# from here on this is the original function. # from here on this is the original function.
for field, old_value in ast.iter_fields(node): for field, old_value in ast_comments.iter_fields(node):
if isinstance(old_value, list): if isinstance(old_value, list):
new_values = [] new_values = []
for value in old_value: for value in old_value:
if isinstance(value, ast.AST): if isinstance(value, ast_comments.AST):
value = self.visit(value) value = self.visit(value)
if value is None: if value is None:
continue continue
elif not isinstance(value, ast.AST): elif not isinstance(value, ast_comments.AST):
new_values.extend(value) new_values.extend(value)
continue continue
new_values.append(value) new_values.append(value)
old_value[:] = new_values old_value[:] = new_values
elif isinstance(old_value, ast.AST): elif isinstance(old_value, ast_comments.AST):
new_node = self.visit(old_value) new_node = self.visit(old_value)
if new_node is None: if new_node is None:
delattr(node, field) delattr(node, field)
@ -163,8 +167,8 @@ class NameUpdater(ast.NodeTransformer):
# node.module = "freqtrade.strategy" # node.module = "freqtrade.strategy"
return node return node
def visit_If(self, node: ast.If): def visit_If(self, node: ast_comments.If):
for child in ast.iter_child_nodes(node): for child in ast_comments.iter_child_nodes(node):
self.visit(child) self.visit(child)
return node return node
@ -175,7 +179,7 @@ class NameUpdater(ast.NodeTransformer):
def visit_Attribute(self, node): def visit_Attribute(self, node):
if ( if (
isinstance(node.value, ast.Name) isinstance(node.value, ast_comments.Name)
and node.value.id == 'trades' and node.value.id == 'trades'
and node.attr == 'nr_of_successful_buys' and node.attr == 'nr_of_successful_buys'
): ):
@ -184,33 +188,33 @@ class NameUpdater(ast.NodeTransformer):
def visit_ClassDef(self, node): def visit_ClassDef(self, node):
# check if the class is derived from IStrategy # check if the class is derived from IStrategy
if any(isinstance(base, ast.Name) and if any(isinstance(base, ast_comments.Name) and
base.id == 'IStrategy' for base in node.bases): base.id == 'IStrategy' for base in node.bases):
# check if the INTERFACE_VERSION variable exists # check if the INTERFACE_VERSION variable exists
has_interface_version = any( has_interface_version = any(
isinstance(child, ast.Assign) and isinstance(child, ast_comments.Assign) and
isinstance(child.targets[0], ast.Name) and isinstance(child.targets[0], ast_comments.Name) and
child.targets[0].id == 'INTERFACE_VERSION' child.targets[0].id == 'INTERFACE_VERSION'
for child in node.body for child in node.body
) )
# if the INTERFACE_VERSION variable does not exist, add it as the first child # if the INTERFACE_VERSION variable does not exist, add it as the first child
if not has_interface_version: if not has_interface_version:
node.body.insert(0, ast.parse('INTERFACE_VERSION = 3').body[0]) node.body.insert(0, ast_comments.parse('INTERFACE_VERSION = 3').body[0])
# otherwise, update its value to 3 # otherwise, update its value to 3
else: else:
for child in node.body: for child in node.body:
if ( if (
isinstance(child, ast.Assign) isinstance(child, ast_comments.Assign)
and isinstance(child.targets[0], ast.Name) and isinstance(child.targets[0], ast_comments.Name)
and child.targets[0].id == 'INTERFACE_VERSION' and child.targets[0].id == 'INTERFACE_VERSION'
): ):
child.value = ast.parse('3').body[0].value child.value = ast_comments.parse('3').body[0].value
self.generic_visit(node) self.generic_visit(node)
return node return node
def visit_Subscript(self, node): def visit_Subscript(self, node):
if isinstance(node.slice, ast.Constant): if isinstance(node.slice, ast_comments.Constant):
if node.slice.value in StrategyUpdater.rename_dict: if node.slice.value in StrategyUpdater.rename_dict:
# Replace the slice attributes with the values from rename_dict # Replace the slice attributes with the values from rename_dict
node.slice.value = StrategyUpdater.rename_dict[node.slice.value] node.slice.value = StrategyUpdater.rename_dict[node.slice.value]
@ -232,12 +236,12 @@ class NameUpdater(ast.NodeTransformer):
# sub function again needed since the structure itself is highly flexible ... # sub function again needed since the structure itself is highly flexible ...
def visit_elt(self, elt): def visit_elt(self, elt):
if isinstance(elt, ast.Constant) and elt.value in StrategyUpdater.rename_dict: if isinstance(elt, ast_comments.Constant) and elt.value in StrategyUpdater.rename_dict:
elt.value = StrategyUpdater.rename_dict[elt.value] elt.value = StrategyUpdater.rename_dict[elt.value]
if hasattr(elt, "elts"): if hasattr(elt, "elts"):
self.visit_elts(elt.elts) self.visit_elts(elt.elts)
if hasattr(elt, "args"): if hasattr(elt, "args"):
if isinstance(elt.args, ast.arguments): if isinstance(elt.args, ast_comments.arguments):
self.visit_elts(elt.args) self.visit_elts(elt.args)
else: else:
for arg in elt.args: for arg in elt.args:

View File

@ -65,7 +65,23 @@ sell_reason == 'sell_signal'
sell_reason == 'force_sell' sell_reason == 'force_sell'
sell_reason == 'emergency_sell' sell_reason == 'emergency_sell'
""") """)
modified_code9 = instance_strategy_updater.update_code("""
# This is the 1st comment
import talib.abstract as ta
# This is the 2nd comment
import freqtrade.vendor.qtpylib.indicators as qtpylib
class someStrategy(IStrategy):
# This is the 3rd comment
# This attribute will be overridden if the config file contains "minimal_roi"
minimal_roi = {
"0": 0.50
}
# This is the 4th comment
stoploss = -0.1
""")
# currently still missing: # currently still missing:
# Webhook terminology, Telegram notification settings, Strategy/Config settings # Webhook terminology, Telegram notification settings, Strategy/Config settings
@ -108,3 +124,7 @@ sell_reason == 'emergency_sell'
assert "exit_reason" in modified_code8 assert "exit_reason" in modified_code8
assert "force_exit" in modified_code8 assert "force_exit" in modified_code8
assert "emergency_exit" in modified_code8 assert "emergency_exit" in modified_code8
assert "This is the 1st comment" in modified_code9
assert "This is the 2nd comment" in modified_code9
assert "This is the 3rd comment" in modified_code9