In this tutorial, we will build an OpenID Connect and JWT authentication system with Django, Django REST Framework (DRF), and Simple JWT, using the RSA (RS256) signing algorithm.
First, let’s install the required packages:
pip install django
pip install djangorestframework
pip install djangorestframework-simplejwt
pip install "pyjwt[crypto]"
pip install requests
Next, let’s create a Django app called authentication:
python manage.py startapp authentication
We’ll define the models, serializers, views, and URLs in the authentication app.
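In models.py, we’ll define a custom user model that extends the AbstractBaseUser class. For Django to use it, also add 'authentication' to INSTALLED_APPS and set AUTH_USER_MODEL = 'authentication.CustomUser' in settings.py before running the first migration: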
from django.contrib.auth.models import AbstractBaseUser, BaseUserManager
from django.db import models
class CustomUserManager(BaseUserManager):
    """Manager for a user model identified by e-mail address instead of username."""

    def create_user(self, email, password=None, **extra_fields):
        """Create, save and return a user.

        Args:
            email: Address used as the login identifier; must be non-empty.
            password: Raw password, hashed via ``set_password``.
            **extra_fields: Additional model field values passed to the model
                constructor.

        Raises:
            ValueError: If ``email`` is empty.
        """
        if not email:
            raise ValueError('The Email field must be set')
        email = self.normalize_email(email)
        user = self.model(email=email, **extra_fields)
        user.set_password(password)
        user.save(using=self._db)
        return user

    def create_superuser(self, email, password=None, **extra_fields):
        """Create, save and return a user with ``is_staff`` and ``is_superuser`` set.

        Raises:
            ValueError: If the caller explicitly passes ``is_staff`` or
                ``is_superuser`` with a value other than ``True``, or if
                ``email`` is empty.
        """
        extra_fields.setdefault('is_staff', True)
        extra_fields.setdefault('is_superuser', True)
        # setdefault() does not override explicit values, so reject a call that
        # would otherwise silently create a "superuser" without the flags.
        if extra_fields.get('is_staff') is not True:
            raise ValueError('Superuser must have is_staff=True.')
        if extra_fields.get('is_superuser') is not True:
            raise ValueError('Superuser must have is_superuser=True.')
        return self.create_user(email, password, **extra_fields)
class CustomUser(AbstractBaseUser):
    """User model that logs in with an e-mail address.

    Permissions are all-or-nothing: the model does not use Django's
    ``PermissionsMixin``, so ``has_perm`` / ``has_module_perms`` are answered
    from the ``is_active`` and ``is_superuser`` flags only.
    """

    email = models.EmailField(unique=True)
    first_name = models.CharField(max_length=50, blank=True)
    last_name = models.CharField(max_length=50, blank=True)
    is_active = models.BooleanField(default=True)
    is_staff = models.BooleanField(default=False)
    is_superuser = models.BooleanField(default=False)

    USERNAME_FIELD = 'email'
    # No extra fields are prompted for beyond USERNAME_FIELD and the password.
    REQUIRED_FIELDS = []

    objects = CustomUserManager()

    def __str__(self):
        """Return the e-mail address as the human-readable representation."""
        return self.email

    def has_perm(self, perm, obj=None):
        """Return True only for active superusers.

        Returning True unconditionally would grant every permission to every
        account, including inactive and non-staff ones.
        """
        return bool(self.is_active and self.is_superuser)

    def has_module_perms(self, app_label):
        """Return True only for active superusers (see ``has_perm``)."""
        return bool(self.is_active and self.is_superuser)

    @property
    def token_payload(self):
        """Return the e-mail and name fields as a dict."""
        return {
            'email': self.email,
            'first_name': self.first_name,
            'last_name': self.last_name,
        }
In serializers.py, we’ll define the serializers for the authentication endpoints:
from rest_framework import serializers
class LoginSerializer(serializers.Serializer):
    """Validate the e-mail/password pair submitted to the login endpoint."""

    email = serializers.EmailField()
    # trim_whitespace=False: DRF strips leading/trailing whitespace from
    # CharField input by default, which would alter a password that
    # legitimately starts or ends with a space.
    password = serializers.CharField(
        max_length=128,
        write_only=True,
        trim_whitespace=False,
        style={'input_type': 'password'},
    )
class RefreshTokenSerializer(serializers.Serializer):
    """Validate the body of a token-refresh request (one ``refresh_token`` string)."""

    refresh_token = serializers.CharField()
In views.py, we’ll define the views for the authentication endpoints:
import time
import uuid
from urllib.parse import urlencode

import jwt
import requests
from django.conf import settings
from django.http import HttpResponse, HttpResponseRedirect, JsonResponse
from django.urls import reverse
from rest_framework import generics, status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from .serializers import LoginSerializer, RefreshTokenSerializer
from .models import CustomUser as User
class LoginView(generics.GenericAPIView):
    """Exchange an e-mail/password pair for a Simple JWT access/refresh token pair."""

    serializer_class = LoginSerializer
    permission_classes = (AllowAny,)
    # Do not run token authentication on the login endpoint itself: a client
    # that still sends an expired/invalid Authorization header must be able to
    # log in again instead of being rejected before post() runs.
    authentication_classes = ()

    def post(self, request):
        """Authenticate the credentials in the request body.

        Returns:
            200 with ``access_token`` and ``refresh_token`` on success;
            401 with a generic error for an unknown e-mail, a wrong password,
            or an inactive account (the same message in every case, so the
            response does not reveal which check failed).
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        email = serializer.validated_data['email']
        password = serializer.validated_data['password']

        user = User.objects.filter(email=email).first()
        if user is None or not user.check_password(password) or not user.is_active:
            return Response(
                {'error': 'Invalid email or password'},
                status=status.HTTP_401_UNAUTHORIZED,
            )

        # Derive the access token from the refresh token so both belong to the
        # same token family.
        refresh_token = RefreshToken.for_user(user)
        return Response({
            'access_token': str(refresh_token.access_token),
            'refresh_token': str(refresh_token),
        })
class RefreshTokenView(generics.GenericAPIView):
    """Exchange a valid refresh token for a new access/refresh token pair."""

    serializer_class = RefreshTokenSerializer
    permission_classes = (AllowAny,)

    def post(self, request):
        """Validate the submitted refresh token and issue a fresh pair.

        Returns:
            200 with ``access_token`` and ``refresh_token`` on success;
            401 if the token is invalid/expired, lacks the user-id claim, or
            refers to a missing or inactive user.
        """
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        raw_refresh_token = serializer.validated_data['refresh_token']

        try:
            token = RefreshToken(raw_refresh_token)
            # Tokens minted by for_user() identify the user through the
            # configured user-id claim; they carry no 'email' claim.
            user_id = token[api_settings.USER_ID_CLAIM]
            user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})
        except TokenError as exc:
            return Response({'error': str(exc)}, status=status.HTTP_401_UNAUTHORIZED)
        except (KeyError, User.DoesNotExist):
            return Response(
                {'error': 'Invalid refresh token'},
                status=status.HTTP_401_UNAUTHORIZED,
            )

        if not user.is_active:
            return Response(
                {'error': 'Invalid refresh token'},
                status=status.HTTP_401_UNAUTHORIZED,
            )

        # NOTE(review): the submitted refresh token is not blacklisted here and
        # stays usable until it expires — confirm whether rotation with
        # blacklisting is wanted.
        new_refresh_token = RefreshToken.for_user(user)
        return Response({
            'access_token': str(new_refresh_token.access_token),
            'refresh_token': str(new_refresh_token),
        })
def oidc_login(request):
    """
    Redirects the user to the OpenID Connect provider's login page.

    The ``state`` parameter is a short-lived HS256 JWT signed with
    ``settings.SECRET_KEY``. It carries the callback URL and the nonce sent to
    the provider, so the callback can read both back without server-side
    storage.
    """
    redirect_uri = request.build_absolute_uri(reverse('authentication:oidc_callback'))
    nonce = str(uuid.uuid4())
    state = jwt.encode(
        {
            'redirect_uri': redirect_uri,
            'nonce': nonce,
            # Bound the lifetime of a login attempt to ten minutes.
            'exp': int(time.time()) + 600,
        },
        settings.SECRET_KEY,
        algorithm='HS256',
    )
    if isinstance(state, bytes):
        # PyJWT 1.x returns bytes; interpolating that would yield "b'...'".
        state = state.decode('ascii')

    # urlencode() escapes the spaces in ``scope`` and the reserved characters
    # in ``redirect_uri``; raw string interpolation produced an invalid URL.
    query = urlencode({
        'client_id': settings.OIDC_CLIENT_ID,
        'redirect_uri': redirect_uri,
        'response_type': 'code',
        'scope': 'openid email profile',
        'state': state,
        'nonce': nonce,
        'prompt': 'select_account',
    })
    return HttpResponseRedirect(f'{settings.OIDC_AUTHORIZATION_ENDPOINT}?{query}')
def _decode_id_token(id_token):
    """Return the claims of ``id_token`` after validating it for this client.

    When ``settings.OIDC_JWKS_URI`` is configured, the signature is verified
    against the provider's published keys (RS256), together with expiry and
    audience, plus issuer if ``settings.OIDC_ISSUER`` is set. Without it the
    signature is NOT verified and only the audience is checked.
    """
    client_id = settings.OIDC_CLIENT_ID
    jwks_uri = getattr(settings, 'OIDC_JWKS_URI', None)
    if jwks_uri:
        signing_key = jwt.PyJWKClient(jwks_uri).get_signing_key_from_jwt(id_token)
        return jwt.decode(
            id_token,
            signing_key.key,
            algorithms=['RS256'],
            audience=client_id,
            issuer=getattr(settings, 'OIDC_ISSUER', None),
        )

    # NOTE(review): unverified fallback. The token was fetched directly from
    # the token endpoint, so this relies on that connection being trusted;
    # configure OIDC_JWKS_URI to enable signature verification.
    claims = jwt.decode(id_token, options={'verify_signature': False})
    audience = claims.get('aud')
    audiences = audience if isinstance(audience, list) else [audience]
    if client_id not in audiences:
        raise jwt.InvalidAudienceError('ID token was not issued for this client')
    return claims


def oidc_callback(request):
    """
    Handles the callback from the OpenID Connect provider after the user logs in.

    Validates ``state``, exchanges the authorization code for an ID token,
    looks up or creates the local user by e-mail, and returns a JSON body with
    ``access_token`` and ``refresh_token``. Failures are returned as JSON
    ``{'error': ...}`` with status 400, or 403 for an inactive account.
    """
    code = request.GET.get('code')
    state = request.GET.get('state')
    if not code or not state:
        return JsonResponse(
            {'error': 'Missing code or state'}, status=status.HTTP_400_BAD_REQUEST
        )

    try:
        payload = jwt.decode(state, settings.SECRET_KEY, algorithms=['HS256'])
        redirect_uri = payload['redirect_uri']
    except (jwt.InvalidTokenError, KeyError):
        return JsonResponse({'error': 'Invalid state'}, status=status.HTTP_400_BAD_REQUEST)

    try:
        response = requests.post(
            settings.OIDC_TOKEN_ENDPOINT,
            data={
                # grant_type and redirect_uri were computed but never sent;
                # the authorization-code exchange needs both.
                'grant_type': 'authorization_code',
                'code': code,
                'redirect_uri': redirect_uri,
                'client_id': settings.OIDC_CLIENT_ID,
                'client_secret': settings.OIDC_CLIENT_SECRET,
            },
            timeout=10,
        )
    except requests.RequestException:
        return JsonResponse(
            {'error': 'Unable to retrieve access token'},
            status=status.HTTP_400_BAD_REQUEST,
        )
    if response.status_code != 200:
        return JsonResponse(
            {'error': 'Unable to retrieve access token'},
            status=status.HTTP_400_BAD_REQUEST,
        )

    try:
        claims = _decode_id_token(response.json()['id_token'])
        email = claims['email']
    except (ValueError, KeyError, TypeError, jwt.PyJWTError):
        return JsonResponse({'error': 'Invalid ID token'}, status=status.HTTP_400_BAD_REQUEST)

    # If the state carries the nonce that was sent to the provider, the ID
    # token must echo it back.
    expected_nonce = payload.get('nonce')
    if expected_nonce is not None and claims.get('nonce') != expected_nonce:
        return JsonResponse({'error': 'Invalid ID token'}, status=status.HTTP_400_BAD_REQUEST)

    # Do not bind a local account to an address the provider explicitly
    # reports as unverified.
    if claims.get('email_verified') is False:
        return JsonResponse({'error': 'Email not verified'}, status=status.HTTP_400_BAD_REQUEST)

    user, created = User.objects.get_or_create(email=email, defaults={'is_active': True})
    if created:
        # Accounts created through OIDC have no local password.
        user.set_unusable_password()
        user.save(update_fields=['password'])
    if not user.is_active:
        return JsonResponse({'error': 'Account disabled'}, status=status.HTTP_403_FORBIDDEN)

    refresh_token = RefreshToken.for_user(user)
    # The tokens are returned in the response body. The previous redirect
    # pointed back at this callback URL with the tokens in the query string,
    # which re-entered this view without ``code``/``state``.
    return JsonResponse({
        'access_token': str(refresh_token.access_token),
        'refresh_token': str(refresh_token),
    })
Finally, in urls.py, we’ll define the URLs for the authentication endpoints:
from django.urls import path
from rest_framework_simplejwt.views import TokenRefreshView

from . import views

# Namespace used by reverse('authentication:oidc_callback') in views.py.
app_name = 'authentication'

urlpatterns = [
    path('login/', views.LoginView.as_view(), name='login'),
    # oidc_login and oidc_callback are plain function views, so they are
    # routed directly rather than through .as_view().
    path('oidc/login/', views.oidc_login, name='oidc_login'),
    path('oidc/callback/', views.oidc_callback, name='oidc_callback'),
    # NOTE(review): HelloView is not part of the views.py listing in this
    # tutorial — define it there or remove this route.
    path('hello/', views.HelloView.as_view(), name='hello'),
    path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]
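In this example, we’ve defined endpoints for the OpenID Connect login flow (/oidc/login/ and /oidc/callback/), a simple hello world endpoint (/hello/), and the token refresh endpoint provided by Simple JWT (/token/refresh/). Note that HelloView is not shown in the views.py listing above, so you will need to add it yourself or remove that route.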
Note that we’ve also imported the TokenRefreshView from Simple JWT and included it in our urlpatterns. This view provides a default implementation for refreshing access tokens, and requires a valid refresh token to be provided in the request body. You can customize the behavior of this view by subclassing it and overriding its methods as needed.
In summary, when using DRF Simple JWT, you can use the following to configure it to authenticate tokens issued by an external auth server:
import requests
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.backends import default_backend
from django.conf import settings
from rest_framework_simplejwt import authentication, exceptions
def get_public_key():
    """Return the SSO auth server's public key for verifying JWT signatures.

    Downloads the JWKS document from ``settings.SSO_AUTH_SERVER_JWKS_URI`` and
    selects the entry whose ``kid`` equals ``settings.SSO_AUTH_SERVER_KEY_ID``.
    The key is taken from the entry's first ``x5c`` certificate when present,
    otherwise it is built from the JWK parameters themselves.

    Raises:
        exceptions.AuthenticationFailed: If the JWKS cannot be fetched or
            parsed, or no usable key with the configured ``kid`` exists.
    """
    import base64
    import json

    from cryptography.x509 import load_der_x509_certificate
    from jwt.algorithms import RSAAlgorithm
    from jwt.exceptions import InvalidKeyError

    failure_message = 'Failed to retrieve public key'

    try:
        response = requests.get(settings.SSO_AUTH_SERVER_JWKS_URI, timeout=10)
    except requests.RequestException as exc:
        raise exceptions.AuthenticationFailed(failure_message) from exc
    if response.status_code != 200:
        raise exceptions.AuthenticationFailed(failure_message)

    try:
        jwks = response.json()['keys']
    except (ValueError, KeyError, TypeError) as exc:
        raise exceptions.AuthenticationFailed(failure_message) from exc

    for jwk in jwks:
        # ``kid`` is optional in a JWK, so do not index it directly.
        if jwk.get('kid') != settings.SSO_AUTH_SERVER_KEY_ID:
            continue
        try:
            certificate_chain = jwk.get('x5c')
            if certificate_chain:
                # x5c entries are base64-encoded DER certificates (RFC 7517
                # section 4.7), not PEM text, so the PEM loader rejects them.
                der_bytes = base64.b64decode(certificate_chain[0])
                certificate = load_der_x509_certificate(der_bytes, default_backend())
                return certificate.public_key()
            # No certificate published: build the key from the JWK parameters.
            return RSAAlgorithm.from_jwk(json.dumps(jwk))
        except (ValueError, TypeError, InvalidKeyError) as exc:
            raise exceptions.AuthenticationFailed(failure_message) from exc

    raise exceptions.AuthenticationFailed(failure_message)
class JWTAuthentication(authentication.JWTAuthentication):
    """Authenticate requests whose JWT was signed by the external SSO auth server."""

    def get_validated_token(self, raw_token):
        """Verify ``raw_token`` with the SSO public key and return its claims.

        This is the method the parent class calls to validate a token, so
        overriding it is what routes verification through
        ``get_verified_payload``.

        Raises:
            exceptions.InvalidToken: If the signature or the registered claims
                fail verification.
        """
        import jwt

        try:
            # NOTE(review): this returns the plain claims dict instead of a
            # Simple JWT token object; it assumes the parent's get_user() only
            # reads claims by key — confirm against the installed version.
            return self.get_verified_payload(raw_token)
        except jwt.PyJWTError as exc:
            raise exceptions.InvalidToken(str(exc)) from exc

    def get_verified_payload(self, raw_token):
        """Decode ``raw_token`` using the public key from the SSO auth server.

        The expected audience and issuer are read from the optional settings
        ``SSO_AUTH_SERVER_AUDIENCE`` and ``SSO_AUTH_SERVER_ISSUER``.
        """
        import jwt

        public_key = get_public_key()
        options = {'verify_signature': True, 'verify_aud': True, 'verify_iss': True}
        return jwt.decode(
            raw_token,
            public_key,
            # Pin the algorithm so the token header cannot choose a weaker one.
            algorithms=['RS256'],
            audience=getattr(settings, 'SSO_AUTH_SERVER_AUDIENCE', None),
            issuer=getattr(settings, 'SSO_AUTH_SERVER_ISSUER', None),
            options=options,
        )
# DRF defaults: authenticate with the custom JWT class, require an
# authenticated user on every view, and render responses as JSON.
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=['path.to.JWTAuthentication'],
    DEFAULT_PERMISSION_CLASSES=['rest_framework.permissions.IsAuthenticated'],
    DEFAULT_RENDERER_CLASSES=['rest_framework.renderers.JSONRenderer'],
)

# SSO auth server: where its JWKS document is published and which key ID
# (``kid``) to select from it.
SSO_AUTH_SERVER_JWKS_URI = 'https://sso.authserver.com/.well-known/jwks.json'
SSO_AUTH_SERVER_KEY_ID = 'my_key_id'
In the example above, we define a custom JWTAuthentication class that extends the rest_framework_simplejwt.authentication.JWTAuthentication class. We override the get_verified_payload method to use the public key retrieved from the SSO auth server’s JWKS endpoint to verify the JWT token.
We also define a get_public_key function that retrieves the public key from the SSO auth server’s JWKS endpoint. This function searches for the key with the matching kid value, and returns the public key as a cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey object.
Finally, we configure DRF to use our custom authentication class as the default authentication class, and define the SSO auth server’s JWKS endpoint URL and key ID in the settings file.
You can also condense the above into a single self-contained class, as shown below:

import requests
from rest_framework_simplejwt.authentication import JWTAuthentication
from jwt import decode, InvalidSignatureError
class ExternalJWTAuthentication(JWTAuthentication):
    """Authenticate requests with JWTs signed by an external server's JWKS keys."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jwks_url = "https://example.com/.well-known/jwks.json"

    def get_public_key(self, kid):
        """Return the JWK dict whose ``kid`` matches, or ``None`` if none does.

        Raises:
            requests.RequestException: If the JWKS document cannot be fetched
                or the server answers with an error status.
        """
        response = requests.get(self.jwks_url, timeout=10)
        response.raise_for_status()
        for key in response.json().get("keys", []):
            if key.get("kid") == kid:
                return key
        return None

    def get_validated_token(self, raw_token):
        """Validate ``raw_token`` via ``decode_jwt`` and return its claims.

        This is the method the parent class calls to validate a token, so
        overriding it is what puts ``decode_jwt`` on the authentication path.
        """
        from rest_framework_simplejwt.exceptions import InvalidToken

        try:
            # NOTE(review): returns the claims dict rather than a Simple JWT
            # token object; assumes the parent's get_user() only reads claims
            # by key — confirm against the installed version.
            return self.decode_jwt(raw_token, None)
        except InvalidSignatureError as exc:
            raise InvalidToken(str(exc)) from exc

    def decode_jwt(self, token, options=None, *args, **kwargs):
        """Verify ``token`` against the matching JWKS key and return its claims.

        Args:
            token: The encoded JWT.
            options: Optional PyJWT ``options`` dict forwarded to ``decode``.

        Raises:
            InvalidSignatureError: On any failure (unknown ``kid``, JWKS fetch
                error, bad signature, failed claim checks). The original cause
                is chained.
        """
        import json

        from jwt import get_unverified_header
        from jwt.algorithms import RSAAlgorithm

        try:
            # ``kid`` lives in the JWT header, not in the payload.
            kid = get_unverified_header(token).get("kid")
            jwk = self.get_public_key(kid)
            if not jwk:
                raise InvalidSignatureError("Public key not found for kid: {}".format(kid))
            # decode() needs a key object, not the raw JWK dict.
            public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))
            return decode(token, public_key, algorithms=["RS256"], options=options)
        except Exception as e:
            # Deliberately broad: callers rely on a single exception type for
            # every verification failure.
            raise InvalidSignatureError("Error verifying signature: {}".format(e)) from e