Extract get_trades function

This commit is contained in:
Matthias 2019-10-29 15:01:10 +01:00
parent 01efebc42f
commit 26a5800a7f
2 changed files with 24 additions and 9 deletions

View File

@ -106,7 +106,7 @@ def load_trades_from_db(db_url: str) -> pd.DataFrame:
t.stop_loss, t.initial_stop_loss, t.stop_loss, t.initial_stop_loss,
t.strategy, t.ticker_interval t.strategy, t.ticker_interval
) )
for t in Trade.query.all()], for t in Trade.get_trades().all()],
columns=columns) columns=columns)
return trades return trades

View File

@ -11,6 +11,7 @@ from sqlalchemy import (Boolean, Column, DateTime, Float, Integer, String,
create_engine, desc, func, inspect) create_engine, desc, func, inspect)
from sqlalchemy.exc import NoSuchModuleError from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query
from sqlalchemy.orm.scoping import scoped_session from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.orm.session import sessionmaker from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
@ -391,12 +392,33 @@ class Trade(_DECL_BASE):
profit_percent = (close_trade_price / open_trade_price) - 1 profit_percent = (close_trade_price / open_trade_price) - 1
return float(f"{profit_percent:.8f}") return float(f"{profit_percent:.8f}")
@staticmethod
def get_trades(trade_filter=None) -> Query:
"""
Helper function to query Trades using filter.
:param trade_filter: Filter to apply to trades
:return: Query object
"""
if trade_filter is not None:
if not isinstance(trade_filter, list):
trade_filter = [trade_filter]
return Trade.query.filter(*trade_filter)
else:
return Trade.query
@staticmethod
def get_open_trades() -> List[Any]:
"""
Query trades from persistence layer
"""
return Trade.get_trades(Trade.is_open.is_(True)).all()
@staticmethod @staticmethod
def get_open_order_trades(): def get_open_order_trades():
""" """
Returns all open trades Returns all open trades
""" """
return Trade.query.filter(Trade.open_order_id.isnot(None)).all() return Trade.get_trades(Trade.open_order_id.isnot(None)).all()
@staticmethod @staticmethod
def total_open_trades_stakes() -> float: def total_open_trades_stakes() -> float:
@ -443,13 +465,6 @@ class Trade(_DECL_BASE):
.order_by(desc('profit_sum')).first() .order_by(desc('profit_sum')).first()
return best_pair return best_pair
@staticmethod
def get_open_trades() -> List[Any]:
"""
Query trades from persistence layer
"""
return Trade.query.filter(Trade.is_open.is_(True)).all()
@staticmethod @staticmethod
def stoploss_reinitialization(desired_stoploss): def stoploss_reinitialization(desired_stoploss):
""" """