mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-11-10 10:21:59 +00:00
fix: swap tensor dimension to play nicely with pandas
This commit is contained in:
parent
c94c667fb1
commit
72dc65cb6a
|
@ -27,6 +27,12 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
||||||
...
|
...
|
||||||
"freqai": {
|
"freqai": {
|
||||||
...
|
...
|
||||||
|
"conv_width": 30, // PyTorchTransformer is based on windowing
|
||||||
|
"feature_parameters": {
|
||||||
|
...
|
||||||
|
"include_shifted_candles": 0, // which removes the need for shifted candles
|
||||||
|
...
|
||||||
|
},
|
||||||
"model_training_parameters" : {
|
"model_training_parameters" : {
|
||||||
"learning_rate": 3e-4,
|
"learning_rate": 3e-4,
|
||||||
"trainer_kwargs": {
|
"trainer_kwargs": {
|
||||||
|
@ -120,16 +126,16 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
||||||
# create empty torch tensor
|
# create empty torch tensor
|
||||||
self.model.model.eval()
|
self.model.model.eval()
|
||||||
yb = torch.empty(0).to(self.device)
|
yb = torch.empty(0).to(self.device)
|
||||||
if x.shape[1] > 1:
|
if x.shape[1] > self.window_size:
|
||||||
ws = self.window_size
|
ws = self.window_size
|
||||||
for i in range(0, x.shape[1] - ws):
|
for i in range(0, x.shape[1] - ws):
|
||||||
xb = x[:, i:i + ws, :].to(self.device)
|
xb = x[:, i:i + ws, :].to(self.device)
|
||||||
y = self.model.model(xb)
|
y = self.model.model(xb)
|
||||||
yb = torch.cat((yb, y), dim=0)
|
yb = torch.cat((yb, y), dim=1)
|
||||||
else:
|
else:
|
||||||
yb = self.model.model(x)
|
yb = self.model.model(x)
|
||||||
|
|
||||||
yb = yb.cpu().squeeze()
|
yb = yb.cpu().squeeze(0)
|
||||||
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
|
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
|
||||||
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)
|
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user