Combine custom_data classes to one file

This commit is contained in:
Matthias 2024-02-12 20:14:37 +01:00
parent 7fd70b82fa
commit b7904b8e80
5 changed files with 118 additions and 123 deletions

View File

@ -1,6 +1,6 @@
# flake8: noqa: F401
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
from freqtrade.persistence.custom_data import CustomDataWrapper
from freqtrade.persistence.key_value_store import KeyStoreKeys, KeyValueStore
from freqtrade.persistence.models import init_db
from freqtrade.persistence.pairlock_middleware import PairLocks

View File

@ -1,5 +1,7 @@
import json
import logging
from datetime import datetime
from typing import ClassVar, Optional, Sequence
from typing import Any, ClassVar, List, Optional, Sequence
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -9,6 +11,9 @@ from freqtrade.persistence.base import ModelBase, SessionType
from freqtrade.util import dt_now
logger = logging.getLogger(__name__)
class CustomData(ModelBase):
"""
CustomData database model
@ -60,3 +65,107 @@ class CustomData(ModelBase):
filters.append(CustomData.cd_key.ilike(key))
return CustomData.session.scalars(select(CustomData).filter(*filters)).all()
class CustomDataWrapper:
"""
CustomData middleware class
Abstracts the database layer away so it becomes optional - which will be necessary to support
backtesting and hyperopt in the future.
"""
use_db = True
custom_data: List[CustomData] = []
unserialized_types = ['bool', 'float', 'int', 'str']
@staticmethod
def reset_custom_data() -> None:
"""
Resets all key-value pairs. Only active for backtesting mode.
"""
if not CustomDataWrapper.use_db:
CustomDataWrapper.custom_data = []
@staticmethod
def get_custom_data(key: Optional[str] = None,
trade_id: Optional[int] = None) -> CustomData:
if trade_id is None:
trade_id = 0
if CustomDataWrapper.use_db:
filtered_custom_data = []
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
data_entry.cd_value = json.loads(data_entry.cd_value)
filtered_custom_data.append(data_entry)
return filtered_custom_data
else:
filtered_custom_data = [
data_entry for data_entry in CustomDataWrapper.custom_data
if (data_entry.ft_trade_id == trade_id)
]
if key is not None:
filtered_custom_data = [
data_entry for data_entry in filtered_custom_data
if (data_entry.cd_key.casefold() == key.casefold())
]
return filtered_custom_data
@staticmethod
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
value_type = type(value).__name__
value_db = None
if value_type not in CustomDataWrapper.unserialized_types:
try:
value_db = json.dumps(value)
except TypeError as e:
logger.warning(f"could not serialize {key} value due to {e}")
else:
value_db = str(value)
if trade_id is None:
trade_id = 0
custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id)
if custom_data:
data_entry = custom_data[0]
data_entry.cd_value = value_db
data_entry.updated_at = dt_now()
else:
data_entry = CustomData(
ft_trade_id=trade_id,
cd_key=key,
cd_type=value_type,
cd_value=value_db,
created_at=dt_now()
)
if CustomDataWrapper.use_db and value_db is not None:
data_entry.cd_value = value_db
CustomData.session.add(data_entry)
CustomData.session.commit()
elif not CustomDataWrapper.use_db:
cd_index = -1
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
cd_index = index
break
if cd_index >= 0:
data_entry.cd_type = value_type
data_entry.cd_value = value_db
data_entry.updated_at = dt_now()
CustomDataWrapper.custom_data[cd_index] = data_entry
else:
CustomDataWrapper.custom_data.append(data_entry)
@staticmethod
def get_all_custom_data() -> List[CustomData]:
if CustomDataWrapper.use_db:
return list(CustomData.query_cd())
else:
return CustomDataWrapper.custom_data

View File

@ -1,113 +0,0 @@
import json
import logging
from typing import Any, List, Optional
from freqtrade.persistence.custom_data import CustomData
from freqtrade.util import dt_now
logger = logging.getLogger(__name__)
class CustomDataWrapper:
"""
CustomData middleware class
Abstracts the database layer away so it becomes optional - which will be necessary to support
backtesting and hyperopt in the future.
"""
use_db = True
custom_data: List[CustomData] = []
unserialized_types = ['bool', 'float', 'int', 'str']
@staticmethod
def reset_custom_data() -> None:
"""
Resets all key-value pairs. Only active for backtesting mode.
"""
if not CustomDataWrapper.use_db:
CustomDataWrapper.custom_data = []
@staticmethod
def get_custom_data(key: Optional[str] = None,
trade_id: Optional[int] = None) -> CustomData:
if trade_id is None:
trade_id = 0
if CustomDataWrapper.use_db:
filtered_custom_data = []
for data_entry in CustomData.query_cd(trade_id=trade_id, key=key):
if data_entry.cd_type not in CustomDataWrapper.unserialized_types:
data_entry.cd_value = json.loads(data_entry.cd_value)
filtered_custom_data.append(data_entry)
return filtered_custom_data
else:
filtered_custom_data = [
data_entry for data_entry in CustomDataWrapper.custom_data
if (data_entry.ft_trade_id == trade_id)
]
if key is not None:
filtered_custom_data = [
data_entry for data_entry in filtered_custom_data
if (data_entry.cd_key.casefold() == key.casefold())
]
return filtered_custom_data
@staticmethod
def set_custom_data(key: str, value: Any, trade_id: Optional[int] = None) -> None:
value_type = type(value).__name__
value_db = None
if value_type not in CustomDataWrapper.unserialized_types:
try:
value_db = json.dumps(value)
except TypeError as e:
logger.warning(f"could not serialize {key} value due to {e}")
else:
value_db = str(value)
if trade_id is None:
trade_id = 0
custom_data = CustomDataWrapper.get_custom_data(key=key, trade_id=trade_id)
if custom_data:
data_entry = custom_data[0]
data_entry.cd_value = value_db
data_entry.updated_at = dt_now()
else:
data_entry = CustomData(
ft_trade_id=trade_id,
cd_key=key,
cd_type=value_type,
cd_value=value_db,
created_at=dt_now()
)
if CustomDataWrapper.use_db and value_db is not None:
data_entry.cd_value = value_db
CustomData.session.add(data_entry)
CustomData.session.commit()
elif not CustomDataWrapper.use_db:
cd_index = -1
for index, data_entry in enumerate(CustomDataWrapper.custom_data):
if data_entry.ft_trade_id == trade_id and data_entry.cd_key == key:
cd_index = index
break
if cd_index >= 0:
data_entry.cd_type = value_type
data_entry.cd_value = value_db
data_entry.updated_at = dt_now()
CustomDataWrapper.custom_data[cd_index] = data_entry
else:
CustomDataWrapper.custom_data.append(data_entry)
@staticmethod
def get_all_custom_data() -> List[CustomData]:
if CustomDataWrapper.use_db:
return list(CustomData.query_cd())
else:
return CustomDataWrapper.custom_data

View File

@ -23,8 +23,7 @@ from freqtrade.exchange import (ROUND_DOWN, ROUND_UP, amount_to_contract_precisi
from freqtrade.leverage import interest
from freqtrade.misc import safe_value_fallback
from freqtrade.persistence.base import ModelBase, SessionType
from freqtrade.persistence.custom_data import CustomData
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
from freqtrade.persistence.custom_data import CustomData, CustomDataWrapper
from freqtrade.util import FtPrecise, dt_from_ts, dt_now, dt_ts
@ -345,7 +344,7 @@ class LocalTrade:
id: int = 0
orders: List[Order] = []
custom_data: List[CustomData] = []
custom_data: List[_CustomData] = []
exchange: str = ''
pair: str = ''
@ -1209,7 +1208,7 @@ class LocalTrade:
def set_custom_data(self, key: str, value: Any) -> None:
CustomDataWrapper.set_custom_data(key=key, value=value, trade_id=self.id)
def get_custom_data(self, key: Optional[str]) -> List[CustomData]:
def get_custom_data(self, key: Optional[str]) -> List[_CustomData]:
return CustomDataWrapper.get_custom_data(key=key, trade_id=self.id)
@property
@ -1467,7 +1466,7 @@ class Trade(ModelBase, LocalTrade):
orders: Mapped[List[Order]] = relationship(
"Order", order_by="Order.id", cascade="all, delete-orphan", lazy="selectin",
innerjoin=True) # type: ignore
custom_data: Mapped[List[CustomData]] = relationship(
custom_data: Mapped[List[_CustomData]] = relationship(
"CustomData", order_by="CustomData.id", cascade="all, delete-orphan",
lazy="raise") # type: ignore
@ -1574,9 +1573,9 @@ class Trade(ModelBase, LocalTrade):
Order.session.delete(order)
for entry in self.custom_data:
CustomData.session.delete(entry)
_CustomData.session.delete(entry)
CustomData.session.commit()
_CustomData.session.commit()
Trade.session.delete(self)
Trade.commit()

View File

@ -1,5 +1,5 @@
from freqtrade.persistence.custom_data_middleware import CustomDataWrapper
from freqtrade.persistence.custom_data import CustomDataWrapper
from freqtrade.persistence.pairlock_middleware import PairLocks
from freqtrade.persistence.trade_model import Trade