Verify Cognito JWT in a Flask Application
Send the Cognito JWT from the React client and verify it on the Flask server
Send the Cognito JWT in the header of an API request
The JWT is stored in the browser's local storage during sign-in.
// Sign in with Amplify and keep the Cognito access token so later API calls
// can send it in the Authorization header.
Auth.signIn(email, password)
  .then(user => {
    console.log('user', user)
    // signIn can resolve with a pending challenge (user.challengeName set,
    // e.g. NEW_PASSWORD_REQUIRED or an MFA step) and no session yet; reading
    // signInUserSession.accessToken would then throw a TypeError.
    if (!user.signInUserSession) {
      console.warn('Sign-in needs another step:', user.challengeName)
      return
    }
    // NOTE(review): localStorage is readable by any script on the page, so an
    // XSS bug would expose this token.
    localStorage.setItem("access_token", user.signInUserSession.accessToken.jwtToken)
    window.location.href = "/"
  })
  .catch(error => {
    // Without a handler a failed sign-in (wrong password, unconfirmed user,
    // network error) becomes an unhandled promise rejection.
    console.error('Sign-in failed:', error)
  })
Read the JWT from the browser's local storage and send it as the value of the `Authorization` header of the API request, using the `Bearer` scheme.
const backend_url = `${process.env.REACT_APP_BACKEND_URL}/api/activities/home`
// getItem returns null when nobody is signed in; attach the header only when
// a token exists instead of sending the literal string "Bearer null".
const access_token = localStorage.getItem("access_token")
const res = await fetch(backend_url, {
  headers: access_token ? { Authorization: `Bearer ${access_token}` } : {},
  method: "GET"
});
Get the JWT from the API request on the server side
from flask import request
@app.route("/api/activities/home", methods=['GET'])
def data_home():
    """Excerpt: read the raw ``Authorization`` header sent by the React client.

    This fragment only shows how to obtain the header value; the complete
    handler that verifies the token is shown later in this document.
    """
    jwt_token = request.headers.get("Authorization")
    # Never log the bearer token itself: anyone who can read the logs could
    # replay it until it expires. Log only whether one was sent.
    app.logger.debug("Auth header present: %s", jwt_token is not None)
Verify the Cognito JWT on the server side
The verification code is adapted from the Flask-AWSCognito repository.
It was changed slightly to work in this repo; the adapted code is listed below.
import functools
import time

import requests
from jose import jwk, jwt
from jose.exceptions import JOSEError
from jose.utils import base64url_decode
class FlaskAWSCognitoError(Exception):
    """Raised when the Cognito token verifier cannot be set up.

    ``CognitoJwtToken`` raises it when no AWS region is supplied and when
    the user pool's public keys (JWKS) cannot be loaded.
    """
class TokenVerifyError(Exception):
    """Raised when a JWT fails verification.

    Covers a missing or malformed token, an unknown signing key, a bad
    signature, an expired token, or a token issued for another audience.
    """
class CognitoJwtToken:
    """Verify JSON Web Tokens issued by an Amazon Cognito user pool.

    The pool's public signing keys (JWKS) are downloaded once, when the
    instance is created. :meth:`verify` then checks a token's signature
    against those keys, its expiry time and its audience.
    """

    def __init__(self, user_pool_id, user_pool_client_id, region,
                 request_client=None, request_timeout=10):
        """Create a verifier and load the user pool's public keys.

        Args:
            user_pool_id: Cognito user pool ID; part of the JWKS URL.
            user_pool_client_id: App client ID that a token's audience must match.
            region: AWS region hosting the user pool.
            request_client: Optional callable taking the JWKS URL and returning
                a response object with a ``.json()`` method. Defaults to
                ``requests.get``.
            request_timeout: Seconds to wait for the JWKS download when the
                default client is used; ignored for a custom ``request_client``.

        Raises:
            FlaskAWSCognitoError: If ``region`` is empty or the keys cannot be
                loaded.
        """
        self.region = region
        if not self.region:
            raise FlaskAWSCognitoError("No AWS region provided")
        self.user_pool_id = user_pool_id
        self.user_pool_client_id = user_pool_client_id
        # Claims of the most recent successful verify(). NOTE(review): if one
        # instance is shared between concurrent requests, another request can
        # overwrite this; prefer the value returned by verify().
        self.claims = None
        if not request_client:
            # requests.get has no timeout by default, so an unreachable
            # endpoint would block construction (and app start-up) forever.
            self.request_client = functools.partial(requests.get, timeout=request_timeout)
        else:
            self.request_client = request_client
        self._load_jwk_keys()

    @classmethod
    def extract_access_token(cls, auth_header):
        """Return the token part of an ``Authorization`` header value.

        ``"Bearer <token>"`` yields ``"<token>"``. A value without a space
        (including ``None`` or ``""``) is returned unchanged. A value with a
        space but no token after the scheme (e.g. ``"Bearer "``) yields ``""``,
        so :meth:`verify` reports it as a missing token instead of the
        caller crashing.
        """
        if not auth_header or " " not in auth_header:
            return auth_header
        # maxsplit=1 tolerates extra spaces; the former two-name unpacking
        # raised ValueError for "Bearer " or "Bearer a b", which escaped the
        # caller's TokenVerifyError handler.
        parts = auth_header.split(maxsplit=1)
        return parts[1].strip() if len(parts) == 2 else ""

    def _load_jwk_keys(self):
        """Download the user pool's JWKS and store its key list in ``self.jwk_keys``.

        Raises:
            FlaskAWSCognitoError: If the request fails or the response has no
                ``keys`` list.
        """
        keys_url = f"https://cognito-idp.{self.region}.amazonaws.com/{self.user_pool_id}/.well-known/jwks.json"
        try:
            response = self.request_client(keys_url)
            keys = response.json()["keys"]
        except requests.exceptions.RequestException as e:
            raise FlaskAWSCognitoError(str(e)) from e
        except (KeyError, TypeError, ValueError) as e:
            # The body was not JSON, or was a JSON document without "keys".
            raise FlaskAWSCognitoError(f"Invalid JWKS response from {keys_url}: {e!r}") from e
        if not isinstance(keys, list):
            raise FlaskAWSCognitoError(f"Invalid JWKS response from {keys_url}: 'keys' is not a list")
        self.jwk_keys = keys

    @staticmethod
    def _extract_headers(token):
        """Return the token's header without verifying it.

        Raises:
            TokenVerifyError: If the header cannot be decoded.
        """
        try:
            return jwt.get_unverified_headers(token)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e

    def _find_pkey(self, headers):
        """Return the downloaded JWKS entry whose ``kid`` matches the token header.

        Raises:
            TokenVerifyError: If the header has no ``kid`` or no key matches it.
        """
        kid = headers.get("kid")
        if kid is None:
            raise TokenVerifyError("Token header has no 'kid'")
        for key in self.jwk_keys:
            if key.get("kid") == kid:
                return key
        raise TokenVerifyError("Public key not found in jwks.json")

    @staticmethod
    def _verify_signature(token, pkey_data):
        """Check the token's signature with the public key described by *pkey_data*.

        Raises:
            TokenVerifyError: If the key cannot be built, the token is
                malformed, or the signature does not match.
        """
        try:
            public_key = jwk.construct(pkey_data)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e
        try:
            if isinstance(token, bytes):
                # str(b"...") would give "b'...'" and corrupt the signed message.
                token = token.decode("utf-8")
            # The signed message is everything before the last "."; the part
            # after it is the base64url-encoded signature.
            message, sep, encoded_signature = str(token).rpartition(".")
            if not sep:
                raise TokenVerifyError("Malformed token")
            decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
            is_valid = public_key.verify(message.encode("utf8"), decoded_signature)
        except (JOSEError, ValueError) as e:
            # ValueError covers bad base64 (binascii.Error) and bad UTF-8.
            raise TokenVerifyError(str(e)) from e
        if not is_valid:
            raise TokenVerifyError("Signature verification failed")

    @staticmethod
    def _extract_claims(token):
        """Return the token's claims without verifying them.

        Raises:
            TokenVerifyError: If the claims cannot be decoded.
        """
        try:
            return jwt.get_unverified_claims(token)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e

    @staticmethod
    def _check_expiration(claims, current_time):
        """Raise unless the token's ``exp`` time lies after *current_time*.

        Args:
            claims: Decoded token claims.
            current_time: Unix time to compare against, or ``None`` for now.

        Raises:
            TokenVerifyError: If ``exp`` is missing or non-numeric, or the
                token has expired.
        """
        if current_time is None:
            current_time = time.time()
        exp = claims.get("exp")
        if not isinstance(exp, (int, float)):
            raise TokenVerifyError("Token has no valid 'exp' claim")
        # RFC 7519 section 4.1.4: the token must not be accepted on or after exp.
        if current_time >= exp:
            raise TokenVerifyError("Token is expired")

    def _check_audience(self, claims):
        """Raise unless the token was issued for ``self.user_pool_client_id``.

        The audience is read from ``aud`` when present, otherwise from
        ``client_id`` (the claim used when verifying an access token).

        Raises:
            TokenVerifyError: If there is no audience claim or it does not
                match the app client ID.
        """
        audience = claims["aud"] if "aud" in claims else claims.get("client_id")
        if audience is None:
            raise TokenVerifyError("Token has no audience claim")
        # RFC 7519 allows "aud" to be a list of audiences.
        audiences = audience if isinstance(audience, list) else [audience]
        if self.user_pool_client_id not in audiences:
            raise TokenVerifyError("Token was not issued for this audience")

    def verify(self, token, current_time=None):
        """Verify *token* and return its claims.

        Follows https://github.com/awslabs/aws-support-tools/blob/master/Cognito/decode-verify-jwt/decode-verify-jwt.py
        The steps are:

        1. Find the signing key by ``kid``.
        2. Check the signature.
        3. Check the expiry.
        4. Check the audience.

        Args:
            token: Encoded JWT.
            current_time: Unix time used for the expiry check; ``None`` means now.

        Returns:
            The token's claims (also stored in ``self.claims``).

        Raises:
            TokenVerifyError: If the token is missing or any check fails.
        """
        if not token:
            raise TokenVerifyError("No token provided")
        headers = self._extract_headers(token)
        pkey_data = self._find_pkey(headers)
        self._verify_signature(token, pkey_data)
        claims = self._extract_claims(token)
        self._check_expiration(claims, current_time)
        self._check_audience(claims)
        self.claims = claims
        return claims
Call the code to verify the JWT on the server side
The code depends on the python-jose library.
Add this line to requirements.txt:
python-jose
Then call the verification code from the Flask app:
# For Cognito JWT token
import os

from lib.cognito_jwt_token import CognitoJwtToken, TokenVerifyError

# Created once at import time: the constructor downloads the user pool's
# public keys (JWKS), so building it per request would repeat that download.
jwt_service = CognitoJwtToken(
    user_pool_id=os.getenv("AWS_COGNITO_USER_POOL_ID"),
    user_pool_client_id=os.getenv("AWS_COGNITO_USER_POOL_CLIENT_ID"),
    region=os.getenv("AWS_DEFAULT_REGION"),
)
@app.route("/api/activities/home", methods=['GET'])
def data_home():
    """Return home activities, personalised when a valid Cognito token is sent.

    The ``Authorization`` header is optional. A missing or invalid token
    falls back to the anonymous feed instead of failing the request.
    """
    auth_header = request.headers.get("Authorization")
    access_token = CognitoJwtToken.extract_access_token(auth_header)
    try:
        claims = jwt_service.verify(access_token)
    except TokenVerifyError as e:
        app.logger.debug("Unauthenticated: %s", e)
        return HomeActivities.run(LOGGER), 200

    app.logger.debug("Authenticated")
    app.logger.debug("Claims: %s", claims)
    # Access tokens carry "username"; a verified ID token (accepted via its
    # "aud" claim) carries "cognito:username" instead.
    username = claims.get("username") or claims.get("cognito:username")
    if username:
        data = HomeActivities.run(LOGGER, username)
    else:
        data = HomeActivities.run(LOGGER)
    return data, 200
The username from the claims is passed to `HomeActivities.run` only when the JWT is verified successfully.
This is how the app tells whether the user is authenticated.