Implement token/login and token/refresh endpoints

This commit is contained in:
Matthias 2020-05-10 10:35:38 +02:00
parent b72997fc2b
commit 8139058fcc

View File

@ -2,11 +2,16 @@ import logging
import threading
from datetime import date, datetime
from ipaddress import IPv4Address
from typing import Dict, Callable, Any
from typing import Any, Callable, Dict
from arrow import Arrow
from flask import Flask, jsonify, request
from flask.json import JSONEncoder
from flask_jwt_extended import (JWTManager, create_access_token,
create_refresh_token, get_jwt_identity,
jwt_refresh_token_required,
verify_jwt_in_request_optional)
from werkzeug.security import safe_str_cmp
from werkzeug.serving import make_server
from freqtrade.__init__ import __version__
@ -38,9 +43,10 @@ class ArrowJSONEncoder(JSONEncoder):
def require_login(func: Callable[[Any, Any], Any]):
def func_wrapper(obj, *args, **kwargs):
verify_jwt_in_request_optional()
auth = request.authorization
if auth and obj.check_auth(auth.username, auth.password):
i = get_jwt_identity()
if i or auth and obj.check_auth(auth.username, auth.password):
return func(obj, *args, **kwargs)
else:
return jsonify({"error": "Unauthorized"}), 401
@ -70,8 +76,8 @@ class ApiServer(RPC):
"""
def check_auth(self, username, password):
return (username == self._config['api_server'].get('username') and
password == self._config['api_server'].get('password'))
return (safe_str_cmp(username, self._config['api_server'].get('username')) and
safe_str_cmp(password, self._config['api_server'].get('password')))
def __init__(self, freqtrade) -> None:
"""
@ -83,6 +89,11 @@ class ApiServer(RPC):
self._config = freqtrade.config
self.app = Flask(__name__)
# Setup the Flask-JWT-Extended extension
self.app.config['JWT_SECRET_KEY'] = 'super-secret' # Change this!
self.jwt = JWTManager(self.app)
self.app.json_encoder = ArrowJSONEncoder
# Register application handling
@ -148,6 +159,10 @@ class ApiServer(RPC):
self.app.register_error_handler(404, self.page_not_found)
# Actions to control the bot
self.app.add_url_rule(f'{BASE_URI}/token/login', 'login',
view_func=self._login, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/token/refresh', 'token_refresh',
view_func=self._refresh_token, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/start', 'start',
view_func=self._start, methods=['POST'])
self.app.add_url_rule(f'{BASE_URI}/stop', 'stop', view_func=self._stop, methods=['POST'])
@ -199,6 +214,37 @@ class ApiServer(RPC):
'code': 404
}), 404
@require_login
@rpc_catch_errors
def _login(self):
"""
Handler for /token/login
Returns a JWT token
"""
auth = request.authorization
if auth and self.check_auth(auth.username, auth.password):
keystuff = {'u': auth.username}
ret = {
'access_token': create_access_token(identity=keystuff),
'refresh_token': create_refresh_token(identity=keystuff),
}
return self.rest_dump(ret)
return jsonify({"error": "Unauthorized"}), 401
@jwt_refresh_token_required
@rpc_catch_errors
def _refresh_token(self):
"""
Handler for /token/refresh
Returns a JWT token based on a JWT refresh token
"""
current_user = get_jwt_identity()
new_token = create_access_token(identity=current_user, fresh=False)
ret = {'access_token': new_token}
return self.rest_dump(ret)
@require_login
@rpc_catch_errors
def _start(self):