Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion chats/consumers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ async def connect(self):
await self.channel_layer.group_add(
EventGroupType.GENERAL_EVENTS, self.channel_name
)
await self.accept()
# Confirm selected subprotocol so browser clients finish handshake.
subprotocol = None
if (
self.scope.get("subprotocols")
and len(self.scope["subprotocols"]) >= 1
):
subprotocol = self.scope["subprotocols"][0]

await self.accept(subprotocol=subprotocol)

async def disconnect(self, close_code):
"""User disconnected from websocket"""
Expand Down
3 changes: 3 additions & 0 deletions core/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Authentication utilities for ASGI/WebSocket middleware.
"""
36 changes: 19 additions & 17 deletions chats/middleware.py → core/auth/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from urllib.parse import parse_qs

import jwt
from channels.db import database_sync_to_async
from django.conf import settings
Expand Down Expand Up @@ -97,33 +95,37 @@ def get_user(scope):

class TokenAuthMiddleware:
"""
Custom middleware that takes a token from the query string and authenticates via Django Rest Framework authtoken.
Custom middleware that takes a token from WebSocket subprotocols and authenticates via JWT.
"""

SUBPROTOCOL_KEYWORD = "Bearer"

def __init__(self, app):
# Store the ASGI application we were passed
self.app = app

async def __call__(self, scope, receive, send):
# Look up user from query string

# TODO: (you should also do things like
# checking if it is a valid user ID, or if scope["user" ] is already
# populated).

query_string = scope["query_string"].decode()
query_dict = parse_qs(query_string)
try:
token = query_dict["token"][0]
if token is None:
raise ValueError("Token is missing from headers")
# Extract token from Sec-WebSocket-Protocol header.
token = self._extract_token_from_subprotocol(scope.get("subprotocols", []))

if token:
scope["token"] = token
scope["user"] = await get_user(scope)
except (ValueError, KeyError, IndexError):
# Token is missing from query string
else:
from django.contrib.auth.models import AnonymousUser

scope["user"] = AnonymousUser()

return await self.app(scope, receive, send)

def _extract_token_from_subprotocol(self, subprotocols: list[str]) -> str | None:
"""
Expect subprotocols in the form ["Bearer", "<JWT>"].
"""
if not subprotocols:
return None

if len(subprotocols) >= 2 and subprotocols[0] == self.SUBPROTOCOL_KEYWORD:
return subprotocols[1]

return None
6 changes: 3 additions & 3 deletions procollab/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
# Ensure Django app registry is loaded before importing project routes.
django_asgi_app = get_asgi_application()

import chats.routing # noqa: E402
from chats.middleware import TokenAuthMiddleware # noqa: E402
from core.auth.middleware import TokenAuthMiddleware # noqa: E402
from procollab.websocket_routing import websocket_urlpatterns # noqa: E402

application = ProtocolTypeRouter(
{
"http": django_asgi_app,
"websocket": TokenAuthMiddleware(URLRouter(chats.routing.websocket_urlpatterns)),
"websocket": TokenAuthMiddleware(URLRouter(websocket_urlpatterns)),
}
)
4 changes: 4 additions & 0 deletions procollab/websocket_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from chats.routing import websocket_urlpatterns as chat_websocket_urlpatterns

websocket_urlpatterns = []
websocket_urlpatterns += chat_websocket_urlpatterns