Merge pull request #9378 from freqtrade/fix/transformer-dimensions

Bugfix: PyTorchTransformer
This commit is contained in:
Matthias 2023-11-04 15:51:54 +01:00 committed by GitHub
commit 8ce39a6d75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,6 +27,12 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
...
"freqai": {
...
"conv_width": 30, // PyTorchTransformer is based on windowing
"feature_parameters": {
...
"include_shifted_candles": 0, // which removes the need for shifted candles
...
},
"model_training_parameters" : {
"learning_rate": 3e-4,
"trainer_kwargs": {
@ -120,16 +126,16 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor):
# create empty torch tensor
self.model.model.eval()
yb = torch.empty(0).to(self.device)
if x.shape[1] > 1:
if x.shape[1] > self.window_size:
ws = self.window_size
for i in range(0, x.shape[1] - ws):
xb = x[:, i:i + ws, :].to(self.device)
y = self.model.model(xb)
yb = torch.cat((yb, y), dim=0)
yb = torch.cat((yb, y), dim=1)
else:
yb = self.model.model(x)
yb = yb.cpu().squeeze()
yb = yb.cpu().squeeze(0)
pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list)
pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)