From 6420d8f81d625c417df33c49821ab177e3760e0b Mon Sep 17 00:00:00 2001
From: "T. Franzel" <tfranzel@users.noreply.github.com>
Date: Wed, 16 Apr 2025 18:57:56 +0200
Subject: [PATCH] Add allauth's DRF token auth #1401

Co-authored-by: Ahmed Kamal <27medkamal@users.noreply.github.com>
---
 drf_spectacular/contrib/__init__.py       |  1 +
 drf_spectacular/contrib/django_allauth.py | 15 ++++++++++
 requirements/optionals.txt                |  2 +-
 tests/conftest.py                         |  6 ++++
 tests/contrib/test_django_allauth.py      | 36 +++++++++++++++++++++++
 tests/contrib/test_rest_auth.py           |  9 ++++++
 6 files changed, 68 insertions(+), 1 deletion(-)
 create mode 100644 drf_spectacular/contrib/django_allauth.py
 create mode 100644 tests/contrib/test_django_allauth.py

diff --git a/drf_spectacular/contrib/__init__.py b/drf_spectacular/contrib/__init__.py
index 4da50f0a..ae24ae4f 100644
--- a/drf_spectacular/contrib/__init__.py
+++ b/drf_spectacular/contrib/__init__.py
@@ -1,5 +1,6 @@
 __all__ = [
     'django_oauth_toolkit',
+    'django_allauth',
     'djangorestframework_camel_case',
     'rest_auth',
     'rest_framework',
diff --git a/drf_spectacular/contrib/django_allauth.py b/drf_spectacular/contrib/django_allauth.py
new file mode 100644
index 00000000..97c5742e
--- /dev/null
+++ b/drf_spectacular/contrib/django_allauth.py
@@ -0,0 +1,15 @@
+from drf_spectacular.extensions import OpenApiAuthenticationExtension
+
+
+class XSessionTokenAuthenticationScheme(OpenApiAuthenticationExtension):
+    target_class = 'allauth.headless.contrib.rest_framework.authentication.XSessionTokenAuthentication'
+    name = 'XSessionTokenAuth'
+    optional = True
+
+    def get_security_definition(self, auto_schema):
+        return {
+            "type": "apiKey",
+            "in": "header",
+            "name": "X-Session-Token",
+            "description": "X-Session-Token authentication",
+        }
diff --git a/requirements/optionals.txt b/requirements/optionals.txt
index 96dc77ad..646f15c8 100644
--- a/requirements/optionals.txt
+++ b/requirements/optionals.txt
@@ -1,4 +1,3 @@
-django-allauth<0.55.0  # breaking change breaking dj-rest-auth
 drf-jwt>=0.13.0
 dj-rest-auth>=1.0.0
 djangorestframework-simplejwt>=4.4.0
@@ -16,3 +15,4 @@ djangorestframework-dataclasses>=1.0.0; python_version >= '3.7'
 djangorestframework-gis>=1.0.0
 pydantic>=2,<3; python_version >= '3.7'
 django-rest-knox>=4.1
+django-allauth[socialaccount]
diff --git a/tests/conftest.py b/tests/conftest.py
index 0fa8463a..9c276c75 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,7 @@ def pytest_configure(config):
         'dj_rest_auth.registration',
         'allauth',
         'allauth.account',
+        'allauth.socialaccount',
         'oauth2_provider',
         'django_filters',
         'knox',
@@ -30,6 +31,10 @@ def pytest_configure(config):
         # 'polymorphic',
         # 'rest_framework_jwt',
     ]
+    try:
+        from allauth import __version__ as allauth_version
+    except ImportError:
+        allauth_version = ""
 
     # only load GIS if library is installed. This is required for the GIS test to work
     if is_gis_installed():
@@ -71,6 +76,7 @@ def pytest_configure(config):
         MIDDLEWARE=(
             'django.contrib.sessions.middleware.SessionMiddleware',
             'django.middleware.common.CommonMiddleware',
+            *(['allauth.account.middleware.AccountMiddleware'] if allauth_version > "0.60.0" else []),
             'django.contrib.auth.middleware.AuthenticationMiddleware',
             'django.middleware.locale.LocaleMiddleware',
         ),
diff --git a/tests/contrib/test_django_allauth.py b/tests/contrib/test_django_allauth.py
new file mode 100644
index 00000000..bf3a9697
--- /dev/null
+++ b/tests/contrib/test_django_allauth.py
@@ -0,0 +1,36 @@
+import pytest
+from rest_framework import permissions
+from rest_framework.views import APIView
+
+from drf_spectacular.utils import extend_schema
+from tests import generate_schema
+
+try:
+    from allauth import __version__ as allauth_version
+    from allauth.headless.contrib.rest_framework.authentication import XSessionTokenAuthentication
+except ImportError:
+    XSessionTokenAuthentication = object
+    allauth_version = "0.0.0"
+
+
+@pytest.mark.contrib('django_allauth')
+@pytest.mark.skipif(allauth_version < "0.65.4", reason='')
+def test_allauth_token_auth(no_warnings):
+
+    class XAPIView(APIView):
+        authentication_classes = [XSessionTokenAuthentication]
+        permission_classes = [permissions.IsAuthenticated]
+
+        @extend_schema(responses=int)
+        def get(self, request):
+            pass  # pragma: no cover
+
+    schema = generate_schema('x', view=XAPIView)
+    assert schema['components']['securitySchemes'] == {
+        'XSessionTokenAuth': {
+            "type": "apiKey",
+            "in": "header",
+            "name": "X-Session-Token",
+            "description": "X-Session-Token authentication",
+        }
+    }
diff --git a/tests/contrib/test_rest_auth.py b/tests/contrib/test_rest_auth.py
index 77179b66..6f965514 100644
--- a/tests/contrib/test_rest_auth.py
+++ b/tests/contrib/test_rest_auth.py
@@ -9,6 +9,13 @@
 from tests import assert_schema, generate_schema
 from tests.models import SimpleModel, SimpleSerializer
 
+try:
+    from allauth import __version__ as allauth_version
+    from dj_rest_auth.__version__ import __version__ as dj_rest_auth_version
+except ImportError:
+    dj_rest_auth_version = ""
+    allauth_version = ""
+
 transforms = [
     # User model first_name differences
     lambda x: re.sub(r'(first_name:\n *type: string\n *maxLength:) 30', r'\g<1> 150', x),
@@ -17,6 +24,7 @@
 ]
 
 
+@pytest.mark.skipif(dj_rest_auth_version < "5" and allauth_version >= "0.55.0", reason='')
 @pytest.mark.contrib('dj_rest_auth', 'allauth')
 @mock.patch('drf_spectacular.settings.spectacular_settings.SCHEMA_PATH_PREFIX', '')
 def test_rest_auth(no_warnings):
@@ -30,6 +38,7 @@ def test_rest_auth(no_warnings):
     )
 
 
+@pytest.mark.skipif(dj_rest_auth_version < "5" and allauth_version >= "0.55.0", reason='')
 @pytest.mark.contrib('dj_rest_auth', 'allauth', 'rest_framework_simplejwt')
 @mock.patch('drf_spectacular.settings.spectacular_settings.SCHEMA_PATH_PREFIX', '')
 @mock.patch('dj_rest_auth.app_settings.api_settings.USE_JWT', True)
