diff --git a/cross_connect_client/controllers/cross_connect.py b/cross_connect_client/controllers/cross_connect.py index 2018b843a5..c86af26008 100644 --- a/cross_connect_client/controllers/cross_connect.py +++ b/cross_connect_client/controllers/cross_connect.py @@ -22,5 +22,5 @@ def cross_connect( if not server: raise UserError(_("Server not found")) - url = server._get_cross_connect_url(request.params.get("redirect_url")) + url = server._get_cross_connect_url(**params) return request.redirect(url, local=False) diff --git a/cross_connect_client/models/cross_connect_server.py b/cross_connect_client/models/cross_connect_server.py index 9543a5616b..ed2c72488e 100644 --- a/cross_connect_client/models/cross_connect_server.py +++ b/cross_connect_client/models/cross_connect_server.py @@ -2,6 +2,8 @@ # @author Florian Mounier # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). +from urllib.parse import urlencode + import requests from odoo import _, api, fields, models @@ -98,7 +100,7 @@ def _request(self, method, url, headers=None, data=None): response.raise_for_status() return response.json() - def _get_cross_connect_url(self, redirect_url=None): + def _get_cross_connect_url(self, **params): self.ensure_one() groups = self.env.user.groups_id & self.group_ids if not groups: @@ -115,8 +117,6 @@ def _get_cross_connect_url(self, redirect_url=None): "lang": self.env.user.lang, "groups": [group.cross_connect_server_group_id for group in groups], } - if redirect_url: - data["redirect_url"] = redirect_url response = self._request("POST", "/access", data=data) client_id = response.get("client_id") @@ -124,7 +124,11 @@ def _get_cross_connect_url(self, redirect_url=None): if not token: raise UserError(_("Missing token")) - return self._absolute_url_for(f"login/{client_id}/{token}") + url = f"login/{client_id}/{token}" + if params: + url += "?" + urlencode(params) + + return self._absolute_url_for(url) def _sync_groups(self): self.ensure_one() diff --git a/cross_connect_server/models/cross_connect_client.py b/cross_connect_server/models/cross_connect_client.py index 42a615a0b6..739aad94f3 100644 --- a/cross_connect_server/models/cross_connect_client.py +++ b/cross_connect_server/models/cross_connect_client.py @@ -1,6 +1,7 @@ # Copyright 2024 Akretion (http://www.akretion.com). # @author Florian Mounier # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). +import re from datetime import datetime, timedelta, timezone from secrets import token_urlsafe @@ -34,6 +35,14 @@ class CrossConnectClient(models.Model): related="endpoint_id.cross_connect_allowed_group_ids", ) + bypass_user_mail_re = fields.Char( + string="Bypass Users Email Regexes", + help=( + "If set, users with an email matching one of these regex will bypass " + "the token user/login creation. The regexes are comma separated." + ), + ) + group_ids = fields.Many2many( "res.groups", string="Groups", @@ -64,6 +73,12 @@ def _compute_user_count(self): record.user_count = len(record.user_ids) def _request_access(self, access_request): + if self.bypass_user_mail_re and any( + re.search(mail_re.strip(), access_request.email) + for mail_re in self.bypass_user_mail_re.split(",") + ): + return "bypass" + # check groups groups = self.env["res.groups"].browse(access_request.groups) if groups - self.group_ids or not groups.exists(): @@ -72,6 +87,13 @@ def _request_access(self, access_request): user = self.user_ids.filtered( lambda u: u.cross_connect_client_user_id == access_request.id ) + + # Fallback to default lang if not installed + if access_request.lang not in [ + code for code, _name in self.env["res.lang"].get_installed() + ]: + access_request.lang = "en_US" + vals = { "login": f"{self.id}_{access_request.id}_{access_request.login}", "email": access_request.email, @@ -94,7 +116,6 @@ def _request_access(self, access_request): "exp": datetime.now(tz=timezone.utc) + timedelta(minutes=2), "aud": str(self.id), "id": user.id, - "redirect_url": access_request.redirect_url or "/web", }, self.endpoint_id.cross_connect_secret_key, algorithm="HS256", @@ -117,4 +138,10 @@ def _log_from_token(self, token): if not user: raise AccessDenied(_("Invalid Token")) - return user, obj["redirect_url"] + return user + + def _get_final_redirect_url(self, **params): + """Get the final redirect url after login. + Override this method to customize the local landing action. + """ + return "/web" diff --git a/cross_connect_server/routers/cross_connect.py b/cross_connect_server/routers/cross_connect.py index 94fe4ee83c..0341c9d681 100644 --- a/cross_connect_server/routers/cross_connect.py +++ b/cross_connect_server/routers/cross_connect.py @@ -4,7 +4,7 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from fastapi.responses import RedirectResponse from odoo import _, api @@ -49,12 +49,19 @@ async def login( client_id: int, token: str, env: Annotated[api.Environment, Depends(odoo_env)], + request: Request, ) -> RedirectResponse: """Log user and redirect to odoo index.""" cross_connect_client = env["cross.connect.client"].sudo().browse(client_id) if not cross_connect_client: raise MissingError(_("Client not found")) - user, redirect_url = cross_connect_client.sudo()._log_from_token(token) + params = request.query_params + if token == "bypass": + return RedirectResponse( + url=cross_connect_client._get_final_redirect_url(bypass=True, **params) + ) + + user = cross_connect_client.sudo()._log_from_token(token) user = user.with_user(user) user._update_last_login() env = env(user=user.id) @@ -68,7 +75,9 @@ async def login( session.session_token = user._compute_session_token(session.sid) root.session_store.save(session) # Redirect after login - response = RedirectResponse(url=redirect_url) + response = RedirectResponse( + url=cross_connect_client._get_final_redirect_url(**params) + ) response.set_cookie( "session_id", session.sid, diff --git a/cross_connect_server/schemas.py b/cross_connect_server/schemas.py index 04e59628f7..b1295e3729 100644 --- a/cross_connect_server/schemas.py +++ b/cross_connect_server/schemas.py @@ -35,7 +35,6 @@ class AccessRequest(StrictExtendableBaseModel, extra="ignore"): email: str lang: str groups: list[int] - redirect_url: str = None class AccessResponse(StrictExtendableBaseModel): diff --git a/cross_connect_server/views/fastapi_endpoint_views.xml b/cross_connect_server/views/fastapi_endpoint_views.xml index 41f2c4660d..06557c82ba 100644 --- a/cross_connect_server/views/fastapi_endpoint_views.xml +++ b/cross_connect_server/views/fastapi_endpoint_views.xml @@ -36,6 +36,7 @@ widget="many2many_tags" options="{'no_create': True}" /> +