In this tutorial, we will build an OpenID Connect (OIDC) and JWT authentication system with Django, Django REST Framework (DRF), Simple JWT, and RSA-signed tokens.
First, let’s install the required packages:
pip install django
pip install djangorestframework
pip install djangorestframework-simplejwt
# Quote the extras spec so shells such as zsh don't expand [crypto] as a glob.
pip install "pyjwt[crypto]"
# The views and JWKS helpers call the requests library.
pip install requests
Next, let’s create a Django app called `authentication`:
# Run from the project root (the directory that contains manage.py).
python manage.py startapp authentication
We’ll define the models, serializers, views, and URLs in the `authentication` app.
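In `models.py`, we’ll define a custom user model that extends the `AbstractBaseUser` class. Remember to add `'authentication'` to `INSTALLED_APPS` and to tell Django to use this model with `AUTH_USER_MODEL = 'authentication.CustomUser'` in `settings.py` (before running the first migration):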
from django.contrib.auth.models import AbstractBaseUser, BaseUserManager
from django.db import models


class CustomUserManager(BaseUserManager):
    """Manager for :class:`CustomUser`, which is identified by email, not username."""

    def create_user(self, email, password=None, **extra_fields):
        """Create, save and return a user with the given email and password.

        Raises:
            ValueError: if ``email`` is empty.
        """
        if not email:
            raise ValueError('The Email field must be set')
        email = self.normalize_email(email)
        user = self.model(email=email, **extra_fields)
        # With password=None Django stores an unusable password, so the account
        # cannot log in with a password (useful for SSO-only users).
        user.set_password(password)
        user.save(using=self._db)
        return user

    def create_superuser(self, email, password, **extra_fields):
        """Create a user with ``is_staff`` and ``is_superuser`` forced on.

        Raises:
            ValueError: if the caller explicitly passes ``is_staff`` or
                ``is_superuser`` as anything other than True.
        """
        extra_fields.setdefault('is_staff', True)
        extra_fields.setdefault('is_superuser', True)
        if extra_fields.get('is_staff') is not True:
            raise ValueError('Superuser must have is_staff=True.')
        if extra_fields.get('is_superuser') is not True:
            raise ValueError('Superuser must have is_superuser=True.')
        return self.create_user(email, password, **extra_fields)


class CustomUser(AbstractBaseUser):
    """User model that logs in with an email address instead of a username."""

    email = models.EmailField(unique=True)
    first_name = models.CharField(max_length=50, blank=True)
    last_name = models.CharField(max_length=50, blank=True)
    is_active = models.BooleanField(default=True)
    is_staff = models.BooleanField(default=False)
    is_superuser = models.BooleanField(default=False)

    USERNAME_FIELD = 'email'
    # createsuperuser prompts only for USERNAME_FIELD and the password.
    REQUIRED_FIELDS = []

    objects = CustomUserManager()

    def __str__(self):
        return self.email

    def has_perm(self, perm, obj=None):
        """Return True only for active superusers.

        This model stores no per-user or per-group permissions (it does not use
        PermissionsMixin), so superuser status is the only grant available.
        Returning True unconditionally would give every user every permission.
        """
        return self.is_active and self.is_superuser

    def has_module_perms(self, app_label):
        """Return True only for active superusers (see :meth:`has_perm`)."""
        return self.is_active and self.is_superuser

    @property
    def token_payload(self):
        """Return the user's email and name fields as a dict.

        Presumably intended as extra claims for issued tokens; confirm before relying on it.
        """
        return {
            'email': self.email,
            'first_name': self.first_name,
            'last_name': self.last_name,
        }
In `serializers.py`, we’ll define the serializers for the authentication endpoints:
from rest_framework import serializers


class LoginSerializer(serializers.Serializer):
    """Validate the credentials posted to the email/password login endpoint."""

    email = serializers.EmailField()
    # trim_whitespace=False: DRF strips CharField input by default, which would
    # silently change passwords that begin or end with whitespace.
    password = serializers.CharField(
        max_length=128,
        write_only=True,
        trim_whitespace=False,
        style={'input_type': 'password'},
    )


class RefreshTokenSerializer(serializers.Serializer):
    """Validate the body posted to the custom refresh-token endpoint."""

    refresh_token = serializers.CharField()
In `views.py`, we’ll define the views for the authentication endpoints:
import secrets
from urllib.parse import urlencode

import jwt
import requests
from django.conf import settings
from django.contrib.auth import authenticate, get_user_model
from django.http import HttpResponseRedirect, JsonResponse
from django.shortcuts import resolve_url
from django.urls import reverse
from django.utils.crypto import constant_time_compare
from django.utils.http import url_has_allowed_host_and_scheme
from rest_framework import generics, status
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import RefreshToken

from .serializers import LoginSerializer, RefreshTokenSerializer

# The active user model (settings.AUTH_USER_MODEL); models.py defines no `User` class.
User = get_user_model()

# Seconds to wait for the OIDC provider's token endpoint before giving up.
OIDC_HTTP_TIMEOUT = 10

# Session key holding the state/nonce/next values between oidc_login and oidc_callback.
OIDC_SESSION_KEY = 'oidc_auth'


def _issue_tokens(user):
    """Return a new Simple JWT access/refresh token pair for *user*."""
    refresh = RefreshToken.for_user(user)
    return {
        'access_token': str(refresh.access_token),
        'refresh_token': str(refresh),
    }


def _oidc_callback_url(request):
    """Absolute URL of oidc_callback; must be identical in both OIDC requests."""
    return request.build_absolute_uri(reverse('authentication:oidc_callback'))


class LoginView(generics.GenericAPIView):
    """Exchange an email/password pair for a JWT access/refresh token pair."""

    serializer_class = LoginSerializer
    permission_classes = (AllowAny,)

    def post(self, request):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # authenticate() goes through the configured auth backends: Django's
        # ModelBackend rejects inactive users and still hashes the password for
        # unknown emails, so timing does not reveal which emails are registered.
        user = authenticate(
            request,
            email=serializer.validated_data['email'],
            password=serializer.validated_data['password'],
        )
        if user is None:
            return Response({'error': 'Invalid email or password'},
                            status=status.HTTP_401_UNAUTHORIZED)
        return Response(_issue_tokens(user))


class RefreshTokenView(generics.GenericAPIView):
    """Exchange a valid refresh token for a new access/refresh token pair."""

    serializer_class = RefreshTokenSerializer
    permission_classes = (AllowAny,)

    def post(self, request):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        try:
            # Checks signature, expiry and that this really is a refresh token.
            token = RefreshToken(serializer.validated_data['refresh_token'])
        except TokenError:
            return Response({'error': 'Invalid or expired refresh token'},
                            status=status.HTTP_401_UNAUTHORIZED)
        # RefreshToken.for_user() identifies the user through USER_ID_CLAIM
        # ('user_id' by default), not through an 'email' claim.
        user_id = token.get(api_settings.USER_ID_CLAIM)
        try:
            user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})
        except User.DoesNotExist:
            user = None
        if user is None or not user.is_active:
            return Response({'error': 'Invalid or expired refresh token'},
                            status=status.HTTP_401_UNAUTHORIZED)
        # NOTE: the presented refresh token stays valid until it expires; with
        # Simple JWT's token_blacklist app installed, call token.blacklist() here
        # to enforce strict rotation.
        return Response(_issue_tokens(user))


class HelloView(APIView):
    """Minimal protected endpoint that greets the authenticated user."""

    permission_classes = (IsAuthenticated,)

    def get(self, request):
        return Response({'message': f'Hello, {request.user}!'})


def oidc_login(request):
    """Redirect the browser to the OpenID Connect provider's login page.

    A random ``state`` (binds the callback to this browser session) and
    ``nonce`` (binds the ID token to this login attempt) are saved in the
    session for :func:`oidc_callback` to check, so Django's session middleware
    must be enabled. An optional ``?next=`` URL on this host chooses where the
    browser goes after login; anything else falls back to
    ``settings.LOGIN_REDIRECT_URL``.
    """
    next_url = request.GET.get('next', '')
    if not url_has_allowed_host_and_scheme(
        next_url,
        allowed_hosts={request.get_host()},
        require_https=request.is_secure(),
    ):
        next_url = resolve_url(settings.LOGIN_REDIRECT_URL)

    state = secrets.token_urlsafe(32)
    nonce = secrets.token_urlsafe(32)
    request.session[OIDC_SESSION_KEY] = {'state': state, 'nonce': nonce, 'next': next_url}

    # urlencode() escapes the redirect URI and the space-separated scope.
    query = urlencode({
        'client_id': settings.OIDC_CLIENT_ID,
        'redirect_uri': _oidc_callback_url(request),
        'response_type': 'code',
        'scope': 'openid email profile',
        'state': state,
        'nonce': nonce,
        'prompt': 'select_account',
    })
    return HttpResponseRedirect(f'{settings.OIDC_AUTHORIZATION_ENDPOINT}?{query}')


def oidc_callback(request):
    """Handle the provider's redirect back to us after the user logs in.

    This view:

    1. Checks ``state`` against the value saved by :func:`oidc_login`.
    2. Exchanges the authorization ``code`` at the provider's token endpoint.
    3. Validates the returned ID token.
    4. Finds or creates the local user by email.
    5. Redirects to the saved ``next`` URL with our own JWTs in the URL fragment.

    Errors are returned as JSON (this is a plain Django view, so DRF's
    ``Response`` cannot be rendered here).
    """
    # pop(): each state value is single-use.
    expected = request.session.pop(OIDC_SESSION_KEY, None)
    state = request.GET.get('state', '')
    if not expected or not constant_time_compare(state, expected['state']):
        return JsonResponse({'error': 'Invalid state'}, status=status.HTTP_400_BAD_REQUEST)

    code = request.GET.get('code')
    if not code:
        # Providers report failures such as a cancelled login via ?error=...
        return JsonResponse(
            {'error': request.GET.get('error', 'Missing authorization code')},
            status=status.HTTP_400_BAD_REQUEST,
        )

    try:
        token_response = requests.post(
            settings.OIDC_TOKEN_ENDPOINT,
            data={
                'grant_type': 'authorization_code',
                'code': code,
                # Must match the redirect_uri sent by oidc_login.
                'redirect_uri': _oidc_callback_url(request),
                'client_id': settings.OIDC_CLIENT_ID,
                'client_secret': settings.OIDC_CLIENT_SECRET,
            },
            timeout=OIDC_HTTP_TIMEOUT,
        )
        token_response.raise_for_status()
        id_token = token_response.json()['id_token']
    except (requests.RequestException, ValueError, KeyError):
        return JsonResponse({'error': 'Unable to retrieve access token'},
                            status=status.HTTP_400_BAD_REQUEST)

    # The ID token came directly from the token endpoint over TLS, which
    # OpenID Connect Core 1.0 section 3.1.3.7 accepts in place of a signature
    # check. Expiry, audience and nonce are still validated.
    try:
        claims = jwt.decode(
            id_token,
            audience=settings.OIDC_CLIENT_ID,
            options={'verify_signature': False, 'verify_exp': True, 'verify_aud': True},
        )
    except jwt.InvalidTokenError:
        return JsonResponse({'error': 'Invalid ID token'}, status=status.HTTP_400_BAD_REQUEST)
    if not constant_time_compare(str(claims.get('nonce', '')), expected['nonce']):
        return JsonResponse({'error': 'Invalid ID token'}, status=status.HTTP_400_BAD_REQUEST)

    email = claims.get('email')
    # Linking accounts by an unverified address would let anyone who can set
    # that address at the provider take over the matching local account.
    if not email or claims.get('email_verified') is False:
        return JsonResponse({'error': 'ID token has no verified email'},
                            status=status.HTTP_400_BAD_REQUEST)
    # Normalize the same way create_user() does so lookups and inserts agree.
    email = User.objects.normalize_email(email)

    try:
        user = User.objects.get(email=email)
    except User.DoesNotExist:
        # No password is passed, so the account can only sign in through OIDC.
        user = User.objects.create_user(email=email)
    if not user.is_active:
        return JsonResponse({'error': 'User account is disabled'},
                            status=status.HTTP_403_FORBIDDEN)

    # Tokens go in the URL fragment, which browsers do not send to servers, so
    # they stay out of access logs and Referer headers.
    return HttpResponseRedirect(f"{expected['next']}#{urlencode(_issue_tokens(user))}")
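Finally, in `urls.py`, we’ll define the URLs for the authentication endpoints. Because the views call `reverse('authentication:oidc_callback')`, include this module from the project’s root URLconf under the `authentication` namespace, for example `path('auth/', include(('authentication.urls', 'authentication')))`: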
from django.urls import path
from rest_framework_simplejwt.views import TokenRefreshView

from .views import HelloView, LoginView, oidc_callback, oidc_login

# URL namespace; views.py reverses 'authentication:oidc_callback'.
app_name = 'authentication'

urlpatterns = [
    path('login/', LoginView.as_view(), name='login'),
    # oidc_login / oidc_callback are plain function views, so no .as_view().
    path('oidc/login/', oidc_login, name='oidc_login'),
    path('oidc/callback/', oidc_callback, name='oidc_callback'),
    path('hello/', HelloView.as_view(), name='hello'),
    path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]
In this example, we’ve defined endpoints for the OpenID Connect login flow (`/oidc/login/` and `/oidc/callback/`), a simple hello world endpoint (`/hello/`), and the token refresh endpoint provided by Simple JWT (`/token/refresh/`).
Note that we’ve also imported Simple JWT’s `TokenRefreshView` and included it in our `urlpatterns`. This view provides a default implementation for refreshing access tokens. It requires a valid refresh token in the request body, in a field named `refresh`. You can customize its behavior by subclassing it and overriding its methods as needed.
In summary, when using DRF with Simple JWT, you can use the following to make Simple JWT authenticate tokens issued by an external authorization (SSO) server:
# --- authentication module (e.g. authentication/sso.py) ---
import functools

import jwt
from django.conf import settings
from rest_framework_simplejwt import authentication, exceptions


@functools.lru_cache(maxsize=None)
def _jwks_client():
    """Return one shared PyJWKClient for the SSO auth server's JWKS endpoint.

    Reusing a single client lets PyJWT cache the downloaded key set between
    requests (PyJWT >= 2.6 caches it for ``lifespan`` seconds) instead of
    fetching it on every request.
    """
    return jwt.PyJWKClient(settings.SSO_AUTH_SERVER_JWKS_URI)


def get_public_key(kid=None):
    """Return the SSO auth server's public key whose ``kid`` matches.

    Args:
        kid: Key ID to look up in the JWKS. Defaults to
            ``settings.SSO_AUTH_SERVER_KEY_ID``.

    Returns:
        The key object from PyJWT's JWK parsing (a cryptography
        ``RSAPublicKey`` for RSA keys).

    Raises:
        exceptions.AuthenticationFailed: the JWKS could not be fetched or has
            no signing key with that ``kid``.
    """
    if kid is None:
        kid = settings.SSO_AUTH_SERVER_KEY_ID
    try:
        # PyJWK builds the key from the JWK's own fields (n/e for RSA), so no
        # x5c certificate parsing is needed.
        return _jwks_client().get_signing_key(kid).key
    except jwt.PyJWTError as exc:
        raise exceptions.AuthenticationFailed('Failed to retrieve public key') from exc


class JWTAuthentication(authentication.JWTAuthentication):
    """Simple JWT authentication that accepts RS256 tokens from the SSO auth server."""

    def get_verified_payload(self, raw_token):
        """Verify *raw_token* with the SSO server's public key and return its claims.

        Raises:
            exceptions.InvalidToken: bad signature, expired, wrong audience or
                issuer, or a missing required claim.
            exceptions.AuthenticationFailed: the signing key could not be found.
        """
        try:
            # The header's kid only selects which key to try (so key rotation
            # works); the signature check below is what establishes trust.
            kid = jwt.get_unverified_header(raw_token).get('kid')
            public_key = get_public_key(kid)
            return jwt.decode(
                raw_token,
                public_key,
                # Pin the algorithm; never let the token header choose it.
                algorithms=['RS256'],
                audience=settings.SSO_AUTH_SERVER_AUDIENCE,
                issuer=settings.SSO_AUTH_SERVER_ISSUER,
                options={'require': ['exp', 'iss', 'aud']},
            )
        except jwt.PyJWTError as exc:
            raise exceptions.InvalidToken('Token is invalid or expired') from exc

    def get_validated_token(self, raw_token):
        """Simple JWT hook called during authentication; delegates to get_verified_payload.

        Overriding this hook is what makes DRF use get_verified_payload at all.

        NOTE(review): Simple JWT's get_user() then reads
        ``SIMPLE_JWT['USER_ID_CLAIM']`` ('user_id' by default) from the returned
        claims. Confirm that it matches the claim your SSO server issues.
        """
        return self.get_verified_payload(raw_token)


# --- settings.py ---
# Configure DRF to use the custom JWT authentication class above.
REST_FRAMEWORK = {
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'path.to.JWTAuthentication',
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.IsAuthenticated',
    ],
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
    ],
}

# SSO auth server configuration read by the authentication module above.
SSO_AUTH_SERVER_JWKS_URI = 'https://sso.authserver.com/.well-known/jwks.json'
SSO_AUTH_SERVER_KEY_ID = 'my_key_id'  # used when a token header carries no kid
SSO_AUTH_SERVER_AUDIENCE = 'my-api'  # expected "aud" claim
SSO_AUTH_SERVER_ISSUER = 'https://sso.authserver.com/'  # expected "iss" claim
In the example above, we define a custom `JWTAuthentication` class that extends `rest_framework_simplejwt.authentication.JWTAuthentication`. We override the `get_verified_payload` method so that it verifies the JWT with the public key retrieved from the SSO auth server’s JWKS endpoint.
We also define a `get_public_key` function that retrieves the public key from the SSO auth server’s JWKS endpoint. It searches for the key with the matching `kid` value and returns it as a `cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey` object.
Finally, we configure DRF to use our custom authentication class by default (replace `path.to.JWTAuthentication` with the class’s real dotted import path). We also define the SSO auth server’s JWKS endpoint URL and key ID in the settings file.
You can also package the above into a single, self-contained class:
import requests
from jwt import PyJWK, InvalidSignatureError, decode, get_unverified_header
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken


class ExternalJWTAuthentication(JWTAuthentication):
    """Simple JWT authentication that verifies RS256 tokens against a remote JWKS."""

    # Forwarded to jwt.decode() as audience=/issuer=; override in a subclass
    # with your API's expected "aud" and the server's "iss".
    audience = None
    issuer = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jwks_url = "https://example.com/.well-known/jwks.json"

    def get_public_key(self, kid):
        """Return the JWK dict from ``self.jwks_url`` whose ``kid`` matches, or None.

        NOTE: the key set is downloaded on every call; cache it in production.
        """
        response = requests.get(self.jwks_url, timeout=10)
        response.raise_for_status()
        for key in response.json()["keys"]:
            if key.get("kid") == kid:
                return key
        return None

    def decode_jwt(self, token, options, *args, **kwargs):
        """Verify *token* with the matching JWKS key and return its claims.

        Args:
            token: Encoded JWT (str or bytes).
            options: PyJWT ``options`` dict passed to ``jwt.decode``.

        Raises:
            InvalidSignatureError: for any verification failure, including a
                missing key, a bad signature, or an expired token.
        """
        try:
            # kid lives in the JWT *header*; decoding the token returns the payload.
            kid = get_unverified_header(token).get("kid")
            jwk = self.get_public_key(kid)
            if jwk is None:
                raise InvalidSignatureError("Public key not found for kid: {}".format(kid))
            # jwt.decode needs a key object, not the raw JWK dict.
            public_key = PyJWK(jwk).key
            return decode(
                token,
                public_key,
                algorithms=["RS256"],
                audience=self.audience,
                issuer=self.issuer,
                options=options,
            )
        except InvalidSignatureError:
            raise
        except Exception as e:
            raise InvalidSignatureError("Error verifying signature: {}".format(e)) from e

    def get_validated_token(self, raw_token):
        """Simple JWT hook called during authentication; verifies via decode_jwt.

        Without this override decode_jwt would never run, because the base
        class does not call it.
        """
        try:
            return self.decode_jwt(raw_token, options={"require": ["exp"]})
        except InvalidSignatureError as exc:
            raise InvalidToken(str(exc)) from exc