From b7904b8e805b57a7cd6b588c203ee981229b0b5e Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 12 Feb 2024 20:14:37 +0100 Subject: [PATCH] Combine custom_data classes to one file --- freqtrade/persistence/__init__.py | 2 +- freqtrade/persistence/custom_data.py | 111 ++++++++++++++++- .../persistence/custom_data_middleware.py | 113 ------------------ freqtrade/persistence/trade_model.py | 13 +- freqtrade/persistence/usedb_context.py | 2 +- 5 files changed, 118 insertions(+), 123 deletions(-) delete mode 100644 freqtrade/persistence/custom_data_middleware.py diff --git a/freqtrade/persistence/__init__.py b/freqtrade/persistence/__init__.py index 5926f2ad3..d5584c22c 100644 --- a/freqtrade/persistence/__init__.py +++ b/freqtrade/persistence/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa: F401 -from freqtrade.persistence.custom_data_middleware import CustomDataWrapper +from freqtrade.persistence.custom_data import CustomDataWrapper from freqtrade.persistence.key_value_store import KeyStoreKeys, KeyValueStore from freqtrade.persistence.models import init_db from freqtrade.persistence.pairlock_middleware import PairLocks diff --git a/freqtrade/persistence/custom_data.py b/freqtrade/persistence/custom_data.py index cfe150967..bf6056278 100644 --- a/freqtrade/persistence/custom_data.py +++ b/freqtrade/persistence/custom_data.py @@ -1,5 +1,7 @@ +import json +import logging from datetime import datetime -from typing import ClassVar, Optional, Sequence +from typing import Any, ClassVar, List, Optional, Sequence from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -9,6 +11,9 @@ from freqtrade.persistence.base import ModelBase, SessionType from freqtrade.util import dt_now +logger = logging.getLogger(__name__) + + class CustomData(ModelBase): """ CustomData database model @@ -60,3 +65,107 @@ class CustomData(ModelBase): filters.append(CustomData.cd_key.ilike(key)) return CustomData.session.scalars(select(CustomData).filter(*filters)).all() + + +class CustomDataWrapper: + """ + CustomData middleware class + Abstracts the database layer away so it becomes optional - which will be necessary to support + backtesting and hyperopt in the future. + """ + + use_db = True + custom_data: List[CustomData] = [] + unserialized_types = ['bool', 'float', 'int', 'str'] + + @staticmethod + def reset_custom_data() -> None: + """ + Resets all key-value pairs. Only active for backtesting mode. + """ + if not CustomDataWrapper.use_db: + CustomDataWrapper.custom_data = [] + + @staticmethod + def get_custom_data(key: Optional[str] = None, + trade_id: Optional[int] = None) -> CustomData: + if trade_id is None: + trade_id = 0 + + if CustomDataWrapper.use_db: + filtered_custom_data = [] + for data_entry in CustomData.query_cd(trade_id=trade_id, key=key): + if data_entry.cd_type not in CustomDataWrapper.unserialized_types: + data_entry.cd_value = json.loads(data_entry.cd_value) + filtered_custom_data.append(data_entry) + return filtered_custom_data + else: + filtered_custom_data = [ + data_entry for data_entry in CustomDataWrapper.custom_data + if (data_entry.ft_trade_id == trade_id) + ] + if key is not None: + filtered_custom_data = [ + data_entry for data_entry in filtered_custom_data + if (data_entry.cd_key.casefold() == key.casefold()) + ] + return filtered_custom_data + + @staticmethod + def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None: + + value_type = type(value).__name__ + value_db = None + + if value_type not in CustomDataWrapper.unserialized_types: + try: + value_db = json.dumps(value) + except TypeError as e: + logger.warning(f"could not serialize {key} value due to {e}") + else: + value_db = str(value) + + if trade_id is None: + trade_id = 0 + + custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id) + if custom_data: + data_entry = custom_data[0] + data_entry.cd_value = value_db + data_entry.updated_at = dt_now() + else: + data_entry = CustomData( + ft_trade_id=trade_id, + cd_key=key, + cd_type=value_type, + cd_value=value_db, + created_at=dt_now() + ) + + if CustomDataWrapper.use_db and value_db is not None: + data_entry.cd_value = value_db + CustomData.session.add(data_entry) + CustomData.session.commit() + elif not CustomDataWrapper.use_db: + cd_index = -1 + for index, data_entry in enumerate(CustomDataWrapper.custom_data): + if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key: + cd_index = index + break + + if cd_index >= 0: + data_entry.cd_type = value_type + data_entry.cd_value = value_db + data_entry.updated_at = dt_now() + + CustomDataWrapper.custom_data[cd_index] = data_entry + else: + CustomDataWrapper.custom_data.append(data_entry) + + @staticmethod + def get_all_custom_data() -> List[CustomData]: + + if CustomDataWrapper.use_db: + return list(CustomData.query_cd()) + else: + return CustomDataWrapper.custom_data diff --git a/freqtrade/persistence/custom_data_middleware.py b/freqtrade/persistence/custom_data_middleware.py deleted file mode 100644 index acc65606b..000000000 --- a/freqtrade/persistence/custom_data_middleware.py +++ /dev/null @@ -1,113 +0,0 @@ -import json -import logging -from typing import Any, List, Optional - -from freqtrade.persistence.custom_data import CustomData -from freqtrade.util import dt_now - - -logger = logging.getLogger(__name__) - - -class CustomDataWrapper: - """ - CustomData middleware class - Abstracts the database layer away so it becomes optional - which will be necessary to support - backtesting and hyperopt in the future. - """ - - use_db = True - custom_data: List[CustomData] = [] - unserialized_types = ['bool', 'float', 'int', 'str'] - - @staticmethod - def reset_custom_data() -> None: - """ - Resets all key-value pairs. Only active for backtesting mode. - """ - if not CustomDataWrapper.use_db: - CustomDataWrapper.custom_data = [] - - @staticmethod - def get_custom_data(key: Optional[str] = None, - trade_id: Optional[int] = None) -> CustomData: - if trade_id is None: - trade_id = 0 - - if CustomDataWrapper.use_db: - filtered_custom_data = [] - for data_entry in CustomData.query_cd(trade_id=trade_id, key=key): - if data_entry.cd_type not in CustomDataWrapper.unserialized_types: - data_entry.cd_value = json.loads(data_entry.cd_value) - filtered_custom_data.append(data_entry) - return filtered_custom_data - else: - filtered_custom_data = [ - data_entry for data_entry in CustomDataWrapper.custom_data - if (data_entry.ft_trade_id == trade_id) - ] - if key is not None: - filtered_custom_data = [ - data_entry for data_entry in filtered_custom_data - if (data_entry.cd_key.casefold() == key.casefold()) - ] - return filtered_custom_data - - @staticmethod - def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None: - - value_type = type(value).__name__ - value_db = None - - if value_type not in CustomDataWrapper.unserialized_types: - try: - value_db = json.dumps(value) - except TypeError as e: - logger.warning(f"could not serialize {key} value due to {e}") - else: - value_db = str(value) - - if trade_id is None: - trade_id = 0 - - custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id) - if custom_data: - data_entry = custom_data[0] - data_entry.cd_value = value_db - data_entry.updated_at = dt_now() - else: - data_entry = CustomData( - ft_trade_id=trade_id, - cd_key=key, - cd_type=value_type, - cd_value=value_db, - created_at=dt_now() - ) - - if CustomDataWrapper.use_db and value_db is not None: - data_entry.cd_value = value_db - CustomData.session.add(data_entry) - CustomData.session.commit() - elif not CustomDataWrapper.use_db: - cd_index = -1 - for index, data_entry in enumerate(CustomDataWrapper.custom_data): - if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key: - cd_index = index - break - - if cd_index >= 0: - data_entry.cd_type = value_type - data_entry.cd_value = value_db - data_entry.updated_at = dt_now() - - CustomDataWrapper.custom_data[cd_index] = data_entry - else: - CustomDataWrapper.custom_data.append(data_entry) - - @staticmethod - def get_all_custom_data() -> List[CustomData]: - - if CustomDataWrapper.use_db: - return list(CustomData.query_cd()) - else: - return CustomDataWrapper.custom_data diff --git a/freqtrade/persistence/trade_model.py b/freqtrade/persistence/trade_model.py index ea03e2c29..7487c72b3 100644 --- a/freqtrade/persistence/trade_model.py +++ b/freqtrade/persistence/trade_model.py @@ -23,8 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi from freqtrade.leverage import interest from freqtrade.misc import safe_value_fallback from freqtrade.persistence.base import ModelBase, SessionType -from freqtrade.persistence.custom_data import CustomData -from freqtrade.persistence.custom_data_middleware import CustomDataWrapper +from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts @@ -345,7 +344,7 @@ class LocalTrade: id: int = 0 orders: List[Order] = [] - custom_data: List[CustomData] = [] + custom_data: List[_CustomData] = [] exchange: str = '' pair: str = '' @@ -1209,7 +1208,7 @@ class LocalTrade: def set_custom_data(self, key: str, value: Any) -> None: CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id) - def get_custom_data(self, key: Optional[str]) -> List[CustomData]: + def get_custom_data(self, key: Optional[str]) -> List[_CustomData]: return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id) @property @@ -1467,7 +1466,7 @@ class Trade(ModelBase, LocalTrade): orders: Mapped[List[Order]] = relationship( "Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin", innerjoin=True) # type: ignore - custom_data: Mapped[List[CustomData]] = relationship( + custom_data: Mapped[List[_CustomData]] = relationship( "CustomData", order_by="CustomData.id", cascade="all, delete-orphan", lazy="raise") # type: ignore @@ -1574,9 +1573,9 @@ class Trade(ModelBase, LocalTrade): Order.session.delete(order) for entry in self.custom_data: - CustomData.session.delete(entry) + _CustomData.session.delete(entry) - CustomData.session.commit() + _CustomData.session.commit() Trade.session.delete(self) Trade.commit() diff --git a/freqtrade/persistence/usedb_context.py b/freqtrade/persistence/usedb_context.py index 193f7021d..732f0b0f8 100644 --- a/freqtrade/persistence/usedb_context.py +++ b/freqtrade/persistence/usedb_context.py @@ -1,5 +1,5 @@ -from freqtrade.persistence.custom_data_middleware import CustomDataWrapper +from freqtrade.persistence.custom_data import CustomDataWrapper from freqtrade.persistence.pairlock_middleware import PairLocks from freqtrade.persistence.trade_model import Trade