ruff format: update persistence

Matthias 2024-05-12 16:48:11 +02:00
parent 5783a44c86
commit cebbe0121e
9 changed files with 807 additions and 665 deletions

View File

@@ -1,4 +1,3 @@
from sqlalchemy.orm import DeclarativeBase, Session, scoped_session

View File

@@ -23,16 +23,17 @@ class _CustomData(ModelBase):
- One trade can have many metadata entries
- One metadata entry can only be associated with one Trade
"""
__tablename__ = 'trade_custom_data'
__tablename__ = "trade_custom_data"
__allow_unmapped__ = True
session: ClassVar[SessionType]
# Uniqueness should be ensured over pair, order_id
# it's likely that order_id is unique per Pair on some exchanges.
__table_args__ = (UniqueConstraint('ft_trade_id', 'cd_key', name="_trade_id_cd_key"),)
__table_args__ = (UniqueConstraint("ft_trade_id", "cd_key", name="_trade_id_cd_key"),)
id = mapped_column(Integer, primary_key=True)
ft_trade_id = mapped_column(Integer, ForeignKey('trades.id'), index=True)
ft_trade_id = mapped_column(Integer, ForeignKey("trades.id"), index=True)
trade = relationship("Trade", back_populates="custom_data")
@@ -46,17 +47,22 @@ class _CustomData(ModelBase):
value: Any = None
def __repr__(self):
create_time = (self.created_at.strftime(DATETIME_PRINT_FORMAT)
if self.created_at is not None else None)
update_time = (self.updated_at.strftime(DATETIME_PRINT_FORMAT)
if self.updated_at is not None else None)
return (f'CustomData(id={self.id}, key={self.cd_key}, type={self.cd_type}, ' +
f'value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, ' +
f'updated={update_time})')
create_time = (
self.created_at.strftime(DATETIME_PRINT_FORMAT) if self.created_at is not None else None
)
update_time = (
self.updated_at.strftime(DATETIME_PRINT_FORMAT) if self.updated_at is not None else None
)
return (
f"CustomData(id={self.id}, key={self.cd_key}, type={self.cd_type}, "
+ f"value={self.cd_value}, trade_id={self.ft_trade_id}, created={create_time}, "
+ f"updated={update_time})"
)
@classmethod
def query_cd(cls, key: Optional[str] = None,
trade_id: Optional[int] = None) -> Sequence['_CustomData']:
def query_cd(
cls, key: Optional[str] = None, trade_id: Optional[int] = None
) -> Sequence["_CustomData"]:
"""
Get all CustomData; if trade_id is not specified,
the result will be generic values not tied to a trade
@@ -80,17 +86,17 @@ class CustomDataWrapper:
use_db = True
custom_data: List[_CustomData] = []
unserialized_types = ['bool', 'float', 'int', 'str']
unserialized_types = ["bool", "float", "int", "str"]
@staticmethod
def _convert_custom_data(data: _CustomData) -> _CustomData:
if data.cd_type in CustomDataWrapper.unserialized_types:
data.value = data.cd_value
if data.cd_type == 'bool':
data.value = data.cd_value.lower() == 'true'
elif data.cd_type == 'int':
if data.cd_type == "bool":
data.value = data.cd_value.lower() == "true"
elif data.cd_type == "int":
data.value = int(data.cd_value)
elif data.cd_type == 'float':
elif data.cd_type == "float":
data.value = float(data.cd_value)
else:
data.value = json.loads(data.cd_value)
@@ -111,31 +117,32 @@ class CustomDataWrapper:
@staticmethod
def get_custom_data(*, trade_id: int, key: Optional[str] = None) -> List[_CustomData]:
if CustomDataWrapper.use_db:
filters = [
_CustomData.ft_trade_id == trade_id,
]
if key is not None:
filters.append(_CustomData.cd_key.ilike(key))
filtered_custom_data = _CustomData.session.scalars(select(_CustomData).filter(
*filters)).all()
filtered_custom_data = _CustomData.session.scalars(
select(_CustomData).filter(*filters)
).all()
else:
filtered_custom_data = [
data_entry for data_entry in CustomDataWrapper.custom_data
data_entry
for data_entry in CustomDataWrapper.custom_data
if (data_entry.ft_trade_id == trade_id)
]
if key is not None:
filtered_custom_data = [
data_entry for data_entry in filtered_custom_data
data_entry
for data_entry in filtered_custom_data
if (data_entry.cd_key.casefold() == key.casefold())
]
return [CustomDataWrapper._convert_custom_data(d) for d in filtered_custom_data]
@staticmethod
def set_custom_data(trade_id: int, key: str, value: Any) -> None:
value_type = type(value).__name__
if value_type not in CustomDataWrapper.unserialized_types:
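
The hunk above ends mid-function, but the conversion logic is complete: _convert_custom_data restores typed values - bool/int/float are parsed back from their string form, str passes through, and anything else is decoded with json.loads. A minimal usage sketch (trade id and keys are hypothetical; assumes init_db has already bound the sessions):

from freqtrade.persistence.custom_data import CustomDataWrapper

# Store a scalar and a JSON-serializable structure for trade 1.
CustomDataWrapper.set_custom_data(trade_id=1, key="entry_adx", value=27.5)
CustomDataWrapper.set_custom_data(trade_id=1, key="grid_levels", value=[100.0, 98.5, 97.0])

# get_custom_data runs each row through _convert_custom_data, so .value
# comes back as a float and a list respectively, not as the stored strings.
for entry in CustomDataWrapper.get_custom_data(trade_id=1, key="grid_levels"):
    print(entry.cd_key, entry.value)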

View File

@@ -12,22 +12,23 @@ ValueTypes = Union[str, datetime, float, int]
class ValueTypesEnum(str, Enum):
STRING = 'str'
DATETIME = 'datetime'
FLOAT = 'float'
INT = 'int'
STRING = "str"
DATETIME = "datetime"
FLOAT = "float"
INT = "int"
class KeyStoreKeys(str, Enum):
BOT_START_TIME = 'bot_start_time'
STARTUP_TIME = 'startup_time'
BOT_START_TIME = "bot_start_time"
STARTUP_TIME = "startup_time"
class _KeyValueStoreModel(ModelBase):
"""
Key-value store database model.
"""
__tablename__ = 'KeyValueStore'
__tablename__ = "KeyValueStore"
session: ClassVar[SessionType]
id: Mapped[int] = mapped_column(primary_key=True)
@@ -56,8 +57,11 @@ class KeyValueStore:
:param key: Key to store the value for - can be used in get-value to retrieve the key
:param value: Value to store - can be str, datetime, float or int
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(_KeyValueStoreModel.key == key)
.first()
)
if kv is None:
kv = _KeyValueStoreModel(key=key)
if isinstance(value, str):
@@ -73,7 +77,7 @@ class KeyValueStore:
kv.value_type = ValueTypesEnum.INT
kv.int_value = value
else:
raise ValueError(f'Unknown value type {kv.value_type}')
raise ValueError(f"Unknown value type {kv.value_type}")
_KeyValueStoreModel.session.add(kv)
_KeyValueStoreModel.session.commit()
@@ -83,8 +87,11 @@ class KeyValueStore:
Delete the value for the given key.
:param key: Key to delete the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(_KeyValueStoreModel.key == key)
.first()
)
if kv is not None:
_KeyValueStoreModel.session.delete(kv)
_KeyValueStoreModel.session.commit()
@@ -95,8 +102,11 @@ class KeyValueStore:
Get the value for the given key.
:param key: Key to get the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(_KeyValueStoreModel.key == key)
.first()
)
if kv is None:
return None
if kv.value_type == ValueTypesEnum.STRING:
@@ -108,7 +118,7 @@ class KeyValueStore:
if kv.value_type == ValueTypesEnum.INT:
return kv.int_value
# This should never happen unless someone messed with the database manually
raise ValueError(f'Unknown value type {kv.value_type}') # pragma: no cover
raise ValueError(f"Unknown value type {kv.value_type}") # pragma: no cover
@staticmethod
def get_string_value(key: KeyStoreKeys) -> Optional[str]:
@@ -116,9 +126,14 @@ class KeyValueStore:
Get the value for the given key.
:param key: Key to get the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.STRING).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.STRING,
)
.first()
)
if kv is None:
return None
return kv.string_value
@@ -129,9 +144,14 @@ class KeyValueStore:
Get the value for the given key.
:param key: Key to get the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.DATETIME).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.DATETIME,
)
.first()
)
if kv is None or kv.datetime_value is None:
return None
return kv.datetime_value.replace(tzinfo=timezone.utc)
@@ -142,9 +162,14 @@ class KeyValueStore:
Get the value for the given key.
:param key: Key to get the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.FLOAT).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.FLOAT,
)
.first()
)
if kv is None:
return None
return kv.float_value
@@ -155,9 +180,13 @@ class KeyValueStore:
Get the value for the given key.
:param key: Key to get the value for
"""
kv = _KeyValueStoreModel.session.query(_KeyValueStoreModel).filter(
_KeyValueStoreModel.key == key,
_KeyValueStoreModel.value_type == ValueTypesEnum.INT).first()
kv = (
_KeyValueStoreModel.session.query(_KeyValueStoreModel)
.filter(
_KeyValueStoreModel.key == key, _KeyValueStoreModel.value_type == ValueTypesEnum.INT
)
.first()
)
if kv is None:
return None
return kv.int_value
@@ -168,12 +197,13 @@ def set_startup_time():
sets bot_start_time to the first trade open date - or "now" on new databases.
sets startup_time to "now"
"""
st = KeyValueStore.get_value('bot_start_time')
st = KeyValueStore.get_value("bot_start_time")
if st is None:
from freqtrade.persistence import Trade
t = Trade.session.query(Trade).order_by(Trade.open_date.asc()).first()
if t is not None:
KeyValueStore.store_value('bot_start_time', t.open_date_utc)
KeyValueStore.store_value("bot_start_time", t.open_date_utc)
else:
KeyValueStore.store_value('bot_start_time', datetime.now(timezone.utc))
KeyValueStore.store_value('startup_time', datetime.now(timezone.utc))
KeyValueStore.store_value("bot_start_time", datetime.now(timezone.utc))
KeyValueStore.store_value("startup_time", datetime.now(timezone.utc))

View File

@@ -25,8 +25,8 @@ def get_column_def(columns: List, column: str, default: str) -> str:
def get_backup_name(tabs: List[str], backup_prefix: str):
table_back_name = backup_prefix
for i, table_back_name in enumerate(tabs):
table_back_name = f'{backup_prefix}{i}'
logger.debug(f'trying {table_back_name}')
table_back_name = f"{backup_prefix}{i}"
logger.debug(f"trying {table_back_name}")
return table_back_name
@@ -35,21 +35,22 @@ def get_last_sequence_ids(engine, trade_back_name: str, order_back_name: str):
order_id: Optional[int] = None
trade_id: Optional[int] = None
if engine.name == 'postgresql':
if engine.name == "postgresql":
with engine.begin() as connection:
trade_id = connection.execute(text("select nextval('trades_id_seq')")).fetchone()[0]
order_id = connection.execute(text("select nextval('orders_id_seq')")).fetchone()[0]
with engine.begin() as connection:
connection.execute(text(
f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak"))
connection.execute(text(
f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak"))
connection.execute(
text(f"ALTER SEQUENCE orders_id_seq rename to {order_back_name}_id_seq_bak")
)
connection.execute(
text(f"ALTER SEQUENCE trades_id_seq rename to {trade_back_name}_id_seq_bak")
)
return order_id, trade_id
def set_sequence_ids(engine, order_id, trade_id, pairlock_id=None):
if engine.name == 'postgresql':
if engine.name == "postgresql":
with engine.begin() as connection:
if order_id:
connection.execute(text(f"ALTER SEQUENCE orders_id_seq RESTART WITH {order_id}"))
@@ -57,84 +58,95 @@ def set_sequence_ids(engine, order_id, trade_id, pairlock_id=None):
connection.execute(text(f"ALTER SEQUENCE trades_id_seq RESTART WITH {trade_id}"))
if pairlock_id:
connection.execute(
text(f"ALTER SEQUENCE pairlocks_id_seq RESTART WITH {pairlock_id}"))
text(f"ALTER SEQUENCE pairlocks_id_seq RESTART WITH {pairlock_id}")
)
def drop_index_on_table(engine, inspector, table_bak_name):
with engine.begin() as connection:
# drop indexes on backup table in new session
for index in inspector.get_indexes(table_bak_name):
if engine.name == 'mysql':
if engine.name == "mysql":
connection.execute(text(f"drop index {index['name']} on {table_bak_name}"))
else:
connection.execute(text(f"drop index {index['name']}"))
def migrate_trades_and_orders_table(
decl_base, inspector, engine,
trade_back_name: str, cols: List,
order_back_name: str, cols_order: List):
base_currency = get_column_def(cols, 'base_currency', 'null')
stake_currency = get_column_def(cols, 'stake_currency', 'null')
fee_open = get_column_def(cols, 'fee_open', 'fee')
fee_open_cost = get_column_def(cols, 'fee_open_cost', 'null')
fee_open_currency = get_column_def(cols, 'fee_open_currency', 'null')
fee_close = get_column_def(cols, 'fee_close', 'fee')
fee_close_cost = get_column_def(cols, 'fee_close_cost', 'null')
fee_close_currency = get_column_def(cols, 'fee_close_currency', 'null')
open_rate_requested = get_column_def(cols, 'open_rate_requested', 'null')
close_rate_requested = get_column_def(cols, 'close_rate_requested', 'null')
stop_loss = get_column_def(cols, 'stop_loss', '0.0')
stop_loss_pct = get_column_def(cols, 'stop_loss_pct', 'null')
initial_stop_loss = get_column_def(cols, 'initial_stop_loss', '0.0')
initial_stop_loss_pct = get_column_def(cols, 'initial_stop_loss_pct', 'null')
decl_base,
inspector,
engine,
trade_back_name: str,
cols: List,
order_back_name: str,
cols_order: List,
):
base_currency = get_column_def(cols, "base_currency", "null")
stake_currency = get_column_def(cols, "stake_currency", "null")
fee_open = get_column_def(cols, "fee_open", "fee")
fee_open_cost = get_column_def(cols, "fee_open_cost", "null")
fee_open_currency = get_column_def(cols, "fee_open_currency", "null")
fee_close = get_column_def(cols, "fee_close", "fee")
fee_close_cost = get_column_def(cols, "fee_close_cost", "null")
fee_close_currency = get_column_def(cols, "fee_close_currency", "null")
open_rate_requested = get_column_def(cols, "open_rate_requested", "null")
close_rate_requested = get_column_def(cols, "close_rate_requested", "null")
stop_loss = get_column_def(cols, "stop_loss", "0.0")
stop_loss_pct = get_column_def(cols, "stop_loss_pct", "null")
initial_stop_loss = get_column_def(cols, "initial_stop_loss", "0.0")
initial_stop_loss_pct = get_column_def(cols, "initial_stop_loss_pct", "null")
is_stop_loss_trailing = get_column_def(
cols, 'is_stop_loss_trailing',
f'coalesce({stop_loss_pct}, 0.0) <> coalesce({initial_stop_loss_pct}, 0.0)')
max_rate = get_column_def(cols, 'max_rate', '0.0')
min_rate = get_column_def(cols, 'min_rate', 'null')
exit_reason = get_column_def(cols, 'sell_reason', get_column_def(cols, 'exit_reason', 'null'))
strategy = get_column_def(cols, 'strategy', 'null')
enter_tag = get_column_def(cols, 'buy_tag', get_column_def(cols, 'enter_tag', 'null'))
realized_profit = get_column_def(cols, 'realized_profit', '0.0')
cols,
"is_stop_loss_trailing",
f"coalesce({stop_loss_pct}, 0.0) <> coalesce({initial_stop_loss_pct}, 0.0)",
)
max_rate = get_column_def(cols, "max_rate", "0.0")
min_rate = get_column_def(cols, "min_rate", "null")
exit_reason = get_column_def(cols, "sell_reason", get_column_def(cols, "exit_reason", "null"))
strategy = get_column_def(cols, "strategy", "null")
enter_tag = get_column_def(cols, "buy_tag", get_column_def(cols, "enter_tag", "null"))
realized_profit = get_column_def(cols, "realized_profit", "0.0")
trading_mode = get_column_def(cols, 'trading_mode', 'null')
trading_mode = get_column_def(cols, "trading_mode", "null")
# Leverage Properties
leverage = get_column_def(cols, 'leverage', '1.0')
liquidation_price = get_column_def(cols, 'liquidation_price',
get_column_def(cols, 'isolated_liq', 'null'))
leverage = get_column_def(cols, "leverage", "1.0")
liquidation_price = get_column_def(
cols, "liquidation_price", get_column_def(cols, "isolated_liq", "null")
)
# sqlite does not support literals for booleans
if engine.name == 'postgresql':
is_short = get_column_def(cols, 'is_short', 'false')
if engine.name == "postgresql":
is_short = get_column_def(cols, "is_short", "false")
else:
is_short = get_column_def(cols, 'is_short', '0')
is_short = get_column_def(cols, "is_short", "0")
# Futures Properties
interest_rate = get_column_def(cols, 'interest_rate', '0.0')
funding_fees = get_column_def(cols, 'funding_fees', '0.0')
funding_fee_running = get_column_def(cols, 'funding_fee_running', 'null')
max_stake_amount = get_column_def(cols, 'max_stake_amount', 'stake_amount')
interest_rate = get_column_def(cols, "interest_rate", "0.0")
funding_fees = get_column_def(cols, "funding_fees", "0.0")
funding_fee_running = get_column_def(cols, "funding_fee_running", "null")
max_stake_amount = get_column_def(cols, "max_stake_amount", "stake_amount")
# If ticker-interval existed use that, else null.
if has_column(cols, 'ticker_interval'):
timeframe = get_column_def(cols, 'timeframe', 'ticker_interval')
if has_column(cols, "ticker_interval"):
timeframe = get_column_def(cols, "timeframe", "ticker_interval")
else:
timeframe = get_column_def(cols, 'timeframe', 'null')
timeframe = get_column_def(cols, "timeframe", "null")
open_trade_value = get_column_def(cols, 'open_trade_value',
f'amount * open_rate * (1 + {fee_open})')
open_trade_value = get_column_def(
cols, "open_trade_value", f"amount * open_rate * (1 + {fee_open})"
)
close_profit_abs = get_column_def(
cols, 'close_profit_abs',
f"(amount * close_rate * (1 - {fee_close})) - {open_trade_value}")
exit_order_status = get_column_def(cols, 'exit_order_status',
get_column_def(cols, 'sell_order_status', 'null'))
amount_requested = get_column_def(cols, 'amount_requested', 'amount')
cols, "close_profit_abs", f"(amount * close_rate * (1 - {fee_close})) - {open_trade_value}"
)
exit_order_status = get_column_def(
cols, "exit_order_status", get_column_def(cols, "sell_order_status", "null")
)
amount_requested = get_column_def(cols, "amount_requested", "amount")
amount_precision = get_column_def(cols, 'amount_precision', 'null')
price_precision = get_column_def(cols, 'price_precision', 'null')
precision_mode = get_column_def(cols, 'precision_mode', 'null')
contract_size = get_column_def(cols, 'contract_size', 'null')
amount_precision = get_column_def(cols, "amount_precision", "null")
price_precision = get_column_def(cols, "price_precision", "null")
precision_mode = get_column_def(cols, "precision_mode", "null")
contract_size = get_column_def(cols, "contract_size", "null")
# Schema migration necessary
with engine.begin() as connection:
@@ -151,7 +163,8 @@ def migrate_trades_and_orders_table(
# Copy data back - following the correct schema
with engine.begin() as connection:
connection.execute(text(f"""insert into trades
connection.execute(
text(f"""insert into trades
(id, exchange, pair, base_currency, stake_currency, is_open,
fee_open, fee_open_cost, fee_open_currency,
fee_close, fee_close_cost, fee_close_currency, open_rate,
@@ -196,7 +209,8 @@ def migrate_trades_and_orders_table(
{precision_mode} precision_mode, {contract_size} contract_size,
{max_stake_amount} max_stake_amount
from {trade_back_name}
"""))
""")
)
migrate_orders_table(engine, order_back_name, cols_order)
set_sequence_ids(engine, order_id, trade_id)
@@ -212,19 +226,19 @@ def drop_orders_table(engine, table_back_name: str):
def migrate_orders_table(engine, table_back_name: str, cols_order: List):
ft_fee_base = get_column_def(cols_order, 'ft_fee_base', 'null')
average = get_column_def(cols_order, 'average', 'null')
stop_price = get_column_def(cols_order, 'stop_price', 'null')
funding_fee = get_column_def(cols_order, 'funding_fee', '0.0')
ft_amount = get_column_def(cols_order, 'ft_amount', 'coalesce(amount, 0.0)')
ft_price = get_column_def(cols_order, 'ft_price', 'coalesce(price, 0.0)')
ft_cancel_reason = get_column_def(cols_order, 'ft_cancel_reason', 'null')
ft_order_tag = get_column_def(cols_order, 'ft_order_tag', 'null')
ft_fee_base = get_column_def(cols_order, "ft_fee_base", "null")
average = get_column_def(cols_order, "average", "null")
stop_price = get_column_def(cols_order, "stop_price", "null")
funding_fee = get_column_def(cols_order, "funding_fee", "0.0")
ft_amount = get_column_def(cols_order, "ft_amount", "coalesce(amount, 0.0)")
ft_price = get_column_def(cols_order, "ft_price", "coalesce(price, 0.0)")
ft_cancel_reason = get_column_def(cols_order, "ft_cancel_reason", "null")
ft_order_tag = get_column_def(cols_order, "ft_order_tag", "null")
# sqlite does not support literals for booleans
with engine.begin() as connection:
connection.execute(text(f"""
connection.execute(
text(f"""
insert into orders (id, ft_trade_id, ft_order_side, ft_pair, ft_is_open, order_id,
status, symbol, order_type, side, price, amount, filled, average, remaining, cost,
stop_price, order_date, order_filled_date, order_update_date, ft_fee_base, funding_fee,
@@ -237,36 +251,36 @@ def migrate_orders_table(engine, table_back_name: str, cols_order: List):
{ft_amount} ft_amount, {ft_price} ft_price, {ft_cancel_reason} ft_cancel_reason,
{ft_order_tag} ft_order_tag
from {table_back_name}
"""))
""")
)
def migrate_pairlocks_table(
decl_base, inspector, engine,
pairlock_back_name: str, cols: List):
def migrate_pairlocks_table(decl_base, inspector, engine, pairlock_back_name: str, cols: List):
# Schema migration necessary
with engine.begin() as connection:
connection.execute(text(f"alter table pairlocks rename to {pairlock_back_name}"))
drop_index_on_table(engine, inspector, pairlock_back_name)
side = get_column_def(cols, 'side', "'*'")
side = get_column_def(cols, "side", "'*'")
# let SQLAlchemy create the schema as required
decl_base.metadata.create_all(engine)
# Copy data back - following the correct schema
with engine.begin() as connection:
connection.execute(text(f"""insert into pairlocks
connection.execute(
text(f"""insert into pairlocks
(id, pair, side, reason, lock_time,
lock_end_time, active)
select id, pair, {side} side, reason, lock_time,
lock_end_time, active
from {pairlock_back_name}
"""))
""")
)
def set_sqlite_to_wal(engine):
if engine.name == 'sqlite' and str(engine.url) != 'sqlite://':
if engine.name == "sqlite" and str(engine.url) != "sqlite://":
# Set journal mode to WAL
with engine.begin() as connection:
connection.execute(text("PRAGMA journal_mode=wal"))
@@ -274,7 +288,6 @@ def set_sqlite_to_wal(engine):
def fix_old_dry_orders(engine):
with engine.begin() as connection:
# Update current dry-run Orders where
# - stoploss order is Open (will be replaced eventually)
# 2nd query:
@@ -283,26 +296,28 @@ def fix_old_dry_orders(engine):
# - current Order trade_id not equal to current Trade.id
# - current Order not stoploss
stmt = update(Order).where(
Order.ft_is_open.is_(True),
Order.ft_order_side == 'stoploss',
Order.order_id.like('dry%'),
).values(ft_is_open=False)
stmt = (
update(Order)
.where(
Order.ft_is_open.is_(True),
Order.ft_order_side == "stoploss",
Order.order_id.like("dry%"),
)
.values(ft_is_open=False)
)
connection.execute(stmt)
# Close dry-run orders for closed trades.
stmt = update(Order).where(
Order.ft_is_open.is_(True),
Order.ft_trade_id.not_in(
select(
Trade.id
).where(Trade.is_open.is_(True))
),
Order.ft_order_side != 'stoploss',
Order.order_id.like('dry%')
).values(ft_is_open=False)
stmt = (
update(Order)
.where(
Order.ft_is_open.is_(True),
Order.ft_trade_id.not_in(select(Trade.id).where(Trade.is_open.is_(True))),
Order.ft_order_side != "stoploss",
Order.order_id.like("dry%"),
)
.values(ft_is_open=False)
)
connection.execute(stmt)
@@ -312,15 +327,15 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
"""
inspector = inspect(engine)
cols_trades = inspector.get_columns('trades')
cols_orders = inspector.get_columns('orders')
cols_pairlocks = inspector.get_columns('pairlocks')
tabs = get_table_names_for_table(inspector, 'trades')
table_back_name = get_backup_name(tabs, 'trades_bak')
order_tabs = get_table_names_for_table(inspector, 'orders')
order_table_bak_name = get_backup_name(order_tabs, 'orders_bak')
pairlock_tabs = get_table_names_for_table(inspector, 'pairlocks')
pairlock_table_bak_name = get_backup_name(pairlock_tabs, 'pairlocks_bak')
cols_trades = inspector.get_columns("trades")
cols_orders = inspector.get_columns("orders")
cols_pairlocks = inspector.get_columns("pairlocks")
tabs = get_table_names_for_table(inspector, "trades")
table_back_name = get_backup_name(tabs, "trades_bak")
order_tabs = get_table_names_for_table(inspector, "orders")
order_table_bak_name = get_backup_name(order_tabs, "orders_bak")
pairlock_tabs = get_table_names_for_table(inspector, "pairlocks")
pairlock_table_bak_name = get_backup_name(pairlock_tabs, "pairlocks_bak")
# Check if migration necessary
# Migrates both trades and orders table!
@@ -328,27 +343,37 @@ def check_migrate(engine, decl_base, previous_tables) -> None:
# or not has_column(cols_orders, 'funding_fee')):
migrating = False
# if not has_column(cols_trades, 'funding_fee_running'):
if not has_column(cols_orders, 'ft_order_tag'):
if not has_column(cols_orders, "ft_order_tag"):
migrating = True
logger.info(f"Running database migration for trades - "
f"backup: {table_back_name}, {order_table_bak_name}")
logger.info(
f"Running database migration for trades - "
f"backup: {table_back_name}, {order_table_bak_name}"
)
migrate_trades_and_orders_table(
decl_base, inspector, engine, table_back_name, cols_trades,
order_table_bak_name, cols_orders)
decl_base,
inspector,
engine,
table_back_name,
cols_trades,
order_table_bak_name,
cols_orders,
)
if not has_column(cols_pairlocks, 'side'):
if not has_column(cols_pairlocks, "side"):
migrating = True
logger.info(f"Running database migration for pairlocks - "
f"backup: {pairlock_table_bak_name}")
logger.info(
f"Running database migration for pairlocks - " f"backup: {pairlock_table_bak_name}"
)
migrate_pairlocks_table(
decl_base, inspector, engine, pairlock_table_bak_name, cols_pairlocks
)
if 'orders' not in previous_tables and 'trades' in previous_tables:
if "orders" not in previous_tables and "trades" in previous_tables:
raise OperationalException(
"Your database seems to be very old. "
"Please update to freqtrade 2022.3 to migrate this database or "
"start with a fresh database.")
"start with a fresh database."
)
set_sqlite_to_wal(engine)
fix_old_dry_orders(engine)
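
Most of the churn in this file is in calls to get_column_def, whose body sits outside the diff. Its contract is clear from the call sites: inside the INSERT ... SELECT statements, {base_currency} base_currency must expand to either the existing column name or a SQL fallback expression. A plausible minimal implementation, shown only to make the migration SQL above readable:

def has_column(columns, column: str) -> bool:
    # columns comes from inspector.get_columns(): a list of dicts with a "name" key.
    return any(col["name"] == column for col in columns)

def get_column_def(columns, column: str, default: str) -> str:
    # Select the real column when the old table has it, otherwise inject the
    # fallback literal/expression, e.g. "null", "0.0" or "coalesce(amount, 0.0)".
    return column if has_column(columns, column) else default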

View File

@@ -1,6 +1,7 @@
"""
This module contains the class to persist trades into SQLite
"""
import logging
import threading
from contextvars import ContextVar
@@ -23,7 +24,7 @@ from freqtrade.persistence.trade_model import Order, Trade
logger = logging.getLogger(__name__)
REQUEST_ID_CTX_KEY: Final[str] = 'request_id'
REQUEST_ID_CTX_KEY: Final[str] = "request_id"
_request_id_ctx_var: ContextVar[Optional[str]] = ContextVar(REQUEST_ID_CTX_KEY, default=None)
@@ -39,7 +40,7 @@ def get_request_or_thread_id() -> Optional[str]:
return id
_SQL_DOCS_URL = 'http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls'
_SQL_DOCS_URL = "http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls"
def init_db(db_url: str) -> None:
@@ -52,35 +53,44 @@ def init_db(db_url: str) -> None:
"""
kwargs: Dict[str, Any] = {}
if db_url == 'sqlite:///':
if db_url == "sqlite:///":
raise OperationalException(
f'Bad db-url {db_url}. For in-memory database, please use `sqlite://`.')
if db_url == 'sqlite://':
kwargs.update({
'poolclass': StaticPool,
})
f"Bad db-url {db_url}. For in-memory database, please use `sqlite://`."
)
if db_url == "sqlite://":
kwargs.update(
{
"poolclass": StaticPool,
}
)
# Take care of thread ownership
if db_url.startswith('sqlite://'):
kwargs.update({
'connect_args': {'check_same_thread': False},
})
if db_url.startswith("sqlite://"):
kwargs.update(
{
"connect_args": {"check_same_thread": False},
}
)
try:
engine = create_engine(db_url, future=True, **kwargs)
except NoSuchModuleError:
raise OperationalException(f"Given value for db_url: '{db_url}' "
f"is no valid database URL! (See {_SQL_DOCS_URL})")
raise OperationalException(
f"Given value for db_url: '{db_url}' "
f"is no valid database URL! (See {_SQL_DOCS_URL})"
)
# https://docs.sqlalchemy.org/en/13/orm/contextual.html#thread-local-scope
# Scoped sessions proxy requests to the appropriate thread-local session.
# Since we also use fastAPI, we need to make it aware of the request id, too
Trade.session = scoped_session(sessionmaker(
bind=engine, autoflush=False), scopefunc=get_request_or_thread_id)
Trade.session = scoped_session(
sessionmaker(bind=engine, autoflush=False), scopefunc=get_request_or_thread_id
)
Order.session = Trade.session
PairLock.session = Trade.session
_KeyValueStoreModel.session = Trade.session
_CustomData.session = scoped_session(sessionmaker(bind=engine, autoflush=True),
scopefunc=get_request_or_thread_id)
_CustomData.session = scoped_session(
sessionmaker(bind=engine, autoflush=True), scopefunc=get_request_or_thread_id
)
previous_tables = inspect(engine).get_table_names()
ModelBase.metadata.create_all(engine)
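
For reference, the URL handling above in a few lines (import path assumed, as the diff does not show the filename; freqtrade.exceptions is assumed as the home of OperationalException):

from freqtrade.exceptions import OperationalException
from freqtrade.persistence.models import init_db

init_db("sqlite://")  # in-memory: StaticPool plus check_same_thread=False

try:
    init_db("sqlite:///")  # explicitly rejected - the error points to `sqlite://`
except OperationalException as err:
    print(err)

init_db("sqlite:///tradesv3.sqlite")  # file-backed database (hypothetical filename)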

View File

@@ -12,7 +12,8 @@ class PairLock(ModelBase):
"""
Pair Locks database model.
"""
__tablename__ = 'pairlocks'
__tablename__ = "pairlocks"
session: ClassVar[SessionType]
id: Mapped[int] = mapped_column(primary_key=True)
@@ -32,43 +33,48 @@ class PairLock(ModelBase):
lock_time = self.lock_time.strftime(DATETIME_PRINT_FORMAT)
lock_end_time = self.lock_end_time.strftime(DATETIME_PRINT_FORMAT)
return (
f'PairLock(id={self.id}, pair={self.pair}, side={self.side}, lock_time={lock_time}, '
f'lock_end_time={lock_end_time}, reason={self.reason}, active={self.active})')
f"PairLock(id={self.id}, pair={self.pair}, side={self.side}, lock_time={lock_time}, "
f"lock_end_time={lock_end_time}, reason={self.reason}, active={self.active})"
)
@staticmethod
def query_pair_locks(
pair: Optional[str], now: datetime, side: str = '*') -> ScalarResult['PairLock']:
pair: Optional[str], now: datetime, side: str = "*"
) -> ScalarResult["PairLock"]:
"""
Get all currently active locks for this pair
:param pair: Pair to check for. Returns all current locks if pair is empty
:param now: Datetime object (generated via datetime.now(timezone.utc)).
"""
filters = [PairLock.lock_end_time > now,
# Only active locks
PairLock.active.is_(True), ]
filters = [
PairLock.lock_end_time > now,
# Only active locks
PairLock.active.is_(True),
]
if pair:
filters.append(PairLock.pair == pair)
if side != '*':
filters.append(or_(PairLock.side == side, PairLock.side == '*'))
if side != "*":
filters.append(or_(PairLock.side == side, PairLock.side == "*"))
else:
filters.append(PairLock.side == '*')
filters.append(PairLock.side == "*")
return PairLock.session.scalars(select(PairLock).filter(*filters))
@staticmethod
def get_all_locks() -> ScalarResult['PairLock']:
def get_all_locks() -> ScalarResult["PairLock"]:
return PairLock.session.scalars(select(PairLock))
def to_json(self) -> Dict[str, Any]:
return {
'id': self.id,
'pair': self.pair,
'lock_time': self.lock_time.strftime(DATETIME_PRINT_FORMAT),
'lock_timestamp': int(self.lock_time.replace(tzinfo=timezone.utc).timestamp() * 1000),
'lock_end_time': self.lock_end_time.strftime(DATETIME_PRINT_FORMAT),
'lock_end_timestamp': int(self.lock_end_time.replace(tzinfo=timezone.utc
).timestamp() * 1000),
'reason': self.reason,
'side': self.side,
'active': self.active,
"id": self.id,
"pair": self.pair,
"lock_time": self.lock_time.strftime(DATETIME_PRINT_FORMAT),
"lock_timestamp": int(self.lock_time.replace(tzinfo=timezone.utc).timestamp() * 1000),
"lock_end_time": self.lock_end_time.strftime(DATETIME_PRINT_FORMAT),
"lock_end_timestamp": int(
self.lock_end_time.replace(tzinfo=timezone.utc).timestamp() * 1000
),
"reason": self.reason,
"side": self.side,
"active": self.active,
}
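
Note the paired timestamp fields: to_json emits both the formatted string and an epoch-milliseconds integer. The conversion is the standard one, e.g.:

from datetime import datetime, timezone

dt = datetime(2024, 5, 12, 14, 48, tzinfo=timezone.utc)
int(dt.replace(tzinfo=timezone.utc).timestamp() * 1000)  # 1715525280000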

View File

@@ -21,7 +21,7 @@ class PairLocks:
use_db = True
locks: List[PairLock] = []
timeframe: str = ''
timeframe: str = ""
@staticmethod
def reset_locks() -> None:
@@ -32,8 +32,14 @@ class PairLocks:
PairLocks.locks = []
@staticmethod
def lock_pair(pair: str, until: datetime, reason: Optional[str] = None, *,
now: Optional[datetime] = None, side: str = '*') -> PairLock:
def lock_pair(
pair: str,
until: datetime,
reason: Optional[str] = None,
*,
now: Optional[datetime] = None,
side: str = "*",
) -> PairLock:
"""
Create PairLock from now to "until".
Uses database by default, unless PairLocks.use_db is set to False,
@@ -50,7 +56,7 @@ class PairLocks:
lock_end_time=timeframe_to_next_date(PairLocks.timeframe, until),
reason=reason,
side=side,
active=True
active=True,
)
if PairLocks.use_db:
PairLock.session.add(lock)
@@ -60,8 +66,9 @@ class PairLocks:
return lock
@staticmethod
def get_pair_locks(pair: Optional[str], now: Optional[datetime] = None,
side: str = '*') -> Sequence[PairLock]:
def get_pair_locks(
pair: Optional[str], now: Optional[datetime] = None, side: str = "*"
) -> Sequence[PairLock]:
"""
Get all currently active locks for this pair
:param pair: Pair to check for. Returns all current locks if pair is empty
@@ -74,17 +81,22 @@ class PairLocks:
if PairLocks.use_db:
return PairLock.query_pair_locks(pair, now, side).all()
else:
locks = [lock for lock in PairLocks.locks if (
lock.lock_end_time >= now
and lock.active is True
and (pair is None or lock.pair == pair)
and (lock.side == '*' or lock.side == side)
)]
locks = [
lock
for lock in PairLocks.locks
if (
lock.lock_end_time >= now
and lock.active is True
and (pair is None or lock.pair == pair)
and (lock.side == "*" or lock.side == side)
)
]
return locks
@staticmethod
def get_pair_longest_lock(
pair: str, now: Optional[datetime] = None, side: str = '*') -> Optional[PairLock]:
pair: str, now: Optional[datetime] = None, side: str = "*"
) -> Optional[PairLock]:
"""
Get the lock that expires the latest for the pair given.
"""
@@ -93,7 +105,7 @@ class PairLocks:
return locks[0] if locks else None
@staticmethod
def unlock_pair(pair: str, now: Optional[datetime] = None, side: str = '*') -> None:
def unlock_pair(pair: str, now: Optional[datetime] = None, side: str = "*") -> None:
"""
Release all locks for this pair.
:param pair: Pair to unlock
@@ -124,10 +136,11 @@ class PairLocks:
if PairLocks.use_db:
# used in live modes
logger.info(f"Releasing all locks with reason '{reason}':")
filters = [PairLock.lock_end_time > now,
PairLock.active.is_(True),
PairLock.reason == reason
]
filters = [
PairLock.lock_end_time > now,
PairLock.active.is_(True),
PairLock.reason == reason,
]
locks = PairLock.session.scalars(select(PairLock).filter(*filters)).all()
for lock in locks:
logger.info(f"Releasing lock for {lock.pair} with reason '{reason}'.")
@@ -141,7 +154,7 @@ class PairLocks:
lock.active = False
@staticmethod
def is_global_lock(now: Optional[datetime] = None, side: str = '*') -> bool:
def is_global_lock(now: Optional[datetime] = None, side: str = "*") -> bool:
"""
:param now: Datetime object (generated via datetime.now(timezone.utc)).
defaults to datetime.now(timezone.utc)
@@ -149,10 +162,10 @@ class PairLocks:
if not now:
now = datetime.now(timezone.utc)
return len(PairLocks.get_pair_locks('*', now, side)) > 0
return len(PairLocks.get_pair_locks("*", now, side)) > 0
@staticmethod
def is_pair_locked(pair: str, now: Optional[datetime] = None, side: str = '*') -> bool:
def is_pair_locked(pair: str, now: Optional[datetime] = None, side: str = "*") -> bool:
"""
:param pair: Pair to check for
:param now: Datetime object (generated via datetime.now(timezone.utc)).
@@ -161,9 +174,8 @@ class PairLocks:
if not now:
now = datetime.now(timezone.utc)
return (
len(PairLocks.get_pair_locks(pair, now, side)) > 0
or PairLocks.is_global_lock(now, side)
return len(PairLocks.get_pair_locks(pair, now, side)) > 0 or PairLocks.is_global_lock(
now, side
)
@staticmethod
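
The reformatting above changes these signatures only cosmetically; for reference, a typical call sequence through this middleware (pair, timeframe, and times are hypothetical):

from datetime import datetime, timedelta, timezone

from freqtrade.persistence.pairlock_middleware import PairLocks

PairLocks.use_db = False   # keep locks in memory, e.g. during backtesting
PairLocks.timeframe = "5m" # needed to round `until` up to the next candle

now = datetime.now(timezone.utc)
PairLocks.lock_pair("BTC/USDT", now + timedelta(minutes=30), "cooldown", now=now, side="long")

assert PairLocks.is_pair_locked("BTC/USDT", now, side="long")
assert not PairLocks.is_global_lock(now)  # no '*' pair lock was set

PairLocks.unlock_pair("BTC/USDT", now, side="long")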

File diff suppressed because it is too large

View File

@@ -1,4 +1,3 @@
from freqtrade.persistence.custom_data import CustomDataWrapper
from freqtrade.persistence.pairlock_middleware import PairLocks
from freqtrade.persistence.trade_model import Trade
@@ -20,13 +19,13 @@ def enable_database_use() -> None:
Cleanup function to restore database usage.
"""
PairLocks.use_db = True
PairLocks.timeframe = ''
PairLocks.timeframe = ""
Trade.use_db = True
CustomDataWrapper.use_db = True
class FtNoDBContext:
def __init__(self, timeframe: str = ''):
def __init__(self, timeframe: str = ""):
self.timeframe = timeframe
def __enter__(self):
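
__enter__ is cut off here, but enable_database_use documents the invariant the context manager must restore on exit: presumably entering the block flips use_db off (and applies self.timeframe) for PairLocks, Trade and CustomDataWrapper. Usage would then look like this (import path assumed, body hypothetical):

from freqtrade.persistence.usedb_context import FtNoDBContext

with FtNoDBContext(timeframe="5m"):
    # In-memory locks/trades/custom data only; nothing touches the database.
    ...
# On exit, enable_database_use() restores use_db = True and resets the timeframe.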