mirror of
https://github.com/freqtrade/freqtrade.git
synced 2024-09-20 09:31:12 +00:00
fix: swap tensor dimension to play nicely with pandas
This commit is contained in:
parent
c94c667fb1
commit
72dc65cb6a
|
@ -27,6 +27,12 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||
...
|
||||
"freqai": {
|
||||
...
|
||||
"conv_width": 30, // PyTorchTransformer is based on windowing
|
||||
"feature_parameters": {
|
||||
...
|
||||
"include_shifted_candles": 0, // which removes the need for shifted candles
|
||||
...
|
||||
},
|
||||
"model_training_parameters" : {
|
||||
"learning_rate": 3e-4,
|
||||
"trainer_kwargs": {
|
||||
|
@ -120,16 +126,16 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
|
|||
# create empty torch tensor
|
||||
self.model.model.eval()
|
||||
yb = torch.empty(0).to(self.device)
|
||||
if x.shape[1] > 1:
|
||||
if x.shape[1] > self.window_size:
|
||||
ws = self.window_size
|
||||
for i in range(0, x.shape[1] - ws):
|
||||
xb = x[:, i:i + ws, :].to(self.device)
|
||||
y = self.model.model(xb)
|
||||
yb = torch.cat((yb, y), dim=0)
|
||||
yb = torch.cat((yb, y), dim=1)
|
||||
else:
|
||||
yb = self.model.model(x)
|
||||
|
||||
yb = yb.cpu().squeeze()
|
||||
yb = yb.cpu().squeeze(0)
|
||||
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
|
||||
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user