JWT Authentication System with Django, SSO and DRF

In this tutorial we will build an OpenID Connect (OIDC) and JWT authentication system with Django, Django REST framework (DRF), Simple JWT, and RSA-signed tokens.

First, let’s install the required packages:

pip install django
pip install djangorestframework
pip install djangorestframework-simplejwt
pip install pyjwt[crypto]

Next, let’s create a Django app called authentication:

python manage.py startapp authentication

We’ll define the models, serializers, views, and URLs in the authentication app.

In models.py, we’ll define a custom user model that extends the AbstractBaseUser class:

from django.contrib.auth.models import AbstractBaseUser, BaseUserManager
from django.db import models


class CustomUserManager(BaseUserManager):
    """Manager for CustomUser, which identifies users by email address."""

    def create_user(self, email, password=None, **extra_fields):
        """Create, save and return a user with the given email and password.

        ``extra_fields`` are passed straight to the model constructor. A
        ``password`` of None is handed to ``set_password`` unchanged (Django then
        stores an unusable password).

        Raises:
            ValueError: if ``email`` is empty.
        """
        if not email:
            raise ValueError('The Email field must be set')

        user = self.model(email=self.normalize_email(email), **extra_fields)
        user.set_password(password)
        user.save(using=self._db)
        return user

    def create_superuser(self, email, password=None, **extra_fields):
        """Create a user with ``is_staff`` and ``is_superuser`` set to True.

        ``password`` defaults to None so callers that pass it as a keyword (or
        omit it) still work.

        Raises:
            ValueError: if the caller explicitly passes a falsy ``is_staff`` or
                ``is_superuser``, which would contradict "superuser".
        """
        extra_fields.setdefault('is_staff', True)
        extra_fields.setdefault('is_superuser', True)

        # Mirror Django's built-in UserManager: refuse contradictory flags
        # instead of silently creating a "superuser" that is not one.
        for flag in ('is_staff', 'is_superuser'):
            if extra_fields.get(flag) is not True:
                raise ValueError(f'Superuser must have {flag}=True.')

        return self.create_user(email, password, **extra_fields)


class CustomUser(AbstractBaseUser):
    """User model that logs in with an email address instead of a username.

    Permissions are all-or-nothing: an active superuser holds every
    permission, and every other user holds none.
    """

    email = models.EmailField(unique=True)
    first_name = models.CharField(max_length=50, blank=True)
    last_name = models.CharField(max_length=50, blank=True)
    is_active = models.BooleanField(default=True)
    is_staff = models.BooleanField(default=False)
    is_superuser = models.BooleanField(default=False)

    # The email field is the login identifier.
    USERNAME_FIELD = 'email'

    objects = CustomUserManager()

    def __str__(self):
        return self.email

    def has_perm(self, perm, obj=None):
        """Return True only for active superusers.

        Returning True unconditionally would grant every permission to every
        user, including inactive and non-superuser accounts.
        """
        return self.is_active and self.is_superuser

    def has_module_perms(self, app_label):
        """Return True only for active superusers, matching ``has_perm``."""
        return self.is_active and self.is_superuser

    @property
    def token_payload(self):
        """Return the user's email and name fields as a dict.

        Presumably intended as extra JWT claims; nothing shown here uses it yet.
        """
        return {
            'email': self.email,
            'first_name': self.first_name,
            'last_name': self.last_name,
        }

In serializers.py, we’ll define the serializers for the authentication endpoints:

from rest_framework import serializers


class LoginSerializer(serializers.Serializer):
    """Email/password credentials posted to the login endpoint."""

    email = serializers.EmailField()
    # DRF CharField strips leading/trailing whitespace by default
    # (trim_whitespace=True), which would silently change passwords that
    # contain it. The style hint masks the field in the browsable API.
    password = serializers.CharField(
        max_length=128,
        write_only=True,
        trim_whitespace=False,
        style={'input_type': 'password'},
    )


class RefreshTokenSerializer(serializers.Serializer):
    """Request body for the refresh endpoint: a previously issued refresh token."""

    # required=True and allow_blank=False are CharField's defaults, spelled out.
    refresh_token = serializers.CharField(required=True, allow_blank=False)

In views.py, we’ll define the views for the authentication endpoints:

import time
import uuid
from urllib.parse import urlencode

import jwt
import requests
from django.conf import settings
from django.contrib.auth import authenticate
from django.http import HttpResponse, HttpResponseRedirect, JsonResponse
from django.urls import reverse
from rest_framework import generics, status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken

from .serializers import LoginSerializer, RefreshTokenSerializer
from .models import CustomUser as User


class LoginView(generics.GenericAPIView):
    """Exchange an email/password pair for a SimpleJWT access/refresh token pair.

    POST body is validated by LoginSerializer. On success the response is
    ``{"access_token": ..., "refresh_token": ...}``. Every failure returns the
    same 401 message, so callers cannot tell which credential was wrong.
    """

    serializer_class = LoginSerializer
    permission_classes = (AllowAny,)

    def post(self, request):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        # authenticate() runs the configured auth backends. Django's default
        # ModelBackend rejects inactive users and hashes the password even for
        # unknown emails, which blunts timing-based email enumeration.
        # NOTE(review): assumes AUTH_USER_MODEL points at CustomUser.
        user = authenticate(
            request,
            email=serializer.validated_data['email'],
            password=serializer.validated_data['password'],
        )
        if user is None:
            return Response({'error': 'Invalid email or password'}, status=status.HTTP_401_UNAUTHORIZED)

        # Derive the access token from the refresh token so both carry the same claims.
        refresh_token = RefreshToken.for_user(user)
        return Response({
            'access_token': str(refresh_token.access_token),
            'refresh_token': str(refresh_token),
        })


class RefreshTokenView(generics.GenericAPIView):
    """Trade a valid refresh token for a new access/refresh token pair.

    POST body is validated by RefreshTokenSerializer. Responds 401 with a fixed
    message when the token is invalid or expired, or when its user is missing
    or inactive.
    """

    serializer_class = RefreshTokenSerializer
    permission_classes = (AllowAny,)

    def post(self, request):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)

        invalid = Response({'error': 'Invalid or expired refresh token'}, status=status.HTTP_401_UNAUTHORIZED)

        try:
            # Constructing a RefreshToken verifies signature, expiry and type.
            token = RefreshToken(serializer.validated_data['refresh_token'])
            # SimpleJWT stores the user under USER_ID_CLAIM ("user_id" by
            # default), not "email", so look the user up the same way it does.
            user_id = token[api_settings.USER_ID_CLAIM]
            user = User.objects.get(**{api_settings.USER_ID_FIELD: user_id})
        except (TokenError, KeyError, User.DoesNotExist):
            return invalid

        if not user.is_active:
            return invalid

        # blacklist() only exists when the token_blacklist app is installed.
        if api_settings.BLACKLIST_AFTER_ROTATION and hasattr(token, 'blacklist'):
            token.blacklist()

        new_refresh_token = RefreshToken.for_user(user)
        return Response({
            'access_token': str(new_refresh_token.access_token),
            'refresh_token': str(new_refresh_token),
        })


def oidc_login(request):
    """
    Redirect the user to the OpenID Connect provider's login page.

    The ``state`` parameter is a short-lived HS256 JWT signed with SECRET_KEY.
    It carries the callback URL and the nonce, so the callback can check both
    without server-side storage.

    NOTE(review): the state is not tied to the browser session, so it does not
    stop login CSRF. Consider also storing it in request.session and comparing
    it in the callback.
    """
    state_ttl_seconds = 600

    redirect_uri = request.build_absolute_uri(reverse('authentication:oidc_callback'))
    nonce = str(uuid.uuid4())
    state = jwt.encode(
        {
            'redirect_uri': redirect_uri,
            'nonce': nonce,
            'exp': int(time.time()) + state_ttl_seconds,
        },
        settings.SECRET_KEY,
        algorithm='HS256',
    )

    # urlencode escapes the space-separated scope and the embedded redirect URL;
    # plain f-string interpolation would produce a malformed query string.
    query = urlencode({
        'client_id': settings.OIDC_CLIENT_ID,
        'redirect_uri': redirect_uri,
        'response_type': 'code',
        'scope': 'openid email profile',
        'state': state,
        'nonce': nonce,
        'prompt': 'select_account',
    })

    return HttpResponseRedirect(f'{settings.OIDC_AUTHORIZATION_ENDPOINT}?{query}')


def _oidc_error(message, http_status=status.HTTP_400_BAD_REQUEST):
    """Build the JSON error response returned by the OIDC callback."""
    # A plain Django view cannot return DRF's Response (it has no renderer set),
    # so use JsonResponse.
    return JsonResponse({'error': message}, status=http_status)


def _exchange_code_for_id_token(code, redirect_uri):
    """POST the authorization code to the token endpoint; return the ID token string or None."""
    try:
        response = requests.post(
            settings.OIDC_TOKEN_ENDPOINT,
            data={
                'grant_type': 'authorization_code',
                'code': code,
                # Must equal the redirect_uri sent in the authorization request.
                'redirect_uri': redirect_uri,
                'client_id': settings.OIDC_CLIENT_ID,
                'client_secret': settings.OIDC_CLIENT_SECRET,
            },
            timeout=10,
        )
    except requests.RequestException:
        return None

    if response.status_code != 200:
        return None
    try:
        body = response.json()
    except ValueError:
        return None
    return body.get('id_token') if isinstance(body, dict) else None


def _decode_id_token(id_token, expected_nonce):
    """Return the ID token's claims after checking exp, aud, iss and nonce, or None if invalid."""
    # The token came straight from the token endpoint over TLS, which OIDC Core
    # 3.1.3.7 accepts in place of a signature check. The claims are still
    # validated. OIDC_ISSUER is optional; when unset the issuer is not checked.
    try:
        claims = jwt.decode(
            id_token,
            options={'verify_signature': False, 'verify_exp': True, 'verify_aud': True, 'verify_iss': True},
            audience=settings.OIDC_CLIENT_ID,
            issuer=getattr(settings, 'OIDC_ISSUER', None),
        )
    except jwt.InvalidTokenError:
        return None

    if expected_nonce is not None and claims.get('nonce') != expected_nonce:
        return None
    return claims


def _get_or_create_oidc_user(email):
    """Return the user with this email, creating one without a usable password if needed."""
    email = User.objects.normalize_email(email)
    try:
        return User.objects.get(email=email)
    except User.DoesNotExist:
        # create_user with no password calls set_password(None).
        return User.objects.create_user(email)


def oidc_callback(request):
    """
    Handle the callback from the OpenID Connect provider after the user logs in.

    Validates the signed ``state``, exchanges ``code`` for an ID token, checks
    its claims, and finds or creates the matching local user.

    Tokens are issued as follows:
    - If ``settings.OIDC_LOGIN_REDIRECT_URL`` is set, the user is redirected
      there with the tokens in the URL fragment.
    - Otherwise the tokens are returned as JSON.

    Errors are JSON ``{"error": ...}`` responses with status 400 (403 for
    disabled accounts).
    """
    if 'error' in request.GET:
        return _oidc_error(f"Provider returned error: {request.GET['error']}")

    code = request.GET.get('code')
    state = request.GET.get('state')
    if not code or not state:
        return _oidc_error('Missing code or state')

    try:
        payload = jwt.decode(state, settings.SECRET_KEY, algorithms=['HS256'])
    except jwt.InvalidTokenError:
        return _oidc_error('Invalid state')

    # Other HS256 tokens signed with SECRET_KEY decode fine but lack this claim.
    redirect_uri = payload.get('redirect_uri')
    if not redirect_uri:
        return _oidc_error('Invalid state')

    id_token = _exchange_code_for_id_token(code, redirect_uri)
    if id_token is None:
        return _oidc_error('Unable to retrieve access token')

    claims = _decode_id_token(id_token, payload.get('nonce'))
    if claims is None:
        return _oidc_error('Invalid ID token')

    email = claims.get('email')
    if not email:
        return _oidc_error('ID token has no email claim')
    # Linking an unverified email to an existing account would allow takeover.
    # Providers that omit email_verified entirely are accepted.
    if claims.get('email_verified') in (False, 'false'):
        return _oidc_error('Email address is not verified')

    user = _get_or_create_oidc_user(email)
    if not user.is_active:
        return _oidc_error('User account is disabled', status.HTTP_403_FORBIDDEN)

    refresh_token = RefreshToken.for_user(user)
    tokens = {
        'access_token': str(refresh_token.access_token),
        'refresh_token': str(refresh_token),
    }

    # The state's redirect_uri is this callback's own URL, so it cannot serve
    # as the post-login destination.
    frontend_url = getattr(settings, 'OIDC_LOGIN_REDIRECT_URL', None)
    if frontend_url:
        # Browsers do not send URL fragments to servers, so the tokens stay out
        # of server logs.
        return HttpResponseRedirect(f'{frontend_url}#{urlencode(tokens)}')
    return JsonResponse(tokens)

Finally, in urls.py, we’ll define the URLs for the authentication endpoints:

from django.urls import path
from rest_framework_simplejwt.views import TokenRefreshView

from .views import LoginView, RefreshTokenView, oidc_login, oidc_callback

# views.oidc_login calls reverse('authentication:oidc_callback'), which needs
# this application namespace.
app_name = 'authentication'

urlpatterns = [
    path('login/', LoginView.as_view(), name='login'),
    path('refresh/', RefreshTokenView.as_view(), name='refresh'),
    # oidc_login and oidc_callback are plain function views, so no .as_view().
    path('oidc/login/', oidc_login, name='oidc_login'),
    path('oidc/callback/', oidc_callback, name='oidc_callback'),
    # views.py defines no HelloView. Once one exists, register it with
    # path('hello/', HelloView.as_view(), name='hello').
    path('token/refresh/', TokenRefreshView.as_view(), name='token_refresh'),
]

In this example, we’ve defined endpoints for the OpenID Connect login flow (/oidc/login/ and /oidc/callback/), a simple hello world endpoint (/hello/), and the token refresh endpoint provided by Simple JWT (/token/refresh/).

Note that we’ve also imported the TokenRefreshView from Simple JWT and included it in our urlpatterns. This view provides a default implementation for refreshing access tokens, and requires a valid refresh token to be provided in the request body. You can customize the behavior of this view by subclassing it and overriding its methods as needed.

In summary, when using DRF with Simple JWT, you can use the following code to make Simple JWT authenticate tokens issued by an external auth server:

import requests
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.backends import default_backend
from django.conf import settings
from rest_framework_simplejwt import authentication, exceptions


def get_public_key():
    """Fetch the SSO auth server's signing public key from its JWKS endpoint.

    Downloads the JWKS at ``settings.SSO_AUTH_SERVER_JWKS_URI`` and picks the
    key whose ``kid`` equals ``settings.SSO_AUTH_SERVER_KEY_ID``.

    Returns:
        The ``cryptography`` public-key object for that JWK.

    Raises:
        exceptions.AuthenticationFailed: if the JWKS cannot be fetched or
            parsed, or contains no key with the configured ``kid``.
    """
    import base64

    from cryptography.x509 import load_der_x509_certificate
    from jwt import PyJWK

    try:
        response = requests.get(settings.SSO_AUTH_SERVER_JWKS_URI, timeout=10)
    except requests.RequestException as exc:
        raise exceptions.AuthenticationFailed('Failed to retrieve public key') from exc

    if response.status_code != 200:
        raise exceptions.AuthenticationFailed('Failed to retrieve public key')

    try:
        jwks = response.json()['keys']
    except (ValueError, KeyError, TypeError) as exc:
        raise exceptions.AuthenticationFailed('Failed to retrieve public key') from exc

    for jwk in jwks:
        if jwk.get('kid') != settings.SSO_AUTH_SERVER_KEY_ID:
            continue
        x5c = jwk.get('x5c')
        if x5c:
            # x5c entries are standard-base64 DER certificates (RFC 7517
            # section 4.7), not PEM, so load_pem_x509_certificate cannot read them.
            certificate = load_der_x509_certificate(base64.b64decode(x5c[0]), default_backend())
            return certificate.public_key()
        # No certificate chain: build the key from the JWK parameters (e.g. n/e for RSA).
        return PyJWK(jwk).key

    raise exceptions.AuthenticationFailed('Failed to retrieve public key')


class JWTAuthentication(authentication.JWTAuthentication):
    """SimpleJWT authentication that trusts tokens signed by the SSO auth server.

    SimpleJWT's authenticate() calls ``get_validated_token``. This subclass
    overrides it to verify tokens against the key returned by get_public_key()
    instead of SimpleJWT's own SIGNING_KEY.

    NOTE(review): the inherited get_user() reads SIMPLE_JWT's USER_ID_CLAIM
    from the payload. It probably needs to point at the SSO server's user
    claim (e.g. "sub"); confirm against the issuer.
    """

    def get_verified_payload(self, raw_token):
        """Verify ``raw_token`` with the SSO public key and return its claims dict.

        Optional settings:
        - ``SSO_AUTH_SERVER_ALGORITHMS``: allowed algorithms, default ``['RS256']``.
        - ``SSO_AUTH_SERVER_AUDIENCE``: expected audience. When it is unset,
          PyJWT rejects any token that carries an ``aud`` claim.
        - ``SSO_AUTH_SERVER_ISSUER``: expected issuer. When it is unset, the
          issuer is not checked.

        Raises:
            jwt.InvalidTokenError: if verification fails.
            exceptions.AuthenticationFailed: if the public key cannot be fetched.
        """
        import jwt

        public_key = get_public_key()
        options = {'verify_signature': True, 'verify_aud': True, 'verify_iss': True}
        return jwt.decode(
            raw_token,
            public_key,
            algorithms=getattr(settings, 'SSO_AUTH_SERVER_ALGORITHMS', ['RS256']),
            audience=getattr(settings, 'SSO_AUTH_SERVER_AUDIENCE', None),
            issuer=getattr(settings, 'SSO_AUTH_SERVER_ISSUER', None),
            options=options,
        )

    def get_validated_token(self, raw_token):
        """SimpleJWT hook: return the verified payload or raise InvalidToken."""
        import jwt

        try:
            return self.get_verified_payload(raw_token)
        except jwt.InvalidTokenError as exc:
            raise exceptions.InvalidToken(
                {'detail': 'Given token not valid', 'messages': [str(exc)]}
            ) from exc


# Django REST framework settings: authenticate every request with the custom
# JWKS-backed JWTAuthentication class (replace "path.to" with its module
# path), require an authenticated user by default, and render only JSON.
REST_FRAMEWORK = dict(
    DEFAULT_AUTHENTICATION_CLASSES=['path.to.JWTAuthentication'],
    DEFAULT_PERMISSION_CLASSES=['rest_framework.permissions.IsAuthenticated'],
    DEFAULT_RENDERER_CLASSES=['rest_framework.renderers.JSONRenderer'],
)

# SSO auth server: where its JWKS is published and which key id to trust.
SSO_AUTH_SERVER_JWKS_URI = 'https://sso.authserver.com/.well-known/jwks.json'
SSO_AUTH_SERVER_KEY_ID = 'my_key_id'

In the example above, we define a custom JWTAuthentication class that extends the rest_framework_simplejwt.authentication.JWTAuthentication class. We override the get_verified_payload method to use the public key retrieved from the SSO auth server’s JWKS endpoint to verify the JWT token.

We also define a get_public_key function that retrieves the public key from the SSO auth server’s JWKS endpoint. This function searches for the key with the matching kid value, and returns the public key as a cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey object.

Finally, we configure DRF to use our custom authentication class as the default authentication class, and define the SSO auth server’s JWKS endpoint URL and key ID in the settings file.

Alternatively, you can package the same logic into a single self-contained class, as shown below:

import requests
from rest_framework_simplejwt.authentication import JWTAuthentication
from jwt import decode, InvalidSignatureError

class ExternalJWTAuthentication(JWTAuthentication):
    """SimpleJWT authentication for RS256 tokens issued by an external server.

    The verification key is chosen from the server's JWKS by the ``kid`` in the
    token's header. SimpleJWT calls ``get_validated_token``, which delegates to
    ``decode_jwt``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.jwks_url = "https://example.com/.well-known/jwks.json"
        # Optional expected "aud"/"iss" values passed to jwt.decode. With
        # audience=None, PyJWT rejects tokens that carry an "aud" claim; with
        # issuer=None, the issuer is not checked.
        self.audience = None
        self.issuer = None

    def get_public_key(self, kid):
        """Return the JWK dict from the JWKS whose ``kid`` matches, or None."""
        response = requests.get(self.jwks_url, timeout=10)
        response.raise_for_status()
        for key in response.json()["keys"]:
            if key.get("kid") == kid:
                return key
        return None

    def decode_jwt(self, token, options, *args, **kwargs):
        """Verify ``token`` against the JWKS key named by its header ``kid``.

        ``options`` is forwarded to ``jwt.decode``, as are the optional
        ``audience`` and ``issuer`` keyword arguments.

        Returns:
            The decoded claims dict.

        Raises:
            InvalidSignatureError: for any failure, chained to the original error.
        """
        import jwt

        try:
            # "kid" lives in the JOSE header; decoding the payload never yields it.
            kid = jwt.get_unverified_header(token).get("kid")
            jwk = self.get_public_key(kid)
            if jwk is None:
                raise InvalidSignatureError("Public key not found for kid: {}".format(kid))
            # jwt.decode needs a key object, not the raw JWK dict.
            public_key = jwt.PyJWK(jwk).key
            return decode(
                token,
                public_key,
                algorithms=["RS256"],
                options=options,
                audience=kwargs.get("audience"),
                issuer=kwargs.get("issuer"),
            )
        except InvalidSignatureError:
            raise
        except Exception as e:
            raise InvalidSignatureError("Error verifying signature: {}".format(e)) from e

    def get_validated_token(self, raw_token):
        """SimpleJWT hook: return the verified claims or raise InvalidToken."""
        from rest_framework_simplejwt.exceptions import InvalidToken

        try:
            return self.decode_jwt(raw_token, options=None, audience=self.audience, issuer=self.issuer)
        except InvalidSignatureError as e:
            raise InvalidToken(str(e)) from e

Thanks for your attention.

Jesus Saves

Leave a Comment

Your email address will not be published. Required fields are marked *