ruff format: update persistence

Matthias 2024-05-12 16:48:11 +02:00
parent 5783a44c86
commit cebbe0121e
9 changed files with 807 additions and 665 deletions
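Note: a formatting-only commit like this is typically produced by letting ruff's formatter rewrite quotes and line wrapping in place. A minimal reproduction sketch; the target path and ruff's availability are assumptions, not part of this commit:

    # Hypothetical reproduction; assumes ruff is installed and the repo root is the cwd.
    import subprocess

    # Rewrite the persistence package in place (double quotes, wrapped calls, etc.).
    subprocess.run(["ruff", "format", "freqtrade/persistence/"], check=True)

    # Exits non-zero if any file would still be reformatted.
    subprocess.run(["ruff", "format", "--check", "freqtrade/persistence/"], check=True)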

View File

@@ -1,4 +1,3 @@
 from sqlalchemy.orm import DeclarativeBase, Session, scoped_session

View File

@@ -23,16 +23,17 @@ class _CustomData(ModelBase):
     - One trade can have many metadata entries
     - One metadata entry can only be associated with one Trade
     """
-    __tablename__ = 'trade_custom_data'
+
+    __tablename__ = "trade_custom_data"
     __allow_unmapped__ = True
     session: ClassVar[SessionType]

     # Uniqueness should be ensured over pair, order_id
     # its likely that order_id is unique per Pair on some exchanges.
-    __table_args__ = (UniqueConstraint('ft_trade_id', 'cd_key', name="_trade_id_cd_key"),)
+    __table_args__ = (UniqueConstraint("ft_trade_id", "cd_key", name="_trade_id_cd_key"),)

     id = mapped_column(Integer, primary_key=True)
-    ft_trade_id = mapped_column(Integer, ForeignKey('trades.id'), index=True)
+    ft_trade_id = mapped_column(Integer, ForeignKey("trades.id"), index=True)

     trade = relationship("Trade", back_populates="custom_data")
@@ -46,17 +47,22 @@ class _CustomData(ModelBase):
     value: Any = None

     def __repr__(self):
-        create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT)
-                       if self.created_at is not None else None)
-        update_time = (self.updated_at.strftime(DATETIME_PRINT_FORMAT)
-                       if self.updated_at is not None else None)
-        return (f'CustomData(id={self.id}, key={self.cd_key}, type={self.cd_type}, ' +
-                f'value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, ' +
-                f'updated={update_time})')
+        create_time = (
+            self.created_at.strftime(DATETIME_PRINT_FORMAT) if self.created_at is not None else None
+        )
+        update_time = (
+            self.updated_at.strftime(DATETIME_PRINT_FORMAT) if self.updated_at is not None else None
+        )
+        return (
+            f"CustomData(id={self.id}, key={self.cd_key}, type={self.cd_type}, "
+            + f"value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, "
+            + f"updated={update_time})"
+        )

     @classmethod
-    def query_cd(cls, key: Optional[str] = None,
-                 trade_id: Optional[int] = None) -> Sequence['_CustomData']:
+    def query_cd(
+        cls, key: Optional[str] = None, trade_id: Optional[int] = None
+    ) -> Sequence["_CustomData"]:
         """
         Get all CustomData, if trade_id is not specified
         return will be for generic values not tied to a trade
@@ -80,17 +86,17 @@ class CustomDataWrapper:
     use_db = True
     custom_data: List[_CustomData] = []
-    unserialized_types = ['bool', 'float', 'int', 'str']
+    unserialized_types = ["bool", "float", "int", "str"]

     @staticmethod
     def _convert_custom_data(data: _CustomData) -> _CustomData:
         if data.cd_type in CustomDataWrapper.unserialized_types:
             data.value = data.cd_value
-            if data.cd_type == 'bool':
-                data.value = data.cd_value.lower() == 'true'
-            elif data.cd_type == 'int':
+            if data.cd_type == "bool":
+                data.value = data.cd_value.lower() == "true"
+            elif data.cd_type == "int":
                 data.value = int(data.cd_value)
-            elif data.cd_type == 'float':
+            elif data.cd_type == "float":
                 data.value = float(data.cd_value)
         else:
             data.value = json.loads(data.cd_value)
@@ -111,31 +117,32 @@ class CustomDataWrapper:
     @staticmethod
     def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> List[_CustomData]:
         if CustomDataWrapper.use_db:
             filters = [
                 _CustomData.ft_trade_id == trade_id,
             ]
             if key is not None:
                 filters.append(_CustomData.cd_key.ilike(key))
-            filtered_custom_data = _CustomData.session.scalars(select(_CustomData).filter(
-                *filters)).all()
+            filtered_custom_data = _CustomData.session.scalars(
+                select(_CustomData).filter(*filters)
+            ).all()
         else:
             filtered_custom_data = [
-                data_entry for data_entry in CustomDataWrapper.custom_data
+                data_entry
+                for data_entry in CustomDataWrapper.custom_data
                 if (data_entry.ft_trade_id == trade_id)
             ]
             if key is not None:
                 filtered_custom_data = [
-                    data_entry for data_entry in filtered_custom_data
+                    data_entry
+                    for data_entry in filtered_custom_data
                     if (data_entry.cd_key.casefold() == key.casefold())
                 ]
         return [CustomDataWrapper._convert_custom_data(d) for d in filtered_custom_data]

     @staticmethod
     def set_custom_data(trade_id: int, key: str, value: Any) -> None:
         value_type = type(value).__name__
         if value_type not in CustomDataWrapper.unserialized_types:
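The hunks above show the conversion logic this file formats: custom data is persisted as a string column (cd_value) plus a type tag (cd_type), with primitives stored verbatim and everything else as JSON. A standalone sketch of that round trip, illustrative only and not freqtrade's API:

    import json
    from typing import Any, Tuple

    def serialize(value: Any) -> Tuple[str, str]:
        # Primitives are stored as their string form; everything else as JSON.
        value_type = type(value).__name__
        if value_type in ("bool", "float", "int", "str"):
            return value_type, str(value)
        return "json", json.dumps(value)

    def deserialize(cd_type: str, cd_value: str) -> Any:
        # Mirrors _convert_custom_data above: dispatch on the stored type tag.
        if cd_type == "bool":
            return cd_value.lower() == "true"
        if cd_type == "int":
            return int(cd_value)
        if cd_type == "float":
            return float(cd_value)
        if cd_type == "str":
            return cd_value
        return json.loads(cd_value)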

View File

@@ -12,22 +12,23 @@ ValueTypes = Union[str, datetime, float, int]

 class ValueTypesEnum(str, Enum):
-    STRING = 'str'
-    DATETIME = 'datetime'
-    FLOAT = 'float'
-    INT = 'int'
+    STRING = "str"
+    DATETIME = "datetime"
+    FLOAT = "float"
+    INT = "int"


 class KeyStoreKeys(str, Enum):
-    BOT_START_TIME = 'bot_start_time'
-    STARTUP_TIME = 'startup_time'
+    BOT_START_TIME = "bot_start_time"
+    STARTUP_TIME = "startup_time"


 class _KeyValueStoreModel(ModelBase):
     """
     Pair Locks database model.
     """
-    __tablename__ = 'KeyValueStore'
+
+    __tablename__ = "KeyValueStore"
     session: ClassVar[SessionType]

     id: Mapped[int] = mapped_column(primary_key=True)
@@ -56,8 +57,11 @@ class KeyValueStore:
         :param key: Key to store the value for - can be used in get-value to retrieve the key
         :param value: Value to store - can be str, datetime, float or int
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(_KeyValueStoreModel.key == key)
+            .first()
+        )
         if kv is None:
             kv = _KeyValueStoreModel(key=key)
         if isinstance(value, str):
@@ -73,7 +77,7 @@ class KeyValueStore:
             kv.value_type = ValueTypesEnum.INT
             kv.int_value = value
         else:
-            raise ValueError(f'Unknown value type {kv.value_type}')
+            raise ValueError(f"Unknown value type {kv.value_type}")
         _KeyValueStoreModel.session.add(kv)
         _KeyValueStoreModel.session.commit()
@@ -83,8 +87,11 @@ class KeyValueStore:
         Delete the value for the given key.
         :param key: Key to delete the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(_KeyValueStoreModel.key == key)
+            .first()
+        )
         if kv is not None:
             _KeyValueStoreModel.session.delete(kv)
             _KeyValueStoreModel.session.commit()
@@ -95,8 +102,11 @@ class KeyValueStore:
         Get the value for the given key.
         :param key: Key to get the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(_KeyValueStoreModel.key == key)
+            .first()
+        )
         if kv is None:
             return None
         if kv.value_type == ValueTypesEnum.STRING:
@@ -108,7 +118,7 @@ class KeyValueStore:
         if kv.value_type == ValueTypesEnum.INT:
             return kv.int_value
         # This should never happen unless someone messed with the database manually
-        raise ValueError(f'Unknown value type {kv.value_type}')  # pragma: no cover
+        raise ValueError(f"Unknown value type {kv.value_type}")  # pragma: no cover

     @staticmethod
     def get_string_value(key: KeyStoreKeys) -> Optional[str]:
@@ -116,9 +126,14 @@ class KeyValueStore:
         Get the value for the given key.
         :param key: Key to get the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key,
-            _KeyValueStoreModel.value_type == ValueTypesEnum.STRING).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(
+                _KeyValueStoreModel.key == key,
+                _KeyValueStoreModel.value_type == ValueTypesEnum.STRING,
+            )
+            .first()
+        )
         if kv is None:
             return None
         return kv.string_value
@@ -129,9 +144,14 @@ class KeyValueStore:
         Get the value for the given key.
         :param key: Key to get the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key,
-            _KeyValueStoreModel.value_type == ValueTypesEnum.DATETIME).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(
+                _KeyValueStoreModel.key == key,
+                _KeyValueStoreModel.value_type == ValueTypesEnum.DATETIME,
+            )
+            .first()
+        )
         if kv is None or kv.datetime_value is None:
             return None
         return kv.datetime_value.replace(tzinfo=timezone.utc)
@@ -142,9 +162,14 @@ class KeyValueStore:
         Get the value for the given key.
         :param key: Key to get the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key,
-            _KeyValueStoreModel.value_type == ValueTypesEnum.FLOAT).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(
+                _KeyValueStoreModel.key == key,
+                _KeyValueStoreModel.value_type == ValueTypesEnum.FLOAT,
+            )
+            .first()
+        )
         if kv is None:
             return None
         return kv.float_value
@@ -155,9 +180,13 @@ class KeyValueStore:
         Get the value for the given key.
         :param key: Key to get the value for
         """
-        kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
-            _KeyValueStoreModel.key == key,
-            _KeyValueStoreModel.value_type == ValueTypesEnum.INT).first()
+        kv = (
+            _KeyValueStoreModel.session.query(_KeyValueStoreModel)
+            .filter(
+                _KeyValueStoreModel.key == key, _KeyValueStoreModel.value_type == ValueTypesEnum.INT
+            )
+            .first()
+        )
         if kv is None:
             return None
         return kv.int_value
@@ -168,12 +197,13 @@ def set_startup_time():
     sets bot_start_time to the first trade open date - or "now" on new databases.
     sets startup_time to "now"
     """
-    st = KeyValueStore.get_value('bot_start_time')
+    st = KeyValueStore.get_value("bot_start_time")
     if st is None:
         from freqtrade.persistence import Trade

         t = Trade.session.query(Trade).order_by(Trade.open_date.asc()).first()
         if t is not None:
-            KeyValueStore.store_value('bot_start_time', t.open_date_utc)
+            KeyValueStore.store_value("bot_start_time", t.open_date_utc)
         else:
-            KeyValueStore.store_value('bot_start_time', datetime.now(timezone.utc))
-    KeyValueStore.store_value('startup_time', datetime.now(timezone.utc))
+            KeyValueStore.store_value("bot_start_time", datetime.now(timezone.utc))
+    KeyValueStore.store_value("startup_time", datetime.now(timezone.utc))
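Taken together, KeyValueStore keeps one row per key with a value_type discriminator and a dedicated column per supported type. A hedged usage sketch; import paths are inferred from this diff, and a session must first be configured via init_db:

    from datetime import datetime, timezone

    from freqtrade.persistence import init_db
    from freqtrade.persistence.key_value_store import KeyValueStore

    init_db("sqlite://")  # in-memory database, for illustration only

    KeyValueStore.store_value("bot_start_time", datetime.now(timezone.utc))
    started = KeyValueStore.get_value("bot_start_time")  # returns the stored datetime (UTC)
    KeyValueStore.delete_value("bot_start_time")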

View File

@@ -25,8 +25,8 @@ def get_column_def(columns: List, column: str, default: str) -> str:

 def get_backup_name(tabs: List[str], backup_prefix: str):
     table_back_name = backup_prefix
     for i, table_back_name in enumerate(tabs):
-        table_back_name = f'{backup_prefix}{i}'
-        logger.debug(f'trying {table_back_name}')
+        table_back_name = f"{backup_prefix}{i}"
+        logger.debug(f"trying {table_back_name}")

     return table_back_name
@@ -35,21 +35,22 @@ def get_last_sequence_ids(engine, trade_back_name: str, order_back_name: str):
     order_id: Optional[int] = None
     trade_id: Optional[int] = None

-    if engine.name == 'postgresql':
+    if engine.name == "postgresql":
         with engine.begin() as connection:
             trade_id = connection.execute(text("select nextval('trades_id_seq')")).fetchone()[0]
             order_id = connection.execute(text("select nextval('orders_id_seq')")).fetchone()[0]
         with engine.begin() as connection:
-            connection.execute(text(
-                f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak"))
-            connection.execute(text(
-                f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak"))
+            connection.execute(
+                text(f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak")
+            )
+            connection.execute(
+                text(f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak")
+            )
     return order_id, trade_id


 def set_sequence_ids(engine, order_id, trade_id, pairlock_id=None):
-    if engine.name == 'postgresql':
+    if engine.name == "postgresql":
         with engine.begin() as connection:
             if order_id:
                 connection.execute(text(f"ALTER SEQUENCE orders_id_seq RESTART WITH {order_id}"))
@@ -57,84 +58,95 @@ def set_sequence_ids(engine, order_id, trade_id, pairlock_id=None):
                 connection.execute(text(f"ALTER SEQUENCE trades_id_seq RESTART WITH {trade_id}"))
             if pairlock_id:
                 connection.execute(
-                    text(f"ALTER SEQUENCE pairlocks_id_seq RESTART WITH {pairlock_id}"))
+                    text(f"ALTER SEQUENCE pairlocks_id_seq RESTART WITH {pairlock_id}")
+                )


 def drop_index_on_table(engine, inspector, table_bak_name):
     with engine.begin() as connection:
         # drop indexes on backup table in new session
         for index in inspector.get_indexes(table_bak_name):
-            if engine.name == 'mysql':
+            if engine.name == "mysql":
                 connection.execute(text(f"drop index {index['name']} on {table_bak_name}"))
             else:
                 connection.execute(text(f"drop index {index['name']}"))


 def migrate_trades_and_orders_table(
-        decl_base, inspector, engine,
-        trade_back_name: str, cols: List,
-        order_back_name: str, cols_order: List):
-    base_currency = get_column_def(cols, 'base_currency', 'null')
-    stake_currency = get_column_def(cols, 'stake_currency', 'null')
-    fee_open = get_column_def(cols, 'fee_open', 'fee')
-    fee_open_cost = get_column_def(cols, 'fee_open_cost', 'null')
-    fee_open_currency = get_column_def(cols, 'fee_open_currency', 'null')
-    fee_close = get_column_def(cols, 'fee_close', 'fee')
-    fee_close_cost = get_column_def(cols, 'fee_close_cost', 'null')
-    fee_close_currency = get_column_def(cols, 'fee_close_currency', 'null')
-    open_rate_requested = get_column_def(cols, 'open_rate_requested', 'null')
-    close_rate_requested = get_column_def(cols, 'close_rate_requested', 'null')
-    stop_loss = get_column_def(cols, 'stop_loss', '0.0')
-    stop_loss_pct = get_column_def(cols, 'stop_loss_pct', 'null')
-    initial_stop_loss = get_column_def(cols, 'initial_stop_loss', '0.0')
-    initial_stop_loss_pct = get_column_def(cols, 'initial_stop_loss_pct', 'null')
+    decl_base,
+    inspector,
+    engine,
+    trade_back_name: str,
+    cols: List,
+    order_back_name: str,
+    cols_order: List,
+):
+    base_currency = get_column_def(cols, "base_currency", "null")
+    stake_currency = get_column_def(cols, "stake_currency", "null")
+    fee_open = get_column_def(cols, "fee_open", "fee")
+    fee_open_cost = get_column_def(cols, "fee_open_cost", "null")
+    fee_open_currency = get_column_def(cols, "fee_open_currency", "null")
+    fee_close = get_column_def(cols, "fee_close", "fee")
+    fee_close_cost = get_column_def(cols, "fee_close_cost", "null")
+    fee_close_currency = get_column_def(cols, "fee_close_currency", "null")
+    open_rate_requested = get_column_def(cols, "open_rate_requested", "null")
+    close_rate_requested = get_column_def(cols, "close_rate_requested", "null")
+    stop_loss = get_column_def(cols, "stop_loss", "0.0")
+    stop_loss_pct = get_column_def(cols, "stop_loss_pct", "null")
+    initial_stop_loss = get_column_def(cols, "initial_stop_loss", "0.0")
+    initial_stop_loss_pct = get_column_def(cols, "initial_stop_loss_pct", "null")
     is_stop_loss_trailing = get_column_def(
-        cols, 'is_stop_loss_trailing',
-        f'coalesce({stop_loss_pct}, 0.0) <> coalesce({initial_stop_loss_pct}, 0.0)')
-    max_rate = get_column_def(cols, 'max_rate', '0.0')
-    min_rate = get_column_def(cols, 'min_rate', 'null')
-    exit_reason = get_column_def(cols, 'sell_reason', get_column_def(cols, 'exit_reason', 'null'))
-    strategy = get_column_def(cols, 'strategy', 'null')
-    enter_tag = get_column_def(cols, 'buy_tag', get_column_def(cols, 'enter_tag', 'null'))
-    realized_profit = get_column_def(cols, 'realized_profit', '0.0')
+        cols,
+        "is_stop_loss_trailing",
+        f"coalesce({stop_loss_pct}, 0.0) <> coalesce({initial_stop_loss_pct}, 0.0)",
+    )
+    max_rate = get_column_def(cols, "max_rate", "0.0")
+    min_rate = get_column_def(cols, "min_rate", "null")
+    exit_reason = get_column_def(cols, "sell_reason", get_column_def(cols, "exit_reason", "null"))
+    strategy = get_column_def(cols, "strategy", "null")
+    enter_tag = get_column_def(cols, "buy_tag", get_column_def(cols, "enter_tag", "null"))
+    realized_profit = get_column_def(cols, "realized_profit", "0.0")

-    trading_mode = get_column_def(cols, 'trading_mode', 'null')
+    trading_mode = get_column_def(cols, "trading_mode", "null")

     # Leverage Properties
-    leverage = get_column_def(cols, 'leverage', '1.0')
-    liquidation_price = get_column_def(cols, 'liquidation_price',
-                                       get_column_def(cols, 'isolated_liq', 'null'))
+    leverage = get_column_def(cols, "leverage", "1.0")
+    liquidation_price = get_column_def(
+        cols, "liquidation_price", get_column_def(cols, "isolated_liq", "null")
+    )
     # sqlite does not support literals for booleans
-    if engine.name == 'postgresql':
-        is_short = get_column_def(cols, 'is_short', 'false')
+    if engine.name == "postgresql":
+        is_short = get_column_def(cols, "is_short", "false")
     else:
-        is_short = get_column_def(cols, 'is_short', '0')
+        is_short = get_column_def(cols, "is_short", "0")

     # Futures Properties
-    interest_rate = get_column_def(cols, 'interest_rate', '0.0')
-    funding_fees = get_column_def(cols, 'funding_fees', '0.0')
-    funding_fee_running = get_column_def(cols, 'funding_fee_running', 'null')
-    max_stake_amount = get_column_def(cols, 'max_stake_amount', 'stake_amount')
+    interest_rate = get_column_def(cols, "interest_rate", "0.0")
+    funding_fees = get_column_def(cols, "funding_fees", "0.0")
+    funding_fee_running = get_column_def(cols, "funding_fee_running", "null")
+    max_stake_amount = get_column_def(cols, "max_stake_amount", "stake_amount")

     # If ticker-interval existed use that, else null.
-    if has_column(cols, 'ticker_interval'):
-        timeframe = get_column_def(cols, 'timeframe', 'ticker_interval')
+    if has_column(cols, "ticker_interval"):
+        timeframe = get_column_def(cols, "timeframe", "ticker_interval")
     else:
-        timeframe = get_column_def(cols, 'timeframe', 'null')
+        timeframe = get_column_def(cols, "timeframe", "null")

-    open_trade_value = get_column_def(cols, 'open_trade_value',
-                                      f'amount * open_rate * (1 + {fee_open})')
+    open_trade_value = get_column_def(
+        cols, "open_trade_value", f"amount * open_rate * (1 + {fee_open})"
+    )
     close_profit_abs = get_column_def(
-        cols, 'close_profit_abs',
-        f"(amount * close_rate * (1 - {fee_close})) - {open_trade_value}")
-    exit_order_status = get_column_def(cols, 'exit_order_status',
-                                       get_column_def(cols, 'sell_order_status', 'null'))
-    amount_requested = get_column_def(cols, 'amount_requested', 'amount')
+        cols, "close_profit_abs", f"(amount * close_rate * (1 - {fee_close})) - {open_trade_value}"
+    )
+    exit_order_status = get_column_def(
+        cols, "exit_order_status", get_column_def(cols, "sell_order_status", "null")
+    )
+    amount_requested = get_column_def(cols, "amount_requested", "amount")

-    amount_precision = get_column_def(cols, 'amount_precision', 'null')
-    price_precision = get_column_def(cols, 'price_precision', 'null')
-    precision_mode = get_column_def(cols, 'precision_mode', 'null')
-    contract_size = get_column_def(cols, 'contract_size', 'null')
+    amount_precision = get_column_def(cols, "amount_precision", "null")
+    price_precision = get_column_def(cols, "price_precision", "null")
+    precision_mode = get_column_def(cols, "precision_mode", "null")
+    contract_size = get_column_def(cols, "contract_size", "null")

     # Schema migration necessary
     with engine.begin() as connection:
@@ -151,7 +163,8 @@ def migrate_trades_and_orders_table(

     # Copy data back - following the correct schema
     with engine.begin() as connection:
-        connection.execute(text(f"""insert into trades
+        connection.execute(
+            text(f"""insert into trades
             (id, exchange, pair, base_currency, stake_currency, is_open,
             fee_open, fee_open_cost, fee_open_currency,
             fee_close, fee_close_cost, fee_close_currency, open_rate,
@@ -196,7 +209,8 @@ def migrate_trades_and_orders_table(
             {precision_mode} precision_mode, {contract_size} contract_size,
             {max_stake_amount} max_stake_amount
             from {trade_back_name}
-            """))
+            """)
+        )

     migrate_orders_table(engine, order_back_name, cols_order)
     set_sequence_ids(engine, order_id, trade_id)
@@ -212,19 +226,19 @@ def drop_orders_table(engine, table_back_name: str):

 def migrate_orders_table(engine, table_back_name: str, cols_order: List):
-
-    ft_fee_base = get_column_def(cols_order, 'ft_fee_base', 'null')
-    average = get_column_def(cols_order, 'average', 'null')
-    stop_price = get_column_def(cols_order, 'stop_price', 'null')
-    funding_fee = get_column_def(cols_order, 'funding_fee', '0.0')
-    ft_amount = get_column_def(cols_order, 'ft_amount', 'coalesce(amount, 0.0)')
-    ft_price = get_column_def(cols_order, 'ft_price', 'coalesce(price, 0.0)')
-    ft_cancel_reason = get_column_def(cols_order, 'ft_cancel_reason', 'null')
-    ft_order_tag = get_column_def(cols_order, 'ft_order_tag', 'null')
+    ft_fee_base = get_column_def(cols_order, "ft_fee_base", "null")
+    average = get_column_def(cols_order, "average", "null")
+    stop_price = get_column_def(cols_order, "stop_price", "null")
+    funding_fee = get_column_def(cols_order, "funding_fee", "0.0")
+    ft_amount = get_column_def(cols_order, "ft_amount", "coalesce(amount, 0.0)")
+    ft_price = get_column_def(cols_order, "ft_price", "coalesce(price, 0.0)")
+    ft_cancel_reason = get_column_def(cols_order, "ft_cancel_reason", "null")
+    ft_order_tag = get_column_def(cols_order, "ft_order_tag", "null")

     # sqlite does not support literals for booleans
     with engine.begin() as connection:
-        connection.execute(text(f"""
+        connection.execute(
+            text(f"""
         insert into orders (id, ft_trade_id, ft_order_side, ft_pair, ft_is_open, order_id,
         status, symbol, order_type, side, price, amount, filled, average, remaining, cost,
         stop_price, order_date, order_filled_date, order_update_date, ft_fee_base, funding_fee,
@@ -237,36 +251,36 @@ def migrate_orders_table(engine, table_back_name: str, cols_order: List):
         {ft_amount} ft_amount, {ft_price} ft_price, {ft_cancel_reason} ft_cancel_reason,
         {ft_order_tag} ft_order_tag
         from {table_back_name}
-        """))
+        """)
+        )


-def migrate_pairlocks_table(
-        decl_base, inspector, engine,
-        pairlock_back_name: str, cols: List):
+def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: List):
     # Schema migration necessary
     with engine.begin() as connection:
         connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}"))
     drop_index_on_table(engine, inspector, pairlock_back_name)

-    side = get_column_def(cols, 'side', "'*'")
+    side = get_column_def(cols, "side", "'*'")

     # let SQLAlchemy create the schema as required
     decl_base.metadata.create_all(engine)

     # Copy data back - following the correct schema
     with engine.begin() as connection:
-        connection.execute(text(f"""insert into pairlocks
+        connection.execute(
+            text(f"""insert into pairlocks
         (id, pair, side, reason, lock_time,
         lock_end_time, active)
         select id, pair, {side} side, reason, lock_time,
         lock_end_time, active
         from {pairlock_back_name}
-        """))
+        """)
+        )
 def set_sqlite_to_wal(engine):
-    if engine.name == 'sqlite' and str(engine.url) != 'sqlite://':
+    if engine.name == "sqlite" and str(engine.url) != "sqlite://":
         # Set Mode to
         with engine.begin() as connection:
             connection.execute(text("PRAGMA journal_mode=wal"))
@@ -274,7 +288,6 @@ def set_sqlite_to_wal(engine):

 def fix_old_dry_orders(engine):
     with engine.begin() as connection:
-
         # Update current dry-run Orders where
         # - stoploss order is Open (will be replaced eventually)
         # 2nd query:
@@ -283,26 +296,28 @@ def fix_old_dry_orders(engine):
         # - current Order trade_id not equal to current Trade.id
         # - current Order not stoploss
-        stmt = update(Order).where(
-            Order.ft_is_open.is_(True),
-            Order.ft_order_side == 'stoploss',
-            Order.order_id.like('dry%'),
-        ).values(ft_is_open=False)
+        stmt = (
+            update(Order)
+            .where(
+                Order.ft_is_open.is_(True),
+                Order.ft_order_side == "stoploss",
+                Order.order_id.like("dry%"),
+            )
+            .values(ft_is_open=False)
+        )
         connection.execute(stmt)

         # Close dry-run orders for closed trades.
-        stmt = update(Order).where(
-            Order.ft_is_open.is_(True),
-            Order.ft_trade_id.not_in(
-                select(
-                    Trade.id
-                ).where(Trade.is_open.is_(True))
-            ),
-            Order.ft_order_side != 'stoploss',
-            Order.order_id.like('dry%')
-        ).values(ft_is_open=False)
+        stmt = (
+            update(Order)
+            .where(
+                Order.ft_is_open.is_(True),
+                Order.ft_trade_id.not_in(select(Trade.id).where(Trade.is_open.is_(True))),
+                Order.ft_order_side != "stoploss",
+                Order.order_id.like("dry%"),
+            )
+            .values(ft_is_open=False)
+        )
         connection.execute(stmt)
@@ -312,15 +327,15 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
     """
     inspector = inspect(engine)

-    cols_trades = inspector.get_columns('trades')
-    cols_orders = inspector.get_columns('orders')
-    cols_pairlocks = inspector.get_columns('pairlocks')
-    tabs = get_table_names_for_table(inspector, 'trades')
-    table_back_name = get_backup_name(tabs, 'trades_bak')
-    order_tabs = get_table_names_for_table(inspector, 'orders')
-    order_table_bak_name = get_backup_name(order_tabs, 'orders_bak')
-    pairlock_tabs = get_table_names_for_table(inspector, 'pairlocks')
-    pairlock_table_bak_name = get_backup_name(pairlock_tabs, 'pairlocks_bak')
+    cols_trades = inspector.get_columns("trades")
+    cols_orders = inspector.get_columns("orders")
+    cols_pairlocks = inspector.get_columns("pairlocks")
+    tabs = get_table_names_for_table(inspector, "trades")
+    table_back_name = get_backup_name(tabs, "trades_bak")
+    order_tabs = get_table_names_for_table(inspector, "orders")
+    order_table_bak_name = get_backup_name(order_tabs, "orders_bak")
+    pairlock_tabs = get_table_names_for_table(inspector, "pairlocks")
+    pairlock_table_bak_name = get_backup_name(pairlock_tabs, "pairlocks_bak")

     # Check if migration necessary
     # Migrates both trades and orders table!
@@ -328,27 +343,37 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
     # or not has_column(cols_orders, 'funding_fee')):
     migrating = False
     # if not has_column(cols_trades, 'funding_fee_running'):
-    if not has_column(cols_orders, 'ft_order_tag'):
+    if not has_column(cols_orders, "ft_order_tag"):
         migrating = True
-        logger.info(f"Running database migration for trades - "
-                    f"backup: {table_back_name}, {order_table_bak_name}")
+        logger.info(
+            f"Running database migration for trades - "
+            f"backup: {table_back_name}, {order_table_bak_name}"
+        )
         migrate_trades_and_orders_table(
-            decl_base, inspector, engine, table_back_name, cols_trades,
-            order_table_bak_name, cols_orders)
+            decl_base,
+            inspector,
+            engine,
+            table_back_name,
+            cols_trades,
+            order_table_bak_name,
+            cols_orders,
+        )

-    if not has_column(cols_pairlocks, 'side'):
+    if not has_column(cols_pairlocks, "side"):
         migrating = True
-        logger.info(f"Running database migration for pairlocks - "
-                    f"backup: {pairlock_table_bak_name}")
+        logger.info(
+            f"Running database migration for pairlocks - " f"backup: {pairlock_table_bak_name}"
+        )
         migrate_pairlocks_table(
             decl_base, inspector, engine, pairlock_table_bak_name, cols_pairlocks
         )

-    if 'orders' not in previous_tables and 'trades' in previous_tables:
+    if "orders" not in previous_tables and "trades" in previous_tables:
         raise OperationalException(
             "Your database seems to be very old. "
             "Please update to freqtrade 2022.3 to migrate this database or "
-            "start with a fresh database.")
+            "start with a fresh database."
+        )

     set_sqlite_to_wal(engine)
     fix_old_dry_orders(engine)
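The migration code above leans on get_column_def: when the backed-up table already has a column, the copy-back SELECT references it; otherwise a literal SQL default expression is substituted. A sketch consistent with how the helper is used here (the exact implementation may differ):

    from typing import List

    def has_column(columns: List, searchname: str) -> bool:
        # True if the inspected table has a column with this name.
        return any(col["name"] == searchname for col in columns)

    def get_column_def(columns: List, column: str, default: str) -> str:
        # Use the existing column name in the copy-back SELECT, else the default expression.
        return default if not has_column(columns, column) else column

    # Nested fallback as used above: prefer the old "sell_reason" column,
    # then "exit_reason", and finally a literal NULL.
    # exit_reason = get_column_def(cols, "sell_reason", get_column_def(cols, "exit_reason", "null"))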

View File

@@ -1,6 +1,7 @@
 """
 This module contains the class to persist trades into SQLite
 """
+
 import logging
 import threading
 from contextvars import ContextVar
@@ -23,7 +24,7 @@ from freqtrade.persistence.trade_model import Order, Trade
 logger = logging.getLogger(__name__)

-REQUEST_ID_CTX_KEY: Final[str] = 'request_id'
+REQUEST_ID_CTX_KEY: Final[str] = "request_id"
 _request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
@@ -39,7 +40,7 @@ def get_request_or_thread_id() -> Optional[str]:
     return id

-_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
+_SQL_DOCS_URL = "http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls"


 def init_db(db_url: str) -> None:
@@ -52,35 +53,44 @@ def init_db(db_url: str) -> None:
     """
     kwargs: Dict[str, Any] = {}

-    if db_url == 'sqlite:///':
+    if db_url == "sqlite:///":
         raise OperationalException(
-            f'Bad db-url {db_url}. For in-memory database, please use `sqlite://`.')
-    if db_url == 'sqlite://':
-        kwargs.update({
-            'poolclass': StaticPool,
-        })
+            f"Bad db-url {db_url}. For in-memory database, please use `sqlite://`."
+        )
+    if db_url == "sqlite://":
+        kwargs.update(
+            {
+                "poolclass": StaticPool,
+            }
+        )
         # Take care of thread ownership
-    if db_url.startswith('sqlite://'):
-        kwargs.update({
-            'connect_args': {'check_same_thread': False},
-        })
+    if db_url.startswith("sqlite://"):
+        kwargs.update(
+            {
+                "connect_args": {"check_same_thread": False},
+            }
+        )

     try:
         engine = create_engine(db_url, future=True, **kwargs)
     except NoSuchModuleError:
-        raise OperationalException(f"Given value for db_url: '{db_url}' "
-                                   f"is no valid database URL! (See {_SQL_DOCS_URL})")
+        raise OperationalException(
+            f"Given value for db_url: '{db_url}' "
+            f"is no valid database URL! (See {_SQL_DOCS_URL})"
+        )

     # https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
     # Scoped sessions proxy requests to the appropriate thread-local session.
     # Since we also use fastAPI, we need to make it aware of the request id, too
-    Trade.session = scoped_session(sessionmaker(
-        bind=engine, autoflush=False), scopefunc=get_request_or_thread_id)
+    Trade.session = scoped_session(
+        sessionmaker(bind=engine, autoflush=False), scopefunc=get_request_or_thread_id
+    )
     Order.session = Trade.session
     PairLock.session = Trade.session
     _KeyValueStoreModel.session = Trade.session
-    _CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True),
-                                         scopefunc=get_request_or_thread_id)
+    _CustomData.session = scoped_session(
+        sessionmaker(bind=engine, autoflush=True), scopefunc=get_request_or_thread_id
+    )

     previous_tables = inspect(engine).get_table_names()
     ModelBase.metadata.create_all(engine)
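The scoped_session/scopefunc wiring above gives each FastAPI request or thread its own Session while all callers share one module-level handle. A minimal self-contained sketch of the same pattern, with illustrative names rather than the freqtrade module:

    import threading

    from sqlalchemy import create_engine
    from sqlalchemy.orm import scoped_session, sessionmaker

    def scope() -> str:
        # Falls back to the thread id; the real code prefers a FastAPI request id.
        return str(threading.current_thread().ident)

    engine = create_engine("sqlite://", future=True)
    Session = scoped_session(sessionmaker(bind=engine, autoflush=False), scopefunc=scope)

    # Every caller in the same scope gets the same Session instance.
    assert Session() is Session()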

View File

@@ -12,7 +12,8 @@ class PairLock(ModelBase):
     """
     Pair Locks database model.
     """
-    __tablename__ = 'pairlocks'
+
+    __tablename__ = "pairlocks"
     session: ClassVar[SessionType]

     id: Mapped[int] = mapped_column(primary_key=True)
@@ -32,43 +33,48 @@ class PairLock(ModelBase):
         lock_time = self.lock_time.strftime(DATETIME_PRINT_FORMAT)
         lock_end_time = self.lock_end_time.strftime(DATETIME_PRINT_FORMAT)
         return (
-            f'PairLock(id={self.id}, pair={self.pair}, side={self.side}, lock_time={lock_time}, '
-            f'lock_end_time={lock_end_time}, reason={self.reason}, active={self.active})')
+            f"PairLock(id={self.id}, pair={self.pair}, side={self.side}, lock_time={lock_time}, "
+            f"lock_end_time={lock_end_time}, reason={self.reason}, active={self.active})"
+        )

     @staticmethod
     def query_pair_locks(
-            pair: Optional[str], now: datetime, side: str = '*') -> ScalarResult['PairLock']:
+        pair: Optional[str], now: datetime, side: str = "*"
+    ) -> ScalarResult["PairLock"]:
         """
         Get all currently active locks for this pair
         :param pair: Pair to check for. Returns all current locks if pair is empty
         :param now: Datetime object (generated via datetime.now(timezone.utc)).
         """
-        filters = [PairLock.lock_end_time > now,
-                   # Only active locks
-                   PairLock.active.is_(True), ]
+        filters = [
+            PairLock.lock_end_time > now,
+            # Only active locks
+            PairLock.active.is_(True),
+        ]
         if pair:
             filters.append(PairLock.pair == pair)
-        if side != '*':
-            filters.append(or_(PairLock.side == side, PairLock.side == '*'))
+        if side != "*":
+            filters.append(or_(PairLock.side == side, PairLock.side == "*"))
         else:
-            filters.append(PairLock.side == '*')
+            filters.append(PairLock.side == "*")
         return PairLock.session.scalars(select(PairLock).filter(*filters))

     @staticmethod
-    def get_all_locks() -> ScalarResult['PairLock']:
+    def get_all_locks() -> ScalarResult["PairLock"]:
         return PairLock.session.scalars(select(PairLock))

     def to_json(self) -> Dict[str, Any]:
         return {
-            'id': self.id,
-            'pair': self.pair,
-            'lock_time': self.lock_time.strftime(DATETIME_PRINT_FORMAT),
-            'lock_timestamp': int(self.lock_time.replace(tzinfo=timezone.utc).timestamp() * 1000),
-            'lock_end_time': self.lock_end_time.strftime(DATETIME_PRINT_FORMAT),
-            'lock_end_timestamp': int(self.lock_end_time.replace(tzinfo=timezone.utc
-                                                                 ).timestamp() * 1000),
-            'reason': self.reason,
-            'side': self.side,
-            'active': self.active,
+            "id": self.id,
+            "pair": self.pair,
+            "lock_time": self.lock_time.strftime(DATETIME_PRINT_FORMAT),
+            "lock_timestamp": int(self.lock_time.replace(tzinfo=timezone.utc).timestamp() * 1000),
+            "lock_end_time": self.lock_end_time.strftime(DATETIME_PRINT_FORMAT),
+            "lock_end_timestamp": int(
+                self.lock_end_time.replace(tzinfo=timezone.utc).timestamp() * 1000
+            ),
+            "reason": self.reason,
+            "side": self.side,
+            "active": self.active,
         }
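to_json stores naive datetimes and attaches UTC only when converting to epoch milliseconds; the pattern in isolation:

    from datetime import datetime, timezone

    def to_ms(dt: datetime) -> int:
        # Datetimes come back from the DB naive; attach UTC before taking the epoch value.
        return int(dt.replace(tzinfo=timezone.utc).timestamp() * 1000)

    lock_time = datetime(2024, 5, 12, 14, 48)  # naive, interpreted as UTC
    print(to_ms(lock_time))  # milliseconds since the Unix epoch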

View File

@@ -21,7 +21,7 @@ class PairLocks:

     use_db = True
     locks: List[PairLock] = []
-    timeframe: str = ''
+    timeframe: str = ""

     @staticmethod
     def reset_locks() -> None:
@@ -32,8 +32,14 @@ class PairLocks:
         PairLocks.locks = []
     @staticmethod
-    def lock_pair(pair: str, until: datetime, reason: Optional[str] = None, *,
-                  now: Optional[datetime] = None, side: str = '*') -> PairLock:
+    def lock_pair(
+        pair: str,
+        until: datetime,
+        reason: Optional[str] = None,
+        *,
+        now: Optional[datetime] = None,
+        side: str = "*",
+    ) -> PairLock:
         """
         Create PairLock from now to "until".
         Uses database by default, unless PairLocks.use_db is set to False,
@@ -50,7 +56,7 @@ class PairLocks:
             lock_end_time=timeframe_to_next_date(PairLocks.timeframe, until),
             reason=reason,
             side=side,
-            active=True
+            active=True,
         )
         if PairLocks.use_db:
             PairLock.session.add(lock)
@@ -60,8 +66,9 @@ class PairLocks:
         return lock

     @staticmethod
-    def get_pair_locks(pair: Optional[str], now: Optional[datetime] = None,
-                       side: str = '*') -> Sequence[PairLock]:
+    def get_pair_locks(
+        pair: Optional[str], now: Optional[datetime] = None, side: str = "*"
+    ) -> Sequence[PairLock]:
         """
         Get all currently active locks for this pair
         :param pair: Pair to check for. Returns all current locks if pair is empty
@@ -74,17 +81,22 @@ class PairLocks:
         if PairLocks.use_db:
             return PairLock.query_pair_locks(pair, now, side).all()
         else:
-            locks = [lock for lock in PairLocks.locks if (
-                lock.lock_end_time >= now
-                and lock.active is True
-                and (pair is None or lock.pair == pair)
-                and (lock.side == '*' or lock.side == side)
-            )]
+            locks = [
+                lock
+                for lock in PairLocks.locks
+                if (
+                    lock.lock_end_time >= now
+                    and lock.active is True
+                    and (pair is None or lock.pair == pair)
+                    and (lock.side == "*" or lock.side == side)
+                )
+            ]
             return locks
     @staticmethod
     def get_pair_longest_lock(
-            pair: str, now: Optional[datetime] = None, side: str = '*') -> Optional[PairLock]:
+        pair: str, now: Optional[datetime] = None, side: str = "*"
+    ) -> Optional[PairLock]:
         """
         Get the lock that expires the latest for the pair given.
         """
@@ -93,7 +105,7 @@ class PairLocks:
         return locks[0] if locks else None
     @staticmethod
-    def unlock_pair(pair: str, now: Optional[datetime] = None, side: str = '*') -> None:
+    def unlock_pair(pair: str, now: Optional[datetime] = None, side: str = "*") -> None:
         """
         Release all locks for this pair.
         :param pair: Pair to unlock
@@ -124,10 +136,11 @@ class PairLocks:
         if PairLocks.use_db:
             # used in live modes
             logger.info(f"Releasing all locks with reason '{reason}':")
-            filters = [PairLock.lock_end_time > now,
-                       PairLock.active.is_(True),
-                       PairLock.reason == reason
-                       ]
+            filters = [
+                PairLock.lock_end_time > now,
+                PairLock.active.is_(True),
+                PairLock.reason == reason,
+            ]
             locks = PairLock.session.scalars(select(PairLock).filter(*filters)).all()
             for lock in locks:
                 logger.info(f"Releasing lock for {lock.pair} with reason '{reason}'.")
@@ -141,7 +154,7 @@ class PairLocks:
                 lock.active = False

     @staticmethod
-    def is_global_lock(now: Optional[datetime] = None, side: str = '*') -> bool:
+    def is_global_lock(now: Optional[datetime] = None, side: str = "*") -> bool:
         """
         :param now: Datetime object (generated via datetime.now(timezone.utc)).
             defaults to datetime.now(timezone.utc)
@@ -149,10 +162,10 @@ class PairLocks:
         if not now:
             now = datetime.now(timezone.utc)

-        return len(PairLocks.get_pair_locks('*', now, side)) > 0
+        return len(PairLocks.get_pair_locks("*", now, side)) > 0

     @staticmethod
-    def is_pair_locked(pair: str, now: Optional[datetime] = None, side: str = '*') -> bool:
+    def is_pair_locked(pair: str, now: Optional[datetime] = None, side: str = "*") -> bool:
         """
         :param pair: Pair to check for
         :param now: Datetime object (generated via datetime.now(timezone.utc)).
@@ -161,9 +174,8 @@ class PairLocks:
         if not now:
             now = datetime.now(timezone.utc)

-        return (
-            len(PairLocks.get_pair_locks(pair, now, side)) > 0
-            or PairLocks.is_global_lock(now, side)
-        )
+        return len(PairLocks.get_pair_locks(pair, now, side)) > 0 or PairLocks.is_global_lock(
+            now, side
+        )

     @staticmethod
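A hedged usage sketch of the PairLocks facade formatted above; the method names and the in-memory use_db=False mode are taken from this diff, so treat it as illustrative rather than canonical:

    from datetime import datetime, timedelta, timezone

    from freqtrade.persistence.pairlock_middleware import PairLocks

    PairLocks.use_db = False  # keep locks in memory, as backtesting does
    PairLocks.timeframe = "5m"  # needed to round lock end times to a candle boundary

    now = datetime.now(timezone.utc)
    PairLocks.lock_pair("BTC/USDT", now + timedelta(minutes=30), reason="cooldown", side="*")

    assert PairLocks.is_pair_locked("BTC/USDT", now)
    PairLocks.unlock_pair("BTC/USDT", now)
    assert not PairLocks.is_pair_locked("BTC/USDT", now)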

File diff suppressed because it is too large

View File

@@ -1,4 +1,3 @@
 from freqtrade.persistence.custom_data import CustomDataWrapper
 from freqtrade.persistence.pairlock_middleware import PairLocks
 from freqtrade.persistence.trade_model import Trade
@@ -20,13 +19,13 @@ def enable_database_use() -> None:
     Cleanup function to restore database usage.
     """
     PairLocks.use_db = True
-    PairLocks.timeframe = ''
+    PairLocks.timeframe = ""
     Trade.use_db = True
     CustomDataWrapper.use_db = True


 class FtNoDBContext:
-    def __init__(self, timeframe: str = ''):
+    def __init__(self, timeframe: str = ""):
         self.timeframe = timeframe

     def __enter__(self):