diff --git a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py index 846d6df2e..b1f2eecc6 100644 --- a/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py +++ b/freqtrade/freqai/prediction_models/PyTorchTransformerRegressor.py @@ -27,6 +27,12 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor): ... "freqai": { ... + "conv_width": 30, // PyTorchTransformer is based on windowing + "feature_parameters": { + ... + "include_shifted_candles": 0, // which removes the need for shifted candles + ... + }, "model_training_parameters" : { "learning_rate": 3e-4, "trainer_kwargs": { @@ -120,16 +126,16 @@ class PyTorchTransformerRegressor(BasePyTorchRegressor): # create empty torch tensor self.model.model.eval() yb = torch.empty(0).to(self.device) - if x.shape[1] > 1: + if x.shape[1] > self.window_size: ws = self.window_size for i in range(0, x.shape[1] - ws): xb = x[:, i:i + ws, :].to(self.device) y = self.model.model(xb) - yb = torch.cat((yb, y), dim=0) + yb = torch.cat((yb, y), dim=1) else: yb = self.model.model(x) - yb = yb.cpu().squeeze() + yb = yb.cpu().squeeze(0) pred_df = pd.DataFrame(yb.detach().numpy(), columns=dk.label_list) pred_df, _, _ = dk.label_pipeline.inverse_transform(pred_df)