diff --git a/.dns/dns_api.py b/.dns/dns_api.py index a53a469da..7fee31ec9 100644 --- a/.dns/dns_api.py +++ b/.dns/dns_api.py @@ -7,7 +7,7 @@ import logging import os import re -import subprocess +import subprocess # noqa: S404 from collections import defaultdict from dataclasses import dataclass from enum import StrEnum @@ -189,7 +189,7 @@ class BindDNSServerManager: """Bind9 DNS server manager.""" @staticmethod - def _get_zone_obj_by_zone_name(zone_name) -> dns.zone.Zone: + def _get_zone_obj_by_zone_name(zone_name: str) -> dns.zone.Zone: """Get DNS zone object by zone name. Algorithm: @@ -582,7 +582,7 @@ def restart(self) -> None: Algorithm: 1. Call rndc reconfig. """ - subprocess.run( # noqa: S603 + subprocess.run( [ "/usr/sbin/rndc", "reconfig", @@ -846,8 +846,8 @@ def update_record( self, old_record: DNSRecord, new_record: DNSRecord, - record_type, - zone_name, + record_type: DNSRecordType, + zone_name: str, ) -> None: """Update a record in a zone (value or TTL). @@ -998,7 +998,11 @@ def get_server_settings() -> list[DNSServerParam]: async def get_dns_manager() -> type[BindDNSServerManager]: - """Get DNS server manager client.""" + """Get DNS server manager client. + + Returns: + BindDNSServerManager: dns manager. + """ return BindDNSServerManager() @@ -1043,7 +1047,11 @@ def delete_zone( async def get_all_records_by_zone( dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], ) -> list[DNSZone]: - """Get all DNS records grouped by zone.""" + """Get all DNS records grouped by zone. + + Returns: + list[DNSZone]: List of DNSZone objects with records. + """ return dns_manager.get_all_records() @@ -1051,7 +1059,11 @@ async def get_all_records_by_zone( async def get_forward_zones( dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], ) -> list[DNSForwardZone]: - """Get all forward DNS zones.""" + """Get all forward DNS zones. + + Returns: + list[DNSForwardZone]: List of DNSForwardZone objects. + """ return await dns_manager.get_forward_zones() @@ -1141,7 +1153,11 @@ def update_dns_server_settings( async def get_server_settings( dns_manager: Annotated[BindDNSServerManager, Depends(get_dns_manager)], ) -> list[DNSServerParam]: - """Get list of modifiable server settings.""" + """Get list of modifiable server settings. + + Returns: + list[DNSServerParam]: List of server parameters. + """ return dns_manager.get_server_settings() @@ -1155,7 +1171,11 @@ def setup_server( def create_app() -> FastAPI: - """Create FastAPI app.""" + """Create FastAPI app. + + Returns: + FastAPI: FastAPI application instance. + """ app = FastAPI( name="DNSServerManager", title="DNSServerManager", diff --git a/.docker/lint.Dockerfile b/.docker/lint.Dockerfile index fab799df3..11405d13e 100644 --- a/.docker/lint.Dockerfile +++ b/.docker/lint.Dockerfile @@ -21,7 +21,7 @@ RUN --mount=type=cache,target=$POETRY_CACHE_DIR poetry install --with linters -- # The runtime image, used to just run the code provided its virtual environment FROM python:3.12.6-slim-bookworm AS runtime -WORKDIR /app +WORKDIR /md RUN set -eux; ENV VIRTUAL_ENV=/venvs/.venv \ @@ -31,5 +31,7 @@ ENV VIRTUAL_ENV=/venvs/.venv \ COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} -COPY app /app -COPY pyproject.toml ./ \ No newline at end of file +COPY app ./app +COPY tests ./tests +COPY .kerberos ./.kerberos +COPY pyproject.toml ./ diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 01b92ef8b..f796e9727 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -26,7 +26,7 @@ jobs: - name: Run linters env: NEW_TAG: linter - run: docker run $NEW_TAG ruff check --output-format=github . + run: docker run $NEW_TAG ruff check --output-format=github . --preview ruff_format: runs-on: ubuntu-latest @@ -46,7 +46,7 @@ jobs: - name: Run linters env: NEW_TAG: linter - run: docker run $NEW_TAG ruff format --check + run: docker run $NEW_TAG ruff format --check --preview mypy: runs-on: ubuntu-latest @@ -86,4 +86,4 @@ jobs: - name: Run tests env: TAG: tests - run: docker compose -f docker-compose.remote.test.yml up --no-log-prefix --attach md-test --exit-code-from md-test \ No newline at end of file + run: docker compose -f docker-compose.remote.test.yml up --no-log-prefix --attach md-test --exit-code-from md-test diff --git a/.kerberos/config_server.py b/.kerberos/config_server.py index 28ab8c66d..da906fe9b 100644 --- a/.kerberos/config_server.py +++ b/.kerberos/config_server.py @@ -99,70 +99,57 @@ async def add_princ( ) -> None: """Create principal. - :param str name: principal - :param str | None password: if empty - uses randkey. + Args: + name (str): principal name + password (str | None): password, if empty - uses randkey. + **dbargs: database arguments """ @abstractmethod async def get_princ(self, name: str) -> Principal | None: - """Get principal. - - :param str name: principal - :return kadmin.Principal: Principal - """ + """Get principal.""" @abstractmethod async def change_password(self, name: str, new_password: str) -> None: - """Chanage principal's password. - - :param str name: principal - :param str new_password: ... - """ + """Change principal's password.""" @abstractmethod - async def create_or_update_princ_pw(self, name: str, new_password) -> None: + async def create_or_update_princ_pw( + self, + name: str, + new_password: str, + ) -> None: """Create new principal or update password. - :param str name: principal - :param _type_ new_password: pw + Args: + name (str): principal name + new_password (str): password """ @abstractmethod async def del_princ(self, name: str) -> None: - """Delete principal by name. - - :param str name: principal - """ + """Delete principal by name.""" @abstractmethod async def rename_princ(self, name: str, new_name: str) -> None: - """Rename principal. - - :param str name: original name - :param str new_name: new name - """ + """Rename principal.""" @abstractmethod async def ktadd(self, names: list[str], fn: str) -> None: """Create or write to keytab. - :param str name: principal - :param str fn: filename + Args: + names (list[str]): principal names + fn (str): file name """ @abstractmethod async def lock_princ(self, name: str, **dbargs) -> None: - """Lock principal. - - :param str name: principal - """ + """Lock principal.""" @abstractmethod async def force_pw_principal(self, name: str, **dbargs) -> None: - """Lock principal. - - :param str name: principal - """ + """Force password principal.""" class KAdminLocalManager(AbstractKRBManager): @@ -175,7 +162,11 @@ def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None: self.loop = loop or asyncio.get_running_loop() async def connect(self) -> Self: - """Create threadpool for kadmin client.""" + """Create threadpool for kadmin client. + + Returns: + KAdminLocalManager: + """ self.pool = ThreadPoolExecutor(max_workers=500).__enter__() self.client = await asyncio.wait_for(self._init_client(), 40) return self @@ -195,11 +186,21 @@ async def __aexit__( exc: BaseException | None, tb: TracebackType | None, ) -> None: - """Destroy threadpool.""" + """Destroy threadpool. + + Args: + exc_type (type[BaseException] | None): exception type + exc (BaseException | None): exception + tb (TracebackType | None): traceback + """ await self.disconnect() async def _init_client(self) -> KAdminProtocol: - """Init kadmin local connection.""" + """Init kadmin local connection. + + Returns: + KAdminProtocol: client of kadmin.KAdmin + """ return await self.loop.run_in_executor(self.pool, kadmv.local) async def add_princ( @@ -210,8 +211,10 @@ async def add_princ( ) -> None: """Create principal. - :param str name: principal - :param str | None password: if empty - uses randkey. + Args: + name (str): principal name + password (str): password, if empty - uses randkey. + **dbargs: database arguments """ await self.loop.run_in_executor( self.pool, @@ -243,18 +246,14 @@ async def _get_raw_principal(self, name: str) -> PrincipalProtocol: async def get_princ(self, name: str) -> Principal: """Get principal. - :param str name: principal - :return kadmin.Principal: Principal + Returns: + Principal: Principal kadmin object """ principal = await self._get_raw_principal(name) return Principal.model_validate(principal, from_attributes=True) async def change_password(self, name: str, new_password: str) -> None: - """Chanage principal's password. - - :param str name: principal - :param str new_password: ... - """ + """Chanage principal's password.""" princ = await self._get_raw_principal(name) await self.loop.run_in_executor( self.pool, @@ -262,11 +261,16 @@ async def change_password(self, name: str, new_password: str) -> None: new_password, ) - async def create_or_update_princ_pw(self, name: str, new_password) -> None: - """Create new principal or update password. + async def create_or_update_princ_pw( + self, + name: str, + new_password: str, + ) -> None: + """Create new or update password principal. - :param str name: principal - :param _type_ new_password: ... + Args: + name (str): principal name + new_password (str): password """ try: await self.change_password(name, new_password) @@ -274,18 +278,11 @@ async def create_or_update_princ_pw(self, name: str, new_password) -> None: await self.add_princ(name, new_password) async def del_princ(self, name: str) -> None: - """Delete principal by name. - - :param str name: principal - """ + """Delete principal by name.""" await self.loop.run_in_executor(self.pool, self.client.delprinc, name) async def rename_princ(self, name: str, new_name: str) -> None: - """Rename principal. - - :param str name: original name - :param str new_name: new name - """ + """Rename principal.""" await self.loop.run_in_executor( self.pool, self.client.rename_principal, @@ -296,9 +293,12 @@ async def rename_princ(self, name: str, new_name: str) -> None: async def ktadd(self, names: list[str], fn: str) -> None: """Create or write to keytab. - :param str name: principal - :param str fn: filename - :raises self.PrincipalNotFoundError: on not found princ + Args: + names (list[str]): principal names + fn (str): file name + + Raises: + PrincipalNotFoundError: Principal not found """ principals = [await self._get_raw_principal(name) for name in names] if not all(principals): @@ -308,19 +308,13 @@ async def ktadd(self, names: list[str], fn: str) -> None: await self.loop.run_in_executor(self.pool, princ.ktadd, fn) async def lock_princ(self, name: str, **dbargs) -> None: - """Lock princ. - - :param str name: upn - """ + """Lock princ.""" princ = await self._get_raw_principal(name) princ.expire = "Now" await self.loop.run_in_executor(self.pool, princ.commit) async def force_pw_principal(self, name: str, **dbargs) -> None: - """Lock princ. - - :param str name: upn - """ + """Force password principal.""" princ = await self._get_raw_principal(name) princ.pwexpire = "Now" await self.loop.run_in_executor(self.pool, princ.commit) @@ -328,7 +322,11 @@ async def force_pw_principal(self, name: str, **dbargs) -> None: @asynccontextmanager async def kadmin_lifespan(app: FastAPI) -> AsyncIterator[None]: - """Create kadmin instance.""" + """Create kadmin instance. + + Yields: + AsyncIterator[None]: Async iterator + """ loop = asyncio.get_running_loop() async def try_set_kadmin(app: FastAPI) -> None: @@ -343,7 +341,7 @@ async def try_set_kadmin(app: FastAPI) -> None: logging.info("Successfully connected to kadmin local") return - loop.create_task(try_set_kadmin(app)) + await loop.create_task(try_set_kadmin(app)) yield if kadmind := getattr(app.state, "kadmind", None): await kadmind.disconnect() @@ -352,12 +350,24 @@ async def try_set_kadmin(app: FastAPI) -> None: def get_kadmin() -> KAdminLocalManager: - """Stub.""" + """Stub. + + Raises: + NotImplementedError: NotImplementedError + """ raise NotImplementedError def handle_db_error(request: Request, exc: BaseException): # noqa: ARG001 - """Handle duplicate.""" + """Handle duplicate. + + Args: + request (Request): request + exc (BaseException): exception + + Raises: + HTTPException: Database Error + """ raise HTTPException( status.HTTP_424_FAILED_DEPENDENCY, detail="Database Error", @@ -365,7 +375,15 @@ def handle_db_error(request: Request, exc: BaseException): # noqa: ARG001 def handle_duplicate(request: Request, exc: BaseException): # noqa: ARG001 - """Handle duplicate.""" + """Handle duplicate. + + Args: + request (Request): request + exc (BaseException): exception + + Raises: + HTTPException: Principal already exists + """ raise HTTPException( status.HTTP_409_CONFLICT, detail="Principal already exists", @@ -373,7 +391,15 @@ def handle_duplicate(request: Request, exc: BaseException): # noqa: ARG001 def handle_not_found(request: Request, exc: BaseException): # noqa: ARG001 - """Handle duplicate.""" + """Handle duplicate. + + Args: + request (Request): request + exc (BaseException): exception + + Raises: + HTTPException: Principal does not exist + """ raise HTTPException( status.HTTP_404_NOT_FOUND, detail="Principal does not exist", @@ -391,8 +417,9 @@ def write_configs( ) -> None: """Write two config files, strings are: hex bytes. - :param Annotated[str, Body krb5_config: krb5 hex bytes format config - :param Annotated[str, Body kdc_config: kdc hex bytes format config + Args: + krb5_config (str): krb5 hex bytes format config. + kdc_config (str): kdc hex bytes format config. """ with open("/etc/krb5.conf", "wb") as f: f.write(bytes.fromhex(krb5_config)) @@ -403,7 +430,14 @@ def write_configs( @setup_router.post("/stash", status_code=201) async def run_setup_stash(schema: ConfigSchema) -> None: - """Set up stash file.""" + """Set up stash file. + + Args: + schema (ConfigSchema): Configuration schema for stash setup. + + Raises: + HTTPException: Failed stash + """ proc = await asyncio.create_subprocess_exec( "kdb5_ldap_util", "-D", @@ -437,8 +471,11 @@ async def run_setup_stash(schema: ConfigSchema) -> None: async def run_setup_subtree(schema: ConfigSchema) -> None: """Set up subtree in ldap. - :param ConfigSchema schema: _description_ - :raises HTTPException: _description_ + Args: + schema (ConfigSchema): Configuration schema for subtree setup. + + Raises: + HTTPException: If setup fails. """ create_proc = await asyncio.create_subprocess_exec( "kdb5_ldap_util", @@ -492,9 +529,10 @@ async def add_princ( ) -> None: """Add principal. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. + password (str | None): Principal password. """ await kadmin.add_princ(name, password) @@ -504,11 +542,14 @@ async def get_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], name: str, ) -> Principal: - """Add principal. + """Get principal. + + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Returns: + Principal: Principal object. """ return await kadmin.get_princ(name) @@ -518,11 +559,11 @@ async def del_princ( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], name: str, ) -> None: - """Add principal. + """Delete principal. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. """ await kadmin.del_princ(name) @@ -533,11 +574,12 @@ async def change_princ_password( name: Annotated[str, Body()], password: Annotated[str, Body()], ) -> None: - """Change princ pw principal. + """Change principal password. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. + password (str): Principal password. """ await kadmin.change_password(name, password) @@ -552,11 +594,12 @@ async def create_or_update_princ_password( name: Annotated[str, Body()], password: Annotated[str, Body()], ) -> None: - """Change princ pw principal or create with new. + """Change principal password or create with new. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. + password (str): Principal password. """ await kadmin.create_or_update_princ_pw(name, password) @@ -573,11 +616,11 @@ async def rename_princ( ) -> None: """Rename principal. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body new_name: principal new name + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. + new_name (str): Principal new name. """ - """""" await kadmin.rename_princ(name, new_name) @@ -588,9 +631,12 @@ async def ktadd( ) -> FileResponse: """Ktadd principal. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name - :param Annotated[str, Body password: principal password + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + names (list[str]): List of principal names. + + Returns: + FileResponse: Keytab file response. """ filename = os.path.join(gettempdir(), str(uuid.uuid1())) await kadmin.ktadd(names, filename) @@ -608,8 +654,9 @@ async def lock_princ( ) -> None: """Lock principal. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. """ await kadmin.lock_princ(name) @@ -619,10 +666,11 @@ async def force_pw_reset_principal( kadmin: Annotated[AbstractKRBManager, Depends(get_kadmin)], name: Annotated[str, Body(embed=True)], ) -> None: - """Mark princ as pw expired. + """Mark principal as password expired. - :param Annotated[AbstractKRBManager, Depends kadmin: kadmin abstract - :param Annotated[str, Body name: principal name + Args: + kadmin (AbstractKRBManager): Kadmin abstract manager. + name (str): Principal name. """ await kadmin.force_pw_principal(name) @@ -633,6 +681,9 @@ def get_status(request: Request) -> bool: true - is ready false - not set + + Returns: + bool: True if kadmin is ready, False otherwise. """ kadmind = getattr(request.app.state, "kadmind", None) @@ -640,7 +691,11 @@ def get_status(request: Request) -> bool: def create_app() -> FastAPI: - """Create FastAPI app.""" + """Create FastAPI app. + + Returns: + FastAPI: web app + """ app = FastAPI( name="KadminMultiDirectory", title="KadminMultiDirectory", diff --git a/Makefile b/Makefile index 134485aac..b105d9e9e 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ help: ## show help message @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[$$()% a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) before_pr: - ruff format . - ruff check . --fix --unsafe-fixes + ruff check . --preview --fix --unsafe-fixes + ruff format . --preview mypy . build: ## build app and manually generate self-signed cert diff --git a/app/alembic/env.py b/app/alembic/env.py index 07f1b6613..e8036ec41 100644 --- a/app/alembic/env.py +++ b/app/alembic/env.py @@ -4,7 +4,7 @@ from logging.config import fileConfig from alembic import context -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine from config import Settings from models import Base @@ -21,7 +21,7 @@ target_metadata = Base.metadata -def run_sync_migrations(connection): +def run_sync_migrations(connection: AsyncConnection): """Run sync migrations.""" context.configure( connection=connection, @@ -34,7 +34,7 @@ def run_sync_migrations(connection): context.run_migrations() -async def run_async_migrations(settings): +async def run_async_migrations(settings: Settings): """Run async migrations.""" engine = create_async_engine(str(settings.POSTGRES_URI)) diff --git a/app/alembic/versions/275222846605_initial_ldap_schema.py b/app/alembic/versions/275222846605_initial_ldap_schema.py index 80e466a0b..2416af494 100644 --- a/app/alembic/versions/275222846605_initial_ldap_schema.py +++ b/app/alembic/versions/275222846605_initial_ldap_schema.py @@ -11,7 +11,7 @@ import sqlalchemy as sa from alembic import op from ldap3.protocol.schemas.ad2012R2 import ad_2012_r2_schema -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.orm import Session from extra.alembic_utils import temporary_stub_entity_type_name @@ -180,7 +180,7 @@ def upgrade() -> None: session.commit() # NOTE: Load objectClasses into the database - async def _create_object_classes(connection): + async def _create_object_classes(connection: AsyncConnection): session = AsyncSession(bind=connection) await session.begin() @@ -240,7 +240,7 @@ async def _create_object_classes(connection): op.run_async(_create_object_classes) - async def _create_attribute_types(connection): + async def _create_attribute_types(connection: AsyncConnection): session = AsyncSession(bind=connection) await session.begin() @@ -262,7 +262,7 @@ async def _create_attribute_types(connection): op.run_async(_create_attribute_types) - async def _modify_object_classes(connection): + async def _modify_object_classes(connection: AsyncConnection): session = AsyncSession(bind=connection) await session.begin() diff --git a/app/alembic/versions/fafc3d0b11ec_.py b/app/alembic/versions/fafc3d0b11ec_.py index 1e4484875..282c572c6 100644 --- a/app/alembic/versions/fafc3d0b11ec_.py +++ b/app/alembic/versions/fafc3d0b11ec_.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy import delete, exists, select from sqlalchemy.exc import DBAPIError, IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from extra.alembic_utils import temporary_stub_entity_type_name from ldap_protocol.policies.access_policy import create_access_policy @@ -31,7 +31,9 @@ def upgrade() -> None: """Upgrade.""" - async def _create_readonly_grp_and_plcy(connection) -> None: + async def _create_readonly_grp_and_plcy( + connection: AsyncConnection, + ) -> None: session = AsyncSession(bind=connection) await session.begin() base_dn_list = await get_base_directories(session) @@ -84,7 +86,9 @@ async def _create_readonly_grp_and_plcy(connection) -> None: def downgrade() -> None: """Downgrade.""" - async def _delete_readonly_grp_and_plcy(connection) -> None: + async def _delete_readonly_grp_and_plcy( + connection: AsyncConnection, + ) -> None: session = AsyncSession(bind=connection) await session.begin() base_dn_list = await get_base_directories(session) diff --git a/app/api/__init__.py b/app/api/__init__.py index 8dfd02557..919b0c275 100644 --- a/app/api/__init__.py +++ b/app/api/__init__.py @@ -17,16 +17,15 @@ from .shadow.router import shadow_router __all__ = [ - "auth_router", - "session_router", - "network_router", - "mfa_router", - "pwd_router", "access_policy_router", - "ldap_schema_router", + "auth_router", "dns_router", - "krb5_router", "entry_router", + "krb5_router", + "ldap_schema_router", + "mfa_router", "network_router", + "pwd_router", + "session_router", "shadow_router", ] diff --git a/app/api/auth/oauth2.py b/app/api/auth/oauth2.py index 4f2e9671a..3a60ab3eb 100644 --- a/app/api/auth/oauth2.py +++ b/app/api/auth/oauth2.py @@ -38,10 +38,13 @@ async def authenticate_user( ) -> User | None: """Get user and verify password. - :param AsyncSession session: sa session - :param str username: any str - :param str password: any str - :return User | None: User model (pydantic) + Args: + session (AsyncSession): sa session + username (str): any str + password (str): any str + + Returns: + User | None: User model (pydantic) """ user = await get_user(session, username) @@ -68,15 +71,21 @@ async def get_current_user( request's cookies, verifies the session, and returns the user schema. Makes a rekey of the session if necessary. - :param FromDishka[Settings] settings: settings - :param FromDishka[AsyncSession] session: db session - :param FromDishka[SessionStorage] session_storage: session storage - :param Request request: request - :param Response response: response - :param Annotated[IPv4Address | IPv6Address] ip: ip address - :param Annotated[str] user_agent: user agent - :return UserSchema: user schema - """ + Args: + settings (FromDishka[Settings]): settings + session (FromDishka[AsyncSession]): db session + session_storage (FromDishka[SessionStorage]): session storage + request (Request): request + response (Response): response + ip (Annotated[IPv4Address | IPv6Address]): ip address + user_agent (Annotated[str]): user agent + + Returns: + UserSchema: user schema + + Raises: + _CREDENTIALS_EXCEPTION: creds not valid + """ # noqa: DOC502 session_key = request.cookies.get("id", "") try: user_id = await session_storage.get_user_id( diff --git a/app/api/auth/router.py b/app/api/auth/router.py index 986d29484..8cd173942 100644 --- a/app/api/auth/router.py +++ b/app/api/auth/router.py @@ -63,25 +63,28 @@ async def login( """Create session to cookies and storage. - **username**: username formats: - `DN`, `userPrincipalName`, `saMAccountName` + `DN`, `userPrincipalName`, `saMAccountName` - **password**: password \f - :param Annotated[OAuth2Form, Depends form: login form - :param FromDishka[AsyncSession] session: db - :param FromDishka[Settings] settings: app settings - :param FromDishka[MultifactorAPI] mfa: mfa api wrapper - :param FromDishka[SessionStorage] storage: session storage - :param Response response: FastAPI response - :param Annotated[IPv4Address | IPv6Address, Depends ip: client ip - :raises HTTPException: 401 if incorrect username or password - :raises HTTPException: 403 if user not part of domain admins - :raises HTTPException: 403 if user account is disabled - :raises HTTPException: 403 if user account is expired - :raises HTTPException: 403 if ip is not provided - :raises HTTPException: 403 if user not part of network policy - :raises HTTPException: 426 if mfa required - :return None: None + Args: + form (OAuth2Form): Login form with username and password. + session (FromDishka[AsyncSession]): Database session. + settings (FromDishka[Settings]): Application settings. + mfa (FromDishka[MultifactorAPI]): MFA API wrapper. + storage (FromDishka[SessionStorage]): Session storage. + response (Response): FastAPI response object. + ip (IPv4Address | IPv6Address): Client IP address. + user_agent (str): Client user agent string. + + Raises: + HTTPException: 401 if incorrect username or password + HTTPException: 403 if user not part of domain admins + HTTPException: 403 if user account is disabled + HTTPException: 403 if user account is expired + HTTPException: 403 if ip is not provided + HTTPException: 403 if user not part of network policy + HTTPException: 426 if mfa required """ user = await authenticate_user(session, form.username, form.password) @@ -150,7 +153,14 @@ async def login( async def users_me( user: Annotated[UserSchema, Depends(get_current_user)], ) -> UserSchema: - """Get current logged in user data.""" + """Get current logged in user data. + + Args: + user (UserSchema): Current user schema from dependency. + + Returns: + UserSchema: Current user data. + """ return user @@ -160,7 +170,13 @@ async def logout( storage: FromDishka[SessionStorage], user: Annotated[UserSchema, Depends(get_current_user)], ) -> None: - """Delete token cookies.""" + """Delete token cookies. + + Args: + response (Response): FastAPI response object. + storage (FromDishka[SessionStorage]): Session storage. + user (UserSchema): Current user schema from dependency. + """ response.delete_cookie("id", httponly=True) await storage.delete_user_session(user.session_id) @@ -182,14 +198,16 @@ async def password_reset( `userPrincipalName`, `saMAccountName` or `DN` - **new_password**: password to set \f - :param FromDishka[AsyncSession] session: db - :param FromDishka[AbstractKadmin] kadmin: kadmin api - :param Annotated[str, Body identity: reset target user - :param Annotated[str, Body new_password: new password for user - :raises HTTPException: 404 if user not found - :raises HTTPException: 422 if password not valid - :raises HTTPException: 424 if kerberos password update failed - :return None: None + Args: + identity (str): Reset target user identity. + new_password (str): New password for user. + session (FromDishka[AsyncSession]): Database session. + kadmin (FromDishka[AbstractKadmin]): Kadmin API instance. + + Raises: + HTTPException: 404 if user not found + HTTPException: 422 if password not valid + HTTPException: 424 if kerberos password update failed """ user = await get_user(session, identity) @@ -226,7 +244,8 @@ async def password_reset( async def check_setup(session: FromDishka[AsyncSession]) -> bool: """Check if initial setup needed. - True if setup already complete, False if setup is needed. + Returns: + bool: True if setup already complete, False if setup is needed. """ query = select(exists(Directory).where(Directory.parent_id.is_(None))) retval = await session.scalars(query) @@ -242,7 +261,17 @@ async def first_setup( request: SetupRequest, session: FromDishka[AsyncSession], ) -> None: - """Perform initial setup.""" + """Perform initial setup. + + Args: + request (SetupRequest): Setup request containing domain and user data. + session (FromDishka[AsyncSession]): Database session. + + Raises: + HTTPException: 422 if password policy validation fails + HTTPException: 423 if setup already performed + HTTPException: 424 if integrity error occurs during setup. + """ setup_already_performed = await session.scalar( select(Directory) .filter(Directory.parent_id.is_(None)) diff --git a/app/api/auth/router_mfa.py b/app/api/auth/router_mfa.py index 72ce39632..71efdcf6e 100644 --- a/app/api/auth/router_mfa.py +++ b/app/api/auth/router_mfa.py @@ -64,9 +64,12 @@ async def setup_mfa( """Set mfa credentials, rewrites if exists. \f - :param MFACreateRequest mfa: MuliFactor credentials - :param FromDishka[AsyncSession] session: db - :return bool: status + Args: + mfa (MFACreateRequest): MuliFactor credentials + session (FromDishka[AsyncSession]): db + + Returns: + bool: status """ async with session.begin_nested(): await session.execute( @@ -96,7 +99,12 @@ async def remove_mfa( session: FromDishka[AsyncSession], scope: Literal["ldap", "http"], ) -> None: - """Remove mfa credentials.""" + """Remove mfa credentials. + + Args: + session (FromDishka[AsyncSession]): Database session. + scope (Literal["ldap", "http"]): Scope of the credentials. + """ if scope == "http": keys = ["mfa_key", "mfa_secret"] else: @@ -117,7 +125,12 @@ async def get_mfa( """Get MFA creds. \f - :return MFAGetResponse: response. + Args: + mfa_creds (FromDishka[MFA_HTTP_Creds]): creds for http app. + mfa_creds_ldap (FromDishka[MFA_LDAP_Creds]): creds for ldap app. + + Returns: + MFAGetResponse: response. """ if not mfa_creds: mfa_creds = MFA_HTTP_Creds(Creds(None, None)) @@ -149,15 +162,20 @@ async def callback_mfa( Callback endpoint for MFA. \f - :param FromDishka[AsyncSession] session: db - :param FromDishka[SessionStorage] storage: session storage - :param FromDishka[Settings] settings: app settings - :param FromDishka[MFA_HTTP_Creds] mfa_creds: - creds for multifactor (http app) - :param Annotated[IPv4Address | IPv6Address, Depends ip: client ip - :param Annotated[str, Form access_token: token from multifactor callback - :raises HTTPException: if mfa not set up - :return RedirectResponse: on bypass or success + Args: + access_token (str): Token from multifactor callback. + session (FromDishka[AsyncSession]): db session. + storage (FromDishka[SessionStorage]): session storage. + settings (FromDishka[Settings]): app settings. + mfa_creds (FromDishka[MFA_HTTP_Creds]): creds for http app. + ip (IPv4Address | IPv6Address): Client IP address. + user_agent (str): Client user agent string. + + Raises: + HTTPException: if mfa not set up + + Returns: + RedirectResponse: on bypass or success """ if not mfa_creds: raise HTTPException(status.HTTP_404_NOT_FOUND) @@ -206,20 +224,26 @@ async def two_factor_protocol( """Initiate two factor protocol with app. \f - :param Annotated[OAuth2Form, Depends form: password form - :param Request request: FastAPI request - :param FromDishka[AsyncSession] session: db - :param FromDishka[MultifactorAPI] api: wrapper for MFA DAO - :param FromDishka[Settings] settings: app settings - :param FromDishka[SessionStorage] storage: redis storage - :param Response response: FastAPI response - :param Annotated[IPv4Address | IPv6Address, Depends ip: client ip - :raises HTTPException: Missing API credentials - :raises HTTPException: Invalid credentials - :raises HTTPException: network policy violation - :raises HTTPException: Multifactor error - :return MFAChallengeResponse: - {'status': 'pending', 'message': https://example.com}. + Args: + form (Annotated[OAuth2Form, Depends]): Password form containing\ + username and password. + request (Request): FastAPI request. + session (FromDishka[AsyncSession]): Database session. + api (FromDishka[MultifactorAPI]): Wrapper for MFA DAO. + settings (FromDishka[Settings]): App settings. + storage (FromDishka[SessionStorage]): Redis storage. + response (Response): FastAPI response. + ip (Annotated[IPv4Address | IPv6Address, Depends]): Client IP address. + user_agent (Annotated[str, Depends]): Client user agent string. + + Raises: + HTTPException: Missing API credentials + HTTPException: Invalid credentials + HTTPException: network policy violation + HTTPException: Multifactor error + + Returns: + MFAChallengeResponse: Response containing status and message. """ if not api: raise HTTPException( diff --git a/app/api/auth/router_pwd_policy.py b/app/api/auth/router_pwd_policy.py index 3fcbc70c6..d03db21bb 100644 --- a/app/api/auth/router_pwd_policy.py +++ b/app/api/auth/router_pwd_policy.py @@ -25,7 +25,15 @@ async def create_policy( policy: PasswordPolicySchema, session: FromDishka[AsyncSession], ) -> PasswordPolicySchema: - """Create current policy setting.""" + """Create current policy setting. + + Args: + policy (PasswordPolicySchema): Password policy schema to create. + session (AsyncSession): Database session. + + Returns: + PasswordPolicySchema: Created password policy schema. + """ return await policy.create_policy_settings(session) @@ -33,7 +41,11 @@ async def create_policy( async def get_policy( session: FromDishka[AsyncSession], ) -> PasswordPolicySchema: - """Get current policy setting.""" + """Get current policy setting. + + Returns: + PasswordPolicySchema: Current password policy schema. + """ return await PasswordPolicySchema.get_policy_settings(session) @@ -42,7 +54,15 @@ async def update_policy( policy: PasswordPolicySchema, session: FromDishka[AsyncSession], ) -> PasswordPolicySchema: - """Update current policy setting.""" + """Update current policy setting. + + Args: + policy (PasswordPolicySchema): Password policy schema to update. + session (AsyncSession): Database session. + + Returns: + PasswordPolicySchema: Updated password policy schema. + """ await policy.update_policy_settings(session) return policy @@ -51,5 +71,9 @@ async def update_policy( async def reset_policy( session: FromDishka[AsyncSession], ) -> PasswordPolicySchema: - """Reset current policy setting.""" + """Reset current policy setting. + + Returns: + PasswordPolicySchema: Reset password policy schema. + """ return await PasswordPolicySchema.delete_policy_settings(session) diff --git a/app/api/auth/schema.py b/app/api/auth/schema.py index e8b2f3321..6aed7961c 100644 --- a/app/api/auth/schema.py +++ b/app/api/auth/schema.py @@ -42,7 +42,12 @@ def __init__( username: str = Form(), password: str = Form(), ): - """Initialize form.""" + """Initialize form. + + Args: + username (str): username + password (str): password + """ self.username = username self.password = password @@ -66,7 +71,16 @@ class SetupRequest(BaseModel): password: str @field_validator("domain") - def validate_domain(cls, v: str) -> str: # noqa + @classmethod + def validate_domain(cls, v: str) -> str: + """Validate domain. + + Returns: + str: Validated domain string. + + Raises: + ValueError: If the domain is invalid. + """ if re.match(_domain_re, v) is None: raise ValueError("Invalid domain value") return v.lower() @@ -82,6 +96,11 @@ class MFACreateRequest(BaseModel): @computed_field # type: ignore @property def key_name(self) -> str: + """Get key name. + + Returns: + str: key name + """ if self.is_ldap_scope: return "mfa_key_ldap" @@ -90,6 +109,11 @@ def key_name(self) -> str: @computed_field # type: ignore @property def secret_name(self) -> str: + """Get secret name. + + Returns: + str: secret name + """ if self.is_ldap_scope: return "mfa_secret_ldap" diff --git a/app/api/auth/session_router.py b/app/api/auth/session_router.py index f6c0d4403..07c999522 100644 --- a/app/api/auth/session_router.py +++ b/app/api/auth/session_router.py @@ -26,7 +26,19 @@ async def get_user_session( storage: FromDishka[SessionStorage], session: FromDishka[AsyncSession], ) -> dict[str, SessionContentSchema]: - """Get user (upn, san or dn) data.""" + """Get user session data by UPN, SAN, or DN. + + Args: + upn (str): User principal name, SAN, or DN. + storage (SessionStorage): Session storage dependency. + session (AsyncSession): Database session. + + Returns: + dict[str, SessionContentSchema]: Dictionary of session data for user. + + Raises: + HTTPException: If user is not found. + """ user = await get_user(session, upn) if not user: raise HTTPException(status.HTTP_404_NOT_FOUND, "User not found.") @@ -39,7 +51,16 @@ async def delete_user_sessions( storage: FromDishka[SessionStorage], session: FromDishka[AsyncSession], ) -> None: - """Delete user (upn, san or dn) data.""" + """Delete all sessions for a user by UPN, SAN, or DN. + + Args: + upn (str): User principal name, SAN, or DN. + storage (SessionStorage): Session storage dependency. + session (AsyncSession): Database session. + + Raises: + HTTPException: If user is not found. + """ user = await get_user(session, upn) if not user: raise HTTPException(status.HTTP_404_NOT_FOUND, "User not found.") @@ -54,5 +75,10 @@ async def delete_session( session_id: str, storage: FromDishka[SessionStorage], ) -> None: - """Delete current logged in user data.""" + """Delete a specific user session by session ID. + + Args: + session_id (str): Session identifier. + storage (SessionStorage): Session storage dependency. + """ await storage.delete_user_session(session_id) diff --git a/app/api/auth/utils.py b/app/api/auth/utils.py index dd78b6400..19375e1f1 100644 --- a/app/api/auth/utils.py +++ b/app/api/auth/utils.py @@ -18,8 +18,14 @@ def get_ip_from_request(request: Request) -> IPv4Address | IPv6Address: """Get IP address from request. - :param Request request: The incoming request object. - :return IPv4Address | None: The IP address or None. + Args: + request (Request): The incoming request object. + + Raises: + HTTPException: If the request client is None. + + Returns: + IPv4Address | IPv6Address: The IP address or None. """ forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: @@ -35,8 +41,8 @@ def get_ip_from_request(request: Request) -> IPv4Address | IPv6Address: def get_user_agent_from_request(request: Request) -> str: """Get user agent from request. - :param Request request: The incoming request object. - :return str: The user agent header. + Returns: + str: The user agent header. """ user_agent_header = request.headers.get("User-Agent") return user_agent_header if user_agent_header else "" @@ -56,10 +62,14 @@ async def create_and_set_session_key( Update the user's last logon time and set the appropriate cookies in the response. - :param User user: db user - :param AsyncSession session: db session - :param Settings settings: app settings - :param Response response: fastapi response object + Args: + user (User): db user + session (AsyncSession): db session + settings (Settings): app settings + response (Response): fastapi response object + storage (SessionStorage): session storage backend + ip (IPv4Address | IPv6Address): IP address of the client + user_agent (str): user agent string of the client """ await set_last_logon_user(user, session, settings.TIMEZONE) diff --git a/app/api/exception_handlers.py b/app/api/exception_handlers.py index 17e329df9..441c38108 100644 --- a/app/api/exception_handlers.py +++ b/app/api/exception_handlers.py @@ -10,7 +10,15 @@ def handle_db_connect_error( request: Request, # noqa: ARG001 exc: Exception, ) -> NoReturn: - """Handle duplicate.""" + """Handle database connection errors. + + Args: + request (Request): FastAPI request object. + exc (Exception): Exception instance. + + Raises: + HTTPException: If connection pool is exceeded or backend error occurs. + """ if "QueuePool limit of size" in str(exc): logger.critical("POOL EXCEEDED {}", exc) @@ -28,7 +36,15 @@ async def handle_dns_error( request: Request, # noqa: ARG001 exc: Exception, ) -> NoReturn: - """Handle EmptyLabel exception.""" + """Handle DNS-related errors. + + Args: + request (Request): FastAPI request object. + exc (Exception): Exception instance. + + Raises: + HTTPException: Always raised for DNS errors. + """ logger.critical("DNS manager error: {}", exc) raise HTTPException(status.HTTP_503_SERVICE_UNAVAILABLE) @@ -37,7 +53,15 @@ async def handle_instance_not_found_error( request: Request, # noqa: ARG001 exc: Exception, # noqa: ARG001 ) -> NoReturn: - """Handle Instance Not Found error.""" + """Handle Instance Not Found error. + + Args: + request (Request): request + exc (Exception): exc. + + Raises: + HTTPException: Instance not found. + """ raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Instance not found.", @@ -48,7 +72,15 @@ async def handle_instance_cant_modify_error( request: Request, # noqa: ARG001 exc: Exception, # noqa: ARG001 ) -> NoReturn: - """Handle Instance Cant Modify error.""" + """Handle Instance Cant Modify error. + + Args: + request (Request): request + exc (Exception): exc. + + Raises: + HTTPException: System Instance cannot be modified. + """ raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="System Instance cannot be modified.", @@ -59,7 +91,11 @@ async def handle_not_implemented_error( request: Request, # noqa: ARG001 exc: Exception, # noqa: ARG001 ) -> NoReturn: - """Handle Not Implemented error.""" + """Handle Not Implemented error. + + Raises: + HTTPException: This feature is supported with selfhosted DNS server. + """ raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="This feature is supported with selfhosted DNS server.", diff --git a/app/api/ldap_schema/attribute_type_router.py b/app/api/ldap_schema/attribute_type_router.py index e3e12039b..5996a5596 100644 --- a/app/api/ldap_schema/attribute_type_router.py +++ b/app/api/ldap_schema/attribute_type_router.py @@ -35,12 +35,10 @@ async def create_one_attribute_type( ) -> None: """Create a new Attribute Type. - \f - :param AttributeTypeSchema request_data: Data for creating Attribute Type. - :param FromDishka[AttributeTypeDAO] attribute_type_dao: Attribute Type\ - manager. - :param FromDishka[AsyncSession] session: Database session. - :return None. + Args: + request_data (AttributeTypeSchema): Data for creating attribute type. + session (AsyncSession): Database session. + attribute_type_dao (AttributeTypeDAO): Database session. """ await attribute_type_dao.create_one( oid=request_data.oid, @@ -62,12 +60,14 @@ async def get_one_attribute_type( attribute_type_name: str, attribute_type_dao: FromDishka[AttributeTypeDAO], ) -> AttributeTypeSchema: - """Retrieve a one Attribute Type. + """Retrieve a single attribute type by name. - \f - :param str attribute_type_name: name of the Attribute Type. - :param FromDishka[AttributeTypeDAO] attribute_type_dao: Attribute Type dao. - :return AttributeTypeSchema: Attribute Type Schema. + Args: + attribute_type_name (str): Name of the attribute type. + attribute_type_dao (AttributeTypeDAO): Attribute Type dao. + + Returns: + AttributeTypeSchema: Attribute type schema. """ attribute_type = await attribute_type_dao.get_one_by_name( attribute_type_name @@ -84,12 +84,14 @@ async def get_list_attribute_types_with_pagination( attribute_type_dao: FromDishka[AttributeTypeDAO], params: Annotated[PaginationParams, Query()], ) -> AttributeTypePaginationSchema: - """Retrieve a chunk of Attribute Types with pagination. + """Retrieve a paginated list of attribute types. + + Args: + attribute_type_dao (AttributeTypeDAO): Attribute Type dao. + params (PaginationParams): Pagination parameters. - \f - :param FromDishka[AttributeTypeDAO] attribute_type_dao: Attribute Type dao. - :param PaginationParams params: Pagination parameters. - :return AttributeTypePaginationSchema: Paginator Schema. + Returns: + AttributeTypePaginationSchema: Paginated attribute types. """ pagination_result = await attribute_type_dao.get_paginator(params=params) @@ -112,14 +114,13 @@ async def modify_one_attribute_type( session: FromDishka[AsyncSession], attribute_type_dao: FromDishka[AttributeTypeDAO], ) -> None: - """Modify an Attribute Type. - - \f - :param str attribute_type_name: name of the Attribute Type for modifying. - :param AttributeTypeUpdateSchema request_data: Changed data. - :param FromDishka[AsyncSession] session: Database session. - :param FromDishka[AttributeTypeDAO] attribute_type_dao: Attribute Type dao. - :return None. + """Modify an attribute type. + + Args: + attribute_type_name (str): Name of the attribute type to modify. + request_data (AttributeTypeUpdateSchema): Data to update. + session (AsyncSession): Database session. + attribute_type_dao (AttributeTypeDAO): Attribute Type dao. """ attribute_type = await attribute_type_dao.get_one_by_name( attribute_type_name @@ -145,11 +146,10 @@ async def delete_bulk_attribute_types( ) -> None: """Delete Attribute Types by their names. - \f - :param LimitedListType attribute_types_names: List of Attribute Type names - :param FromDishka[AsyncSession] session: Database session. - :param FromDishka[AttributeTypeDAO] attribute_type_dao: Attribute type dao. - :return None: None + Args: + attribute_types_names (LimitedListType): List of attribute type names. + session (AsyncSession): Database session. + attribute_type_dao (AttributeTypeDAO): Attribute Type dao. """ await attribute_type_dao.delete_all_by_names(attribute_types_names) await session.commit() diff --git a/app/api/ldap_schema/entity_type_router.py b/app/api/ldap_schema/entity_type_router.py index bf725fbf4..3be4d3ea5 100644 --- a/app/api/ldap_schema/entity_type_router.py +++ b/app/api/ldap_schema/entity_type_router.py @@ -37,11 +37,11 @@ async def create_one_entity_type( """Create a new Entity Type. \f - :param EntityTypeSchema request_data: Data for creating Entity Type. - :param FromDishka[EntityTypeDAO] entity_type_dao: Entity Type DAO. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None. + Args: + request_data (EntityTypeSchema): Data for creating Entity Type. + entity_type_dao (EntityTypeDAO): Entity Type DAO. + object_class_dao (ObjectClassDAO): Object Class DAO. + session (AsyncSession): Database session. """ await object_class_dao.is_all_object_classes_exists( request_data.object_class_names @@ -67,9 +67,12 @@ async def get_one_entity_type( """Retrieve a one Entity Type. \f - :param str entity_type_name: name of the Entity Type. - :param FromDishka[EntityTypeDAO] entity_type_dao: Entity Type DAO. - :return EntityTypeSchema: Entity Type Schema. + Args: + entity_type_name (str): name of the Entity Type. + entity_type_dao (EntityTypeDAO): Entity Type DAO. + + Returns: + EntityTypeSchema: Entity Type Schema. """ entity_type = await entity_type_dao.get_one_by_name(entity_type_name) return EntityTypeSchema.from_db(entity_type) @@ -87,9 +90,12 @@ async def get_list_entity_types_with_pagination( """Retrieve a chunk of Entity Types with pagination. \f - :param FromDishka[EntityTypeDAO] entity_type_dao: Entity Type DAO. - :param PaginationParams params: Pagination parameters. - :return EntityTypePaginationSchema: Paginator Schema. + Args: + entity_type_dao (EntityTypeDAO): Entity Type DAO. + params (PaginationParams): Pagination parameters. + + Returns: + EntityTypePaginationSchema: Paginator Schema. """ pagination_result = await entity_type_dao.get_paginator(params=params) @@ -116,12 +122,12 @@ async def modify_one_entity_type( """Modify an Entity Type. \f - :param str entity_type_name: Name of the Entity Type for modifying. - :param EntityTypeUpdateSchema request_data: Changed data. - :param FromDishka[EntityTypeDAO] entity_type_dao: Entity Type DAO. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None. + Args: + entity_type_name (str): Name of the Entity Type for modifying. + request_data (EntityTypeUpdateSchema): Changed data. + entity_type_dao (EntityTypeDAO): Entity Type DAO. + object_class_dao (ObjectClassDAO): Object Class DAO. + session (AsyncSession): Database session. """ entity_type = await entity_type_dao.get_one_by_name(entity_type_name) @@ -145,10 +151,10 @@ async def delete_bulk_entity_types( """Delete Entity Types by their names. \f - :param LimitedListType entity_type_names: List of Entity Type names. - :param FromDishka[EntityTypeDAO] entity_type_dao: Entity Type DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None: None + Args: + entity_type_names (LimitedListType): List of Entity Type names. + entity_type_dao (EntityTypeDAO): Entity Type DAO. + session (AsyncSession): Database session. """ await entity_type_dao.delete_all_by_names(entity_type_names) await session.commit() diff --git a/app/api/ldap_schema/object_class_router.py b/app/api/ldap_schema/object_class_router.py index af1ceedb2..9fa0dddf0 100644 --- a/app/api/ldap_schema/object_class_router.py +++ b/app/api/ldap_schema/object_class_router.py @@ -32,13 +32,12 @@ async def create_one_object_class( object_class_dao: FromDishka[ObjectClassDAO], session: FromDishka[AsyncSession], ) -> None: - """Create a new Object Class. + """Create a new object class. - \f - :param ObjectClassSchema request_data: Data for creating Object Class. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None. + Args: + request_data (ObjectClassSchema): Data for creating object class. + object_class_dao (ObjectClassDAO): Object Class DAO. + session (AsyncSession): Database session. """ await object_class_dao.create_one( oid=request_data.oid, @@ -61,12 +60,14 @@ async def get_one_object_class( object_class_name: str, object_class_dao: FromDishka[ObjectClassDAO], ) -> ObjectClassSchema: - """Retrieve a one object class. + """Retrieve a single object class by name. - \f - :param str object_class_name: name of the Object Class. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :return ObjectClassSchema: One Object Class Schemas. + Args: + object_class_name (str): Name of the object class. + object_class_dao (ObjectClassDAO): Object Class DAO. + + Returns: + ObjectClassSchema: Object class schema. """ object_class = await object_class_dao.get_one_by_name(object_class_name) @@ -82,12 +83,14 @@ async def get_list_object_classes_with_pagination( object_class_dao: FromDishka[ObjectClassDAO], params: Annotated[PaginationParams, Query()], ) -> ObjectClassPaginationSchema: - """Retrieve a list of all object classes with paginate. + """Retrieve a paginated list of object classes. + + Args: + object_class_dao (ObjectClassDAO): Object Class DAO. + params (PaginationParams): Pagination parameters. - \f - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param PaginationParams params: Pagination parameters. - :return ObjectClassPaginationSchema: Paginator. + Returns: + ObjectClassPaginationSchema: Paginated object classes. """ pagination_result = await object_class_dao.get_paginator(params=params) @@ -110,14 +113,13 @@ async def modify_one_object_class( object_class_dao: FromDishka[ObjectClassDAO], session: FromDishka[AsyncSession], ) -> None: - """Modify an Object Class. - - \f - :param str object_class_name: Name of the Object Class for modifying. - :param ObjectClassUpdateSchema request_data: Changed data. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None. + """Modify an object class. + + Args: + object_class_name (str): Name of the object class to modify. + request_data (ObjectClassUpdateSchema): Data to update. + object_class_dao (ObjectClassDAO): Object Class DAO. + session (AsyncSession): Database session. """ object_class = await object_class_dao.get_one_by_name(object_class_name) @@ -137,13 +139,12 @@ async def delete_bulk_object_classes( object_class_dao: FromDishka[ObjectClassDAO], session: FromDishka[AsyncSession], ) -> None: - """Delete Object Classes by their names. + """Delete object classes by their names. - \f - :param LimitedListType object_classes_names: List of Object Classes names. - :param FromDishka[ObjectClassDAO] object_class_dao: Object Class DAO. - :param FromDishka[AsyncSession] session: Database session. - :return None: None + Args: + object_classes_names (LimitedListType): List of object class names. + object_class_dao (ObjectClassDAO): Object Class DAO. + session (AsyncSession): Database session. """ await object_class_dao.delete_all_by_names(object_classes_names) await session.commit() diff --git a/app/api/main/ap_router.py b/app/api/main/ap_router.py index e315a3477..38eeedd8f 100644 --- a/app/api/main/ap_router.py +++ b/app/api/main/ap_router.py @@ -27,8 +27,9 @@ async def get_access_policies( """Get APs. \f - :param AccessPolicySchema policy: ap - :param FromDishka[AsyncSession] session: db. + + Returns: + list[MaterialAccessPolicySchema]: list of access policies. """ return [ MaterialAccessPolicySchema( diff --git a/app/api/main/dns_router.py b/app/api/main/dns_router.py index e669f4be9..0c98b39d4 100644 --- a/app/api/main/dns_router.py +++ b/app/api/main/dns_router.py @@ -49,7 +49,12 @@ async def create_record( data: DNSServiceRecordCreateRequest, dns_manager: FromDishka[AbstractDNSManager], ) -> None: - """Create DNS record with given params.""" + """Create a DNS record with the given parameters. + + Args: + data (DNSServiceRecordCreateRequest): DNS record creation request data. + dns_manager (AbstractDNSManager): DNS manager dependency. + """ await dns_manager.create_record( data.record_name, data.record_value, @@ -64,7 +69,12 @@ async def delete_single_record( data: DNSServiceRecordDeleteRequest, dns_manager: FromDishka[AbstractDNSManager], ) -> None: - """Delete DNS record with given params.""" + """Delete a DNS record with the given parameters. + + Args: + data (DNSServiceRecordDeleteRequest): DNS record deletion request data. + dns_manager (AbstractDNSManager): DNS manager dependency. + """ await dns_manager.delete_record( data.record_name, data.record_value, @@ -78,7 +88,12 @@ async def update_record( data: DNSServiceRecordUpdateRequest, dns_manager: FromDishka[AbstractDNSManager], ) -> None: - """Update DNS record with given params.""" + """Update a DNS record with the given parameters. + + Args: + data (DNSServiceRecordUpdateRequest): DNS record update request data. + dns_manager (AbstractDNSManager): DNS manager dependency. + """ await dns_manager.update_record( data.record_name, data.record_value, @@ -92,7 +107,14 @@ async def update_record( async def get_all_records( dns_manager: FromDishka[AbstractDNSManager], ) -> list[DNSRecords]: - """Get all DNS records of current zone.""" + """Get all DNS records of the current zone. + + Args: + dns_manager (AbstractDNSManager): DNS manager dependency. + + Returns: + list[DNSRecords]: List of DNS records. + """ return await dns_manager.get_all_records() @@ -101,7 +123,15 @@ async def get_dns_status( session: FromDishka[AsyncSession], dns_settings: FromDishka[DNSManagerSettings], ) -> dict[str, str | None]: - """Get DNS service status.""" + """Get DNS service status. + + Args: + session (AsyncSession): Database session. + dns_settings (DNSManagerSettings): DNS manager settings. + + Returns: + dict[str, str | None]: DNS status, zone name, and DNS server IP. + """ state = await get_dns_state(session) return { "dns_status": state, @@ -119,7 +149,16 @@ async def setup_dns( ) -> None: """Set up DNS service. - Create zone file, get TSIG key, reload DNS server if selfhosted. + Creates zone file, gets TSIG key, reloads DNS server if self-hosted. + + Args: + data (DNSServiceSetupRequest): DNS setup request data. + dns_manager (AbstractDNSManager): DNS manager dependency. + session (AsyncSession): Database session. + settings (Settings): Application settings. + + Raises: + HTTPException: If DNS setup fails. """ dns_ip_address = data.dns_ip_address or settings.DNS_BIND_HOST @@ -142,7 +181,11 @@ async def setup_dns( async def get_dns_zone( dns_manager: FromDishka[AbstractDNSManager], ) -> list[DNSZone]: - """Get all DNS records of all zones.""" + """Get all DNS records of all zones. + + Returns: + list[DNSZone]: List of DNSZone objects with records. + """ return await dns_manager.get_all_zones_records() @@ -150,7 +193,11 @@ async def get_dns_zone( async def get_forward_dns_zones( dns_manager: FromDishka[AbstractDNSManager], ) -> list[DNSForwardZone]: - """Get list of DNS forward zones with forwarders.""" + """Get list of DNS forward zones with forwarders. + + Returns: + list[DNSForwardZone]: List of DNSForwardZone objects. + """ return await dns_manager.get_forward_zones() @@ -194,7 +241,12 @@ async def check_dns_forward_zone( data: DNSServiceForwardZoneCheckRequest, dns_manager: FromDishka[AbstractDNSManager], ) -> list[DNSForwardServerStatus]: - """Check given DNS forward zone for availability.""" + """Check given DNS forward zone for availability. + + Returns: + list[DNSForwardServerStatus]: List of DNSForwardServerStatus objects + indicating the status of each DNS server. + """ return [ await dns_manager.check_forward_dns_server(dns_server_ip) for dns_server_ip in data.dns_server_ips @@ -223,7 +275,11 @@ async def update_server_options( async def get_server_options( dns_manager: FromDishka[AbstractDNSManager], ) -> list[DNSServerParam]: - """Get list of modifiable DNS server params.""" + """Get list of modifiable DNS server params. + + Returns: + list[DNSServerParam]: List of DNSServerParam objects. + """ return await dns_manager.get_server_options() diff --git a/app/api/main/krb5_router.py b/app/api/main/krb5_router.py index 128caba88..b81a1e9f5 100644 --- a/app/api/main/krb5_router.py +++ b/app/api/main/krb5_router.py @@ -66,12 +66,18 @@ async def setup_krb_catalogue( kadmin: FromDishka[AbstractKadmin], entity_type_dao: FromDishka[EntityTypeDAO], ) -> None: - """Generate tree for kdc/kadmin. - - :param Annotated[AsyncSession, Depends session: db - :param Annotated[EmailStr, Body mail: krbadmin email - :param Annotated[SecretStr, Body krbadmin_password: pw - :raises HTTPException: on conflict + """Generate tree for KDC/Kadmin. + + Args: + session (AsyncSession): Database session. + mail (EmailStr): Kerberos admin email. + krbadmin_password (SecretStr): Kerberos admin password. + ldap_session (LDAPSession): LDAP session. + kadmin (AbstractKadmin): Kadmin manager. + entity_type_dao (EntityTypeDAO): Entity Type DAO. + + Raises: + HTTPException: On conflict or failed creation. """ base_dn_list = await get_base_directories(session) base_dn = base_dn_list[0].path_dn @@ -162,20 +168,26 @@ async def setup_kdc( settings: FromDishka[Settings], kadmin: FromDishka[AbstractKadmin], request: Request, -) -> None: +) -> Response: """Set up KDC server. - Create data structure in catalogue, generate config files, trigger commands + Creates data structure in catalogue, generates config files, + and triggers commands. - - **mail**: krbadmin mail - - **password**: krbadmin password + Args: + data (KerberosSetupRequest): Kerberos setup request data. + user (UserSchema): Current user. + session (AsyncSession): Database session. + settings (Settings): Application settings. + kadmin (AbstractKadmin): Kadmin manager. + request (Request): FastAPI request. - \f - :param Annotated[EmailStr, Body mail: json, defaults to 'admin')] - :param Annotated[str, Body password: json, defaults to 'password')] - :param Annotated[AsyncSession, Depends session: db - :param Annotated[LDAPSession, Depends ldap_session: ldap session - """ + Returns: + Response: Background task response. + + Raises: + HTTPException: On authentication or KDC setup failure. + """ # noqa: DOC501 base_dn_list = await get_base_directories(session) base_dn = base_dn_list[0].path_dn domain: str = base_dn_list[0].name @@ -262,7 +274,7 @@ async def setup_kdc( data.admin_password.get_secret_value(), ) - return Response(background=task) # type: ignore + return Response(background=task) finally: await session.commit() @@ -281,8 +293,15 @@ async def ktadd( ) -> StreamingResponse: """Create keytab from kadmin server. - :param Annotated[LDAPSession, Depends ldap_session: ldap - :return bytes: file + Args: + kadmin (AbstractKadmin): Kadmin manager. + names (list[str]): List of principal names. + + Returns: + StreamingResponse: Keytab file as a streaming response. + + Raises: + HTTPException: If principal not found. """ try: response = await kadmin.ktadd(names) @@ -302,11 +321,17 @@ async def get_krb_status( session: FromDishka[AsyncSession], kadmin: FromDishka[AbstractKadmin], ) -> KerberosState: - """Get server status. + """Get Kerberos server status. - :param Annotated[AsyncSession, Depends session: db - :param Annotated[LDAPSession, Depends ldap_session: ldap - :return KerberosState: state + Args: + session (AsyncSession): Database session. + kadmin (AbstractKadmin): Kadmin manager. + + Returns: + KerberosState: Current Kerberos server state. + + Raises: + HTTPException: If unable to get server status. """ db_state = await get_krb_server_state(session) try: @@ -326,12 +351,15 @@ async def add_principal( instance: Annotated[LIMITED_STR, Body()], kadmin: FromDishka[AbstractKadmin], ) -> None: - """Create principal in kerberos with given name. + """Create principal in Kerberos with given name. + + Args: + primary (str): Principal primary name. + instance (str): Principal instance. + kadmin (AbstractKadmin): Kadmin manager. - \f - :param Annotated[str, Body principal_name: upn - :param Annotated[LDAPSession, Depends ldap_session: ldap - :raises HTTPException: on failed kamin request. + Raises: + HTTPException: On failed kadmin request. """ try: await kadmin.add_principal(f"{primary}/{instance}", None) @@ -348,13 +376,15 @@ async def rename_principal( principal_new_name: Annotated[LIMITED_STR, Body()], kadmin: FromDishka[AbstractKadmin], ) -> None: - """Rename principal in kerberos with given name. + """Rename principal in Kerberos. - \f - :param Annotated[str, Body principal_name: upn - :param Annotated[LIMITED_STR, Body principal_new_name: _description_ - :param Annotated[LDAPSession, Depends ldap_session: ldap - :raises HTTPException: on failed kamin request. + Args: + principal_name (str): Current principal name. + principal_new_name (str): New principal name. + kadmin (AbstractKadmin): Kadmin manager. + + Raises: + HTTPException: On failed kadmin request. """ try: await kadmin.rename_princ(principal_name, principal_new_name) @@ -371,13 +401,15 @@ async def reset_principal_pw( new_password: Annotated[LIMITED_STR, Body()], kadmin: FromDishka[AbstractKadmin], ) -> None: - """Reset principal password in kerberos with given name. + """Reset principal password in Kerberos. + + Args: + principal_name (str): Principal name. + new_password (str): New password. + kadmin (AbstractKadmin): Kadmin manager. - \f - :param Annotated[str, Body principal_name: upn - :param Annotated[LIMITED_STR, Body new_password: _description_ - :param Annotated[LDAPSession, Depends ldap_session: ldap - :raises HTTPException: on failed kamin request. + Raises: + HTTPException: On failed kadmin request. """ try: await kadmin.change_principal_password(principal_name, new_password) @@ -393,12 +425,14 @@ async def delete_principal( principal_name: Annotated[LIMITED_STR, Body(embed=True)], kadmin: FromDishka[AbstractKadmin], ) -> None: - """Delete principal in kerberos with given name. + """Delete principal in Kerberos. + + Args: + principal_name (str): Principal name. + kadmin (AbstractKadmin): Kadmin manager. - \f - :param Annotated[str, Body principal_name: upn - :param FromDishka[AbstractKadmin] kadmin: _description_ - :raises HTTPException: on failed kamin request + Raises: + HTTPException: On failed kadmin request. """ try: await kadmin.del_principal(principal_name) diff --git a/app/api/main/router.py b/app/api/main/router.py index 776839490..f3b796786 100644 --- a/app/api/main/router.py +++ b/app/api/main/router.py @@ -32,7 +32,15 @@ async def search( request: SearchRequest, req: Request, ) -> SearchResponse: - """LDAP SEARCH entry request.""" + """Handle LDAP SEARCH entry request. + + Args: + request (SearchRequest): object containing search parameters. + req (Request): object for accessing application state. + + Returns: + SearchResponse: Response containing search results and metadata. + """ responses = await request.handle_api(req.state.dishka_container) metadata: SearchResultDone = responses.pop(-1) # type: ignore @@ -51,7 +59,15 @@ async def add( request: AddRequest, req: Request, ) -> LDAPResult: - """LDAP ADD entry request.""" + """Handle LDAP ADD entry request. + + Args: + request (AddRequest): object containing entry data to add. + req (Request): object for accessing application state. + + Returns: + LDAPResult: Result of the add operation. + """ return await request.handle_api(req.state.dishka_container) @@ -60,7 +76,15 @@ async def modify( request: ModifyRequest, req: Request, ) -> LDAPResult: - """LDAP MODIFY entry request.""" + """Handle LDAP MODIFY entry request. + + Args: + request (ModifyRequest): object containing modification data. + req (Request): object for accessing application state. + + Returns: + LDAPResult: Result of the modify operation. + """ return await request.handle_api(req.state.dishka_container) @@ -69,7 +93,16 @@ async def modify_many( requests: list[ModifyRequest], req: Request, ) -> list[LDAPResult]: - """Bulk LDAP MODIFY entry request.""" + """Handle bulk LDAP MODIFY entry requests. + + Args: + requests (list[ModifyRequest]): List of ModifyRequest objects\ + containing modification data. + req (Request): object for accessing application state. + + Returns: + list[LDAPResult]: List of results for each modify operation. + """ results = [] for request in requests: results.append(await request.handle_api(req.state.dishka_container)) @@ -81,7 +114,15 @@ async def modify_dn( request: ModifyDNRequest, req: Request, ) -> LDAPResult: - """LDAP MODIFY entry DN request.""" + """Handle LDAP MODIFY entry DN request. + + Args: + request (ModifyDNRequest): object containing DN modification data. + req (Request): object for accessing application state. + + Returns: + LDAPResult: Result of the DN modify operation. + """ return await request.handle_api(req.state.dishka_container) @@ -90,5 +131,13 @@ async def delete( request: DeleteRequest, req: Request, ) -> LDAPResult: - """LDAP DELETE entry request.""" + """Handle LDAP DELETE entry request. + + Args: + request (DeleteRequest): object containing entry to delete. + req (Request): object for accessing application state. + + Returns: + LDAPResult: Result of the delete operation. + """ return await request.handle_api(req.state.dishka_container) diff --git a/app/api/main/schema.py b/app/api/main/schema.py index b74261fd9..ad7cd740b 100644 --- a/app/api/main/schema.py +++ b/app/api/main/schema.py @@ -23,7 +23,11 @@ class SearchRequest(LDAPSearchRequest): filter: str = Field(..., examples=["(objectClass=*)"]) # type: ignore def cast_filter(self) -> UnaryExpression | ColumnElement: - """Cast str filter to sa sql.""" + """Cast str filter to sa sql. + + Returns: + UnaryExpression | ColumnElement: SQL expression for the filter. + """ filter_ = self.filter.lower().replace("objectcategory", "objectclass") return cast_str_filter2sql(Filter.parse(filter_).simplify()) @@ -32,7 +36,15 @@ async def handle_api( # type: ignore self, container: AsyncContainer, ) -> list[SearchResultEntry | SearchResultDone]: - """Get all responses.""" + """Get all responses. + + Args: + container (AsyncContainer): Async container with dependencies. + + Returns: + list[SearchResultEntry | SearchResultDone]: List of LDAP search\ + result entries or done responses. + """ return await self._handle_api(container) # type: ignore @@ -51,6 +63,8 @@ class KerberosSetupRequest(BaseModel): class _PolicyFields: + """Policy fields.""" + name: str can_read: bool can_add: bool @@ -60,6 +74,8 @@ class _PolicyFields: class _MaterialFields: + """Material fields.""" + id: int diff --git a/app/api/main/utils.py b/app/api/main/utils.py index 0f98ad933..8fcc9f101 100644 --- a/app/api/main/utils.py +++ b/app/api/main/utils.py @@ -15,6 +15,14 @@ async def get_ldap_session( ldap_session: FromDishka[LDAPSession], user: Annotated[UserSchema, Depends(get_current_user)], ) -> LDAPSession: - """Create LDAP session.""" + """Create LDAP session. + + Args: + ldap_session (FromDishka[LDAPSession]): LDAP session. + user (UserSchema): Current user. + + Returns: + LDAPSession: LDAP session with user set. + """ await ldap_session.set_user(user) return ldap_session diff --git a/app/api/network/router.py b/app/api/network/router.py index 661bbe5cf..c437dc7d5 100644 --- a/app/api/network/router.py +++ b/app/api/network/router.py @@ -44,10 +44,16 @@ async def add_network_policy( """Add policy. \f - :param Policy policy: policy to add - :raises HTTPException: 422 invalid group DN - :raises HTTPException: 422 Entry already exists - :return PolicyResponse: Ready policy + Args: + policy (Policy): policy to add + session (AsyncSession): Database session + + Raises: + HTTPException: 422 invalid group DN + HTTPException: 422 Entry already exists + + Returns: + PolicyResponse: Ready policy """ new_policy = NetworkPolicy( name=policy.name, @@ -110,7 +116,9 @@ async def get_list_network_policies( """Get network. \f - :return list[PolicyResponse]: all policies + + Returns: + list[PolicyResponse]: List of policies with their details. """ groups = selectinload(NetworkPolicy.groups).selectinload(Group.directory) mfa_groups = selectinload(NetworkPolicy.mfa_groups).selectinload( @@ -157,12 +165,18 @@ async def delete_network_policy( """Delete policy. \f - :param int policy_id: id - :param User user: requires login - :raises HTTPException: 404 - :raises HTTPException: 422 On last active policy, - at least 1 should be in database. - :return bool: status of delete + Args: + policy_id (int): id + request (Request): http request + session (AsyncSession): Database session + + Raises: + HTTPException: 404 + HTTPException: 422 On last active policy, at least 1 should be + in database. + + Returns: + bool: status of delete """ policy = await session.get(NetworkPolicy, policy_id, with_for_update=True) @@ -199,12 +213,17 @@ async def switch_network_policy( - **policy_id**: int, policy to switch \f - :param int policy_id: id - :param User user: requires login - :raises HTTPException: 404 - :raises HTTPException: 422 On last active policy, - at least 1 should be active - :return bool: status of update + Args: + policy_id (int): id + session (FromDishka[AsyncSession]): async db session + + Raises: + HTTPException: 404 + HTTPException: 422 On last active policy, at least 1 should be + active + + Returns: + bool: status of update """ policy = await session.get(NetworkPolicy, policy_id, with_for_update=True) @@ -227,11 +246,17 @@ async def update_network_policy( """Update network policy. \f - :param PolicyUpdate policy: update request - :raises HTTPException: 404 policy not found - :raises HTTPException: 422 Invalid group DN - :raises HTTPException: 422 Entry already exists - :return PolicyResponse: Policy from database + Args: + request (PolicyUpdate): update request + session (FromDishka[AsyncSession]): async db session + + Returns: + PolicyResponse: Policy from database + + Raises: + HTTPException: 404 policy not found + HTTPException: 422 Invalid group DN + HTTPException: 422 Entry already exists """ selected_policy = await session.get( NetworkPolicy, @@ -311,10 +336,15 @@ async def swap_network_policy( - **first_policy_id**: policy to swap - **second_policy_id**: policy to swap \f - :param int first_policy_id: policy to swap - :param int second_policy_id: policy to swap - :raises HTTPException: 404 - :return SwapResponse: policy new priorities + Args: + swap (SwapRequest): http request + session (FromDishka[AsyncSession]): async db session + + Raises: + HTTPException: 404 + + Returns: + SwapResponse: policy new priorities """ policy1 = await session.get( NetworkPolicy, diff --git a/app/api/network/schema.py b/app/api/network/schema.py index 5c62ae8b7..37eb3d4fc 100644 --- a/app/api/network/schema.py +++ b/app/api/network/schema.py @@ -40,7 +40,11 @@ class NetmasksMixin: @computed_field # type: ignore @property def complete_netmasks(self) -> list[IPv4Address | IPv4Network]: - """Validate range or return networks range.""" + """Validate range or return networks range. + + Returns: + list[IPv4Address | IPv4Network]: complete netmasks + """ values = [] for item in self.netmasks: if isinstance(item, IPRange): @@ -54,6 +58,14 @@ def complete_netmasks(self) -> list[IPv4Address | IPv4Network]: @field_validator("groups") @classmethod def validate_group(cls, groups: list[str]) -> list[str]: + """Validate groups. + + Returns: + list[str]: groups + + Raises: + ValueError: Invalid DN + """ if not groups: return groups if all(validate_entry(group) for group in groups): @@ -64,6 +76,14 @@ def validate_group(cls, groups: list[str]) -> list[str]: @field_validator("mfa_groups") @classmethod def validate_mfa_group(cls, mfa_groups: list[str]) -> list[str]: + """Validate mfa groups. + + Returns: + list[str]: mfa groups + + Raises: + ValueError: Invalid DN + """ if not mfa_groups: return mfa_groups if all(validate_entry(group) for group in mfa_groups): @@ -79,8 +99,12 @@ def netmasks_serialize( ) -> list[str | dict]: """Serialize netmasks to list. - :param IPv4IntefaceListType netmasks: ip masks - :return list[str | dict]: ready to json serialized + Args: + netmasks(IPv4IntefaceListType): ip masks + netmasks: IPv4IntefaceListType: + + Returns: + list[str | dict]: ready to json serialized """ values: list[str | dict] = [] @@ -159,7 +183,14 @@ class PolicyUpdate(BaseModel, NetmasksMixin): @model_validator(mode="after") def check_passwords_match(self) -> Self: - """Validate if all fields are empty.""" + """Validate if all fields are empty. + + Returns: + PolicyUpdate: + + Raises: + ValueError: Name, netmasks and group cannot be empty + """ if not self.name and not self.netmasks and not self.groups: raise ValueError("Name, netmasks and group cannot be empty") diff --git a/app/api/network/utils.py b/app/api/network/utils.py index c5a96370e..a804bb997 100644 --- a/app/api/network/utils.py +++ b/app/api/network/utils.py @@ -14,8 +14,8 @@ async def check_policy_count(session: AsyncSession) -> None: """Check if policy count euqals 1. - :param AsyncSession session: db - :raises HTTPException: 422 + Raises: + HTTPException: 422 """ count = await session.scalars( ( diff --git a/app/api/shadow/router.py b/app/api/shadow/router.py index b45386950..caf2d6c95 100644 --- a/app/api/shadow/router.py +++ b/app/api/shadow/router.py @@ -33,7 +33,19 @@ async def proxy_request( mfa: FromDishka[LDAPMultiFactorAPI], session: FromDishka[AsyncSession], ) -> None: - """Proxy request to mfa.""" + """Proxy request to mfa. + + Args: + principal (str): user principal name + ip (IPv4Address): user ip address + mfa (FromDishka[LDAPMultiFactorAPI]): mfa api + session (FromDishka[AsyncSession]): db session + + Raises: + HTTPException: 401 if mfa is required but not passed or failed + HTTPException: 403 if user is not allowed to use kerberos + HTTPException: 422 if user not found + """ user = await get_user(session, principal) if not user: @@ -87,13 +99,14 @@ async def sync_password( - **principal**: user upn - **new_password**: password to set \f - :param FromDishka[AsyncSession] session: db - :param FromDishka[AbstractKadmin] kadmin: kadmin api - :param Annotated[str, Body principal: reset target user - :param Annotated[str, Body new_password: new password for user - :raises HTTPException: 404 if user not found - :raises HTTPException: 422 if password not valid - :return None: None + Args: + principal (Annotated[str, Body]): user principal name + new_password (Annotated[str, Body]): new password for user + session (FromDishka[AsyncSession]): db + + Raises: + HTTPException: 422 if password not valid + HTTPException: 404 if user not found """ user = await get_user(session, principal) diff --git a/app/config.py b/app/config.py index 95f597ba3..d7ee66554 100644 --- a/app/config.py +++ b/app/config.py @@ -24,12 +24,17 @@ def _get_vendor_version() -> str: + """Get vendor version. + + Returns: + str: vendor version + """ with open("/pyproject.toml", "rb") as f: return tomllib.load(f)["tool"]["poetry"]["version"] class Settings(BaseModel): - """Settigns with database dsn.""" + """Settings with database dsn.""" DOMAIN: str @@ -69,7 +74,11 @@ class Settings(BaseModel): @computed_field # type: ignore @cached_property def POSTGRES_URI(self) -> PostgresDsn: # noqa - """Build postgres DSN.""" + """Build postgres DSN. + + Returns: + PostgresDsn: postgres DSN + """ return PostgresDsn( f"{self.POSTGRES_SCHEMA}://" f"{self.POSTGRES_USER}:" @@ -118,8 +127,19 @@ def POSTGRES_URI(self) -> PostgresDsn: # noqa GSSAPI_MAX_OUTPUT_TOKEN_SIZE: int = 1024 @field_validator("TIMEZONE", mode="before") - def create_tz(cls, tz: str) -> ZoneInfo: # noqa: N805 - """Get timezone from a string.""" + @classmethod + def create_tz(cls, tz: str) -> ZoneInfo: + """Get timezone from a string. + + Args: + tz (str): string timezone + + Returns: + ZoneInfo: + + Raises: + ValueError: timezone info not found + """ try: value = ZoneInfo(tz) except ZoneInfoNotFoundError as err: @@ -134,14 +154,19 @@ def create_tz(cls, tz: str) -> ZoneInfo: # noqa: N805 def MFA_API_URI(self) -> str: # noqa: N802 """Multifactor API url. - :return str: url + Returns: + str: url """ if self.MFA_API_SOURCE == "dev": return "https://api.multifactor.dev" return "https://api.multifactor.ru" def get_copy_4_tls(self) -> "Settings": - """Create a copy for TLS bind.""" + """Create a copy for TLS bind. + + Returns: + Settings: + """ from copy import copy tls_settings = copy(self) @@ -150,10 +175,18 @@ def get_copy_4_tls(self) -> "Settings": return tls_settings def check_certs_exist(self) -> bool: - """Check if certs exist.""" + """Check if certs exist. + + Returns: + bool + """ return os.path.exists(self.SSL_CERT) and os.path.exists(self.SSL_KEY) @classmethod def from_os(cls) -> "Settings": - """Get cls from environ.""" + """Get cls from environ. + + Returns: + Settings: + """ return Settings(**os.environ) diff --git a/app/extra/alembic_utils.py b/app/extra/alembic_utils.py index 71e9a151d..801bbad0a 100644 --- a/app/extra/alembic_utils.py +++ b/app/extra/alembic_utils.py @@ -21,8 +21,8 @@ def temporary_stub_entity_type_name(func: Callable) -> Callable: that precede the 'ba78cef9700a_initial_entity_type.py' migration and include working with the Directory. - :param Callable func: any function - :return Callable: any function + Returns: + Callable: any function """ def wrapper(*args, **kwargs): diff --git a/app/extra/dump_acme_certs.py b/app/extra/dump_acme_certs.py index 9c9442b78..d6c23cff8 100644 --- a/app/extra/dump_acme_certs.py +++ b/app/extra/dump_acme_certs.py @@ -17,6 +17,12 @@ def dump_acme_cert(resolver: str = "md-resolver") -> None: acme file can be generated long enough to exit the script, try read until file contents is generated. + + Args: + resolver (str): (Default value = "md-resolver") + + Raises: + SystemExit: If there is an error loading the TLS certificate. """ if os.path.exists("/certs/cert.pem") and os.path.exists( "/certs/privkey.pem" diff --git a/app/extra/scripts/check_ldap_principal.py b/app/extra/scripts/check_ldap_principal.py index cc13bc4dc..e97d7e740 100644 --- a/app/extra/scripts/check_ldap_principal.py +++ b/app/extra/scripts/check_ldap_principal.py @@ -25,9 +25,10 @@ async def check_ldap_principal( ) -> None: """Check ldap principal and keytab existence. - :param AbstractKadmin kadmin: kadmin - :param AsyncSession session: db - :param Settings settings: settings + Args: + kadmin (AbstractKadmin): kadmin + session (AsyncSession): db + settings (Settings): settings """ logger.info("Checking ldap principal and keytab existence.") diff --git a/app/extra/scripts/principal_block_user_sync.py b/app/extra/scripts/principal_block_user_sync.py index e42ce8473..70b1653e9 100644 --- a/app/extra/scripts/principal_block_user_sync.py +++ b/app/extra/scripts/principal_block_user_sync.py @@ -26,7 +26,12 @@ async def principal_block_sync( session: AsyncSession, settings: Settings, ) -> None: - """Synchronize principal and user account blocking.""" + """Synchronize principal and user account blocking. + + Args: + session (AsyncSession): Database session. + settings (Settings): Settings. + """ for user in await session.scalars(select(User)): uac_check = await get_check_uac(session, user.directory_id) if uac_check(UserAccountControlFlag.ACCOUNTDISABLE): @@ -91,9 +96,9 @@ async def principal_block_sync( def _find_krb_exp_attr(directory: Directory) -> Attribute | None: """Find krbprincipalexpiration attribute in directory. - :param Directory directory: the directory object - :return Atrribute | None: the attribute with - the name 'krbprincipalexpiration', or None if not found. + Returns: + Attribute | None: the attribute with the name + 'krbprincipalexpiration', or None if not found. """ for attr in directory.attributes: if attr.name == "krbprincipalexpiration": diff --git a/app/extra/scripts/uac_sync.py b/app/extra/scripts/uac_sync.py index d5ccf8454..9936c4137 100644 --- a/app/extra/scripts/uac_sync.py +++ b/app/extra/scripts/uac_sync.py @@ -22,7 +22,10 @@ async def disable_accounts( ) -> None: """Update userAccountControl attr. - :param AsyncSession session: db + Args: + session (AsyncSession): Database session. + kadmin (AbstractKadmin): Kadmin interface for locking principals. + settings (Settings): Application settings. Original query: update "Attributes" a diff --git a/app/extra/scripts/update_krb5_config.py b/app/extra/scripts/update_krb5_config.py index 45c4f570d..85db92691 100644 --- a/app/extra/scripts/update_krb5_config.py +++ b/app/extra/scripts/update_krb5_config.py @@ -17,7 +17,13 @@ async def update_krb5_config( session: AsyncSession, settings: Settings, ) -> None: - """Update kerberos config.""" + """Update kerberos config. + + Args: + kadmin (AbstractKadmin): Kerberos client. + session (AsyncSession): Database session. + settings (Settings): Settings. + """ if not (await kadmin.get_status(wait_for_positive=True)): logger.error("kadmin_api is not running") return diff --git a/app/extra/setup_dev.py b/app/extra/setup_dev.py index 490470d48..d97c0d51f 100644 --- a/app/extra/setup_dev.py +++ b/app/extra/setup_dev.py @@ -9,7 +9,7 @@ CN=User 4 OU="2FA" CN=Service Accounts - CN=User 5 + CN=User 5 Copyright (c) 2024 MultiFactor License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE @@ -54,7 +54,14 @@ async def _create_dir( domain: Directory, parent: Directory | None = None, ) -> None: - """Create data recursively.""" + """Create data recursively. + + Args: + data (dict): data + session (AsyncSession): Database session + domain (Directory): domain + parent (Directory | None): parent + """ dir_ = Directory( object_class=data["object_class"], name=data["name"], @@ -163,7 +170,16 @@ async def setup_enviroment( data: list, dn: str = "multifactor.dev", ) -> None: - """Create directories and users for enviroment.""" + """Create directories and users for enviroment. + + Args: + session (AsyncSession): Database async session + data (list): data + dn (str): domain name (Default value = 'multifactor.dev') + + Raises: + Exception: Failed to setup environment + """ cat_result = await session.execute(select(Directory)) if cat_result.scalar_one_or_none(): logger.warning("dev data already set up") @@ -232,4 +248,4 @@ async def setup_enviroment( import traceback logger.error(traceback.format_exc()) - raise + raise Exception("Failed to setup environment") diff --git a/app/ioc.py b/app/ioc.py index ea7f598a5..c2733eaad 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -54,7 +54,11 @@ class MainProvider(Provider): @provide(scope=Scope.APP) def get_engine(self, settings: Settings) -> AsyncEngine: - """Get async engine.""" + """Get async engine. + + Returns: + AsyncEngine: + """ return create_async_engine( str(settings.POSTGRES_URI), pool_size=settings.INSTANCE_DB_POOL_SIZE, @@ -71,7 +75,11 @@ def get_session_factory( self, engine: AsyncEngine, ) -> async_sessionmaker[AsyncSession]: - """Create session factory.""" + """Create session factory. + + Returns: + async_sessionmaker[AsyncSession]: + """ return async_sessionmaker(engine, expire_on_commit=False) @provide(scope=Scope.REQUEST) @@ -79,7 +87,11 @@ async def create_session( self, async_session: async_sessionmaker[AsyncSession], ) -> AsyncIterator[AsyncSession]: - """Create session for request.""" + """Create session for request. + + Yields: + AsyncIterator[AsyncSession] + """ async with async_session() as session: yield session await session.commit() @@ -89,7 +101,11 @@ async def get_krb_class( self, session_maker: async_sessionmaker[AsyncSession], ) -> type[AbstractKadmin]: - """Get kerberos type.""" + """Get kerberos type. + + Returns: + type[AbstractKadmin]: kerberos class + """ async with session_maker() as session: return await get_kerberos_class(session) @@ -100,10 +116,11 @@ async def get_kadmin_http( ) -> AsyncIterator[KadminHTTPClient]: """Get kadmin class, inherits from AbstractKadmin. - :param Settings settings: app settings - :param AsyncSessionMaker session_maker: session maker - :return AsyncIterator[AbstractKadmin]: kadmin with client - :yield Iterator[AsyncIterator[AbstractKadmin]]: kadmin + Args: + settings (Settings): app settings + + Yields: + AsyncIterator[AbstractKadmin]: kadmin with client """ limits = httpx.Limits( max_connections=settings.KRB5_SERVER_MAX_CONN, @@ -125,10 +142,12 @@ async def get_kadmin( ) -> AbstractKadmin: """Get kadmin class, inherits from AbstractKadmin. - :param Settings settings: app settings - :param AsyncSessionMaker session_maker: session maker - :return AsyncIterator[AbstractKadmin]: kadmin with client - :yield Iterator[AsyncIterator[AbstractKadmin]]: kadmin + Args: + client (KadminHTTPClient): app settings + kadmin_class (type[AbstractKadmin]): session maker + + Returns: + AbstractKadmin: kadmin with client """ return kadmin_class(client) @@ -137,7 +156,11 @@ async def get_dns_mngr_class( self, session_maker: async_sessionmaker[AsyncSession], ) -> type[AbstractDNSManager]: - """Get DNS manager type.""" + """Get DNS manager type. + + Returns: + type[AbstractDNSManager]: DNS manager class + """ async with session_maker() as session: return await get_dns_manager_class(session) @@ -147,7 +170,15 @@ async def get_dns_mngr_settings( session_maker: async_sessionmaker[AsyncSession], settings: Settings, ) -> DNSManagerSettings: - """Get DNS manager's settings.""" + """Get DNS manager's settings. + + Args: + session_maker (async_sessionmaker[AsyncSession]): session maker + settings (Settings): app settings + + Returns: + DNSManagerSettings: DNS manager settings + """ resolve_coro = resolve_dns_server_ip(settings.DNS_BIND_HOST) async with session_maker() as session: return await get_dns_manager_settings(session, resolve_coro) @@ -157,7 +188,11 @@ async def get_dns_http_client( self, settings: Settings, ) -> AsyncIterator[DNSManagerHTTPClient]: - """Get async client for DNS manager.""" + """Get async client for DNS manager. + + Yields: + AsyncIterator[DNSManagerHTTPClient] + """ async with httpx.AsyncClient( base_url=f"http://{settings.DNS_BIND_HOST}:8000", ) as client: @@ -170,7 +205,16 @@ async def get_dns_mngr( dns_manager_class: type[AbstractDNSManager], http_client: DNSManagerHTTPClient, ) -> AsyncIterator[AbstractDNSManager]: - """Get DNSManager class.""" + """Get DNSManager class. + + Args: + settings (DNSManagerSettings): DNS Manager settings + dns_manager_class (type[AbstractDNSManager]): manager class + http_client (DNSManagerHTTPClient): HTTP client for DNS manager + + Yields: + AsyncIterator[AbstractDNSManager] + """ yield dns_manager_class(settings=settings, http_client=http_client) @provide(scope=Scope.REQUEST) @@ -178,7 +222,11 @@ async def get_entity_type_dao( self, session: AsyncSession, ) -> EntityTypeDAO: - """Get Entity Type DAO.""" + """Get Entity Type DAO. + + Returns: + EntityTypeDAO: Entity Type DAO + """ return EntityTypeDAO(session) @provide(scope=Scope.APP) @@ -186,7 +234,17 @@ async def get_redis_for_sessions( self, settings: Settings, ) -> AsyncIterator[SessionStorageClient]: - """Get redis connection.""" + """Get redis connection. + + Args: + settings: Settings with database dsn. + + Yields: + AsyncIterator[SessionStorageClient] + + Raises: + SystemError: Redis is not available + """ client = redis.Redis.from_url(str(settings.SESSION_STORAGE_URL)) if not await client.ping(): @@ -201,7 +259,15 @@ async def get_session_storage( client: SessionStorageClient, settings: Settings, ) -> SessionStorage: - """Get session storage.""" + """Get session storage. + + Args: + client (SessionStorageClient): session storage client + settings (Settings): app settings + + Returns: + SessionStorage: session storage + """ return RedisSessionStorage( client, settings.SESSION_KEY_LENGTH, @@ -216,7 +282,11 @@ class HTTPProvider(Provider): @provide(provides=LDAPSession) async def get_session(self) -> LDAPSession: - """Create ldap session.""" + """Create ldap session. + + Returns: + LDAPSession: ldap session + """ return LDAPSession() @provide(provides=AttributeTypeDAO) @@ -224,7 +294,11 @@ def get_attribute_type_dao( self, session: AsyncSession, ) -> AttributeTypeDAO: - """Get Attribute Type DAO.""" + """Get Attribute Type DAO. + + Returns: + AttributeTypeDAO: Attribute Type DAO. + """ return AttributeTypeDAO(session) @provide(provides=ObjectClassDAO) @@ -232,7 +306,11 @@ def get_object_class_dao( self, session: AsyncSession, ) -> ObjectClassDAO: - """Get Object Class DAO.""" + """Get Object Class DAO. + + Returns: + ObjectClassDAO: Object Class DAO. + """ attribute_type_dao = AttributeTypeDAO(session) return ObjectClassDAO( attribute_type_dao=attribute_type_dao, @@ -244,7 +322,11 @@ def get_entity_type_dao( self, session: AsyncSession, ) -> EntityTypeDAO: - """Get Entity Type DAO.""" + """Get Entity Type DAO. + + Returns: + EntityTypeDAO: Entity Type DAO. + """ return EntityTypeDAO(session) @@ -255,7 +337,11 @@ class LDAPServerProvider(Provider): @provide(scope=Scope.SESSION, provides=LDAPSession) async def get_session(self, storage: SessionStorage) -> LDAPSession: - """Create ldap session.""" + """Create ldap session. + + Returns: + LDAPSession: ldap session + """ return LDAPSession(storage=storage) @@ -268,8 +354,8 @@ class MFACredsProvider(Provider): async def get_auth(self, session: AsyncSession) -> Creds | None: """Admin creds get. - :param Annotated[AsyncSession, Depends session: session - :return MFA_HTTP_Creds: optional creds + Returns: + MFA_HTTP_Creds: optional creds """ return await get_creds(session, "mfa_key", "mfa_secret") @@ -277,8 +363,8 @@ async def get_auth(self, session: AsyncSession) -> Creds | None: async def get_auth_ldap(self, session: AsyncSession) -> Creds | None: """Admin creds get. - :param AsyncSession session: db - :return MFA_LDAP_Creds: optional creds + Returns: + MFA_LDAP_Creds: optional creds """ return await get_creds(session, "mfa_key_ldap", "mfa_secret_ldap") @@ -293,7 +379,11 @@ async def get_client( self, settings: Settings, ) -> AsyncIterator[MFAHTTPClient]: - """Get async client for DI.""" + """Get async client for DI. + + Yields: + AsyncIterator[MFAHTTPClient]. + """ async with httpx.AsyncClient( timeout=settings.MFA_CONNECT_TIMEOUT_SECONDS, limits=httpx.Limits( @@ -312,9 +402,13 @@ async def get_http_mfa( ) -> MultifactorAPI | None: """Get api from DI. - :param httpx.AsyncClient client: httpx client - :param Creds credentials: creds - :return MultifactorAPI: mfa integration + Args: + credentials (MFA_HTTP_Creds): http creds + client (MFAHTTPClient): https client + settings (Settings): settings + + Returns: + MultifactorAPI: mfa integration """ if not credentials or not credentials.key or not credentials.secret: return None @@ -334,9 +428,13 @@ async def get_ldap_mfa( ) -> LDAPMultiFactorAPI | None: """Get api from DI. - :param httpx.AsyncClient client: httpx client - :param Creds credentials: creds - :return MultifactorAPI: mfa integration + Args: + credentials (MFA_LDAP_Creds): ldap creds + client (MFAHTTPClient): https client + settings (Settings): settings + + Returns: + MultifactorAPI: mfa integration """ if not credentials or not credentials.key or not credentials.secret: return None diff --git a/app/ldap_protocol/asn1parser.py b/app/ldap_protocol/asn1parser.py index b1a40bd7e..44e497f4c 100644 --- a/app/ldap_protocol/asn1parser.py +++ b/app/ldap_protocol/asn1parser.py @@ -7,7 +7,7 @@ from contextlib import suppress from dataclasses import dataclass from enum import IntEnum -from typing import Annotated, Generic, TypeVar +from typing import Annotated, TypeVar from asn1 import Classes, Decoder, Encoder, Numbers, Tag, Types from pydantic import AfterValidator @@ -65,7 +65,7 @@ class SubstringTag(IntEnum): @dataclass -class ASN1Row(Generic[T]): +class ASN1Row[T: "ASN1Row | list[ASN1Row] | str | bytes | int | float"]: """Row with metadata.""" class_id: int @@ -74,11 +74,26 @@ class ASN1Row(Generic[T]): @classmethod def from_tag(cls, tag: Tag, value: T) -> "ASN1Row": - """Create row from tag.""" + """Create row from tag. + + Args: + tag (Tag): instance of Tag + value (T): any value + + Returns: + ASN1Row + """ return cls(tag.cls, tag.nr, value) def _handle_extensible_match(self) -> str: - """Handle extensible match filters.""" + """Handle extensible match filters. + + Returns: + str: match + + Raises: + TypeError: If value isnt a list + """ oid = attribute = value = None dn_attributes = False @@ -121,7 +136,14 @@ def _handle_extensible_match(self) -> str: return f"({match})" def _handle_substring(self) -> str: - """Process and format substring operations for LDAP.""" + """Process and format substring operations for LDAP. + + Returns: + str: substring + + Raises: + ValueError: Invalid tag_id + """ value = ( self.value.decode(errors="replace") if isinstance(self.value, bytes) @@ -145,6 +167,16 @@ def serialize(self, obj: "ASN1Row | T | None" = None) -> str: # noqa: C901 Recursively processes ASN.1 structures to construct a valid LDAP filter string based on LDAP operations such as AND, OR, and substring matches. + + Args: + obj (ASN1Row | T | None): (Default value = None) + + Returns: + str: result string + + Raises: + ValueError: Invalid tag_id + TypeError: cant serialize """ if obj is None: obj = self @@ -218,7 +250,7 @@ def serialize(self, obj: "ASN1Row | T | None" = None) -> str: # noqa: C901 elif isinstance(obj, str): return obj - elif isinstance(obj, int) or isinstance(obj, float): + elif isinstance(obj, int | float): return str(obj) else: @@ -229,6 +261,9 @@ def to_ldap_filter(self) -> str: The method recursively serializes ASN.1 rows into the LDAP filter format based on tag IDs and class IDs. + + Returns: + str: """ return self.serialize() @@ -237,7 +272,15 @@ def value_to_string( tag: Tag, value: str | bytes | int | bool, ) -> bytes | str | int: - """Convert value to string.""" + """Convert value to string. + + Args: + tag (Tag): instance of Tag + value (str | bytes | int | bool): value + + Returns: + bytes | str | int: + """ if tag.nr == Numbers.Integer: with suppress(ValueError): return int(value) @@ -255,7 +298,11 @@ def value_to_string( def asn1todict(decoder: Decoder) -> list[ASN1Row]: - """Recursively collect ASN.1 data to list of ASNRows.""" + """Recursively collect ASN.1 data to list of ASNRows. + + Returns: + list[ASN1Row]: + """ out = [] while not decoder.eof(): tag = decoder.peek() @@ -277,7 +324,14 @@ def asn1todict(decoder: Decoder) -> list[ASN1Row]: def _validate_oid(oid: str) -> str: - """Validate ldap oid with regex.""" + """Validate ldap oid with regex. + + Returns: + str: + + Raises: + ValueError: Invalid LDAPOID + """ if not Encoder._re_oid.match(oid): raise ValueError("Invalid LDAPOID") return oid diff --git a/app/ldap_protocol/dependency.py b/app/ldap_protocol/dependency.py index 8ec0e8448..8f55b905b 100644 --- a/app/ldap_protocol/dependency.py +++ b/app/ldap_protocol/dependency.py @@ -12,12 +12,15 @@ T = TypeVar("T", bound=Callable) -async def resolve_deps(func: T, container: AsyncContainer) -> T: +async def resolve_deps[T: Callable](func: T, container: AsyncContainer) -> T: """Provide async dependencies. - :param T func: Awaitable - :param AsyncContainer container: IoC container - :return T: Awaitable + Args: + func (T): Awaitable + container (AsyncContainer): IoC container + + Returns: + T: Awaitable """ hints = get_type_hints(func) del hints["return"] diff --git a/app/ldap_protocol/dialogue.py b/app/ldap_protocol/dialogue.py index fa2e5c818..d95bb7e52 100644 --- a/app/ldap_protocol/dialogue.py +++ b/app/ldap_protocol/dialogue.py @@ -50,7 +50,15 @@ async def from_db( user: User, session_id: str, ) -> UserSchema: - """Create model from db model.""" + """Create model from db model. + + Args: + user (User): instance of User + session_id (str): session id + + Returns: + UserSchema: instance of UserSchema + """ return cls( id=user.id, session_id=session_id.split(".")[0], @@ -85,7 +93,12 @@ def __init__( user: UserSchema | None = None, storage: SessionStorage | None = None, ) -> None: - """Set lock.""" + """Set lock. + + Args: + user (UserSchema | None): instance of UserSchema + storage (SessionStorage | None): instance of SessionStorage + """ self._lock = asyncio.Lock() self._user: UserSchema | None = user self.queue: asyncio.Queue[LDAPRequestMessage] = asyncio.Queue() @@ -93,16 +106,29 @@ def __init__( self.storage = storage def __str__(self) -> str: - """Session with id.""" + """Session with id. + + Returns: + str: session with id + """ return f"LDAPSession({self.id})" @property def user(self) -> UserSchema | None: - """User getter, not implemented.""" + """User getter, not implemented. + + Returns: + UserSchema | None: instance of UserSchema + """ return self._user @user.setter def user(self, user: User) -> None: + """User setter. + + Raises: + NotImplementedError: Cannot manually set user + """ raise NotImplementedError( "Cannot manually set user, use `set_user()` instead", ) @@ -124,13 +150,21 @@ async def delete_user(self) -> None: self._user = None async def get_user(self) -> UserSchema | None: - """Get user from session concurrently save.""" + """Get user from session concurrently save. + + Returns: + UserSchema | None: instance of UserSchema + """ async with self._lock: return self._user @asynccontextmanager async def lock(self) -> AsyncIterator[UserSchema | None]: - """Lock session, user cannot be deleted or get while lock is set.""" + """Lock session, user cannot be deleted or get while lock is set. + + Yields: + AsyncIterator[UserSchema | None]: instance of UserSchema + """ async with self._lock: yield self._user @@ -147,7 +181,15 @@ async def validate_conn( ip: IPv4Address | IPv6Address, session: AsyncSession, ) -> None: - """Validate network policies.""" + """Validate network policies. + + Args: + ip (IPv4Address | IPv6Address): IP + session (AsyncSession): async session + + Raises: + PermissionError: NetworkPolicy is None + """ policy = await self._get_policy(ip, session) # type: ignore if policy is not None: self.policy = policy @@ -158,10 +200,19 @@ async def validate_conn( @property def key(self) -> str: - """Get key.""" + """Get key. + + Returns: + str: key + """ return f"ldap:{self.id}" def _bound_ip(self) -> bool: + """Check if ip is bound. + + Returns: + bool: True if ip is bound, False otherwise + """ return hasattr(self, "ip") async def bind_session(self) -> None: @@ -186,6 +237,10 @@ async def ensure_session_exists(self) -> NoReturn: """Ensure session exists in storage. Does nothing if anonymous, wait 30s and if user bound, check it. + + Raises: + AttributeError: Storage is not set + ConnectionAbortedError: Session missing in storage """ if self.storage is None: raise AttributeError("Storage is not set") diff --git a/app/ldap_protocol/dns/__init__.py b/app/ldap_protocol/dns/__init__.py index 92683ab38..7668e490d 100644 --- a/app/ldap_protocol/dns/__init__.py +++ b/app/ldap_protocol/dns/__init__.py @@ -1,3 +1,9 @@ +"""DNS API module. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + from sqlalchemy.ext.asyncio import AsyncSession from .base import ( @@ -33,7 +39,11 @@ async def get_dns_manager_class( session: AsyncSession, ) -> type[AbstractDNSManager]: - """Get DNS manager class.""" + """Get DNS manager class. + + Returns: + AbstractDNSManager: Class of the DNS manager based on the current DNS. + """ dns_state = await get_dns_state(session) if dns_state == DNSManagerState.SELFHOSTED: return SelfHostedDNSManager @@ -43,28 +53,28 @@ async def get_dns_manager_class( __all__ = [ - "get_dns_manager_class", + "DNS_MANAGER_IP_ADDRESS_NAME", + "DNS_MANAGER_STATE_NAME", + "DNS_MANAGER_ZONE_NAME", "AbstractDNSManager", - "RemoteDNSManager", - "SelfHostedDNSManager", - "StubDNSManager", - "get_dns_state", - "set_dns_manager_state", - "get_dns_manager_settings", - "resolve_dns_server_ip", + "DNSConnectionError", "DNSForwardServerStatus", "DNSForwardZone", "DNSManagerSettings", + "DNSNotImplementedError", "DNSRecords", "DNSServerParam", + "DNSServerParamName", "DNSZone", "DNSZoneParam", - "DNSZoneType", - "DNSServerParamName", "DNSZoneParamName", - "DNSConnectionError", - "DNS_MANAGER_IP_ADDRESS_NAME", - "DNS_MANAGER_ZONE_NAME", - "DNS_MANAGER_STATE_NAME", - "DNSNotImplementedError", + "DNSZoneType", + "RemoteDNSManager", + "SelfHostedDNSManager", + "StubDNSManager", + "get_dns_manager_class", + "get_dns_manager_settings", + "get_dns_state", + "resolve_dns_server_ip", + "set_dns_manager_state", ] diff --git a/app/ldap_protocol/dns/base.py b/app/ldap_protocol/dns/base.py index d3106c99e..cd49469b5 100644 --- a/app/ldap_protocol/dns/base.py +++ b/app/ldap_protocol/dns/base.py @@ -221,22 +221,18 @@ async def setup( await session.execute( update(CatalogueSetting) .where(CatalogueSetting.name.in_(new_settings.keys())) - .values( - { - "value": case( - *settings, - else_=CatalogueSetting.value, - ) - } - ) + .values({ + "value": case( + *settings, + else_=CatalogueSetting.value, + ) + }) ) else: - session.add_all( - [ - CatalogueSetting(name=name, value=value) - for name, value in new_settings.items() - ] - ) + session.add_all([ + CatalogueSetting(name=name, value=value) + for name, value in new_settings.items() + ]) @abstractmethod async def create_record( @@ -246,7 +242,8 @@ async def create_record( record_type: str, ttl: int | None, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Create DNS record.""" @abstractmethod async def update_record( @@ -256,7 +253,8 @@ async def update_record( record_type: str, ttl: int | None, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Update DNS record.""" @abstractmethod async def delete_record( @@ -265,17 +263,43 @@ async def delete_record( ip: str, record_type: str, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Delete DNS record.""" @abstractmethod - async def get_all_records(self) -> list[DNSRecords]: ... + async def get_all_records(self) -> list[DNSRecords]: + """Get all DNS records of all zones. + + Raises: + DNSNotImplementedError: If the method is not implemented. + + Returns: + list[DNSRecords]: List of DNSRecords objects with records. + """ + raise DNSNotImplementedError @abstractmethod async def get_all_zones_records(self) -> list[DNSZone]: + """Get all DNS records grouped by zone. + + Raises: + DNSNotImplementedError: If the method is not implemented. + + Returns: + list[DNSZone]: List of DNSZone objects with records. + """ raise DNSNotImplementedError @abstractmethod async def get_forward_zones(self) -> list[DNSForwardZone]: + """Get all forward zones. + + Raises: + DNSNotImplementedError: If the method is not implemented. + + Returns: + list[DNSForwardZone]: List of DNSForwardZone objects. + """ raise DNSNotImplementedError @abstractmethod @@ -286,6 +310,17 @@ async def create_zone( nameserver: str | None, params: list[DNSZoneParam], ) -> None: + """Create DNS zone. + + Args: + zone_name (str): Name of the zone. + zone_type (DNSZoneType): Type of the zone (master or forward). + nameserver (str | None): Nameserver for the zone, if applicable. + params (list[DNSZoneParam]): List of parameters for the zone. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod @@ -294,13 +329,27 @@ async def update_zone( zone_name: str, params: list[DNSZoneParam] | None, ) -> None: + """Update DNS zone. + + Args: + zone_name (str): Name of the zone to update. + params (list[DNSZoneParam] | None): List of parameters to update. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod - async def delete_zone( - self, - zone_names: list[str], - ) -> None: + async def delete_zone(self, zone_names: list[str]) -> None: + """Delete DNS zone. + + Args: + zone_names (list[str]): List of zone names to delete. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod @@ -308,6 +357,17 @@ async def check_forward_dns_server( self, dns_server_ip: IPv4Address | IPv6Address, ) -> DNSForwardServerStatus: + """Check if the given DNS server is reachable and valid. + + Args: + dns_server_ip (IPv4Address | IPv6Address): IP address of DNS server + + Returns: + DNSForwardServerStatus: Status of the DNS server. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod @@ -315,20 +375,45 @@ async def update_server_options( self, params: list[DNSServerParam], ) -> None: + """Update DNS server options. + + Args: + params (list[DNSServerParam]): List of server parameters to update. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod - async def get_server_options(self) -> list[DNSServerParam]: ... + async def get_server_options(self) -> list[DNSServerParam]: + """Get list of modifiable DNS server params. + + Raises: + DNSNotImplementedError: If the method is not implemented. + + Returns: + list[DNSServerParam]: List of DNSServerParam objects. + """ + raise DNSNotImplementedError @abstractmethod - async def restart_server( - self, - ) -> None: + async def restart_server(self) -> None: + """Restart DNS server. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError @abstractmethod - async def reload_zone( - self, - zone_name: str, - ) -> None: + async def reload_zone(self, zone_name: str) -> None: + """Reload DNS zone. + + Args: + zone_name (str): Name of the zone to reload. + + Raises: + DNSNotImplementedError: If the method is not implemented. + """ raise DNSNotImplementedError diff --git a/app/ldap_protocol/dns/remote.py b/app/ldap_protocol/dns/remote.py index 7f1e349bd..75a1204bc 100644 --- a/app/ldap_protocol/dns/remote.py +++ b/app/ldap_protocol/dns/remote.py @@ -23,7 +23,14 @@ class RemoteDNSManager(AbstractDNSManager): """DNS server manager.""" async def _send(self, action: Message) -> None: - """Send request to DNS server.""" + """Send request to DNS server. + + Args: + action (Message): DNS action to perform. + + Raises: + DNSConnectionError: If the DNS server IP is not set. + """ if self._dns_settings.tsig_key is not None: action.use_tsig( keyring=TsigKey("zone.", self._dns_settings.tsig_key), @@ -52,7 +59,14 @@ async def create_record( @logger_wraps() async def get_all_records(self) -> list[DNSRecords]: - """Get all DNS records.""" + """Get all DNS records. + + Returns: + list[DNSRecords]: List of DNS records grouped by type. + + Raises: + DNSConnectionError: If the DNS server IP or zone name is not set. + """ if ( self._dns_settings.dns_server_ip is None or self._dns_settings.zone_name is None diff --git a/app/ldap_protocol/dns/selfhosted.py b/app/ldap_protocol/dns/selfhosted.py index c29e07d5a..c83d7735a 100644 --- a/app/ldap_protocol/dns/selfhosted.py +++ b/app/ldap_protocol/dns/selfhosted.py @@ -56,6 +56,7 @@ async def update_record( ttl: int | None, zone_name: str | None = None, ) -> None: + """Update DNS record.""" await self._http_client.patch( "/record", json={ @@ -75,6 +76,7 @@ async def delete_record( record_type: str, zone_name: str | None = None, ) -> None: + """Delete DNS record.""" await self._http_client.request( "delete", "/record", @@ -88,6 +90,11 @@ async def delete_record( @logger_wraps() async def get_all_records(self) -> list[DNSRecords]: + """Get all DNS records. + + Returns: + list[DNSRecords]: List of DNS records grouped by type. + """ response = await self._http_client.get("/zone") response_data = response.json() @@ -103,12 +110,22 @@ async def get_all_records(self) -> list[DNSRecords]: @logger_wraps() async def get_all_zones_records(self) -> list[DNSZone]: + """Get all DNS zones with their records. + + Returns: + list[DNSZone]: List of DNS zones with their records. + """ response = await self._http_client.get("/zone") return response.json() @logger_wraps() async def get_forward_zones(self) -> list[DNSForwardZone]: + """Get all forward zones. + + Returns: + list[DNSForwardZone]: List of forward zones. + """ response = await self._http_client.get("/zone/forward") return response.json() @@ -121,6 +138,7 @@ async def create_zone( nameserver: str | None, params: list[DNSZoneParam], ) -> None: + """Create DNS zone.""" await self._http_client.post( "/zone", json={ @@ -137,6 +155,7 @@ async def update_zone( zone_name: str, params: list[DNSZoneParam], ) -> None: + """Update DNS zone.""" await self._http_client.patch( "/zone", json={ @@ -146,10 +165,8 @@ async def update_zone( ) @logger_wraps() - async def delete_zone( - self, - zone_names: list[str], - ) -> None: + async def delete_zone(self, zone_names: list[str]) -> None: + """Delete DNS zone.""" for zone_name in zone_names: await self._http_client.request( "delete", @@ -162,6 +179,11 @@ async def check_forward_dns_server( self, dns_server_ip: IPv4Address | IPv6Address, ) -> DNSForwardServerStatus: + """Check if the forward DNS server is reachable and return its FQDN. + + Returns: + DNSForwardServerStatus: Status of the forward DNS server. + """ str_dns_server_ip = str(dns_server_ip) try: hostname, _, _ = socket.gethostbyaddr(str_dns_server_ip) @@ -183,6 +205,7 @@ async def update_server_options( self, params: list[DNSServerParam], ) -> None: + """Update DNS server options.""" await self._http_client.patch( "/server/settings", json=[asdict(param) for param in params], @@ -190,19 +213,21 @@ async def update_server_options( @logger_wraps() async def get_server_options(self) -> list[DNSServerParam]: + """Get list of modifiable DNS server params. + + Returns: + list[DNSServerParam]: List of DNSServerParam objects. + """ response = await self._http_client.get("/server/settings") return response.json() @logger_wraps() - async def restart_server( - self, - ) -> None: + async def restart_server(self) -> None: + """Restart DNS server.""" await self._http_client.get("/server/restart") @logger_wraps() - async def reload_zone( - self, - zone_name: str, - ) -> None: + async def reload_zone(self, zone_name: str) -> None: + """Reload DNS zone.""" await self._http_client.get(f"/zone/{zone_name}") diff --git a/app/ldap_protocol/dns/stub.py b/app/ldap_protocol/dns/stub.py index 836a98a62..e66e91f80 100644 --- a/app/ldap_protocol/dns/stub.py +++ b/app/ldap_protocol/dns/stub.py @@ -26,7 +26,8 @@ async def create_record( record_type: str, ttl: int | None, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Stub DNS manager create record.""" @logger_wraps(is_stub=True) async def update_record( @@ -36,7 +37,8 @@ async def update_record( record_type: str, ttl: int, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Stub DNS manager update record.""" @logger_wraps(is_stub=True) async def delete_record( @@ -45,13 +47,20 @@ async def delete_record( ip: str, record_type: str, zone_name: str | None = None, - ) -> None: ... + ) -> None: + """Stub DNS manager delete record.""" @logger_wraps(is_stub=True) - async def get_all_zones_records(self) -> None: ... + async def get_all_zones_records(self) -> None: + """Stub DNS manager get all zones records.""" @logger_wraps(is_stub=True) async def get_forward_zones(self) -> list[DNSForwardZone]: + """Stub DNS manager get forward zones. + + Returns: + list[DNSForwardZone]: List of DNSForwardZone objects. + """ return [] @logger_wraps(is_stub=True) @@ -61,49 +70,54 @@ async def create_zone( zone_type: DNSZoneType, nameserver: str | None, params: list[DNSZoneParam], - ) -> None: ... + ) -> None: + """Stub DNS manager create zone.""" @logger_wraps(is_stub=True) async def update_zone( self, zone_name: str, params: list[DNSZoneParam] | None, - ) -> None: ... + ) -> None: + """Stub DNS manager update zone.""" @logger_wraps(is_stub=True) - async def delete_zone( - self, - zone_names: list[str], - ) -> None: ... + async def delete_zone(self, zone_names: list[str]) -> None: + """Stub DNS manager delete zone.""" @logger_wraps(is_stub=True) - async def check_forward_dns_server( - self, - dns_server_ip: str, - ) -> None: ... + async def check_forward_dns_server(self, dns_server_ip: str) -> None: + """Stub DNS manager check forward DNS server.""" @logger_wraps(is_stub=True) async def update_server_options( self, params: list[DNSServerParam], - ) -> None: ... + ) -> None: + """Stub DNS manager update server options.""" @logger_wraps(is_stub=True) async def get_server_options(self) -> list[DNSServerParam]: + """Stub DNS manager get server options. + + Returns: + list[DNSServerParam]: List of DNSServerParam objects. + """ return [] @logger_wraps(is_stub=True) - async def restart_server( - self, - ) -> None: ... + async def restart_server(self) -> None: + """Stub DNS manager restart server.""" @logger_wraps(is_stub=True) - async def reload_zone( - self, - zone_name: str, - ) -> None: ... + async def reload_zone(self, zone_name: str) -> None: + """Stub DNS manager reload zone.""" @logger_wraps(is_stub=True) async def get_all_records(self) -> list[DNSRecords]: - """Stub DNS manager get all records.""" + """Stub DNS manager get all records. + + Returns: + list[DNSRecords]: List of DNSRecords objects. + """ return [] diff --git a/app/ldap_protocol/dns/utils.py b/app/ldap_protocol/dns/utils.py index 8c579b5a3..1cd581a50 100644 --- a/app/ldap_protocol/dns/utils.py +++ b/app/ldap_protocol/dns/utils.py @@ -26,7 +26,11 @@ def logger_wraps(is_stub: bool = False) -> Callable: - """Log DNSManager calls.""" + """Log DNSManager calls. + + Returns: + Callable: Decorator for logging DNSManager calls. + """ def wrapper(func: Callable) -> Callable: name = func.__name__ @@ -53,10 +57,15 @@ async def wrapped(*args: str, **kwargs: str) -> Any: return wrapper -async def get_dns_state( - session: AsyncSession, -) -> "DNSManagerState": - """Get or create DNS manager state.""" +async def get_dns_state(session: AsyncSession) -> "DNSManagerState": + """Get or create DNS manager state. + + Args: + session (AsyncSession): Database session. + + Returns: + DNSManagerState: Current state of the DNS manager. + """ state = await session.scalar( select(CatalogueSetting) .filter(CatalogueSetting.name == DNS_MANAGER_STATE_NAME) @@ -88,7 +97,14 @@ async def set_dns_manager_state( async def resolve_dns_server_ip(host: str) -> str: - """Get DNS server IP from Docker network.""" + """Get DNS server IP from Docker network. + + Returns: + str: IP address of the DNS server. + + Raises: + DNSConnectionError: If the DNS server IP cannot be resolved. + """ async_resolver = AsyncResolver() dns_server_ip_resolve = await async_resolver.resolve(host) if dns_server_ip_resolve is None or dns_server_ip_resolve.rrset is None: @@ -100,7 +116,15 @@ async def get_dns_manager_settings( session: AsyncSession, resolve_coro: Awaitable[str], ) -> "DNSManagerSettings": - """Get DNS manager's settings.""" + """Get DNS manager's settings. + + Args: + session (AsyncSession): Database session. + resolve_coro (Awaitable[str]): Coroutine to resolve DNS server IP. + + Returns: + DNSManagerSettings: DNS manager settings. + """ settings_dict = {} for setting in await session.scalars( select(CatalogueSetting).filter( diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py index 07aefe734..12259dea1 100644 --- a/app/ldap_protocol/filter_interpreter.py +++ b/app/ldap_protocol/filter_interpreter.py @@ -44,6 +44,17 @@ def _from_filter( attr: str, right: ASN1Row, ) -> UnaryExpression: + """Get filter from item. + + Args: + model (type): Any Model + item (ASN1Row): Row with metadata + attr (str): Attribute name + right (ASN1Row): Row with metadata + + Returns: + UnaryExpression + """ is_substring = item.tag_id == TagNumbers.SUBSTRING col = getattr(model, attr) @@ -60,7 +71,14 @@ def _from_filter( def _filter_memberof(dn: str) -> UnaryExpression: - """Retrieve query conditions with the memberOF attribute.""" + """Retrieve query conditions with the memberOF attribute. + + Args: + dn (str): any DN, dn syntax + + Returns: + UnaryExpression + """ group_id_subquery = ( select(Group.id) .join(Group.directory) @@ -78,7 +96,14 @@ def _filter_memberof(dn: str) -> UnaryExpression: def _filter_member(dn: str) -> UnaryExpression: - """Retrieve query conditions with the member attribute.""" + """Retrieve query conditions with the member attribute. + + Args: + dn (str): any DN, dn syntax + + Returns: + UnaryExpression + """ user_id_subquery = ( select(User.id) .join(User.directory) @@ -96,14 +121,31 @@ def _filter_member(dn: str) -> UnaryExpression: def _recursive_filter_memberof(dn: str) -> UnaryExpression: - """Retrieve query conditions with the memberOF attribute(recursive).""" + """Retrieve query conditions with the memberOF attribute(recursive). + + Args: + dn (str): any DN, dn syntax + + Returns: + UnaryExpression + """ cte = find_members_recursive_cte(dn) return Directory.id.in_(select(cte.c.directory_id).offset(1)) # type: ignore def _get_filter_function(column: str) -> Callable[..., UnaryExpression]: - """Retrieve the appropriate filter function based on the attribute.""" + """Retrieve the appropriate filter function based on the attribute. + + Args: + column (str): column name + + Returns: + Callable[..., UnaryExpression]: + + Raises: + ValueError: Incorrect attribute specified + """ if len(column.split(":")) == 1: attribute = column oid = "" @@ -127,7 +169,16 @@ def _ldap_filter_by_attribute( attr: ASN1Row, search_value: ASN1Row, ) -> UnaryExpression: - """Retrieve query conditions based on the specified LDAP attribute.""" + """Retrieve query conditions based on the specified LDAP attribute. + + Args: + oid: ASN1Row | None: + attr: ASN1Row: + search_value: ASN1Row: + + Returns: + UnaryExpression + """ if oid is None: attribute = attr.value.lower() else: @@ -139,6 +190,14 @@ def _ldap_filter_by_attribute( def _cast_item(item: ASN1Row) -> UnaryExpression | ColumnElement: + """Cast item to sqlalchemy condition. + + Args: + item (ASN1Row): Row with metadata + + Returns: + UnaryExpression | ColumnElement + """ # present, for e.g. `attibuteName=*`, `(attibuteName)` if item.tag_id == 7: attr = item.value.lower().replace("objectcategory", "objectclass") @@ -182,7 +241,14 @@ def _cast_item(item: ASN1Row) -> UnaryExpression | ColumnElement: def cast_filter2sql(expr: ASN1Row) -> UnaryExpression | ColumnElement: - """Recursively cast Filter to SQLAlchemy conditions.""" + """Recursively cast Filter to SQLAlchemy conditions. + + Args: + expr: ASN1Row: + + Returns: + UnaryExpression | ColumnElement + """ if expr.tag_id in range(3): conditions = [] for item in expr.value: @@ -212,7 +278,14 @@ def _from_str_filter( def _api_filter(item: Filter) -> UnaryExpression: - """Retrieve query conditions based on the specified LDAP attribute.""" + """Retrieve query conditions based on the specified LDAP attribute. + + Args: + item (Filter): LDAP filter + + Returns: + UnaryExpression + """ filter_func = _get_filter_function(item.attr) return filter_func(item.val) @@ -247,7 +320,14 @@ def _cast_filt_item(item: Filter) -> UnaryExpression | ColumnElement: def cast_str_filter2sql(expr: Filter) -> UnaryExpression | ColumnElement: - """Cast ldap filter to sa query.""" + """Cast ldap filter to sa query. + + Args: + expr (Filter): LDAP Base filter + + Returns: + UnaryExpression | ColumnElement: + """ if expr.type == "group": conditions = [] for item in expr.filters: diff --git a/app/ldap_protocol/kerberos/__init__.py b/app/ldap_protocol/kerberos/__init__.py index c685d6c8c..ebb32d4d6 100644 --- a/app/ldap_protocol/kerberos/__init__.py +++ b/app/ldap_protocol/kerberos/__init__.py @@ -1,3 +1,9 @@ +"""Kerberos API module. + +Copyright (c) 2024 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + from sqlalchemy.ext.asyncio import AsyncSession from .base import ( @@ -14,8 +20,8 @@ async def get_kerberos_class(session: AsyncSession) -> type[AbstractKadmin]: """Get kerberos server state. - :param AsyncSession session: db - :return type[KerberosMDAPIClient] | type[StubKadminMDADPIClient]: api + Returns: + type[KerberosMDAPIClient] | type[StubKadminMDADPIClient]: api """ if await get_krb_server_state(session) == KerberosState.READY: return KerberosMDAPIClient @@ -23,13 +29,13 @@ async def get_kerberos_class(session: AsyncSession) -> type[AbstractKadmin]: __all__ = [ - "get_kerberos_class", - "KerberosMDAPIClient", - "StubKadminMDADPIClient", + "KERBEROS_STATE_NAME", "AbstractKadmin", - "KerberosState", "KRBAPIError", - "unlock_principal", - "KERBEROS_STATE_NAME", + "KerberosMDAPIClient", + "KerberosState", + "StubKadminMDADPIClient", + "get_kerberos_class", "set_state", + "unlock_principal", ] diff --git a/app/ldap_protocol/kerberos/base.py b/app/ldap_protocol/kerberos/base.py index c9c2e5e5c..cdb682adc 100644 --- a/app/ldap_protocol/kerberos/base.py +++ b/app/ldap_protocol/kerberos/base.py @@ -37,10 +37,7 @@ class AbstractKadmin(ABC): client: httpx.AsyncClient def __init__(self, client: httpx.AsyncClient) -> None: - """Set client. - - :param httpx.AsyncClient client: httpx - """ + """Set client.""" self.client = client async def setup_configs( @@ -48,7 +45,15 @@ async def setup_configs( krb5_config: str, kdc_config: str, ) -> None: - """Request Setup.""" + """Request Setup. + + Args: + krb5_config (str): config + kdc_config (str): config + + Raises: + KRBAPIError: not correct + """ log.info("Setting up configs") response = await self.client.post( "/setup/configs", @@ -71,7 +76,20 @@ async def setup_stash( admin_password: str, stash_password: str, ) -> None: - """Set up stash.""" + """Set up stash. + + Args: + domain (str): domain + admin_dn (str): admin_dn + services_dn (str): services_dn + krbadmin_dn (str): krbadmin_dn + krbadmin_password (str): krbadmin_password + admin_password (str): admin_password + stash_password (str): stash_password + + Raises: + KRBAPIError: not correct + """ log.info("Setting up stash") response = await self.client.post( "/setup/stash", @@ -99,7 +117,20 @@ async def setup_subtree( admin_password: str, stash_password: str, ) -> None: - """Set up subtree.""" + """Set up subtree. + + Args: + domain (str): domain + admin_dn (str): admin_dn + services_dn (str): services_dn + krbadmin_dn (str): krbadmin_dn + krbadmin_password (str): krbadmin_password + admin_password (str): admin_password + stash_password (str): stash_password. + + Raises: + KRBAPIError: not correct + """ log.info("Setting up subtree") response = await self.client.post( "/setup/subtree", @@ -135,7 +166,20 @@ async def setup( kdc_config: str, ldap_keytab_path: str, ) -> None: - """Request Setup.""" + """Request Setup. + + Args: + domain (str): domain + admin_dn (str): admin_dn + services_dn (str): services_dn + krbadmin_dn (str): krbadmin_dn + krbadmin_password (str): krbadmin_password + admin_password (str): admin_password + stash_password (str): stash_password + krb5_config (str): krb5_config + kdc_config (str): kdc_config + ldap_keytab_path (str): ldap keytab path + """ await self.setup_configs(krb5_config, kdc_config) await self.setup_stash( domain, @@ -164,35 +208,35 @@ async def setup( ) @abstractmethod - async def add_principal( + async def add_principal( # noqa: D102 self, name: str, password: str | None, - timeout: int | float = 1, + timeout: float = 1, ) -> None: ... @abstractmethod - async def get_principal(self, name: str) -> dict: ... + async def get_principal(self, name: str) -> dict: ... # noqa: D102 @abstractmethod - async def del_principal(self, name: str) -> None: ... + async def del_principal(self, name: str) -> None: ... # noqa: D102 @abstractmethod - async def change_principal_password( + async def change_principal_password( # noqa: D102 self, name: str, password: str, ) -> None: ... @abstractmethod - async def create_or_update_principal_pw( + async def create_or_update_principal_pw( # noqa: D102 self, name: str, password: str, ) -> None: ... @abstractmethod - async def rename_princ(self, name: str, new_name: str) -> None: ... + async def rename_princ(self, name: str, new_name: str) -> None: ... # noqa: D102 @backoff.on_exception( backoff.constant, @@ -209,8 +253,15 @@ async def rename_princ(self, name: str, new_name: str) -> None: ... async def get_status(self, wait_for_positive: bool = False) -> bool | None: """Get status of setup. - :param bool wait_for_positive: wait for positive status - :return bool | None: status or None if max tries achieved + Args: + wait_for_positive (bool): wait for positive status\ + (Default value = False) + + Returns: + bool | None: status or None if max tries achieved + + Raises: + ValueError: not status """ response = await self.client.get("/setup/status") status = response.json() @@ -219,19 +270,20 @@ async def get_status(self, wait_for_positive: bool = False) -> bool | None: return status @abstractmethod - async def ktadd(self, names: list[str]) -> httpx.Response: ... + async def ktadd(self, names: list[str]) -> httpx.Response: ... # noqa: D102 @abstractmethod - async def lock_principal(self, name: str) -> None: ... + async def lock_principal(self, name: str) -> None: ... # noqa: D102 @abstractmethod - async def force_princ_pw_change(self, name: str) -> None: ... + async def force_princ_pw_change(self, name: str) -> None: ... # noqa: D102 async def ldap_principal_setup(self, name: str, path: str) -> None: """LDAP principal setup. - :param str ldap_principal_name: ldap principal name - :param str ldap_keytab_path: ldap keytab path + Args: + name (str): ldap principal name + path (str): ldap keytab path """ response = await self.client.get("/principal", params={"name": name}) if response.status_code == 200: diff --git a/app/ldap_protocol/kerberos/client.py b/app/ldap_protocol/kerberos/client.py index d4b3a5b04..b1e099826 100644 --- a/app/ldap_protocol/kerberos/client.py +++ b/app/ldap_protocol/kerberos/client.py @@ -10,7 +10,7 @@ class KerberosMDAPIClient(AbstractKadmin): """KRB server integration.""" @logger_wraps(is_stub=True) - async def setup(*_, **__) -> None: # type: ignore + async def setup(*args, **kwargs) -> None: # type: ignore """Stub method, setup is not needed.""" @logger_wraps() @@ -20,7 +20,16 @@ async def add_principal( password: str | None, timeout: int = 1, ) -> None: - """Add request.""" + """Add principal. + + Args: + name (str): principal name + password (str | None): password + timeout (int): timeout + + Raises: + KRBAPIError: API error + """ response = await self.client.post( "principal", json={"name": name, "password": password}, @@ -32,7 +41,14 @@ async def add_principal( @logger_wraps() async def get_principal(self, name: str) -> dict: - """Get request.""" + """Get principal. + + Returns: + dict + + Raises: + KRBAPIError: API error + """ response = await self.client.get("principal", params={"name": name}) if response.status_code != 200: raise KRBAPIError(response.text) @@ -41,7 +57,11 @@ async def get_principal(self, name: str) -> dict: @logger_wraps() async def del_principal(self, name: str) -> None: - """Delete principal.""" + """Delete principal. + + Raises: + KRBAPIError: API error + """ response = await self.client.delete("principal", params={"name": name}) if response.status_code != 200: raise KRBAPIError(response.text) @@ -52,7 +72,15 @@ async def change_principal_password( name: str, password: str, ) -> None: - """Change password request.""" + """Change principal password. + + Args: + name (str): principal name + password: password + + Raises: + KRBAPIError: API error + """ response = await self.client.patch( "principal", json={"name": name, "password": password}, @@ -66,7 +94,15 @@ async def create_or_update_principal_pw( name: str, password: str, ) -> None: - """Change password request.""" + """Create or update principal password. + + Args: + name (str): principal name + password: password. + + Raises: + KRBAPIError: API error + """ response = await self.client.post( "/principal/create_or_update", json={"name": name, "password": password}, @@ -76,7 +112,15 @@ async def create_or_update_principal_pw( @logger_wraps() async def rename_princ(self, name: str, new_name: str) -> None: - """Rename request.""" + """Rename principal. + + Args: + name (str): current principal name + new_name: (str): new principal name + + Raises: + KRBAPIError: API error + """ response = await self.client.put( "principal", json={"name": name, "new_name": new_name}, @@ -87,8 +131,14 @@ async def rename_princ(self, name: str, new_name: str) -> None: async def ktadd(self, names: list[str]) -> httpx.Response: """Ktadd build request for stream and return response. - :param list[str] names: principals - :return httpx.Response: stream + Args: + names (list[str]): principal names + + Returns: + httpx.Response: stream + + Raises: + KRBAPIError: principal not found """ request = self.client.build_request( "POST", @@ -104,10 +154,10 @@ async def ktadd(self, names: list[str]) -> httpx.Response: @logger_wraps() async def lock_principal(self, name: str) -> None: - """Lock princ. + """Lock principal. - :param str name: upn - :raises KRBAPIError: on error + Raises: + KRBAPIError: API error """ response = await self.client.post( "principal/lock", @@ -120,8 +170,8 @@ async def lock_principal(self, name: str) -> None: async def force_princ_pw_change(self, name: str) -> None: """Force mark password change for principal. - :param str name: pw - :raises KRBAPIError: err + Raises: + KRBAPIError: API error """ response = await self.client.post( "principal/force_reset", diff --git a/app/ldap_protocol/kerberos/stub.py b/app/ldap_protocol/kerberos/stub.py index ac8ab4240..1c093927f 100644 --- a/app/ldap_protocol/kerberos/stub.py +++ b/app/ldap_protocol/kerberos/stub.py @@ -15,7 +15,7 @@ async def setup(self, *args, **kwargs) -> None: # type: ignore await super().setup(*args, **kwargs) @logger_wraps(is_stub=True) - async def add_principal( + async def add_principal( # noqa: D102 self, name: str, password: str | None, @@ -23,34 +23,34 @@ async def add_principal( ) -> None: ... @logger_wraps(is_stub=True) - async def get_principal(self, name: str) -> None: ... + async def get_principal(self, name: str) -> None: ... # noqa: D102 @logger_wraps(is_stub=True) - async def del_principal(self, name: str) -> None: ... + async def del_principal(self, name: str) -> None: ... # noqa: D102 @logger_wraps(is_stub=True) - async def change_principal_password( + async def change_principal_password( # noqa: D102 self, name: str, password: str, ) -> None: ... @logger_wraps(is_stub=True) - async def create_or_update_principal_pw( + async def create_or_update_principal_pw( # noqa: D102 self, name: str, password: str, ) -> None: ... @logger_wraps(is_stub=True) - async def rename_princ(self, name: str, new_name: str) -> None: ... + async def rename_princ(self, name: str, new_name: str) -> None: ... # noqa: D102 @logger_wraps(is_stub=True) - async def ktadd(self, names: list[str]) -> NoReturn: # noqa: ARG002 + async def ktadd(self, names: list[str]) -> NoReturn: # noqa: ARG002 D102 raise KRBAPIError @logger_wraps(is_stub=True) - async def lock_principal(self, name: str) -> None: ... + async def lock_principal(self, name: str) -> None: ... # noqa: D102 @logger_wraps(is_stub=True) - async def force_princ_pw_change(self, name: str) -> None: ... + async def force_princ_pw_change(self, name: str) -> None: ... # noqa: D102 diff --git a/app/ldap_protocol/kerberos/utils.py b/app/ldap_protocol/kerberos/utils.py index f08f4bb02..845aaa5de 100644 --- a/app/ldap_protocol/kerberos/utils.py +++ b/app/ldap_protocol/kerberos/utils.py @@ -1,7 +1,7 @@ """Utils for kadmin.""" from functools import wraps -from typing import Any, Callable +from typing import Callable import httpx from sqlalchemy import delete, select, update @@ -15,16 +15,21 @@ def logger_wraps(is_stub: bool = False) -> Callable: """Log kadmin calls. - :param bool is_stub: flag to change logs, defaults to False - :return Callable: any method + Returns: + Callable: any method """ def wrapper(func: Callable) -> Callable: + """Wrap kadmin calls. + + Returns: + Callable: wrapped function + """ name = func.__name__ bus_type = " stub " if is_stub else " " @wraps(func) - async def wrapped(*args: str, **kwargs: str) -> Any: + async def wrapped(*args: str, **kwargs: str) -> object: logger = log.opt(depth=1) try: principal = args[1] @@ -58,6 +63,10 @@ async def set_state(session: AsyncSession, state: "KerberosState") -> None: This function updates the server state in the database by either adding a new entry, updating an existing entry, or deleting and re-adding the entry if there are multiple entries found. + + Args: + session (AsyncSession): Database session + state (KerberosState): Kerberos server state """ results = await session.execute( select(CatalogueSetting) @@ -77,7 +86,11 @@ async def set_state(session: AsyncSession, state: "KerberosState") -> None: async def get_krb_server_state(session: AsyncSession) -> "KerberosState": - """Get kerberos server state.""" + """Get kerberos server state. + + Returns: + KerberosState: The current kerberos server state. + """ state = await session.scalar( select(CatalogueSetting) .filter(CatalogueSetting.name == KERBEROS_STATE_NAME) @@ -91,8 +104,9 @@ async def get_krb_server_state(session: AsyncSession) -> "KerberosState": async def unlock_principal(name: str, session: AsyncSession) -> None: """Unlock principal. - :param str name: upn - :param AsyncSession session: db + Args: + name (str): upn + session (AsyncSession): db """ subquery = ( select(Directory.id) diff --git a/app/ldap_protocol/ldap_requests/__init__.py b/app/ldap_protocol/ldap_requests/__init__.py index 90ff4cdd8..246e3baf0 100644 --- a/app/ldap_protocol/ldap_requests/__init__.py +++ b/app/ldap_protocol/ldap_requests/__init__.py @@ -32,4 +32,4 @@ } -__all__ = ["protocol_id_map", "BaseRequest"] +__all__ = ["BaseRequest", "protocol_id_map"] diff --git a/app/ldap_protocol/ldap_requests/abandon.py b/app/ldap_protocol/ldap_requests/abandon.py index 9c616f106..ae8d3b025 100644 --- a/app/ldap_protocol/ldap_requests/abandon.py +++ b/app/ldap_protocol/ldap_requests/abandon.py @@ -20,11 +20,19 @@ class AbandonRequest(BaseRequest): @classmethod def from_data(cls, data: dict[str, list[ASN1Row]]) -> "AbandonRequest": # noqa: ARG003 - """Create structure from ASN1Row dataclass list.""" + """Create structure from ASN1Row dataclass list. + + Returns: + AbandonRequest: Instance of AbandonRequest. + """ return cls(message_id=1) async def handle(self) -> AsyncGenerator: - """Handle message with current user.""" + """Handle message with current user. + + Yields: + AsyncGenerator: Async generator. + """ await asyncio.sleep(0) return yield # type: ignore diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 12e7afc1d..cae3028f4 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -72,15 +72,29 @@ class AddRequest(BaseRequest): @property def attr_names(self) -> dict[str, list[str | bytes]]: + """Get attribute names. + + Returns: + dict[str, list[str | bytes]]: attribute names + """ return {attr.l_name: attr.vals for attr in self.attributes} @property def attributes_dict(self) -> dict[str, list[str | bytes]]: + """Get attributes dictionary. + + Returns: + dict[str, list[str | bytes]]: attributes dictionary + """ return {attr.type: attr.vals for attr in self.attributes} @classmethod def from_data(cls, data: ASN1Row) -> "AddRequest": - """Deserialize.""" + """Deserialize. + + Returns: + AddRequest + """ entry, attributes = data # type: ignore attributes = [ PartialAttribute( @@ -98,7 +112,20 @@ async def handle( # noqa: C901 kadmin: AbstractKadmin, entity_type_dao: EntityTypeDAO, ) -> AsyncGenerator[AddResponse, None]: - """Add request handler.""" + """Add request handler. + + Args: + session (AsyncSession): Async DB session + ldap_session (LDAPSession): LDAP session + kadmin (AbstractKadmin): Abstract Kerberos Admin + entity_type_dao (EntityTypeDAO): Entity Type DAO. + + Yields: + AsyncGenerator[AddResponse, None] + + Raises: + TypeError: not valid attribute type + """ if not ldap_session.user: yield AddResponse(**INVALID_ACCESS_RESPONSE) return @@ -396,9 +423,13 @@ def from_dict( ) -> "AddRequest": """Create AddRequest from dict. - :param str entry: entry - :param dict[str, list[str]] attributes: dict of attrs - :return AddRequest: instance + Args: + entry (str): entry + attributes (dict[str, list[str]]): attributes + password (str | None): (Default value = None) + + Returns: + AddRequest: instance """ return AddRequest( entry=entry, diff --git a/app/ldap_protocol/ldap_requests/base.py b/app/ldap_protocol/ldap_requests/base.py index f48f37534..0d801d4b0 100644 --- a/app/ldap_protocol/ldap_requests/base.py +++ b/app/ldap_protocol/ldap_requests/base.py @@ -27,8 +27,8 @@ colorize=False, ) -type handler = Callable[..., AsyncGenerator[BaseResponse, None]] -type serializer = Callable[..., "BaseRequest"] +type Handler = Callable[..., AsyncGenerator[BaseResponse, None]] +type Serializer = Callable[..., BaseRequest] if TYPE_CHECKING: @@ -40,6 +40,7 @@ async def _handle_api( self, container: AsyncContainer, ) -> list[BaseResponse] | BaseResponse: ... + else: class _APIProtocol: ... @@ -48,8 +49,8 @@ class _APIProtocol: ... class BaseRequest(ABC, _APIProtocol, BaseModel): """Base request builder.""" - handle: ClassVar[handler] - from_data: ClassVar[serializer] + handle: ClassVar[Handler] + from_data: ClassVar[Serializer] @property @abstractmethod @@ -60,11 +61,13 @@ async def _handle_api( self, container: AsyncContainer, ) -> list[BaseResponse]: - """Hanlde response with api user. + """Handle response with api user. + + Args: + container (AsyncContainer): Dependency injection container. - :param DBUser user: user from db - :param AsyncSession session: db session - :return list[BaseResponse]: list of handled responses + Returns: + list[BaseResponse]: list of handled responses """ handler = await resolve_deps(func=self.handle, container=container) ldap_session = await container.get(LDAPSession) @@ -93,5 +96,12 @@ async def _handle_api( return responses async def handle_api(self, container: AsyncContainer) -> LDAPResult: - """Get single response.""" + """Get single response. + + Args: + container (AsyncContainer): Dependency injection container. + + Returns: + LDAPResult: The first response from the handled API responses. + """ return (await self._handle_api(container))[0] # type: ignore diff --git a/app/ldap_protocol/ldap_requests/bind.py b/app/ldap_protocol/ldap_requests/bind.py index 4d12470e6..6e758d369 100644 --- a/app/ldap_protocol/ldap_requests/bind.py +++ b/app/ldap_protocol/ldap_requests/bind.py @@ -59,7 +59,14 @@ class BindRequest(BaseRequest): @classmethod def from_data(cls, data: list[ASN1Row]) -> "BindRequest": - """Get bind from data dict.""" + """Get bind from data dict. + + Returns: + BindRequest + + Raises: + ValueError: Auth version not supported + """ auth = data[2].tag_id otpassword: str | None @@ -99,7 +106,16 @@ async def is_user_group_valid( ldap_session: LDAPSession, session: AsyncSession, ) -> bool: - """Test compability.""" + """Test compability. + + Args: + user (User): db user + ldap_session (LDAPSession): ldap session + session (AsyncSession): async session + + Returns: + bool + """ return await is_user_group_valid(user, ldap_session.policy, session) @staticmethod @@ -111,10 +127,14 @@ async def check_mfa( ) -> bool: """Check mfa api. - :param User user: db user - :param LDAPSession ldap_session: ldap session - :param AsyncSession session: db session - :return bool: response + Args: + api (MultifactorAPI | None): MultiFactor API + identity (str): username + otp (str | None): password + policy (NetworkPolicy): network policy + + Returns: + bool: response """ if api is None: return False @@ -136,7 +156,18 @@ async def handle( settings: Settings, mfa: LDAPMultiFactorAPI, ) -> AsyncGenerator[BindResponse, None]: - """Handle bind request, check user and password.""" + """Handle bind request, check user and password. + + Args: + session (AsyncSession): async session + ldap_session (LDAPSession): ldap session + kadmin (AbstractKadmin): kadmin user + settings (Settings): settings + mfa (LDAPMultiFactorAPI): api + + Yields: + AsyncGenerator[BindResponse, None] + """ if not self.name and self.authentication_choice.is_anonymous(): yield BindResponse(result_code=LDAPCodes.SUCCESS) return @@ -232,14 +263,22 @@ class UnbindRequest(BaseRequest): @classmethod def from_data(cls, data: dict[str, list[ASN1Row]]) -> "UnbindRequest": # noqa: ARG003 - """Unbind request has no body.""" + """Unbind request has no body. + + Returns: + UnbindRequest + """ return cls() async def handle( self, ldap_session: LDAPSession, ) -> AsyncGenerator[BaseResponse, None]: - """Handle unbind request, no need to send response.""" + """Handle unbind request, no need to send response. + + Yields: + AsyncGenerator[BaseResponse, None] + """ await ldap_session.delete_user() return # declare empty async generator and exit yield # type: ignore diff --git a/app/ldap_protocol/ldap_requests/bind_methods/__init__.py b/app/ldap_protocol/ldap_requests/bind_methods/__init__.py index 88eb06bbc..a2683cbad 100644 --- a/app/ldap_protocol/ldap_requests/bind_methods/__init__.py +++ b/app/ldap_protocol/ldap_requests/bind_methods/__init__.py @@ -25,15 +25,15 @@ } __all__ = [ - "get_bad_response", - "sasl_mechanism_map", + "GSSAPISL", "AbstractLDAPAuth", + "GSSAPIAuthStatus", + "LDAPBindErrors", "SASLMethod", "SaslAuthentication", "SaslGSSAPIAuthentication", "SaslPLAINAuthentication", "SimpleAuthentication", - "GSSAPIAuthStatus", - "GSSAPISL", - "LDAPBindErrors", + "get_bad_response", + "sasl_mechanism_map", ] diff --git a/app/ldap_protocol/ldap_requests/bind_methods/base.py b/app/ldap_protocol/ldap_requests/bind_methods/base.py index 756cdf61f..9c2f3c871 100644 --- a/app/ldap_protocol/ldap_requests/bind_methods/base.py +++ b/app/ldap_protocol/ldap_requests/bind_methods/base.py @@ -48,7 +48,11 @@ class LDAPBindErrors(StrEnum): ACCOUNT_LOCKED_OUT = "775" def __str__(self) -> str: - """Return the error message as a string.""" + """Return the error message as a string. + + Returns: + str: Error message + """ return ( "80090308: LdapErr: DSID-0C09030B, " "comment: AcceptSecurityContext error, " @@ -59,11 +63,14 @@ def __str__(self) -> str: def get_bad_response(error_message: LDAPBindErrors) -> BindResponse: """Generate BindResponse object with an invalid credentials error. - :param LDAPBindErrors error_message: Error message to include in the - response - :return BindResponse: A response object with the result code set to - INVALID_CREDENTIALS, an empty matchedDN, and the - provided error message + Args: + error_message (LDAPBindErrors): Error message to include in the\ + response + + Returns: + BindResponse: A response object with the result code set to\ + INVALID_CREDENTIALS, an empty matchedDN, and the provided error\ + message """ return BindResponse( result_code=LDAPCodes.INVALID_CREDENTIALS, @@ -85,15 +92,31 @@ def METHOD_ID(self) -> int: # noqa: N802 @abstractmethod def is_valid(self, user: User) -> bool: - """Validate state.""" + """Validate state. + + Returns: + bool: True if valid, False otherwise + """ @abstractmethod def is_anonymous(self) -> bool: - """Return true if anonymous.""" + """Check if anonymous. + + Returns: + bool: True if anonymous, False otherwise + """ @abstractmethod async def get_user(self, session: AsyncSession, username: str) -> User: - """Get user.""" + """Get user. + + Args: + session (AsyncSession): async db session. + username (str): user name. + + Returns: + User: instance of User. + """ class SaslAuthentication(AbstractLDAPAuth): @@ -105,4 +128,8 @@ class SaslAuthentication(AbstractLDAPAuth): @classmethod @abstractmethod def from_data(cls, data: list[ASN1Row]) -> "SaslAuthentication": - """Get auth from data.""" + """Get auth from data. + + Returns: + SaslAuthentication: Sasl auth form. + """ diff --git a/app/ldap_protocol/ldap_requests/bind_methods/sasl_gssapi.py b/app/ldap_protocol/ldap_requests/bind_methods/sasl_gssapi.py index 40c18a530..2009c5a17 100644 --- a/app/ldap_protocol/ldap_requests/bind_methods/sasl_gssapi.py +++ b/app/ldap_protocol/ldap_requests/bind_methods/sasl_gssapi.py @@ -82,15 +82,16 @@ class SaslGSSAPIAuthentication(SaslAuthentication): def is_valid(self, user: User | None) -> bool: # noqa: ARG002 """Check if GSSAPI token is valid. - :param User | None user: indb user - :return bool: status + Returns: + bool: status """ return True def is_anonymous(self) -> bool: """Check if auth is anonymous. - :return bool: status + Returns: + bool: status """ return False @@ -98,8 +99,12 @@ def is_anonymous(self) -> bool: def from_data(cls, data: list[ASN1Row]) -> "SaslGSSAPIAuthentication": """Get auth from data. - :param list[ASN1Row] data: data - :return SaslGSSAPIAuthentication + Args: + data(list[ASN1Row]): data + data: list[ASN1Row]: + + Returns: + SaslGSSAPIAuthentication """ return cls( ticket=data[1].value if len(data) > 1 else b"", @@ -112,8 +117,9 @@ async def _init_security_context( ) -> None: """Init security context. - :param AsyncSession session: db session - :param Settings settings: settings + Args: + session (AsyncSession): db session + settings (Settings): settings """ base_dn_list = await get_base_directories(session) base_dn = base_dn_list[0].name @@ -140,8 +146,12 @@ def _handle_ticket( ) -> GSSAPIAuthStatus: """Handle the ticket and make gssapi step. - :param gssapi.SecurityContext server_ctx: GSSAPI security context - :return GSSAPIAuthStatus: status + Args: + server_ctx(gssapi.SecurityContext): GSSAPI security context + server_ctx: gssapi.SecurityContext: + + Returns: + GSSAPIAuthStatus: status """ try: out_token = server_ctx.step(self.ticket) @@ -151,12 +161,6 @@ def _handle_ticket( return GSSAPIAuthStatus.ERROR def _validate_security_layer(self, client_layer: GSSAPISL) -> bool: - """Validate security layer. - - :param int client_layer: client security layer - :param Settings settings: settings - :return bool: validate result - """ supported = GSSAPISL.SUPPORTED_SECURITY_LAYERS return (client_layer & supported) == client_layer @@ -166,9 +170,11 @@ def _handle_final_client_message( ) -> GSSAPIAuthStatus: """Handle final client message. - :param gssapi.SecurityContext server_ctx: GSSAPI security context - :param Settings settings: settings - :return GSSAPIAuthStatus: status + Args: + server_ctx (gssapi.SecurityContext): GSSAPI security context + + Returns: + GSSAPIAuthStatus: status """ try: unwrap_message = server_ctx.unwrap(self.ticket) @@ -195,9 +201,12 @@ def _generate_final_message( ) -> bytes: """Generate final wrap message. - :param gssapi.SecurityContext server_ctx: gssapi context - :param Settings settings: settings - :return bytes: message + Args: + server_ctx (gssapi.SecurityContext): gssapi context + settings (Settings): settings + + Returns: + bytes: message """ max_size = settings.GSSAPI_MAX_OUTPUT_TOKEN_SIZE if GSSAPISL.SUPPORTED_SECURITY_LAYERS == GSSAPISL.NO_SECURITY: @@ -219,9 +228,13 @@ async def step( ) -> BindResponse | None: """GSSAPI step. - :param AsyncSession session: db session - :param LDAPSession ldap_session: ldap session - :param Settings settings: settings + Args: + session (AsyncSession): db session + ldap_session (LDAPSession): ldap session + settings (Settings): settings + + Returns: + BindResponse | None """ self._ldap_session = ldap_session @@ -265,8 +278,12 @@ async def get_user( # type: ignore ) -> User | None: """Get user. - :param gssapi.SecurityContext ctx: gssapi context - :param AsyncSession session: db session + Args: + session (AsyncSession): db session + username (str): user name + + Returns: + User | None """ ctx = self._ldap_session.gssapi_security_context if not ctx: diff --git a/app/ldap_protocol/ldap_requests/bind_methods/sasl_plain.py b/app/ldap_protocol/ldap_requests/bind_methods/sasl_plain.py index e8c42cd87..4171bec7c 100644 --- a/app/ldap_protocol/ldap_requests/bind_methods/sasl_plain.py +++ b/app/ldap_protocol/ldap_requests/bind_methods/sasl_plain.py @@ -26,8 +26,8 @@ class SaslPLAINAuthentication(SaslAuthentication): def is_valid(self, user: User | None) -> bool: """Check if pwd is valid for user. - :param User | None user: indb user - :return bool: status + Returns: + bool: True if password is valid, False otherwise. """ password = getattr(user, "password", None) if password is not None: @@ -40,13 +40,18 @@ def is_valid(self, user: User | None) -> bool: def is_anonymous(self) -> bool: """Check if auth is anonymous. - :return bool: status + Returns: + bool: True if anonymous, False otherwise. """ return False @classmethod def from_data(cls, data: list[ASN1Row]) -> "SaslPLAINAuthentication": - """Get auth from data.""" + """Get auth from data. + + Returns: + SaslPLAINAuthentication + """ _, username, password = data[1].value.split("\\x00") return cls( credentials=data[1].value, @@ -55,5 +60,13 @@ def from_data(cls, data: list[ASN1Row]) -> "SaslPLAINAuthentication": ) async def get_user(self, session: AsyncSession, _: str) -> User: - """Get user.""" + """Get user. + + Args: + session (AsyncSession): async db session + _ (str): unused arg + + Returns: + User: user + """ return await get_user(session, self.username) # type: ignore diff --git a/app/ldap_protocol/ldap_requests/bind_methods/simple.py b/app/ldap_protocol/ldap_requests/bind_methods/simple.py index 97cef77e6..594f465b6 100644 --- a/app/ldap_protocol/ldap_requests/bind_methods/simple.py +++ b/app/ldap_protocol/ldap_requests/bind_methods/simple.py @@ -23,8 +23,8 @@ class SimpleAuthentication(AbstractLDAPAuth): def is_valid(self, user: User | None) -> bool: """Check if pwd is valid for user. - :param User | None user: indb user - :return bool: status + Returns: + bool: status """ password = getattr(user, "password", None) if password is not None: @@ -34,10 +34,19 @@ def is_valid(self, user: User | None) -> bool: def is_anonymous(self) -> bool: """Check if auth is anonymous. - :return bool: status + Returns: + bool: True if password is empty, False otherwise. """ return not self.password async def get_user(self, session: AsyncSession, username: str) -> User: - """Get user.""" + """Get user. + + Args: + session (AsyncSession): Database session. + username (str): Username to search for. + + Returns: + User: User object if found, raises exception otherwise. + """ return await get_user(session, username) # type: ignore diff --git a/app/ldap_protocol/ldap_requests/delete.py b/app/ldap_protocol/ldap_requests/delete.py index c6c0fc904..cc0145f4d 100644 --- a/app/ldap_protocol/ldap_requests/delete.py +++ b/app/ldap_protocol/ldap_requests/delete.py @@ -44,6 +44,11 @@ class DeleteRequest(BaseRequest): @classmethod def from_data(cls, data: ASN1Row) -> "DeleteRequest": + """Get delete request from data. + + Returns: + DeleteRequest: Instance of DeleteRequest with the entry set. + """ return cls(entry=data) async def handle( @@ -53,7 +58,17 @@ async def handle( kadmin: AbstractKadmin, session_storage: SessionStorage, ) -> AsyncGenerator[DeleteResponse, None]: - """Delete request handler.""" + """Delete request handler. + + Args: + session (AsyncSession): The database session. + ldap_session (LDAPSession): The LDAP session. + kadmin (AbstractKadmin): The Kerberos administration interface. + session_storage (SessionStorage): Session storage for user sessions + + Yields: + DeleteResponse: The response to the delete request. + """ if not ldap_session.user: yield DeleteResponse(**INVALID_ACCESS_RESPONSE) return diff --git a/app/ldap_protocol/ldap_requests/extended.py b/app/ldap_protocol/ldap_requests/extended.py index bb445f0e3..67e1f95f6 100644 --- a/app/ldap_protocol/ldap_requests/extended.py +++ b/app/ldap_protocol/ldap_requests/extended.py @@ -41,7 +41,11 @@ class BaseExtendedValue(ABC, BaseModel): @classmethod @abstractmethod def from_data(cls, data: ASN1Row) -> "BaseExtendedValue": - """Create model from data, decoded from responseValue bytes.""" + """Create model from data, decoded from responseValue bytes. + + Returns: + BaseExtendedValue: instance of BaseExtendedValue. + """ @abstractmethod async def handle( @@ -51,10 +55,25 @@ async def handle( kadmin: AbstractKadmin, settings: Settings, ) -> BaseExtendedResponseValue: - """Generate specific extended resoponse.""" + """Generate specific extended resoponse. + + Args: + ldap_session (LDAPSession): LDAP session + session (AsyncSession): Database session + kadmin (AbstractKadmin): Kerberos client + settings (Settings): Settings + + Returns: + BaseExtendedResponseValue + """ @staticmethod def _decode_value(data: ASN1Row) -> ASN1Row: + """Decode value. + + Returns: + ASN1Row: Decoded row with metadata + """ dec = Decoder() dec.start(data[1].value) # type: ignore output = asn1todict(dec) @@ -79,7 +98,11 @@ class WhoAmIResponse(BaseExtendedResponseValue): authz_id: str def get_value(self) -> str | None: - """Get authz id.""" + """Get authz id. + + Returns: + str | None + """ return self.authz_id @@ -94,7 +117,11 @@ class WhoAmIRequestValue(BaseExtendedValue): @classmethod def from_data(cls, data: ASN1Row) -> "WhoAmIRequestValue": # noqa: ARG003 - """Create model from data, WhoAmIRequestValue data is empty.""" + """Create model from data, WhoAmIRequestValue data is empty. + + Returns: + WhoAmIRequestValue + """ return cls() async def handle( @@ -104,7 +131,17 @@ async def handle( kadmin: AbstractKadmin, # noqa: ARG002 settings: Settings, # noqa: ARG002 ) -> "WhoAmIResponse": - """Return user from session.""" + """Return user from session. + + Args: + ldap_session (LDAPSession): LDAP session + _ (AsyncSession): Database session + kadmin (AbstractKadmin): Kerberos client + settings (Settings): Settings + + Returns: + WhoAmIResponse + """ un = ( f"u:{ldap_session.user.user_principal_name}" if ldap_session.user @@ -118,7 +155,11 @@ class StartTLSResponse(BaseExtendedResponseValue): """Start tls response.""" def get_value(self) -> str | None: - """Get response value.""" + """Get response value. + + Returns: + str | None + """ return "" @@ -134,7 +175,20 @@ async def handle( kadmin: AbstractKadmin, # noqa: ARG002 settings: Settings, ) -> StartTLSResponse: - """Update password of current or selected user.""" + """Update password of current or selected user. + + Args: + ldap_session: LDAPSession + session: AsyncSession + kadmin: AbstractKadmin + settings: Settings + + Returns: + StartTLSResponse + + Raises: + PermissionError: No TLS + """ if settings.USE_CORE_TLS: return StartTLSResponse() @@ -142,7 +196,11 @@ async def handle( @classmethod def from_data(cls, data: ASN1Row) -> "StartTLSRequestValue": # noqa: ARG003 - """Create model from data, decoded from responseValue bytes.""" + """Create model from data, decoded from responseValue bytes. + + Returns: + StartTLSRequestValue + """ return cls() @@ -156,7 +214,11 @@ class PasswdModifyResponse(BaseExtendedResponseValue): gen_passwd: str = "" def get_value(self) -> str | None: - """Return gen password.""" + """Get response value. + + Returns: + str | None + """ return self.gen_passwd @@ -188,7 +250,20 @@ async def handle( kadmin: AbstractKadmin, settings: Settings, ) -> PasswdModifyResponse: - """Update password of current or selected user.""" + """Update password of current or selected user. + + Args: + ldap_session: LDAPSession + session: AsyncSession + kadmin: AbstractKadmin + settings: Settings + + Returns: + PasswdModifyResponse + + Raises: + PermissionError: user not authorized + """ if not settings.USE_CORE_TLS: raise PermissionError("TLS required") @@ -246,7 +321,11 @@ async def handle( @classmethod def from_data(cls, data: ASN1Row) -> "PasswdModifyRequestValue": - """Create model from data, decoded from responseValue bytes.""" + """Create model from data, decoded from responseValue bytes. + + Returns: + PasswdModifyRequestValue + """ d: list = cls._decode_value(data) # type: ignore if len(d) == 3: return cls( @@ -289,7 +368,17 @@ async def handle( kadmin: AbstractKadmin, settings: Settings, ) -> AsyncGenerator[ExtendedResponse, None]: - """Call proxy handler.""" + """Call proxy handler. + + Args: + ldap_session (LDAPSession): LDAP session + session (AsyncSession): Async db session + kadmin (AbstractKadmin): Stub client for non set up dirs. + settings (Settings): Settings with database dsn + + Yields: + AsyncGenerator[ExtendedResponse, None]: + """ try: response = await self.request_value.handle( ldap_session, @@ -315,8 +404,12 @@ async def handle( def from_data(cls, data: list[ASN1Row]) -> "ExtendedRequest": """Create extended request from asn.1 decoded string. - :param ASN1Row data: any data - :return ExtendedRequest: universal request + Args: + data(ASN1Row): any data + data: list[ASN1Row]: + + Returns: + ExtendedRequest: universal request """ oid = data[0].value ext_request = EXTENDED_REQUEST_OID_MAP[oid] diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index aba9f70ad..82a6b05d0 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -4,7 +4,7 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from enum import IntEnum from typing import AsyncGenerator, ClassVar @@ -68,7 +68,11 @@ class Changes(BaseModel): modification: PartialAttribute def get_name(self) -> str: - """Get mod name.""" + """Get mod name. + + Returns: + str: mod name + """ return self.modification.type.lower() @@ -106,6 +110,11 @@ class ModifyRequest(BaseRequest): @classmethod def from_data(cls, data: list[ASN1Row]) -> "ModifyRequest": + """Get modify request from data. + + Returns: + ModifyRequest: modify request + """ entry, proto_changes = data changes = [] @@ -129,7 +138,12 @@ async def _update_password_expiration( change: Changes, session: AsyncSession, ) -> None: - """Update password expiration if policy allows.""" + """Update password expiration if policy allows. + + Args: + change (Changes): Change + session (AsyncSession): Database session + """ if not ( change.modification.type == "krbpasswordexpiration" and change.modification.vals[0] == "19700101000000Z" @@ -141,7 +155,7 @@ async def _update_password_expiration( if policy.maximum_password_age_days == 0: return - now = datetime.now(timezone.utc) + now = datetime.now(UTC) now += timedelta(days=policy.maximum_password_age_days) change.modification.vals[0] = now.strftime("%Y%m%d%H%M%SZ") @@ -154,7 +168,19 @@ async def handle( settings: Settings, entity_type_dao: EntityTypeDAO, ) -> AsyncGenerator[ModifyResponse, None]: - """Change request handler.""" + """Change request handler. + + Args: + ldap_session (LDAPSession): LDAP session + session (AsyncSession): Database session + session_storage (SessionStorage): Session storage + kadmin (AbstractKadmin): Kadmin + settings (Settings): Settings + entity_type_dao (EntityTypeDAO): Entity Type DAO. + + Yields: + AsyncGenerator[ModifyResponse, None] + """ if not ldap_session.user: yield ModifyResponse( result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS, @@ -244,6 +270,14 @@ async def handle( yield ModifyResponse(result_code=LDAPCodes.SUCCESS) def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]: + """Match bad response. + + Returns: + tuple[LDAPCodes, str]: result code and message + + Raises: + Exception: if can`t match exception and LDAP code. + """ match err: case ValueError(): logger.error(f"Invalid value: {err}") @@ -262,9 +296,9 @@ def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]: return LDAPCodes.STRONGER_AUTH_REQUIRED, "" case _: - raise err + raise Exception - def _get_dir_query(self) -> Select: + def _get_dir_query(self) -> Select[tuple[Directory]]: return ( select(Directory) .join(Directory.attributes) @@ -281,6 +315,16 @@ def _check_password_change_requested( directory: Directory, user_dir_id: int, ) -> bool: + """Check if password change is requested. + + Args: + names (set[str]): attr names + directory (Directory): directory + user_dir_id (int): user id + + Returns: + bool: True if password change is requested, False otherwise + """ return ( ("userpassword" in names or "unicodepwd" in names) and len(names) == 1 diff --git a/app/ldap_protocol/ldap_requests/modify_dn.py b/app/ldap_protocol/ldap_requests/modify_dn.py index 54c648664..06990f24d 100644 --- a/app/ldap_protocol/ldap_requests/modify_dn.py +++ b/app/ldap_protocol/ldap_requests/modify_dn.py @@ -65,6 +65,7 @@ class ModifyDNRequest(BaseRequest): deleteoldrdn=true new_superior='ou=users,dc=multifactor,dc=dev' + Args: >>> cn = main2, ou = users, dc = multifactor, dc = dev """ @@ -77,7 +78,15 @@ class ModifyDNRequest(BaseRequest): @classmethod def from_data(cls, data: list[ASN1Row]) -> "ModifyDNRequest": - """Create structure from ASN1Row dataclass list.""" + """Create structure from ASN1Row dataclass list. + + Args: + data (list[ASN1Row]): List of ASN1Row objects containing\ + the request data. + + Returns: + ModifyDNRequest: Instance of ModifyDNRequest with parsed data. + """ return cls( entry=data[0].value, newrdn=data[1].value, @@ -90,7 +99,15 @@ async def handle( ldap_session: LDAPSession, session: AsyncSession, ) -> AsyncGenerator[ModifyDNResponse, None]: - """Handle message with current user.""" + """Handle message with current user. + + Args: + ldap_session (LDAPSession): Current LDAP session. + session (AsyncSession): Database session. + + Yields: + ModifyDNResponse: Response to the Modify DN request. + """ if not ldap_session.user: yield ModifyDNResponse(**INVALID_ACCESS_RESPONSE) return diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index 4939849a8..71b18cf8c 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -99,14 +99,24 @@ class Config: @field_serializer("filter") def serialize_filter(self, val: ASN1Row | None, _info: Any) -> str | None: - """Serialize filter field.""" + """Serialize filter field. + + Args: + val (ASN1Row | None): instance of ASN1Row + _info (Any): not used + + Returns: + str | None + """ return val.to_ldap_filter() if isinstance(val, ASN1Row) else None @classmethod - def from_data( - cls, - data: dict[str, list[ASN1Row]], - ) -> "SearchRequest": + def from_data(cls, data: dict[str, list[ASN1Row]]) -> "SearchRequest": + """Get search request from data. + + Returns: + SearchRequest: LDAP search request + """ ( base_object, scope, @@ -131,6 +141,11 @@ def from_data( @cached_property def requested_attrs(self) -> list[str]: + """Get requested attributes. + + Returns: + list[str]: requested attributes + """ return [attr.lower() for attr in self.attributes] async def _get_subschema(self, session: AsyncSession) -> SearchResultEntry: @@ -167,7 +182,12 @@ async def get_root_dse( ) -> defaultdict[str, list[str]]: """Get RootDSE. - :return defaultdict[str, list[str]]: queried attrs + Args: + session (AsyncSession): Database session + settings (Settings): Settings + + Returns: + defaultdict[str, list[str]]: queried attrs """ data = defaultdict(list) domain_query = ( @@ -222,9 +242,8 @@ async def get_root_dse( def cast_filter(self) -> UnaryExpression | ColumnElement: """Convert asn1 row filter_ to sqlalchemy obj. - :param ASN1Row filter_: requested filter_ - :param AsyncSession session: sa session - :return UnaryExpression: condition + Returns: + UnaryExpression | ColumnElement """ return cast_filter2sql(self.filter) @@ -241,6 +260,15 @@ async def handle( Provides following responses: Entry -> Reference (optional) -> Done + + Args: + session (AsyncSession): Database session + ldap_session (LDAPSession): LDAP session + settings (Settings): Settings + + Yields: + AsyncGenerator[SearchResultDone | SearchResultReference |\ + SearchResultEntry, None] """ async with ldap_session.lock() as user: async for response in self.get_result(user, session, settings): @@ -254,9 +282,13 @@ async def get_result( ) -> AsyncGenerator[SearchResultEntry | SearchResultDone, None]: """Create response. - :param bool user_logged: is user in session - :param AsyncSession session: sa session - :yield SearchResult: search result + Args: + user (UserSchema | None): schema of user + session (AsyncSession): async session. + settings (Settings): settings. + + Yields: + AsyncGenerator[SearchResultEntry | SearchResultDone, None]: """ is_root_dse = self.scope == Scope.BASE_OBJECT and not self.base_object is_schema = self.base_object.lower() == "cn=schema" @@ -303,18 +335,38 @@ async def get_result( @cached_property def member_of(self) -> bool: + """Check if member of is requested. + + Returns: + bool: True if member of is requested, False otherwise + """ return "memberof" in self.requested_attrs or self.all_attrs @cached_property def member(self) -> bool: + """Check if member is requested. + + Returns: + bool: True if member is requested, False otherwise + """ return "member" in self.requested_attrs or self.all_attrs @cached_property def token_groups(self) -> bool: + """Check if token groups is requested. + + Returns: + bool: True if token groups is requested, False otherwise + """ return "tokengroups" in self.requested_attrs @cached_property def all_attrs(self) -> bool: + """Check if all attributes are requested. + + Returns: + bool: True if all attributes are requested, False otherwise + """ return "*" in self.requested_attrs or not self.requested_attrs def build_query( @@ -322,7 +374,15 @@ def build_query( base_directories: list[Directory], user: UserSchema, ) -> Select: - """Build tree query.""" + """Build tree query. + + Args: + base_directories (list[Directory]): instances of Directory + user (UserSchema): serialized user + + Returns: + Select + """ query = ( select(Directory) .join(User, isouter=True) @@ -397,9 +457,12 @@ async def paginate_query( ) -> tuple[Select, int, int]: """Paginate query. - :param _type_ query: _description_ - :param _type_ session: _description_ - :return tuple[select, int, int]: query, pages_total, count + Args: + query (Select): SQLAlchemy select query + session (AsyncSession): async session + + Returns: + tuple[Select, int, int]: select query, pages_total, count """ if self.page_number is None: return query, 0, 0 @@ -412,14 +475,22 @@ async def paginate_query( end = start + self.size_limit query = query.offset(start).limit(end) - return query, int(ceil(count / float(self.size_limit))), count + return query, ceil(count / float(self.size_limit)), count async def tree_view( # noqa: C901 self, query: Select, session: AsyncSession, ) -> AsyncGenerator[SearchResultEntry, None]: - """Yield all resulted directories.""" + """Yield all resulted directories. + + Args: + query (Select): SQLAlchemy select query + session (AsyncSession): async session + + Yields: + AsyncGenerator[SearchResultEntry, None]: yielded directories + """ directories = await session.stream_scalars(query) # logger.debug(query.compile(compile_kwargs={"literal_binds": True})) # noqa @@ -462,10 +533,8 @@ async def tree_view( # noqa: C901 attrs["authTimestamp"].append(directory.user.last_logon) if ( - self.member_of - and "group" in obj_classes - or "user" in obj_classes - ): + self.member_of and "group" in obj_classes + ) or "user" in obj_classes: for group in directory.groups: attrs["memberOf"].append(group.directory.path_dn) diff --git a/app/ldap_protocol/ldap_responses.py b/app/ldap_protocol/ldap_responses.py index e8e8f3642..9fcdceefe 100644 --- a/app/ldap_protocol/ldap_responses.py +++ b/app/ldap_protocol/ldap_responses.py @@ -38,15 +38,18 @@ class Config: populate_by_name = True arbitrary_types_allowed = True - json_encoders = { - bytes: lambda value: value.hex(), - } + json_encoders: ClassVar[dict] = {bytes: lambda value: value.hex()} class BaseEncoder(BaseModel): """Class with encoder methods.""" def _get_asn1_fields(self) -> dict: + """Get ASN1 fields. + + Returns: + dict: ASN1 fields + """ fields = self.model_dump() fields.pop("PROTOCOL_OP", None) return fields @@ -99,26 +102,38 @@ class PartialAttribute(BaseModel): @property def l_name(self) -> str: - """Get lower case name.""" + """Get lower case name. + + Returns: + str: lower case name + """ return self.type.lower() @field_validator("type", mode="before") @classmethod def validate_type(cls, v: str | bytes | int) -> str: + """Validate type. + + Returns: + str: value + """ return str(v) @field_validator("vals", mode="before") @classmethod def validate_vals(cls, vals: list[str | int | bytes]) -> list[str | bytes]: + """Validate vals. + + Returns: + list[str | bytes]: values + """ return [v if isinstance(v, bytes) else str(v) for v in vals] class Config: """Allow class to use property.""" arbitrary_types_allowed = True - json_encoders = { - bytes: lambda value: value.hex(), - } + json_encoders: ClassVar[dict] = {bytes: lambda value: value.hex()} class SearchResultEntry(BaseResponse): @@ -169,6 +184,11 @@ class SearchResultDone(LDAPResult, BaseResponse): total_objects: int = 0 def _get_asn1_fields(self) -> dict: + """Get ASN1 fields. + + Returns: + dict: ASN1 fields + """ fields = super()._get_asn1_fields() fields.pop("total_pages") fields.pop("total_objects") diff --git a/app/ldap_protocol/ldap_schema/attribute_type_dao.py b/app/ldap_protocol/ldap_schema/attribute_type_dao.py index b371638f8..bf6fd7d5a 100644 --- a/app/ldap_protocol/ldap_schema/attribute_type_dao.py +++ b/app/ldap_protocol/ldap_schema/attribute_type_dao.py @@ -33,7 +33,11 @@ class AttributeTypeSchema(BaseSchemaModel): @classmethod def from_db(cls, attribute_type: AttributeType) -> "AttributeTypeSchema": - """Create an instance of Attribute Type Schema from SQLA object.""" + """Create an instance from database. + + Returns: + AttributeTypeSchema: serialized AttributeType. + """ return cls( oid=attribute_type.oid, name=attribute_type.name, @@ -62,8 +66,6 @@ class AttributeTypeDAO: """Attribute Type DAO.""" _session: AsyncSession - AttributeTypeNotFoundError = InstanceNotFoundError - AttributeTypeCantModifyError = InstanceCantModifyError def __init__(self, session: AsyncSession) -> None: """Initialize Attribute Type DAO with session.""" @@ -73,10 +75,13 @@ async def get_paginator( self, params: PaginationParams, ) -> PaginationResult: - """Retrieve paginated Attribute Types. + """Retrieve paginated attribute_types. + + Args: + params (PaginationParams): parameters for pagination. - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Attribute Types and metadata. + Returns: + PaginationResult: Chunk of attribute_types and metadata. """ return await PaginationResult[AttributeType].get( params=params, @@ -96,13 +101,13 @@ async def create_one( ) -> None: """Create a new Attribute Type. - :param str oid: OID. - :param str name: Name. - :param str syntax: Syntax. - :param bool single_value: Single value. - :param bool no_user_modification: User can't modify it. - :param bool is_system: Attribute Type is system. - :return None. + Args: + oid (str): OID. + name (str): Name. + syntax (str): Syntax. + single_value (bool): Single value. + no_user_modification (bool): User can't modify it. + is_system (bool): Attribute Type is system. """ attribute_type = AttributeType( oid=oid, @@ -120,9 +125,11 @@ async def get_one_by_name( ) -> AttributeType: """Get single Attribute Type by name. - :param str attribute_type_name: Attribute Type name. - :raise AttributeTypeNotFoundError: If Attribute Type not found. - :return AttributeType: Instance of Attribute Type. + Returns: + AttributeType: Attribute Type. + + Raises: + InstanceNotFoundError: Attribute Type not found. """ attribute_type = await self._session.scalar( select(AttributeType) @@ -130,7 +137,7 @@ async def get_one_by_name( ) # fmt: skip if not attribute_type: - raise self.AttributeTypeNotFoundError( + raise InstanceNotFoundError( f"Attribute Type with name '{attribute_type_name}' not found." ) @@ -142,8 +149,8 @@ async def get_all_by_names( ) -> list[AttributeType]: """Get list of Attribute Types by names. - :param list[str] attribute_type_names: Attribute Type names. - :return list[AttributeType]: List of Attribute Types. + Returns: + list[AttributeType]: List of Attribute Types. """ if not attribute_type_names: return [] @@ -161,14 +168,16 @@ async def modify_one( ) -> None: """Modify Attribute Type. - :param AttributeType attribute_type: Attribute Type. - :param AttributeTypeUpdateSchema new_statement: Attribute Type Schema. - :raise AttributeTypeCantModifyError: If Attribute Type is system,\ - it cannot be changed. - :return None. + Args: + attribute_type (AttributeType): Attribute Type. + new_statement (AttributeTypeUpdateSchema): Attribute Type + Schema. + + Raises: + InstanceCantModifyError: System Attribute Type cannot be modified. """ if attribute_type.is_system: - raise self.AttributeTypeCantModifyError( + raise InstanceCantModifyError( "System Attribute Type cannot be modified." ) @@ -184,12 +193,11 @@ async def delete_all_by_names( ) -> None: """Delete not system Attribute Types by names. - :param list[str] attribute_type_names: List of Attribute Types names. - :param AsyncSession session: Database session. - :return None: None. + Args: + attribute_type_names (list[str]): List of Attribute Types OIDs. """ if not attribute_type_names: - return None + return await self._session.execute( delete(AttributeType) diff --git a/app/ldap_protocol/ldap_schema/entity_type_dao.py b/app/ldap_protocol/ldap_schema/entity_type_dao.py index 60fdba633..88aa6fbe1 100644 --- a/app/ldap_protocol/ldap_schema/entity_type_dao.py +++ b/app/ldap_protocol/ldap_schema/entity_type_dao.py @@ -30,7 +30,11 @@ class EntityTypeSchema(BaseModel): @classmethod def from_db(cls, entity_type: EntityType) -> "EntityTypeSchema": - """Create an instance of Entity Type Schema from SQLA object.""" + """Create an instance of Entity Type Schema from SQLA object. + + Returns: + EntityTypeSchema: Instance of Entity Type Schema. + """ return cls( name=entity_type.name, is_system=entity_type.is_system, @@ -55,7 +59,6 @@ class EntityTypeDAO: """Entity Type DAO.""" _session: AsyncSession - EntityTypeNotFoundError = InstanceNotFoundError def __init__(self, session: AsyncSession) -> None: """Initialize Entity Type DAO with a database session.""" @@ -67,8 +70,8 @@ async def get_paginator( ) -> PaginationResult: """Retrieve paginated Entity Types. - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Entity Types and metadata. + Returns: + PaginationResult: Chunk of Entity Types and metadata. """ return await PaginationResult[EntityType].get( params=params, @@ -85,10 +88,10 @@ async def create_one( ) -> None: """Create a new Entity Type instance. - :param str name: Name. - :param Iterable[str] object_class_names: Object Class names. - :param bool is_system: Is system. - :return None. + Args: + name (str): Name. + object_class_names (Iterable[str]): Object Class names. + is_system (bool): Is system. """ entity_type = EntityType( name=name, @@ -103,9 +106,11 @@ async def get_one_by_name( ) -> EntityType: """Get single Entity Type by name. - :param str entity_type_name: Entity Type name. - :raise EntityTypeNotFoundError: If Entity Type not found. - :return EntityType: Instance of Entity Type. + Returns: + EntityType: Instance of Entity Type. + + Raises: + InstanceNotFoundError: If Entity Type not found. """ entity_type = await self._session.scalar( select(EntityType) @@ -113,7 +118,7 @@ async def get_one_by_name( ) # fmt: skip if not entity_type: - raise self.EntityTypeNotFoundError( + raise InstanceNotFoundError( f"Entity Type with name '{entity_type_name}' not found." ) @@ -125,8 +130,8 @@ async def get_entity_type_by_object_class_names( ) -> EntityType | None: """Get single Entity Type by object class names. - :param Iterable[str] object_class_names: object class names. - :return EntityType | None: Instance of Entity Type or None. + Returns: + EntityType | None: Instance of Entity Type or None. """ result = await self._session.execute( select(EntityType) @@ -146,10 +151,11 @@ async def modify_one( ) -> None: """Modify Entity Type. - :param EntityType entity_type: Entity Type. - :param EntityTypeUpdateSchema new_statement: New statement\ - of Entity Type. - :return None. + Args: + entity_type (EntityType): Entity Type. + new_statement (EntityTypeUpdateSchema): New statement\ + of Entity Type. + object_class_dao (ObjectClassDAO): Object Class DAO. """ await object_class_dao.is_all_object_classes_exists( new_statement.object_class_names @@ -191,8 +197,8 @@ async def delete_all_by_names( ) -> None: """Delete not system and not used Entity Type by their names. - :param list[str] entity_type_names: Entity Type names. - :return None. + Args: + entity_type_names (list[str]): Entity Type names. """ await self._session.execute( delete(EntityType) @@ -207,10 +213,7 @@ async def delete_all_by_names( ) # fmt: skip async def attach_entity_type_to_directories(self) -> None: - """Find all Directories without an Entity Type and attach it to them. - - :return None. - """ + """Find all Directories without Entity Type and attach it to them.""" result = await self._session.execute( select(Directory) .where(Directory.entity_type_name.is_(None)) @@ -226,7 +229,7 @@ async def attach_entity_type_to_directories(self) -> None: is_system_entity_type=False, ) - return None + return async def attach_entity_type_to_directory( self, @@ -235,9 +238,9 @@ async def attach_entity_type_to_directory( ) -> None: """Try to find the Entity Type, attach it to the Directory. - :param Directory directory: Directory to attach Entity Type. - :param bool is_system_entity_type: Is system Entity Type. - :return None. + Args: + directory (Directory): Directory to attach Entity Type. + is_system_entity_type (bool): Is system Entity Type. """ object_class_names = directory.object_class_names_set diff --git a/app/ldap_protocol/ldap_schema/object_class_dao.py b/app/ldap_protocol/ldap_schema/object_class_dao.py index 3693377f2..65b76264e 100644 --- a/app/ldap_protocol/ldap_schema/object_class_dao.py +++ b/app/ldap_protocol/ldap_schema/object_class_dao.py @@ -38,7 +38,11 @@ class ObjectClassSchema(BaseSchemaModel): @classmethod def from_db(cls, object_class: ObjectClass) -> "ObjectClassSchema": - """Create an instance of Object Class Schema from SQLA object.""" + """Create an instance of Object Class Schema from SQLA object. + + Returns: + ObjectClassSchema: instance of ObjectClassSchema. + """ return cls( oid=object_class.oid, name=object_class.name, @@ -69,15 +73,17 @@ class ObjectClassDAO: _session: AsyncSession _attribute_type_dao: AttributeTypeDAO - ObjectClassNotFoundError = InstanceNotFoundError - ObjectClassCantModifyError = InstanceCantModifyError - def __init__( self, session: AsyncSession, attribute_type_dao: AttributeTypeDAO, ) -> None: - """Initialize Object Class DAO with session.""" + """Initialize Object Class DAO with session. + + Args: + session (AsyncSession): async db session. + attribute_type_dao (AttributeTypeDAO): Attribute Type DAO. + """ self._session = session self._attribute_type_dao = attribute_type_dao @@ -87,8 +93,8 @@ async def get_paginator( ) -> PaginationResult: """Retrieve paginated Object Classes. - :param PaginationParams params: page_size and page_number. - :return PaginationResult: Chunk of Object Classes and metadata. + Returns: + PaginationResult: Chunk of object_classes and metadata. """ return await PaginationResult[ObjectClass].get( params=params, @@ -109,15 +115,17 @@ async def create_one( ) -> None: """Create a new Object Class. - :param str oid: OID. - :param str name: Name. - :param str | None superior_name: Parent Object Class. - :param KindType kind: Kind. - :param bool is_system: Object Class is system. - :param list[str] attribute_type_names_must: Attribute Types must. - :param list[str] attribute_type_names_may: Attribute Types may. - :raise ObjectClassNotFoundError: If superior Object Class not found. - :return None. + Args: + oid (str): OID. + name (str): Name. + superior_name (str | None): Parent Object Class. + kind (KindType): Kind. + is_system (bool): Object Class is system. + attribute_type_names_must (list[str]): Attribute Types must. + attribute_type_names_may (list[str]): Attribute Types may. + + Raises: + InstanceNotFoundError: Superior (parent) Object class not found. """ superior = ( await self.get_one_by_name(superior_name) @@ -125,7 +133,7 @@ async def create_one( else None ) if superior_name and not superior: - raise self.ObjectClassNotFoundError( + raise InstanceNotFoundError( f"Superior (parent) Object class {superior_name} not found\ in schema." ) @@ -160,8 +168,8 @@ async def count_exists_object_class_by_names( ) -> int: """Count exists Object Class by names. - :param list[str] object_class_names: Object Class names. - :return int. + Returns: + int: count of object classes """ count_query = ( select(func.count()) @@ -177,16 +185,18 @@ async def is_all_object_classes_exists( ) -> Literal[True]: """Check if all Object Classes exist. - :param list[str] object_class_names: Object Class names. - :raise ObjectClassNotFoundError: If Object Class not found. - :return bool. + Returns: + Literal[True]: True if all object classes found. + + Raises: + InstanceNotFoundError: Object class not found. """ count_ = await self.count_exists_object_class_by_names( object_class_names ) if count_ != len(object_class_names): - raise self.ObjectClassNotFoundError( + raise InstanceNotFoundError( f"Not all Object Classes\ with names {object_class_names} found." ) @@ -199,9 +209,11 @@ async def get_one_by_name( ) -> ObjectClass: """Get single Object Class by name. - :param str object_class_name: Object Class name. - :raise ObjectClassNotFoundError: If Object Class not found. - :return ObjectClass: Instance of Object Class. + Returns: + ObjectClass: Object Class. + + Raises: + InstanceNotFoundError: Object class not found. """ object_class = await self._session.scalar( select(ObjectClass) @@ -209,7 +221,7 @@ async def get_one_by_name( ) # fmt: skip if not object_class: - raise self.ObjectClassNotFoundError( + raise InstanceNotFoundError( f"Object Class with name '{object_class_name}' not found." ) @@ -221,8 +233,8 @@ async def get_all_by_names( ) -> list[ObjectClass]: """Get list of Object Classes by names. - :param list[str] object_class_names: Object Classes names. - :return list[ObjectClass]: List of Object Classes. + Returns: + list[ObjectClass]: List of Object Classes. """ query = await self._session.scalars( select(ObjectClass) @@ -241,14 +253,17 @@ async def modify_one( ) -> None: """Modify Object Class. - :param ObjectClass object_class: Object Class. - :param ObjectClassUpdateSchema new_statement: New statement ObjectClass - :raise ObjectClassCantModifyError: If Object Class is system,\ - it cannot be changed. - :return None. + Args: + object_class (ObjectClass): Object Class. + new_statement (ObjectClassUpdateSchema): New statement of object + class + + Raises: + InstanceCantModifyError: If Object Class is system,\ + it cannot be changed. """ if object_class.is_system: - raise self.ObjectClassCantModifyError( + raise InstanceCantModifyError( "System Object Class cannot be modified." ) @@ -277,8 +292,8 @@ async def delete_all_by_names( ) -> None: """Delete not system Object Classes by Names. - :param list[str] object_classes_names: Object Classes names. - :return None. + Args: + object_classes_names (list[str]): Object classes names. """ await self._session.execute( delete(ObjectClass) diff --git a/app/ldap_protocol/messages.py b/app/ldap_protocol/messages.py index 78e2f6afa..317140b37 100644 --- a/app/ldap_protocol/messages.py +++ b/app/ldap_protocol/messages.py @@ -36,7 +36,11 @@ class LDAPMessage(ABC, BaseModel): @property def name(self) -> str: - """Message name.""" + """Message name. + + Returns: + str: message name + """ return get_class_name(self.context) @@ -46,7 +50,11 @@ class LDAPResponseMessage(LDAPMessage): context: SerializeAsAny[BaseResponse] def encode(self) -> bytes: - """Encode message to asn1.""" + """Encode message to asn1. + + Returns: + bytes + """ enc = Encoder() enc.start() enc.enter(Numbers.Sequence) @@ -76,7 +84,17 @@ class LDAPRequestMessage(LDAPMessage): @classmethod def from_bytes(cls, source: bytes) -> "LDAPRequestMessage": - """Create message from bytes.""" + """Create message from bytes. + + Args: + source: bytes + + Returns: + LDAPRequestMessage + + Raises: + ValueError: incorrect schema + """ dec = Decoder() dec.start(source) output = asn1todict(dec) @@ -118,10 +136,12 @@ def from_bytes(cls, source: bytes) -> "LDAPRequestMessage": def from_err(cls, source: bytes, err: Exception) -> LDAPResponseMessage: """Create error response message. - :param bytes source: source data - :param Exception err: any error - :raises ValueError: on invalid schema - :return LDAPResponseMessage: response with err code + Args: + source (bytes): source data + err (Exception): any error + + Returns: + LDAPResponseMessage: response with err code """ output = asn1todict(source) message_id = 0 @@ -153,7 +173,12 @@ async def create_response( ) -> AsyncGenerator[LDAPResponseMessage, None]: """Call unique context handler. - :yield LDAPResponseMessage: create response for context. + Args: + handler (Callable[..., AsyncGenerator[BaseResponse, None]]):\ + handler + + Yields: + LDAPResponseMessage: create response for context. """ async for response in handler(): yield LDAPResponseMessage( diff --git a/app/ldap_protocol/multifactor.py b/app/ldap_protocol/multifactor.py index 155e8af15..25d2a4e4d 100644 --- a/app/ldap_protocol/multifactor.py +++ b/app/ldap_protocol/multifactor.py @@ -58,7 +58,13 @@ async def get_creds( ) -> Creds | None: """Get API creds. - :return tuple[str, str]: api key and secret + Args: + session (AsyncSession): session + key_name (str): key name + secret_name (str): secret name + + Returns: + tuple[str, str]: api key and secret """ query = ( select(CatalogueSetting) @@ -84,23 +90,14 @@ class MultifactorAPI: Methods: - `__init__(key, secret, client, settings)`: Initializes the object with - the required credentials and bound HTTP client from di. + the required credentials and bound HTTP client from di. - `ldap_validate_mfa(username, password)`: Validates MFA for a user. If the - password is not provided, sends a push notification and waits for user - approval with a timeout of 60 seconds. + password is not provided, sends a push notification and waits for user + approval with a timeout of 60 seconds. - `get_create_mfa(username)`: Retrieves or creates an MFA token for the - specified user. + specified user. - `refresh_token()`: Refreshes the authentication token using the refresh - endpoint. - - Attributes: - - `MultifactorError`: Exception class for MFA-related errors. - - `AUTH_URL_USERS`: Endpoint URL for user authentication requests. - - `AUTH_URL_ADMIN`: Endpoint URL for admin authentication requests. - - `REFRESH_URL`: Endpoint URL for token refresh. - - `client`: Asynchronous HTTP client for making requests. - - `settings`: Configuration settings for the MFA service. - + endpoint. """ MultifactorError = _MultifactorError @@ -123,10 +120,12 @@ def __init__( ): """Set creds and web client. - :param str key: mfa key - :param str secret: mfa secret - :param httpx.AsyncClient client: client for making queries (activated) - :param Settings settings: app settings + Args: + key (str): mfa key + secret (str): mfa secret + client (httpx.AsyncClient): client for making queries + (activated) + settings (Settings): app settings """ self.client = client self.settings = settings @@ -134,6 +133,11 @@ def __init__( @staticmethod def _generate_trace_id_header() -> dict[str, str]: + """Generate trace id header. + + Returns: + dict[str, str] + """ return {"mf-trace-id": f"md:{uuid.uuid4()}"} @log_mfa.catch(reraise=True) @@ -149,14 +153,18 @@ async def ldap_validate_mfa( timeout is 60 seconds. "m" key-character is used to mark push request in multifactor API. - :param str username: un - :param str password: pwd - :param NetworkPolicy policy: policy - :raises MultifactorError: connect timeout - :raises MultifactorError: invalid json - :raises MultifactorError: Invalid status - :return bool: status - """ + Args: + username (str): un + password (str): pwd + + Returns: + bool: status + + Raises: + MFAConnectError: API Timeout + MFAMissconfiguredError: API Key or Secret is invalid + MultifactorError: status error + """ # noqa: DOC502 passcode = password or "m" log_mfa.debug(f"LDAP MFA request: {username}, {password}") try: @@ -211,14 +219,19 @@ async def get_create_mfa( ) -> str: """Create mfa link. - :param str username: un - :param str callback_url: callback uri to send token - :param int uid: user id - :raises httpx.TimeoutException: on timeout - :raises self.MultifactorError: on invalid json, Key or error status - code - :return str: url to open in new page - """ + Args: + username (str): un + callback_url (str): callback uri to send token + uid (int): user id + + Returns: + str: url to open in new page + + Raises: + MFAConnectError: API Timeout + MFAMissconfiguredError: API Key or Secret is invalid + MultifactorError: Incorrect resource + """ # noqa: DOC502 data = { "identity": username, "claims": { @@ -264,10 +277,15 @@ async def get_create_mfa( async def refresh_token(self, token: str) -> str: """Refresh mfa token. - :param str token: str jwt token - :raises self.MultifactorError: on api err - :return str: new token - """ + Args: + token (str): str jwt token + + Returns: + str: new token + + Raises: + MultifactorError: on api err + """ # noqa: DOC502 try: response = await self.client.post( self.settings.MFA_API_URI + self.REFRESH_URL, diff --git a/app/ldap_protocol/policies/access_policy.py b/app/ldap_protocol/policies/access_policy.py index 0ffc5dd4a..9014b2a50 100644 --- a/app/ldap_protocol/policies/access_policy.py +++ b/app/ldap_protocol/policies/access_policy.py @@ -21,14 +21,14 @@ from models import AccessPolicy, Directory, Group T = TypeVar("T", bound=Select) -__all__ = ["get_policies", "create_access_policy", "mutate_ap"] +__all__ = ["create_access_policy", "get_policies", "mutate_ap"] async def get_policies(session: AsyncSession) -> list[AccessPolicy]: """Get policies. - :param AsyncSession session: db - :return list[AccessPolicy]: result + Returns: + list[AccessPolicy]: result """ query = select(AccessPolicy).options( selectinload(AccessPolicy.groups).selectinload(Group.directory), @@ -48,7 +48,18 @@ async def create_access_policy( groups: list[GRANT_DN_STRING], session: AsyncSession, ) -> None: - """Get policies.""" + """Get policies. + + Args: + name (str): access policy name + can_read (bool): can read + can_add (bool): can add + can_modify (bool): can modify + can_delete (bool): can delete + grant_dn (GRANT_DN_STRING): main dn + groups (list[GRANT_DN_STRING]): list of groups + session (AsyncSession): session + """ path = get_search_path(grant_dn) dir_filter = get_path_filter( column=Directory.path[1 : len(path)], @@ -71,16 +82,21 @@ async def create_access_policy( await session.flush() -def mutate_ap( +def mutate_ap[T: Select]( query: T, user: UserSchema, action: Literal["add", "read", "modify", "del"] = "read", ) -> T: """Modify query with read rule filter, joins acess policies. - :param T query: select(Directory) - :param UserSchema user: user data - :return T: select(Directory).join(Directory.access_policies) + Args: + query (T): select(Directory) + user (UserSchema): serialized user + action (Literal["add", "read", "modify", "del"]): + (Default value = "read") + + Returns: + T: select(Directory).join(Directory.access_policies) """ whitelist = AccessPolicy.id.in_(user.access_policies_ids) diff --git a/app/ldap_protocol/policies/network_policy.py b/app/ldap_protocol/policies/network_policy.py index c39277b7e..ed935a681 100644 --- a/app/ldap_protocol/policies/network_policy.py +++ b/app/ldap_protocol/policies/network_policy.py @@ -22,11 +22,14 @@ def build_policy_query( ) -> Select: """Build a base query for network policies with optional group filtering. - :param IPv4Address ip: IP address to filter - :param Literal["is_http", "is_ldap", "is_kerberos"] protocol_field_name - protocol: Protocol to filter - :param list[int] | None user_group_ids: List of user group IDs, optional - :return: Select query + Args: + ip (IPv4Address | IPv6Address): IP address to filter + protocol_field_name (Literal["is_http", "is_ldap", "is_kerberos"]):\ + Protocol to filter + user_group_ids (list[int] | None): List of user group IDs, optional + + Returns: + Select: SQLAlchemy query """ protocol_field = getattr(NetworkPolicy, protocol_field_name) query = ( @@ -62,10 +65,13 @@ async def check_mfa_group( ) -> bool: """Check if user is in a group with MFA policy. - :param NetworkPolicy policy: policy object - :param User user: user object - :param AsyncSession session: db session - :return bool: status + Args: + policy (NetworkPolicy): policy object + user (User): user object + session (AsyncSession): db session + + Returns: + bool: status """ return await session.scalar( select( @@ -84,9 +90,13 @@ async def get_user_network_policy( ) -> NetworkPolicy | None: """Get the highest priority network policy for user, ip and protocol. - :param User user: user object - :param AsyncSession session: db session - :return NetworkPolicy | None: a NetworkPolicy object + Args: + ip (IPv4Address | IPv6Address): IP address to filter + user (User): user object + session (AsyncSession): db session + + Returns: + NetworkPolicy | None: a NetworkPolicy object """ user_group_ids = [group.id for group in user.groups] @@ -102,10 +112,13 @@ async def is_user_group_valid( ) -> bool: """Validate user groups, is it including to policy. - :param User user: db user - :param NetworkPolicy policy: db policy - :param AsyncSession session: db - :return bool: status + Args: + user (User): db user + policy (NetworkPolicy): db policy + session (AsyncSession): db + + Returns: + bool: status """ if user is None or policy is None: return False diff --git a/app/ldap_protocol/policies/password_policy.py b/app/ldap_protocol/policies/password_policy.py index 25c789697..79c833ac3 100644 --- a/app/ldap_protocol/policies/password_policy.py +++ b/app/ldap_protocol/policies/password_policy.py @@ -29,8 +29,9 @@ async def post_save_password_actions( ) -> None: """Post save actions for password update. - :param User user: user from db - :param AsyncSession session: db + Args: + user (User): user from db + session (AsyncSession): db """ await session.execute( # update bind reject attribute update(Attribute) @@ -81,8 +82,11 @@ def _validate_minimum_pwd_age(self) -> "PasswordPolicySchema": async def create_policy_settings(self, session: AsyncSession) -> Self: """Create policies settings. - :param AsyncSession session: db session - :return PasswordPolicySchema: password policy. + Returns: + Self: Serialized password policy. + + Raises: + PermissionError: Policy already exists. """ existing_policy = await session.scalar(select(exists(PasswordPolicy))) if existing_policy: @@ -98,8 +102,8 @@ async def get_policy_settings( ) -> "PasswordPolicySchema": """Get policy settings. - :param AsyncSession session: db - :return PasswordPolicySchema: policy + Returns: + PasswordPolicySchema: policy """ policy = await session.scalar(select(PasswordPolicy)) if not policy: @@ -107,10 +111,7 @@ async def get_policy_settings( return cls.model_validate(policy, from_attributes=True) async def update_policy_settings(self, session: AsyncSession) -> None: - """Update policy. - - :param AsyncSession session: db - """ + """Update policy.""" await session.execute( (update(PasswordPolicy).values(self.model_dump(mode="json"))), ) @@ -123,8 +124,8 @@ async def delete_policy_settings( ) -> "PasswordPolicySchema": """Reset (delete) default policy. - :param AsyncSession session: db - :return PasswordPolicySchema: schema policy + Returns: + PasswordPolicySchema: schema policy """ default_policy = cls() await default_policy.update_policy_settings(session) @@ -134,8 +135,8 @@ async def delete_policy_settings( def _count_password_exists_days(last_pwd_set: Attribute) -> int: """Get number of days, pwd exists. - :param Attribute last_pwd_set: pwdLastSet - :return int: days + Returns: + int: count of days """ tz = ZoneInfo("UTC") now = datetime.now(tz=tz) @@ -155,9 +156,12 @@ async def get_pwd_last_set( ) -> Attribute: """Get pwdLastSet. - :param AsyncSession session: db - :param int directory_id: id - :return Attribute: pwdLastSet + Args: + session (AsyncSession): db + directory_id (int): id + + Returns: + Attribute: pwdLastSet """ plset = await session.scalar( select(Attribute) @@ -181,12 +185,9 @@ async def get_pwd_last_set( def validate_min_age(self, last_pwd_set: Attribute) -> bool: """Validate min password change age. - :param Attribute last_pwd_set: last pwd set - :return bool: can change pwd - True - not valid, can not change - False - valid, can change - - on minimum_password_age_days can always change. + Returns: + bool: can change pwd True - not valid, can not change False + - valid, can change on minimum_password_age_days can always change. """ if self.minimum_password_age_days == 0: return False @@ -198,12 +199,9 @@ def validate_min_age(self, last_pwd_set: Attribute) -> bool: def validate_max_age(self, last_pwd_set: Attribute) -> bool: """Validate max password change age. - :param Attribute last_pwd_set: last pwd set - :return bool: is pwd expired - True - not valid, expired - False - valid, not expired - - on maximum_password_age_days always valid. + Returns: + bool: is pwd expired True - not valid, expired False - + valid, not expired on maximum_password_age_days always valid. """ if self.maximum_password_age_days == 0: return False @@ -219,10 +217,12 @@ async def validate_password_with_policy( ) -> list[str]: """Validate password with chosen policy. - :param str password: new raw password - :param User user: db user - :param AsyncSession session: db - :return bool: status + Args: + password (str): new raw password + user (User): db user + + Returns: + bool: status """ errors = [] history: Iterable = [] @@ -242,9 +242,9 @@ async def validate_password_with_policy( errors.append("password minimum length violation") regex = ( - re.search("[A-ZА-Я]", password) is not None, - re.search("[a-zа-я]", password) is not None, - re.search("[0-9]", password) is not None, + re.search(r"[A-ZА-Я]", password) is not None, # noqa: RUF001 + re.search(r"[a-zа-я]", password) is not None, # noqa: RUF001 + re.search(r"[0-9]", password) is not None, password.lower() not in _COMMON_PASSWORDS, ) diff --git a/app/ldap_protocol/server.py b/app/ldap_protocol/server.py index 44f8ebad2..f94923ecd 100644 --- a/app/ldap_protocol/server.py +++ b/app/ldap_protocol/server.py @@ -57,7 +57,12 @@ class PoolClientHandler: ssl_context: ssl.SSLContext | None = None def __init__(self, settings: Settings, container: AsyncContainer): - """Set workers number for single client concurrent handling.""" + """Set workers number for single client concurrent handling. + + Args: + settings (Settings): settings + container (AsyncContainer): container + """ self.container = container self.settings = settings @@ -77,7 +82,12 @@ async def __call__( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: - """Create session, queue and start message handlers concurrently.""" + """Create session, queue and start message handlers concurrently. + + Args: + reader (asyncio.StreamReader): reader + writer (asyncio.StreamWriter): writer + """ async with self.container(scope=Scope.SESSION) as session_scope: ldap_session = await session_scope.get(LDAPSession) addr, first_chunk = await self.recieve( @@ -124,7 +134,11 @@ async def __call__( await writer.wait_closed() def _load_ssl_context(self) -> None: - """Load SSL context for LDAPS.""" + """Load SSL context for LDAPS. + + Raises: + SystemExit: Certs not found + """ if self.settings.USE_CORE_TLS and self.settings.LDAP_LOAD_SSL_CERT: if not self.settings.check_certs_exist(): log.critical("Certs not found, exiting...") @@ -143,8 +157,15 @@ def _extract_proxy_protocol_address( ) -> tuple[IPv4Address | IPv6Address, bytes]: """Get ip from proxy protocol header. - :param bytes data: data - :return tuple: ip, data + Args: + data (bytes): data + writer (asyncio.StreamWriter): writer + + Returns: + tuple: ip, data + + Raises: + ValueError: Invalid source address """ peername = ":".join(map(str, writer.get_extra_info("peername"))) peer_addr = ip_address(peername.split(":")[0]) @@ -184,12 +205,17 @@ async def recieve( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - return_addr: Literal[True, False] | bool = False, + return_addr: bool = False, ) -> tuple[IPv4Address | IPv6Address, bytes] | bytes: """Read N packets by 1kB. - :param asyncio.StreamReader reader: reader - :return tuple: ip, data + Args: + reader (asyncio.StreamReader): reader + writer (asyncio.StreamWriter): writer + return_addr (bool): address (Default value = "read") + + Returns: + tuple[IPv4Address | IPv6Address, bytes] | bytes: """ buffer = BytesIO() addr = None @@ -229,8 +255,11 @@ def _compute_ldap_message_size(data: bytes) -> int: source: https://github.com/cannatag/ldap3/blob/dev/ldap3/strategy/base.py#L455 - :param bytes data: body - :return int: actual size + Args: + data (bytes): body + + Returns: + int: actual size """ if len(data) > 2: if data[1] <= 127: # short @@ -255,12 +284,15 @@ async def _handle_request( ) -> None: """Create request object and send it to queue. - :param bytes data: initial data - :param asyncio.StreamReader reader: reader - :param asyncio.StreamWriter writer: writer - :param AsyncContainer container: container - :raises ConnectionAbortedError: if client sends empty request (b'') - :raises RuntimeError: reraises on unexpected exc + Args: + data (bytes): initial data + reader (asyncio.StreamReader): reader + writer (asyncio.StreamWriter): writer + container (AsyncContainer): container + + Raises: + ConnectionAbortedError: if client sends empty request (b'') + RuntimeError: reraises on unexpected exc """ ldap_session: LDAPSession = await container.get(LDAPSession) while True: @@ -294,9 +326,16 @@ async def _unwrap_request( ) -> bytes: """Unwrap request with GSSAPI security layer if needed. - :param bytes data: request data - :param LDAPSession ldap_session: session - :return bytes: unwrapped data + Args: + data (bytes): request data + ldap_session (LDAPSession): session + + Returns: + bytes: unwrapped data + + Raises: + ConnectionAbortedError: SASL buffer length mismatch or\ + GSSAPI security context not found """ if ldap_session.gssapi_security_layer in ( GSSAPISL.INTEGRITY_PROTECTION, @@ -326,6 +365,7 @@ async def _unwrap_request( @staticmethod def _req_log_full(addr: str, msg: LDAPRequestMessage) -> None: + """Request full log.""" log.debug( f"\nFrom: {addr!r}\n{msg.name}[{msg.message_id}]: " f"{msg.model_dump_json()}\n", @@ -333,6 +373,7 @@ def _req_log_full(addr: str, msg: LDAPRequestMessage) -> None: @staticmethod def _resp_log_full(addr: str, msg: LDAPResponseMessage) -> None: + """Response full log.""" log.debug( f"\nTo: {addr!r}\n{msg.name}[{msg.message_id}]: " f"{msg.model_dump_json()}"[:3000], @@ -340,6 +381,7 @@ def _resp_log_full(addr: str, msg: LDAPResponseMessage) -> None: @staticmethod def _log_short(addr: str, msg: LDAPMessage) -> None: + """Short log.""" log.info(f"\n{addr!r}: {msg.name}[{msg.message_id}]\n") async def _handle_single_response( @@ -347,7 +389,15 @@ async def _handle_single_response( writer: asyncio.StreamWriter, container: AsyncContainer, ) -> None: - """Get message from queue and handle it.""" + """Get message from queue and handle it. + + Args: + writer (asyncio.StreamWriter): writer + container (AsyncContainer): container + + Raises: + RuntimeError: any error + """ ldap_session: LDAPSession = await container.get(LDAPSession) addr = str(ldap_session.ip) @@ -387,10 +437,13 @@ async def _wrap_response( ) -> bytes: """Wrap response with GSSAPI security layer if needed. - :param bytes data: response data - :param LDAPSession ldap_session: session - :param int protocol_op: protocol operation - :return bytes: wrapped data + Args: + data (bytes): response data + ldap_session (LDAPSession): session + protocol_op (int): protocol operation + + Returns: + bytes: wrapped data """ if ( ldap_session.gssapi_authenticated @@ -423,6 +476,10 @@ async def _handle_responses( Spawns (default 5) workers, then every task awaits for queue object, cycle locks until pool completes at least 1 task. + + Args: + writer (asyncio.StreamWriter): writer + container (AsyncContainer): container """ tasks = [ self._handle_single_response(writer, container) @@ -432,7 +489,11 @@ async def _handle_responses( await asyncio.gather(*tasks) async def _get_server(self) -> asyncio.base_events.Server: - """Get async server.""" + """Get async server. + + Returns: + asyncio.base_events.Server: async server + """ return await asyncio.start_server( self, str(self.settings.HOST), @@ -449,6 +510,7 @@ async def _run_server(server: asyncio.base_events.Server) -> None: @staticmethod def log_addrs(server: asyncio.base_events.Server) -> None: + """Log server addresses.""" addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets) log.info(f"Server on {addrs}") diff --git a/app/ldap_protocol/session_storage.py b/app/ldap_protocol/session_storage.py index 1183892c1..86afaf0a1 100644 --- a/app/ldap_protocol/session_storage.py +++ b/app/ldap_protocol/session_storage.py @@ -7,7 +7,7 @@ import json from abc import ABC, abstractmethod from collections import defaultdict -from datetime import datetime, timezone +from datetime import UTC, datetime from secrets import token_hex from typing import Iterable, Literal, Self @@ -32,26 +32,21 @@ class SessionStorage(ABC): async def get(self, key: str) -> dict: """Retrieve data associated with the given key from storage. - :param str key: The key to look up in the storage. - :return dict: The data associated with the key, - or an empty dictionary if the key is not found. + Args: + key (str): The key to look up in the storage. + + Returns: + dict: The data associated with the key, or an empty + dictionary if the key is not found. """ @abstractmethod async def _get_session_keys_by_uid(self, uid: int) -> set[str]: - """Get session keys by user id. - - :param int uid: user id - :return set[str]: session keys - """ + """Get session keys by user id.""" @abstractmethod async def _get_session_keys_by_ip(self, ip: str) -> set[str]: - """Get session keys by ip. - - :param str ip: ip - :return set[str]: session keys - """ + """Get session keys by ip.""" @abstractmethod async def get_user_sessions( @@ -61,9 +56,14 @@ async def get_user_sessions( ) -> dict: """Get sessions by user id. - :param int uid: user id - :param ProtocolType | None protocol: protocol - :return dict: user sessions contents + Args: + uid (int): user id + protocol (ProtocolType | None): The protocol type to filter\ + sessions by (e.g., "http" or "ldap"). If None,\ + sessions for all protocols are returned. + + Returns: + dict: user sessions contents """ @abstractmethod @@ -74,29 +74,35 @@ async def get_ip_sessions( ) -> dict: """Get sessions data by ip. - :param str ip: ip - :param ProtocolType | None protocol: protocol - :return dict: user sessions contents + Args: + ip (str): ip + protocol (ProtocolType | None): The protocol type to filter\ + sessions by (e.g., "http" or "ldap"). If None,\ + sessions for all protocols are returned. + + Returns: + dict: user sessions contents """ @abstractmethod async def clear_user_sessions(self, uid: int) -> None: - """Clear user sessions. - - :param int uid: user id - :return None: - """ + """Clear user sessions.""" @abstractmethod async def delete_user_session(self, session_id: str) -> None: - """Delete user session. - - :param str session_id: session id - :return None: - """ + """Delete user session.""" @staticmethod def _sign(session_id: str, settings: Settings) -> str: + """Sign session id. + + Args: + session_id (str): Session id + settings (Settings): Settings with database dsn. + + Returns: + str: The HMAC signature for the session_id using provided settings. + """ return hmac.new( settings.SECRET_KEY.encode(), session_id.encode(), @@ -104,13 +110,35 @@ def _sign(session_id: str, settings: Settings) -> str: ).hexdigest() def get_user_agent_hash(self, user_agent: str) -> str: - """Get user agent hash.""" + """Get user agent hash. + + Returns: + str: The hash of the user agent. + """ return hashlib.blake2b(user_agent.encode(), digest_size=6).hexdigest() def _get_ip_session_key(self, ip: str, protocol: ProtocolType) -> str: + """Get ip session key. + + Args: + ip (str): IP + protocol (ProtocolType): Type of Protocol + + Returns: + str: The session key for the given IP and protocol. + """ return f"ip:{protocol}:{ip}" def _get_user_session_key(self, uid: int, protocol: ProtocolType) -> str: + """Get user session key. + + Args: + uid (int): uid + protocol (ProtocolType): Type of Protocol + + Returns: + str: The session key for the given user and protocol. + """ return f"keys:{protocol}:{uid}" def _get_protocol(self, session_id: str) -> ProtocolType: @@ -119,16 +147,12 @@ def _get_protocol(self, session_id: str) -> ProtocolType: def _generate_key(self) -> str: """Generate a new key for storing data in the storage. - :return str: A new key. + Returns: + str: New key. """ return f"http:{token_hex(self.key_length)}" def _get_lock_key(self, session_id: str) -> str: - """Get lock key. - - :param str session_id: session id - :return str: lock key - """ return f"lock:{session_id}" @abstractmethod @@ -141,10 +165,14 @@ async def create_session( ) -> str: """Create session. - :param int uid: user id - :param Settings settings: app settings - :param dict | None extra_data: data, defaults to None - :return str: session id + Args: + uid (int): user id + settings (Settings): app settings + extra_data (dict | None): Additional data to include\ + in the session, defaults to None. + + Returns: + str: session id """ async def get_user_id( @@ -156,11 +184,17 @@ async def get_user_id( ) -> int: """Get user from storage. - :param Settings settings: app settings - :param str session_key: session key - :param str user_agent: user agent - :param str ip: ip address - :return int: user id + Args: + settings (Settings): app settings + session_key (str): session key + user_agent (str): user agent + ip (str): ip address + + Returns: + int: user id. + + Raises: + KeyError: key error. """ try: session_id, signature = session_key.split(".") @@ -195,7 +229,17 @@ def _generate_session_data( settings: Settings, extra_data: dict | None, ) -> tuple[str, str, dict]: - """Set data.""" + """Set data. + + Args: + uid (int): uid + settings (Settings): Settings with database dsn. + extra_data (dict | None): additional data + + Returns: + tuple[str, str, dict]: A tuple containing the session_id,\ + signature, and session data dictionary. + """ if extra_data is None: extra_data = {} @@ -203,15 +247,15 @@ def _generate_session_data( signature = self._sign(session_id, settings) data = {"id": uid, "sign": signature} | extra_data - data["issued"] = datetime.now(timezone.utc).isoformat() + data["issued"] = datetime.now(UTC).isoformat() return session_id, signature, data @abstractmethod async def check_session(self, session_id: str) -> bool: """Check session. - :param str session_id: session id - :return bool: True if session exists + Returns: + bool: True if session exists """ @abstractmethod @@ -223,26 +267,34 @@ async def create_ldap_session( ) -> None: """Create ldap session. - :param int uid: user id - :param dict data: data, defaults to None + Args: + uid (int): user id + key (str): key + data (dict): data, defaults to None """ @abstractmethod async def check_rekey(self, session_id: str, rekey_interval: int) -> bool: """Check rekey. - :param str session_id: session id - :param int rekey_interval: rekey interval in seconds - :return bool: True if rekey is needed + Args: + session_id (str): session id + rekey_interval (int): rekey interval in seconds + + Returns: + bool: True if rekey is needed """ @abstractmethod async def rekey_session(self, session_id: str, settings: Settings) -> str: """Rekey session. - :param str session_id: session id - :param Settings settings: app settings - :return str: jwt token + Args: + session_id (str): session id + settings (Settings): app settings + + Returns: + str: jwt token """ @@ -299,10 +351,11 @@ class RedisSessionStorage(SessionStorage): def __init__(self, storage: Redis, key_length: int, key_ttl: int) -> None: """Initialize the storage. - :param Redis storage: - The Redis/DragonflyDB instance to use for storage. - :param int key_length: The length of the keys to generate. - :param int key_ttl: The time-to-live for keys in seconds. + Args: + storage (Redis): The Redis/DragonflyDB instance to use for + storage. + key_length (int): The length of the keys to generate. + key_ttl (int): The time-to-live for keys in seconds. """ self._storage = storage self.key_length = key_length @@ -311,9 +364,12 @@ def __init__(self, storage: Redis, key_length: int, key_ttl: int) -> None: async def _get_lock(self, name: str, blocking_timeout: int = 5) -> Lock: """Get lock. - :param str name: lock name - :param int blocking_timeout: blocking timeout, defaults to 5 - :return Lock: lock object + Args: + name (str): lock name + blocking_timeout (int): blocking timeout, defaults to 5 + + Returns: + Lock: lock object """ return self._storage.lock( name=self._get_lock_key(name), @@ -323,9 +379,15 @@ async def _get_lock(self, name: str, blocking_timeout: int = 5) -> Lock: async def get(self, key: str) -> dict: """Retrieve data associated with the given key from storage. - :param str key: The key to look up in the storage. - :return dict: The data associated with the key, - or an empty dictionary if the key is not found. + Args: + key (str): The key to look up in the storage. + + Returns: + dict: The data associated with the key, or an empty + dictionary if the key is not found. + + Raises: + KeyError: If the key is not found in the storage. """ data = await self._storage.get(key) if data is None: @@ -333,16 +395,14 @@ async def get(self, key: str) -> dict: return json.loads(data) async def delete(self, keys: Iterable[str]) -> None: - """Delete data associated with the given key from storage. - - :param str key: The key to delete from the storage. - """ + """Delete data associated with the given key from storage.""" await self._storage.delete(*keys) async def _fetch_keys(self, key: str) -> set[str]: """Fetch keys. - :param str key: key + Returns: + set[str]: A set of decoded keys from the storage. """ encoded_keys = await self._storage.smembers(key) # type: ignore return {k.decode() for k in encoded_keys} @@ -358,9 +418,12 @@ async def _get_session_keys_by_ip( specific protocol is provided, only sessions for that protocol are returned. - :param str ip: ip - :param ProtocolType | None protocol: protocol - :return set[str]: session keys + Args: + ip (str): ip + protocol (ProtocolType | None): protocol + + Returns: + set[str]: session keys """ if protocol: return await self._fetch_keys( @@ -382,9 +445,12 @@ async def _get_session_keys_by_uid( specific protocol is provided, only sessions for that protocol are returned. - :param int uid: user id - :param ProtocolType | None protocol: protocol - :return set[str]: session keys + Args: + uid (int): user id + protocol (ProtocolType | None): protocol + + Returns: + set[str]: session keys """ if protocol: return await self._fetch_keys( @@ -416,9 +482,12 @@ async def _get_sessions(self, keys: set[str], id_value: str | int) -> dict: 4. Remove expired session keys from the sets that track user ID or IP sessions. - :param set[str] keys: session keys - :param str | int id_value: user id or ip - :return dict: user sessions contents + Args: + keys (set[str]): session keys + id_value (str | int): user id or ip + + Returns: + dict: user sessions contents """ if not keys: return {} @@ -458,9 +527,12 @@ async def get_user_sessions( ) -> dict: """Get sessions by user id. - :param int uid: user id - :param ProtocolType | None protocol: protocol - :return dict: user sessions contents + Args: + uid (int): user id + protocol (ProtocolType | None): protocol + + Returns: + dict: user sessions contents """ keys = await self._get_session_keys_by_uid(uid, protocol) return await self._get_sessions(keys, uid) @@ -472,9 +544,12 @@ async def get_ip_sessions( ) -> dict: """Get sessions data by ip. - :param str ip: ip - :param ProtocolType | None protocol: protocol - :return dict: user sessions contents + Args: + ip (str): ip + protocol (ProtocolType | None): protocol + + Returns: + dict: user sessions contents """ keys = await self._get_session_keys_by_ip(ip, protocol) return await self._get_sessions(keys, ip) @@ -494,7 +569,8 @@ async def clear_user_sessions(self, uid: int) -> None: 5. Identify and remove session references stored under UID-based keys. 6. Delete all user session keys from storage. - :param int uid: user id + Args: + uid (int): user id """ keys = await self._get_session_keys_by_uid(uid) if not keys: @@ -544,7 +620,11 @@ async def delete_user_session(self, session_id: str) -> None: 9. Delete the session data from storage. 10. Release the lock. - :param str session_id: session id + Args: + session_id (str): session id + + Raises: + KeyError: key error. """ try: data = await self.get(session_id) @@ -591,12 +671,13 @@ async def _add_session( Adds a session to the storage and updates the session tracking keys for both user ID and IP address. - :param str session_id: session id - :param dict data: session data - :param int uid: user id - :param str ip_session_key: ip session key - :param str sessions_key: sessions key - :param int | None ttl: time to live, defaults to None + Args: + session_id (str): session id + data (dict): session data + uid (int): user id + ip_session_key (str): ip session key + sessions_key (str): sessions key + ttl (int | None): time to live, defaults to None """ zset_key = ( self.ZSET_HTTP_SESSIONS @@ -634,14 +715,15 @@ async def create_session( 3. Link the session to the user's session tracking key (`keys:http:`). 4. If an IP address is provided in `extra_data`, also link the session - to the IP-based session tracking key (`ip:http:`). + to the IP-based session tracking key (`ip:http:`). + + Args: + uid (int): user id + settings (Settings): settings + extra_data (dict): extra data - :param int uid: user id - :param dict data: data dict - :param str secret: secret key - :param int expires_minutes: exire time in minutes - :param Literal[refresh, access] grant_type: grant type flag - :return str: jwt token + Returns: + str: jwt token """ session_id, signature, data = self._generate_session_data( uid=uid, @@ -666,7 +748,11 @@ async def create_session( return f"{session_id}.{signature}" async def check_session(self, session_id: str) -> bool: - """Check session.""" + """Check session. + + Returns: + bool: True if exists. + """ return await self._storage.exists(session_id) async def create_ldap_session( @@ -689,11 +775,13 @@ async def create_ldap_session( 4. If an IP address is provided in `extra_data`, also link the session to the IP-based session tracking key (`ip:ldap:`). - :param int uid: user id - :param str key: session key - :param dict data: any data + Args: + uid (int): user id + key (str): The session key to use for storing the LDAP session. + This is the unique identifier for the LDAP session in storage. + data (dict): any data """ - data["issued"] = datetime.now(timezone.utc).isoformat() + data["issued"] = datetime.now(UTC).isoformat() ldap_sessions_key = self._get_user_session_key(uid, "ldap") ip_sessions_key = None @@ -711,9 +799,12 @@ async def create_ldap_session( async def check_rekey(self, session_id: str, rekey_interval: int) -> bool: """Check rekey. - :param str session_id: session id - :param int rekey_interval: rekey interval in seconds - :return bool: True if rekey is needed + Args: + session_id (str): session id + rekey_interval (int): rekey interval in seconds + + Returns: + bool: True if rekey is needed """ lock = await self._get_lock(session_id) @@ -723,7 +814,7 @@ async def check_rekey(self, session_id: str, rekey_interval: int) -> bool: data = await self.get(session_id) issued = datetime.fromisoformat(data.get("issued")) # type: ignore - return (datetime.now(timezone.utc) - issued).seconds > rekey_interval + return (datetime.now(UTC) - issued).seconds > rekey_interval async def _rekey_session(self, session_id: str, settings: Settings) -> str: """Rekey session. @@ -745,9 +836,15 @@ async def _rekey_session(self, session_id: str, settings: Settings) -> str: - The IP-based session tracking key (`ip:http:`) 8. Delete the old session. - :param str session_id: session id - :param Settings settings: app settings - :return str: jwt token + Args: + session_id (str): session id + settings (Settings): app settings + + Returns: + str: jwt token + + Raises: + KeyError: key error. """ data = await self.get(session_id) @@ -785,9 +882,12 @@ async def _rekey_session(self, session_id: str, settings: Settings) -> str: async def rekey_session(self, session_id: str, settings: Settings) -> str: """Rekey session. - :param str session_id: session id - :param Settings settings: app settings - :return str: jwt token + Args: + session_id (str): session id + settings (Settings): app settings + + Returns: + str: jwt token """ lock = await self._get_lock(session_id) diff --git a/app/ldap_protocol/user_account_control.py b/app/ldap_protocol/user_account_control.py index 6aab592a0..23616119d 100644 --- a/app/ldap_protocol/user_account_control.py +++ b/app/ldap_protocol/user_account_control.py @@ -72,8 +72,12 @@ class UserAccountControlFlag(IntFlag): def is_value_valid(cls, uac_value: str | int) -> bool: """Check all flags set in the userAccountControl value. - :param int uac_value: userAccountControl attribute value - :return: True if the value is valid (only known flags), False otherwise + Args: + uac_value(int): userAccountControl attribute value + uac_value: str | int: + + Returns: + bool: True if all flags are set correctly, False otherwise """ if isinstance(uac_value, int): pass @@ -94,9 +98,12 @@ async def get_check_uac( ) -> Callable[[UserAccountControlFlag], bool]: """Get userAccountControl attribute and check binary flags in it. - :param AsyncSession session: SA async session - :param int directory_id: id - :return Callable: function to check given flag in current + Args: + session (AsyncSession): SA async session + directory_id (int): id + + Returns: + Callable: function to check given flag in current userAccountControl attribute """ query = ( @@ -112,8 +119,11 @@ async def get_check_uac( def is_flag_true(flag: UserAccountControlFlag) -> bool: """Check given flag in current userAccountControl attribute. - :param userAccountControlFlag flag: flag - :return bool: result + Args: + flag (UserAccountControlFlag): flag + + Returns: + bool: True if flag is set, False otherwise """ return bool(int(value) & flag) diff --git a/app/ldap_protocol/utils/const.py b/app/ldap_protocol/utils/const.py index 7f9ce7bc5..76307a6c1 100644 --- a/app/ldap_protocol/utils/const.py +++ b/app/ldap_protocol/utils/const.py @@ -13,6 +13,17 @@ def _type_validate_entry(entry: str) -> str: + """Validate entry name. + + Args: + entry (str): entry name + + Returns: + str: entry name + + Raises: + ValueError: Invalid entry name + """ if validate_entry(entry): return entry raise ValueError(f"Invalid entry name {entry}") @@ -24,9 +35,17 @@ def _type_validate_entry(entry: str) -> str: def _type_validate_email(email: str) -> str: + """Validate email. + + Returns: + str: email address + + Raises: + ValueError: Invalid email + """ if EMAIL_RE.fullmatch(email): return email - raise ValueError(f"Invalid entry name {email}") + raise ValueError(f"Invalid email {email}") GRANT_DN_STRING = Annotated[str, AfterValidator(_type_validate_entry)] diff --git a/app/ldap_protocol/utils/cte.py b/app/ldap_protocol/utils/cte.py index 98b4e48a1..628956142 100644 --- a/app/ldap_protocol/utils/cte.py +++ b/app/ldap_protocol/utils/cte.py @@ -26,26 +26,26 @@ def find_members_recursive_cte(dn: str) -> CTE: ------------------ 1. **Base Query (Initial Part of the CTE)**: - The function begins by defining the initial part of the CTE, named - `directory_hierarchy`. This query selects the `directory_id` and - `group_id` from the `Directory` and `Groups` tables, filtering based - on the distinguished name (DN) provided by the `dn` argument. + !The function begins by defining the initial part of the CTE, named + !`directory_hierarchy`. This query selects the `directory_id` and + !`group_id` from the `Directory` and `Groups` tables, filtering based + !on the distinguished name (DN) provided by the `dn` argument. 2. **Recursive Part of the CTE**: - The second part of the CTE is recursive. It joins the results of - `directory_hierarchy` with the `DirectoryMemberships` table to find - all groups that are members of other groups, iterating through - all nested memberships. + !The second part of the CTE is recursive. It joins the results of + !`directory_hierarchy` with the `DirectoryMemberships` table to find + !all groups that are members of other groups, iterating through + !all nested memberships. 3. **Combining Results**: - The CTE combines the initial and recursive parts using `union_all` - effectively creating a recursive query that gathers all directorie - and their associated groups, both directly and indirectly related. + !The CTE combines the initial and recursive parts using `union_all` + !effectively creating a recursive query that gathers all directorie + !and their associated groups, both directly and indirectly related. 4. **Final Query**: - The final query applies the method (typically a comparison operation - to the results of the CTE, returning the desired condition for furthe - use in the main query. + !The final query applies the method (typically a comparison operation + !to the results of the CTE, returning the desired condition for furthe + !use in the main query. The query translates to the following SQL: @@ -76,6 +76,11 @@ def find_members_recursive_cte(dn: str) -> CTE: In the case of a recursive search through the specified group1, the search result will be as follows: user1, user2, group2, user3, group3, user4. + Args: + dn (str): domain name + + Returns: + CTE: Common Table Expression """ directory_hierarchy = ( select(Directory.id.label("directory_id"), Group.id.label("group_id")) @@ -102,7 +107,7 @@ def find_members_recursive_cte(dn: str) -> CTE: return directory_hierarchy.union_all(recursive_part) -def find_root_group_recursive_cte(dn_list: list) -> CTE: +def find_root_group_recursive_cte(dn_list: list[str]) -> CTE: """Create CTE to filter directory root group. The query translates to the following SQL: @@ -135,6 +140,11 @@ def find_root_group_recursive_cte(dn_list: list) -> CTE: result will be as follows: group1, group2, group3, user4. + Args: + dn_list (list[str]): domain names + + Returns: + CTE: Common Table Expression """ directory_hierarchy = ( select( @@ -177,6 +187,15 @@ async def get_members_root_group( result will be as follows: group1, user1, user2, group2, user3, group3, user4. + Args: + dn (str): domain name + session (AsyncSession): async session + + Returns: + list[Directory]: list of directories + + Raises: + RuntimeError: not found directory """ cte = find_root_group_recursive_cte([dn]) result = await session.scalars(select(cte.c.directory_id)) @@ -206,7 +225,7 @@ async def get_members_root_group( select(Directory) .where( or_( - *[Directory.id == dir_id for dir_id in dir_ids], + *[Directory.id == dir_id for dir_id in dir_ids] ) ) ) # fmt: skip @@ -222,9 +241,12 @@ async def get_all_parent_group_directories( ) -> AsyncScalarResult | None: """Get all parent groups directory. - :param list[Group] groups: directory groups - :param AsyncSession session: session - :return set[Directory]: all groups and their parent group directories + Args: + groups (list[Group]): directory groups + session (AsyncSession): session + + Returns: + AsyncScalarResult | None: all groups and their parent group directories """ dn_list = [group.directory.path_dn for group in groups] diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 0b0553680..40a31338f 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -148,8 +148,12 @@ def validate_entry(entry: str) -> bool: cn=first,dc=example,dc=com -> valid cn=first,dc=example,dc=com -> valid - :param str entry: any str - :return bool: result + + Args: + entry (str): entry path + + Returns: + bool: entry path is correct """ return all( re.match(r"^[a-zA-Z\-]+$", part.split("=")[0]) @@ -159,22 +163,46 @@ def validate_entry(entry: str) -> bool: def is_dn_in_base_directory(base_directory: Directory, entry: str) -> bool: - """Check if an entry in a base dn.""" + """Check if an entry in a base dn. + + Args: + base_directory (Directory): instance of Directory + entry (str): entry path + + Returns: + bool: True if the entry is in the base directory, False otherwise + """ return entry.lower().endswith(base_directory.path_dn.lower()) def dn_is_base_directory(base_directory: Directory, entry: str) -> bool: - """Check if an entry is a base dn.""" + """Check if an entry is a base dn. + + Args: + base_directory (Directory): base Directory instance + entry (str): entry path + + Returns: + bool: True if the entry is a base dn, False otherwise + """ return base_directory.path_dn.lower() == entry.lower() def get_generalized_now(tz: ZoneInfo) -> str: - """Get generalized time (formated) with tz.""" + """Get generalized time (formated) with tz. + + Returns: + str: generalized time + """ return datetime.now(tz).strftime("%Y%m%d%H%M%S.%f%z") def _get_domain(name: str) -> str: - """Get domain from name.""" + """Get domain from name. + + Returns: + str: domain + """ return ".".join( [ item[3:].lower() @@ -187,15 +215,22 @@ def _get_domain(name: str) -> str: def create_integer_hash(text: str, size: int = 9) -> int: """Create integer hash from text. - :param str text: any string - :param int size: fixed size of hash, defaults to 15 - :return int: hash + Args: + text (str): any string + size (int): fixed size of hash, defaults to 9 + + Returns: + int: hash """ return int(hashlib.sha256(text.encode("utf-8")).hexdigest(), 16) % 10**size def get_windows_timestamp(value: datetime) -> int: - """Get the Windows timestamp from the value.""" + """Get the Windows timestamp from the value. + + Returns: + int: Windows timestamp + """ return (int(value.timestamp()) + 11644473600) * 10000000 @@ -207,6 +242,12 @@ def dt_to_ft(dt: datetime) -> int: """Convert a datetime to a Windows filetime. If the object is time zone-naive, it is forced to UTC before conversion. + + Args: + dt (datetime): date and time + + Returns: + int: Windows filetime """ if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) != 0: dt = dt.astimezone(ZoneInfo("UTC")) @@ -221,6 +262,12 @@ def ft_to_dt(filetime: int) -> datetime: The new datetime object is timezone-naive but is equivalent to tzinfo=utc. 1) Get seconds and remainder in terms of Unix epoch 2) Convert to datetime object, with remainder as microseconds. + + Args: + filetime (int): Windows file time number + + Returns: + datetime: Python datetime """ s, ns100 = divmod(filetime - _EPOCH_AS_FILETIME, _HUNDREDS_OF_NS) return datetime.fromtimestamp(s, tz=ZoneInfo("UTC")).replace( @@ -229,7 +276,11 @@ def ft_to_dt(filetime: int) -> datetime: def ft_now() -> str: - """Get now filetime timestamp.""" + """Get now filetime timestamp. + + Returns: + str: now filetime timestamp + """ return str(dt_to_ft(datetime.now(tz=ZoneInfo("UTC")))) @@ -245,8 +296,11 @@ def string_to_sid(sid_string: str) -> bytes: - The identifier authority is packed as a 6-byte sequence. - Each sub-authority is packed as a 4-byte sequence. - :param sid_string: The string representation of the SID - :return bytes: The binary representation of the SID + Args: + sid_string (str): The string representation of the SID + + Returns: + bytes: The binary representation of the SID """ parts = sid_string.split("-") @@ -274,19 +328,25 @@ def create_object_sid( ) -> str: """Generate the objectSid attribute for an object. - :param domain: domain directory - :param int rid: relative identifier - :param bool reserved: A flag indicating whether the RID is reserved. - If `True`, the given RID is used directly. If - `False`, 1000 is added to the given RID to generate - the final RID - :return str: the complete objectSid as a string + Args: + domain (Directory): domain directory + rid (int): relative identifier + reserved (bool): A flag indicating whether the RID is reserved. + If `True`, the given RID is used directly. If `False`, 1000 + is added to the given RID to generate the final RID + + Returns: + str: the complete objectSid as a string """ return domain.object_sid + f"-{rid if reserved else 1000 + rid}" def generate_domain_sid() -> str: - """Generate domain objectSid attr.""" + """Generate domain objectSid attr. + + Returns: + str: domain objectSid attr + """ sub_authorities = [ random.randint(1000000000, (1 << 32) - 1), random.randint(1000000000, (1 << 32) - 1), @@ -299,6 +359,9 @@ def create_user_name(directory_id: int) -> str: """Create username by directory id. NOTE: keycloak + + Returns: + str: username """ return blake2b(str(directory_id).encode(), digest_size=8).hexdigest() diff --git a/app/ldap_protocol/utils/pagination.py b/app/ldap_protocol/utils/pagination.py index 3e20d06a6..e85b97020 100644 --- a/app/ldap_protocol/utils/pagination.py +++ b/app/ldap_protocol/utils/pagination.py @@ -67,7 +67,14 @@ class BaseSchemaModel[S: Base](BaseModel): @classmethod @abstractmethod def from_db(cls, sqla_instance: S) -> "BaseSchemaModel[S]": - """Create an instance of Schema from instance of SQLA model.""" + """Create an instance of Schema from instance of SQLA model. + + Args: + sqla_instance (S): instance of SQLAlchemy Model + + Returns: + BaseSchemaModel[S]: instance of Schema + """ @dataclass @@ -88,7 +95,20 @@ async def get( sqla_model: type[S], session: AsyncSession, ) -> "PaginationResult[S]": - """Get paginator.""" + """Get paginator. + + Args: + query (Select[tuple[S]]): SQLAlchemy query to execute. + params (PaginationParams): Pagination parameters. + sqla_model (type[S]): SQLAlchemy model class to paginate. + session (AsyncSession): SQLAlchemy async session. + + Raises: + ValueError: If the query does not have an order_by clause. + + Returns: + PaginationResult[S]: Paginator with metadata and items. + """ if query._order_by_clause is None or len(query._order_by_clause) == 0: raise ValueError("Select query must have an order_by clause.") diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index d7561b683..9076ce08d 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -28,7 +28,11 @@ @cache async def get_base_directories(session: AsyncSession) -> list[Directory]: - """Get base domain directories.""" + """Get base domain directories. + + Returns: + list[Directory]: base domain directories + """ result = await session.execute( select(Directory) .filter(Directory.parent_id.is_(None)) @@ -39,9 +43,12 @@ async def get_base_directories(session: AsyncSession) -> list[Directory]: async def get_user(session: AsyncSession, name: str) -> User | None: """Get user with username. - :param AsyncSession session: sqlalchemy session - :param str name: any name: dn, email or upn - :return User | None: user from db + Args: + session (AsyncSession): sqlalchemy session + name (str): any name: dn, email or upn + + Returns: + User | None: user from db """ policies = selectinload(User.groups).selectinload(Group.access_policies) @@ -65,7 +72,15 @@ async def get_directories( dn_list: list[GRANT_DN_STRING], session: AsyncSession, ) -> list[Directory]: - """Get directories by dn list.""" + """Get directories by dn list. + + Args: + dn_list (list[ENTRY_TYPE]): dn list + session (AsyncSession): sqlalchemy session + + Returns: + list[Directory]: directories + """ paths = [] for dn in dn_list: @@ -90,7 +105,15 @@ async def get_directories( async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: - """Get dirs with groups by dn list.""" + """Get dirs with groups by dn list. + + Args: + dn_list (list[str]): dn list + session (AsyncSession): sqlalchemy session + + Returns: + list[Group]: groups + """ return [ directory.group for directory in await get_directories(dn_list, session) @@ -99,14 +122,20 @@ async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: async def get_group( - dn: str | GRANT_DN_STRING, session: AsyncSession + dn: str | GRANT_DN_STRING, + session: AsyncSession, ) -> Directory: """Get dir with group by dn. - :param str dn: Distinguished Name - :param AsyncSession session: SA session - :raises AttributeError: on invalid dn - :return Directory: dir with group + Args: + dn (str| ENTRY_TYPE): Distinguished Name + session (AsyncSession): SA session + + Returns: + Directory: dir with group + + Raises: + ValueError: Cannot set memberOf with base dn or group not found """ for base_directory in await get_base_directories(session): if dn_is_base_directory(base_directory, dn): @@ -132,9 +161,12 @@ async def check_kerberos_group( ) -> bool: """Check if user in kerberos group. - :param User | None user: user (sa model) - :param AsyncSession session: db - :return bool: exists result + Args: + user (User | None): user (sa model) + session (AsyncSession): db + + Returns: + bool: exists result """ if user is None: return False @@ -157,7 +189,13 @@ async def set_last_logon_user( session: AsyncSession, tz: ZoneInfo, ) -> None: - """Update lastLogon attr.""" + """Update lastLogon attr. + + Args: + user (User): user + session (AsyncSession): sqlalchemy session + tz (ZoneInfo): timezone info + """ await session.execute( update(User) .values({"last_logon": datetime.now(tz=tz)}) @@ -169,8 +207,8 @@ async def set_last_logon_user( def get_search_path(dn: str) -> list[str]: """Get search path for dn. - :param str dn: any DN, dn syntax - :return list[str]: reversed list of dn values + Returns: + list[str]: reversed list of dn values """ search_path = [path.strip() for path in dn.lower().split(",")] search_path.reverse() @@ -184,9 +222,13 @@ def get_path_filter( ) -> ColumnElement: """Get filter condition for path equality. - :param list[str] path: dn - :param Column field: path column, defaults to Directory.path - :return ColumnElement: filter (where) element + Args: + path (list[str]): domain name + column (ColumnElement | Column | InstrumentedAttribute):\ + (Default value = Directory.path) + + Returns: + ColumnElement: filter (where) element """ return func.array_lowercase(column) == path @@ -196,7 +238,16 @@ def get_filter_from_path( *, column: Column | InstrumentedAttribute = Directory.path, ) -> ColumnElement: - """Get filter condition for path equality from dn.""" + """Get filter condition for path equality from dn. + + Args: + dn (str): any DN, dn syntax + column (Column | InstrumentedAttribute): (Default value =\ + Directory.path) + + Returns: + ColumnElement: filter (where) element + """ return get_path_filter(get_search_path(dn), column=column) @@ -205,6 +256,13 @@ async def get_dn_by_id(id_: int, session: AsyncSession) -> str: >>> await get_dn_by_id(0, session) >>> "cn=groups,dc=example,dc=com" + + Args: + id_ (int): id + session (AsyncSession): Database session + + Returns: + str: domain name """ query = select(Directory).filter(Directory.id == id_) retval = (await session.scalars(query)).one() @@ -212,7 +270,11 @@ async def get_dn_by_id(id_: int, session: AsyncSession) -> str: def get_domain_object_class(domain: Directory) -> Iterator[Attribute]: - """Get default domain attrs.""" + """Get default domain attrs. + + Yields: + Iterator[Attribute] + """ for value in ["domain", "top", "domainDNS"]: yield Attribute(name="objectClass", value=value, directory=domain) @@ -226,9 +288,13 @@ async def create_group( cn=name,cn=groups,dc=domain,dc=com - :param str name: group name - :param int sid: objectSid - :param AsyncSession session: db + Args: + name (str): group name + sid (int): objectSid + session (AsyncSession): db + + Returns: + tuple[Directory, Group] """ base_dn_list = await get_base_directories(session) @@ -282,8 +348,12 @@ async def create_group( async def is_computer(directory_id: int, session: AsyncSession) -> bool: """Determine whether the entry is a computer. - :param AsyncSession session: db - :param int directory_id: id + Args: + session (AsyncSession): db + directory_id (int): id + + Returns: + bool: True if the entry is a computer, False otherwise """ query = select( select(Attribute) @@ -304,26 +374,25 @@ async def add_lock_and_expire_attributes( ) -> None: """Add `nsAccountLock` and `shadowExpire` attributes to the directory. - :param AsyncSession session: db - :param Directory directory: directory - :param ZoneInfo tz: timezone info + Args: + session (AsyncSession): db + directory (Directory): directory + tz (ZoneInfo): timezone info """ now_with_tz = datetime.now(tz=tz) absolute_date = int(time.mktime(now_with_tz.timetuple()) / 86400) - session.add_all( - [ - Attribute( - name="nsAccountLock", - value="true", - directory=directory, - ), - Attribute( - name="shadowExpire", - value=str(absolute_date), - directory=directory, - ), - ] - ) + session.add_all([ + Attribute( + name="nsAccountLock", + value="true", + directory=directory, + ), + Attribute( + name="shadowExpire", + value=str(absolute_date), + directory=directory, + ), + ]) async def get_principal_directory( @@ -332,9 +401,12 @@ async def get_principal_directory( ) -> Directory | None: """Fetch the principal's directory by principal name. - :param AsyncSession session: db session - :param str principal_name: the principal name to search for - :return Directory | None: the principal's directory + Args: + session (AsyncSession): db session + principal_name (str): the principal name to search for + + Returns: + Directory | None: the principal's directory """ return await session.scalar( select(Directory) diff --git a/app/ldap_protocol/utils/raw_definition_parser.py b/app/ldap_protocol/utils/raw_definition_parser.py index 75954dd9e..b844f41d5 100644 --- a/app/ldap_protocol/utils/raw_definition_parser.py +++ b/app/ldap_protocol/utils/raw_definition_parser.py @@ -16,6 +16,14 @@ class RawDefinitionParser: @staticmethod def _list_to_string(data: list[str]) -> str | None: + """Convert list to string. + + Raises: + ValueError: if list has more than one element + + Returns: + str | None: single string if list has one element + """ if not data: return None if len(data) == 1: @@ -24,13 +32,23 @@ def _list_to_string(data: list[str]) -> str | None: @staticmethod def _get_attribute_type_info(raw_definition: str) -> AttributeTypeInfo: + """Get attribute type info. + + Returns: + AttributeTypeInfo: parsed attribute type info + """ tmp = AttributeTypeInfo.from_definition(definitions=[raw_definition]) - return list(tmp.values())[0] + return next(iter(tmp.values())) @staticmethod def get_object_class_info(raw_definition: str) -> ObjectClassInfo: + """Get object class info. + + Returns: + ObjectClassInfo: parsed object class info + """ tmp = ObjectClassInfo.from_definition(definitions=[raw_definition]) - return list(tmp.values())[0] + return next(iter(tmp.values())) @staticmethod async def _get_attribute_types_by_names( @@ -47,6 +65,14 @@ async def _get_attribute_types_by_names( def create_attribute_type_by_raw( raw_definition: str, ) -> AttributeType: + """Create attribute type by raw definition. + + Args: + raw_definition (str): raw definition of attribute type + + Returns: + AttributeType: created attribute type instance + """ attribute_type_info = RawDefinitionParser._get_attribute_type_info( raw_definition=raw_definition ) @@ -78,7 +104,15 @@ async def create_object_class_by_info( session: AsyncSession, object_class_info: ObjectClassInfo, ) -> ObjectClass: - """Create Object Class by ObjectClassInfo.""" + """Create Object Class by ObjectClassInfo. + + Args: + session (AsyncSession): db session + object_class_info (ObjectClassInfo): object class info + + Returns: + ObjectClass: object class instance + """ superior_name = RawDefinitionParser._list_to_string( object_class_info.superior ) diff --git a/app/models.py b/app/models.py index 0d4263262..6744f807f 100644 --- a/app/models.py +++ b/app/models.py @@ -9,7 +9,7 @@ import enum import uuid from collections import defaultdict -from datetime import datetime, timezone +from datetime import UTC, datetime from ipaddress import IPv4Address, IPv4Network from typing import Annotated, ClassVar, Literal @@ -66,7 +66,16 @@ def compile_create_uc( compiler: DDLCompiler, **kw: dict, ) -> str: - """Add NULLS NOT DISTINCT if its in args.""" + """Add NULLS NOT DISTINCT if its in args. + + Args: + create (DDLElement): The DDL element to compile. + compiler (DDLCompiler): The DDL compiler instance. + **kw (dict): Additional keyword arguments. + + Returns: + str: Compiled unique constraint statement. + """ stmt = compiler.visit_unique_constraint(create, **kw) postgresql_opts = create.dialect_options["postgresql"] # type: ignore @@ -179,12 +188,20 @@ class EntityType(Base): @property def object_class_names_set(self) -> set[str]: - """Get object class names.""" + """Get object class names. + + Returns: + set[str]: object class names + """ return set(self.object_class_names) @classmethod def generate_entity_type_name(cls, directory: Directory) -> str: - """Generate entity type name based on Directory.""" + """Generate entity type name based on Directory. + + Returns: + str: entity type name. + """ return f"{directory.name}_entity_type_{directory.id}" @@ -225,7 +242,11 @@ class Directory(Base): @property def entity_type_object_class_names_set(self) -> set[str]: - """Get object class names of entity type.""" + """Get object class names of entity type. + + Returns: + set[str]: object class names of entity type. + """ return ( self.entity_type.object_class_names_set if self.entity_type @@ -234,6 +255,11 @@ def entity_type_object_class_names_set(self) -> set[str]: @property def object_class_names_set(self) -> set[str]: + """Object class names from directory's attribute. + + Returns: + set[str]: object class names. + """ return set( self.attributes_dict.get("objectClass", []) + self.attributes_dict.get("objectclass", []) @@ -289,6 +315,12 @@ def object_class_names_set(self) -> set[str]: @property def attributes_dict(self) -> defaultdict[str, list[str]]: + """Get attributes dictionary. + + Returns: + defaultdict[str, list[str]]: Dictionary of attribute names\ + to their values. + """ attributes = defaultdict(list) for attribute in self.attributes: attributes[attribute.name].extend(attribute.values) @@ -336,14 +368,14 @@ def attributes_dict(self) -> defaultdict[str, list[str]]: ), ) - search_fields = { + search_fields: ClassVar[dict[str, str]] = { "entitytypename": "entityTypeName", "name": "name", "objectguid": "objectGUID", "objectsid": "objectSid", } - ro_fields = { + ro_fields: ClassVar[set[str]] = { "uid", "whenCreated", "lastLogon", @@ -354,29 +386,52 @@ def attributes_dict(self) -> defaultdict[str, list[str]]: } def get_dn_prefix(self) -> DistinguishedNamePrefix: - """Get distinguished name prefix.""" + """Get distinguished name prefix. + + Returns: + DistinguishedNamePrefix: Prefix for distinguished name. + """ return { "organizationalUnit": "ou", "domain": "dc", }.get(self.object_class, "cn") # type: ignore def get_dn(self, dn: str = "cn") -> str: - """Get distinguished name.""" + """Get distinguished name. + + Args: + dn (str): Distinguished name prefix (default: "cn"). + + Returns: + str: Distinguished name. + """ return f"{dn}={self.name}" @property def is_domain(self) -> bool: - """Is directory domain.""" + """Is directory domain. + + Returns: + bool: True if directory is domain, otherwise False. + """ return not self.parent_id and self.object_class == "domain" @property def host_principal(self) -> str: - """Principal computer name.""" + """Principal computer name. + + Returns: + str: Host principal name. + """ return f"host/{self.name}" @property def path_dn(self) -> str: - """Get DN from path.""" + """Get DN from path. + + Returns: + str: Distinguished name from path. + """ return ",".join(reversed(self.path)) def create_path( @@ -384,18 +439,31 @@ def create_path( parent: Directory | None = None, dn: str = "cn", ) -> None: - """Create path from a new directory.""" + """Create path from a new directory. + + Args: + parent (Directory | None): Parent directory (default: None). + dn (str): Distinguished name prefix (default: "cn"). + """ pre_path: list[str] = parent.path if parent else [] - self.path = pre_path + [self.get_dn(dn)] + self.path = [*pre_path, self.get_dn(dn)] self.depth = len(self.path) self.rdname = dn def __str__(self) -> str: - """Dir name.""" + """Dir name. + + Returns: + str: Directory name. + """ return f"Directory({self.name})" def __repr__(self) -> str: - """Dir id and name.""" + """Dir id and name. + + Returns: + str: Directory id and name. + """ return f"Directory({self.id}:{self.name})" @@ -452,7 +520,7 @@ class User(Base): DateTime(timezone=True), ) - search_fields = { + search_fields: ClassVar[dict[str, str]] = { "mail": "mail", "samaccountname": "sAMAccountName", "userprincipalname": "userPrincipalName", @@ -461,7 +529,7 @@ class User(Base): "accountexpires": "accountExpires", } - fields = { + fields: ClassVar[dict[str, str]] = { "loginshell": "loginShell", "uidnumber": "uidNumber", "homedirectory": "homeDirectory", @@ -486,24 +554,40 @@ class User(Base): ) def get_upn_prefix(self) -> str: - """Get userPrincipalName prefix.""" + """Get userPrincipalName prefix. + + Returns: + str: Prefix of userPrincipalName. + """ return self.user_principal_name.split("@")[0] def __str__(self) -> str: - """User show.""" + """User show. + + Returns: + str: User string representation. + """ return f"User({self.sam_accout_name})" def __repr__(self) -> str: - """User map with dir id.""" + """User map with dir id. + + Returns: + str: User id and directory id. + """ return f"User({self.directory_id}:{self.sam_accout_name})" def is_expired(self) -> bool: - """Check AccountExpires.""" + """Check AccountExpires. + + Returns: + bool: True if account is expired, otherwise False. + """ if self.account_exp is None: return False - now = datetime.now(tz=timezone.utc) - user_account_exp = self.account_exp.astimezone(timezone.utc) + now = datetime.now(tz=UTC) + user_account_exp = self.account_exp.astimezone(UTC) return now > user_account_exp @@ -585,11 +669,19 @@ class Group(Base): ) def __str__(self) -> str: - """Group id.""" + """Group id. + + Returns: + str: Group id. + """ return f"Group({self.id})" def __repr__(self) -> str: - """Group id and dir id.""" + """Group id and dir id. + + Returns: + str: Group id and directory id. + """ return f"Group({self.id}:{self.directory_id})" @@ -623,7 +715,11 @@ class Attribute(Base): @property def _decoded_value(self) -> str | None: - """Get attribute value.""" + """Get attribute value. + + Returns: + str | None: Decoded attribute value. + """ if self.value: return self.value if self.bvalue: @@ -632,15 +728,27 @@ def _decoded_value(self) -> str | None: @property def values(self) -> list[str]: - """Get attribute value by list.""" + """Get attribute value by list. + + Returns: + list[str]: List of attribute values. + """ return [self._decoded_value] if self._decoded_value else [] def __str__(self) -> str: - """Attribute name and value.""" + """Attribute name and value. + + Returns: + str: Attribute name and value. + """ return f"Attribute({self.name}:{self._decoded_value})" def __repr__(self) -> str: - """Attribute name and value.""" + """Attribute name and value. + + Returns: + str: Attribute name and value. + """ return f"Attribute({self.name}:{self._decoded_value})" @@ -663,7 +771,14 @@ class AttributeType(Base): is_system: Mapped[bool] # NOTE: it's not equal `NO-USER-MODIFICATION` def get_raw_definition(self) -> str: - """Format SQLAlchemy Attribute Type object to LDAP definition.""" + """Format SQLAlchemy Attribute Type object to LDAP definition. + + Returns: + str: LDAP definition string. + + Raises: + ValueError: If required fields are missing. + """ if not self.oid or not self.name or not self.syntax: err_msg = f"{self}: Fields 'oid', 'name', and 'syntax' are required for LDAP definition." # noqa: E501 raise ValueError(err_msg) @@ -683,11 +798,19 @@ def get_raw_definition(self) -> str: return " ".join(chunks) def __str__(self) -> str: - """AttributeType name.""" + """AttributeType name. + + Returns: + str: AttributeType name. + """ return f"AttributeType({self.name})" def __repr__(self) -> str: - """AttributeType oid and name.""" + """AttributeType oid and name. + + Returns: + str: AttributeType oid and name. + """ return f"AttributeType({self.oid}:{self.name})" @@ -784,7 +907,14 @@ class ObjectClass(Base): ) def get_raw_definition(self) -> str: - """Format SQLAlchemy Object Class object to LDAP definition.""" + """Format SQLAlchemy Object Class object to LDAP definition. + + Returns: + str: LDAP definition string. + + Raises: + ValueError: If required fields are missing. + """ if not self.oid or not self.name or not self.kind: err_msg = f"{self}: Fields 'oid', 'name', and 'kind' are required for LDAP definition." # noqa: E501 raise ValueError(err_msg) @@ -809,20 +939,36 @@ def get_raw_definition(self) -> str: @property def attribute_type_names_must(self) -> list[str]: - """Display attribute types must.""" + """Display attribute types must. + + Returns: + list[str]: List of must attribute type names. + """ return [attr.name for attr in self.attribute_types_must] @property def attribute_type_names_may(self) -> list[str]: - """Display attribute types may.""" + """Display attribute types may. + + Returns: + list[str]: List of may attribute type names. + """ return [attr.name for attr in self.attribute_types_may] def __str__(self) -> str: - """ObjectClass name.""" + """ObjectClass name. + + Returns: + str: ObjectClass name. + """ return f"ObjectClass({self.name})" def __repr__(self) -> str: - """ObjectClass oid and name.""" + """ObjectClass oid and name. + + Returns: + str: ObjectClass oid and name. + """ return f"ObjectClass({self.oid}:{self.name})" diff --git a/app/multidirectory.py b/app/multidirectory.py index 943aad6e8..34cb6f454 100644 --- a/app/multidirectory.py +++ b/app/multidirectory.py @@ -65,9 +65,12 @@ async def proc_time_header_middleware( ) -> Response: """Set X-Process-Time header. - :param Request request: _description_ - :param Callable call_next: _description_ - :return Response: _description_ + Args: + request (Request): _description_ + call_next (Callable): _description_ + + Returns: + Response: Response object with X-Process-Time header. """ start_time = time.perf_counter() response = await call_next(request) @@ -78,12 +81,21 @@ async def proc_time_header_middleware( @asynccontextmanager async def _lifespan(app: FastAPI) -> AsyncIterator[None]: + """Lifespan context manager. + + Yields: + AsyncIterator: async iterator + """ yield await app.state.dishka_container.close() def _create_basic_app(settings: Settings) -> FastAPI: - """Create basic FastAPI app with dependencies overrides.""" + """Create basic FastAPI app with dependencies overrides. + + Returns: + FastAPI: Configured FastAPI application. + """ app = FastAPI( name="MultiDirectory", title="MultiDirectory", @@ -134,7 +146,11 @@ def _create_basic_app(settings: Settings) -> FastAPI: def _create_shadow_app(settings: Settings) -> FastAPI: - """Create shadow FastAPI app for shadow.""" + """Create shadow FastAPI app for shadow. + + Returns: + FastAPI: Configured FastAPI application for shadow API. + """ app = FastAPI( name="Shadow API", title="Internal API", @@ -150,7 +166,15 @@ def create_prod_app( factory: Callable[[Settings], FastAPI] = _create_basic_app, settings: Settings | None = None, ) -> FastAPI: - """Create production app with container.""" + """Create production app with container. + + Args: + factory (Callable[[Settings], FastAPI]): _create_basic_app + settings (Settings | None): (Default value = None) + + Returns: + FastAPI: application. + """ settings = settings or Settings.from_os() app = factory(settings) container = make_async_container( @@ -189,6 +213,7 @@ async def _servers(settings: Settings) -> None: await asyncio.gather(*servers) def _run() -> None: + """Run ldap server.""" uvloop.run(_servers(settings), debug=settings.DEBUG) try: diff --git a/app/schedule.py b/app/schedule.py index 48a3f9ac9..27cd13412 100644 --- a/app/schedule.py +++ b/app/schedule.py @@ -15,9 +15,9 @@ from ioc import MainProvider from ldap_protocol.dependency import resolve_deps -type task_type = Callable[..., Coroutine] +type TaskType = Callable[..., Coroutine] -_TASKS: set[tuple[task_type, float]] = { +_TASKS: set[tuple[TaskType, float]] = { (disable_accounts, 600.0), (principal_block_sync, 60.0), (check_ldap_principal, -1.0), @@ -26,15 +26,16 @@ async def _schedule( - task: task_type, + task: TaskType, wait: float, container: AsyncContainer, ) -> None: """Run task periodically. - :param Awaitable task: any task - :param AsyncContainer container: container - :param float wait: time to wait after execution + Args: + task (TaskType): callable coroutine + wait (float): time to wait after execution + container (AsyncContainer): container """ logger.info("Registered: {}", task.__name__) while True: @@ -63,6 +64,7 @@ async def runner(settings: Settings) -> None: tg.create_task(_schedule(task, timeout, container)) def _run() -> None: + """Run the scheduler.""" uvloop.run(runner(settings)) try: diff --git a/app/security.py b/app/security.py index 8315c2953..9cae57260 100644 --- a/app/security.py +++ b/app/security.py @@ -12,9 +12,12 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: """Validate password. - :param str plain_password: raw password - :param str hashed_password: pwd hash from db - :return bool: is password valid + Args: + plain_password (str): raw password + hashed_password (str): pwd hash from db + + Returns: + bool: is password valid """ return pwd_context.verify(plain_password, hashed_password) @@ -22,7 +25,7 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: def get_password_hash(password: str) -> str: """Hash password. - :param str password: raw pwd - :return str: hash + Returns: + str: hash """ return pwd_context.hash(password, max_rounds=9) diff --git a/interface b/interface index 2b8f6556f..fccbff889 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 2b8f6556f80005cc3ec387ee4cf37a441111c43a +Subproject commit fccbff88901935affdee79584fde63857932db90 diff --git a/pyproject.toml b/pyproject.toml index 5cb2f5259..67ba833e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ output-format = "grouped" unsafe-fixes = true [tool.ruff.format] +preview = true docstring-code-format = true docstring-code-line-length = 79 line-ending = "lf" @@ -113,6 +114,7 @@ select = [ "N", # pep8-naming "A", # flake8 builtin-attribute-shadowing "D", # pydocstyle, check tool.ruff.lint.pydocstyle + "DOC", # pydoclint TODO uncomment, ruff fix and fix error "UP", # pyupgrade, check tool.ruff.lint.pyupgrade. Must have "ANN", # flake8-annotations, check tool.ruff.lint.flake8-annotations "ASYNC", # flake8-async @@ -121,7 +123,7 @@ select = [ "COM", # flake8-commas # "CPY", # flake8-copyright TODO uncomment, ruff fix and fix error "PIE", # flake8-pie - # "PYI", # flake8-pyi TODO uncomment, ruff fix and fix error + "PYI", # flake8-pyi "PT", # flake8-pytest "Q", # flake8-quotes # "RET", # flake8-return TODO uncomment, ruff fix and fix error @@ -134,50 +136,48 @@ select = [ "ERA", # eradicate # "PGH", # pygrep-hooks TODO does we need it? uncomment, ruff fix and fix error # "PL", # Pylint TODO uncomment, ruff fix and fix error - # "DOC", # pydoclint TODO uncomment, ruff fix and fix error - # "RUF", # Ruff-specific rules TODO uncomment, ruff fix and fix error - "RUF100", # Ruff100-specific rule TODO delete that and uncomment "RUF"-rule in line up. + "RUF", # Ruff-specific rules + "FAST", # FastAPI checks + # "FURB", # Refurb ] # Gradually remove all values marked 'TODO' and fix errors. ignore = [ - "D102", # TODO delete that and fix all errors - "D104", # TODO delete that and fix all errors - "D203", # this is necessary. Conflict with `D211` - "D213", # this is necessary. Conflict with `D212` - "D301", # this is necessary. - "UP017", # TODO delete that and fix all errors - "UP034", # TODO delete that and fix all errors - "UP035", # this is necessary. We allowed deprecated import - "ANN001", # TODO delete that and fix all errors - "ANN002", # this is necessary. - "ANN003", # this is necessary. - "ANN401", # TODO delete that and fix all errors - "ASYNC109", - "ASYNC230", - "S311", # this is necessary. - "B904", # this is necessary. - "COM812", # this is necessary. Cause conflicts when used with the formatter - "TC001", # this is necessary. - "TC002", # this is necessary. - "TC003", # this is necessary. - "SIM101", # analogue simplify-boolean-expressions IF100 - "B905", # this is necessary. get-attr-with-constant + "D203", # Conflict with `D211`. + "D213", # Conflict with `D212`. + "D301", # It convert `"""Description."""` to `r"""Description."""`. + "UP035", # We allowed deprecated import. + "ANN002", # Disable type annotations for *args. + "ANN003", # Disable type annotations for **kwargs. + "ASYNC109", # Allow timeout parameters into async func. + "ASYNC230", # Allow open files with blocking methods like open. + "S311", # We used `random — Generate pseudo-random numbers`. + "B904", # We don't used `raise ... FROM ..`. + "COM812", # Cause conflicts when used with the formatter. + "TC001", # First-party imports not defined in a type-checking block. + "TC002", # Third-party imports not defined in a type-checking block. + "TC003", # Standard library imports not defined in a type-checking block. + "B905", # Allow `zip` calls without an explicit `strict` parameter. + "RUF029", # 'Checks for functions declared async that do not await or otherwise use features requiring the function to be declared async.' + "ANN401", # Allow dynamically typed expressions (typing.Any). ] -extend-select = [] - fixable = ["ALL"] unfixable = [ "T20", # dont auto delete print/pprint lines in code ] +[tool.ruff.lint.pydocstyle] +convention = "google" # Google Python Style Guide - Docstrings: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings +ignore-var-parameters = true + [tool.ruff.lint.flake8-unused-arguments] ignore-variadic-names = true [tool.ruff.lint.per-file-ignores] -"tests/*.py" = ["S101"] # Ignore `Flake8-bandit S101` rule for the `tests/` directory. -"alembic/*.py" = ["I001"] # Ignore `Flake8-isort IO01` rule for the `alembic/` directory. It works incorrect in CI ruff test. +"tests/*.py" = ["S101", "D104", "DOC501", "D417", "DOC201", "DOC402"] # Ignore rules for the `tests/` directory. +"app/alembic/*.py" = ["ANN001"] # Ignore `Flake8-isort IO01` rule for the `alembic/` directory. It works incorrect in CI ruff test. +"alembic/*.py" = ["ANN001"] # Ignore `Flake8-isort IO01` rule for the `alembic/` directory. It works incorrect in CI ruff test. [tool.ruff.lint.mccabe] # 15 Complexity level is too high, need to reduce this level or ignore it `# noqa: C901`. diff --git a/tests/conftest.py b/tests/conftest.py index b57f4454d..68a82b334 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,14 @@ import weakref from contextlib import suppress from dataclasses import dataclass -from typing import AsyncGenerator, AsyncIterator, Generator, Iterator +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Generator, + Iterator, + Literal, +) from unittest.mock import AsyncMock, Mock import aioldap3 @@ -214,7 +221,11 @@ def get_session_factory( self, engine: AsyncEngine, ) -> async_sessionmaker[AsyncSession]: - """Create session factory.""" + """Create session factory. + + Args: + engine (AsyncEngine): async engine + """ return async_sessionmaker( engine, expire_on_commit=False, @@ -340,8 +351,16 @@ class MutePolicyBindRequest(BindRequest): __test__ = False @staticmethod - async def is_user_group_valid(*args, **kwargs) -> bool: # type: ignore - """Stub.""" + async def is_user_group_valid(*args: Any, **kwargs: Any) -> Literal[True]: + """Stub. + + Args: + *args: arguments + **kwargs: keyword arguments + + Returns: + Literal[True]: True + """ return True @@ -380,10 +399,12 @@ async def _migrations( config.attributes["app_settings"] = settings def upgrade(conn: AsyncConnection) -> None: + """Run up migrations.""" config.attributes["connection"] = conn command.upgrade(config, "head") def downgrade(conn: AsyncConnection) -> None: + """Run down migrations.""" config.attributes["connection"] = conn command.downgrade(config, "base") @@ -483,7 +504,12 @@ def _server( event_loop: asyncio.BaseEventLoop, handler: PoolClientHandler, ) -> Generator: - """Run server in background.""" + """Run server in background. + + Args: + event_loop (asyncio.BaseEventLoop): events loop + handler (PoolClientHandler): handler + """ task = asyncio.ensure_future(handler.start(), loop=event_loop) event_loop.run_until_complete(asyncio.sleep(0.1)) yield @@ -496,7 +522,15 @@ async def ldap_client( settings: Settings, creds: TestCreds, ) -> AsyncIterator[aioldap3.LDAPConnection]: - """Get LDAP client without credentials.""" + """Get ldap clinet with creds. + + Args: + settings (Settings): Settings with database dsn. + creds (TestCreds): credentials for ldap auth + + Yields: + aioldap3.LDAPConnection: ldap async client + """ conn = aioldap3.LDAPConnection( aioldap3.Server(host=str(settings.HOST), port=settings.PORT) ) @@ -537,7 +571,8 @@ async def unbound_http_client( ) -> AsyncIterator[httpx.AsyncClient]: """Get async client for fastapi tests. - :param FastAPI app: asgi app + Args: + app (FastAPI): asgi app :yield Iterator[AsyncIterator[httpx.AsyncClient]]: yield client """ async with httpx.AsyncClient( @@ -556,10 +591,13 @@ async def http_client( ) -> httpx.AsyncClient: """Authenticate and return client with cookies. - :param httpx.AsyncClient unbound_http_client: client w/o cookies - :param TestCreds creds: creds to authn - :param None setup_session: just a fixture call - :return httpx.AsyncClient: bound client with cookies + Args: + unbound_http_client (httpx.AsyncClient): client w/o cookies + creds (TestCreds): creds to authn + setup_session (None): just a fixture call + + Returns: + httpx.AsyncClient: bound client with cookies """ response = await unbound_http_client.post( "auth/", diff --git a/tests/test_api/test_auth/test_router.py b/tests/test_api/test_auth/test_router.py index 8ce3879d3..afaf9cd12 100644 --- a/tests/test_api/test_auth/test_router.py +++ b/tests/test_api/test_auth/test_router.py @@ -28,10 +28,11 @@ async def apply_user_account_control( ) -> dict[str, Any]: """Apply userAccountControl value and return response data. - :param AsyncClient http_client: client - :param str user_dn: distinguished name of the user - :param str user_account_control_value: new value to set for the - `userAccountControl` attribute. + Args: + http_client (AsyncClient): client + user_dn (str): distinguished name of the user + user_account_control_value (str): new value to set for the + `userAccountControl` attribute. """ response = await http_client.patch( "entry/update", diff --git a/tests/test_api/test_auth/test_sessions.py b/tests/test_api/test_auth/test_sessions.py index b98e7a8af..7e52fd3b3 100644 --- a/tests/test_api/test_auth/test_sessions.py +++ b/tests/test_api/test_auth/test_sessions.py @@ -38,7 +38,7 @@ async def test_session_creation( assert sessions - key = list(sessions.keys())[0] + key = next(iter(sessions.keys())) assert sessions[key]["id"] == user.id assert sessions[key]["issued"] @@ -67,13 +67,13 @@ async def test_session_rekey( ) sessions = await storage.get_user_sessions(user.id) - old_key = list(sessions.keys())[0] + old_key = next(iter(sessions.keys())) old_session = sessions[old_key] await storage.rekey_session(old_key, settings) sessions = await storage.get_user_sessions(user.id) - new_key = list(sessions.keys())[0] + new_key = next(iter(sessions.keys())) new_session = sessions[new_key] assert len(sessions) == 1 @@ -108,7 +108,7 @@ async def test_session_creation_ldap_bind_unbind( assert sessions - key = list(sessions.keys())[0] + key = next(iter(sessions.keys())) assert sessions[key]["id"] == user.id assert sessions[key]["issued"] @@ -187,7 +187,7 @@ async def test_session_api_delete_detail( response = await http_client.get(f"sessions/{creds.un}") assert response.status_code == 200 - session_id = list(response.json().keys())[0] + session_id = next(iter(response.json().keys())) assert len(await storage.get_user_sessions(user.id)) == 1 @@ -311,36 +311,36 @@ async def test_get_sessions_by_protocol( all_sessions = await storage.get_user_sessions(uid) assert len(all_sessions) == 2 - key = list(all_sessions.keys())[0] + key = next(iter(all_sessions.keys())) assert all_sessions[key]["id"] == user.id http_sessions = await storage.get_user_sessions(uid, "http") assert len(http_sessions) == 1 - key = list(http_sessions.keys())[0] + key = next(iter(http_sessions.keys())) assert http_sessions[key]["id"] == user.id assert http_sessions[key]["ip"] == http_ip ldap_sessions = await storage.get_user_sessions(uid, "ldap") assert len(ldap_sessions) == 1 - key = list(ldap_sessions.keys())[0] + key = next(iter(ldap_sessions.keys())) assert ldap_sessions[key]["id"] == user.id assert ldap_sessions[key]["ip"] == ldap_ip ip_all_sessions = await storage.get_ip_sessions(http_ip) assert len(ip_all_sessions) == 1 - key = list(ip_all_sessions.keys())[0] + key = next(iter(ip_all_sessions.keys())) assert ip_all_sessions[key]["id"] == user.id assert ip_all_sessions[key]["ip"] == http_ip ip_http_sessions = await storage.get_ip_sessions(http_ip, "http") assert len(ip_http_sessions) == 1 - key = list(ip_http_sessions.keys())[0] + key = next(iter(ip_http_sessions.keys())) assert ip_http_sessions[key]["id"] == user.id assert ip_http_sessions[key]["ip"] == http_ip ip_ldap_sessions = await storage.get_ip_sessions(ldap_ip, "ldap") assert len(ip_ldap_sessions) == 1 - key = list(ip_ldap_sessions.keys())[0] + key = next(iter(ip_ldap_sessions.keys())) assert ip_ldap_sessions[key]["id"] == user.id assert ip_ldap_sessions[key]["ip"] == ldap_ip diff --git a/tests/test_api/test_main/test_kadmin.py b/tests/test_api/test_main/test_kadmin.py index 08eda97b8..07bc8c058 100644 --- a/tests/test_api/test_main/test_kadmin.py +++ b/tests/test_api/test_main/test_kadmin.py @@ -21,6 +21,12 @@ def _create_test_user_data( name: str, pw: str, ) -> dict[str, str | list[dict[str, str | list[str]]]]: + """Create test user data. + + Args: + name (str): user name + pw (str): user password + """ return { "entry": "cn=ktest,dc=md,dc=test", "password": pw, @@ -137,8 +143,9 @@ async def test_setup_call( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.post( "/kerberos/setup", @@ -180,8 +187,9 @@ async def test_status_change( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.get("/kerberos/status") assert response.status_code == status.HTTP_200_OK @@ -208,8 +216,9 @@ async def test_ktadd( ) -> None: """Test ktadd. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ names = ["test1", "test2"] response = await http_client.post("/kerberos/ktadd", json=names) @@ -234,8 +243,9 @@ async def test_ktadd_404( ) -> None: """Test ktadd failure. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ kadmin.ktadd.side_effect = KRBAPIError() # type: ignore @@ -253,8 +263,9 @@ async def test_ldap_add( ) -> None: """Test add calls add_principal on user creation. - :param AsyncClient http_client: http - :param TestKadminClient kadmin: kadmin + Args: + http_client (AsyncClient): http + kadmin (TestKadminClient): kadmin """ san = "ktest" pw = "Password123" @@ -385,8 +396,9 @@ async def test_add_princ( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.post( "/kerberos/principal/add", @@ -408,8 +420,9 @@ async def test_rename_princ( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.patch( "/kerberos/principal/rename", @@ -431,8 +444,9 @@ async def test_change_princ( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.patch( "/kerberos/principal/reset", @@ -454,8 +468,9 @@ async def test_delete_princ( ) -> None: """Test setup args. - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap + Args: + http_client (AsyncClient): http cl + ldap_session (LDAPSession): ldap """ response = await http_client.request( "delete", @@ -470,11 +485,7 @@ async def test_delete_princ( @pytest.mark.usefixtures("session") @pytest.mark.usefixtures("setup_session") async def test_admin_incorrect_pw_setup(http_client: AsyncClient) -> None: - """Test setup args. - - :param AsyncClient http_client: http cl - :param LDAPSession ldap_session: ldap - """ + """Test setup args.""" response = await http_client.get("/kerberos/status") assert response.status_code == status.HTTP_200_OK assert response.json() == KerberosState.NOT_CONFIGURED diff --git a/tests/test_api/test_main/test_router/test_add.py b/tests/test_api/test_main/test_router/test_add.py index d011587d9..fb0c36f93 100644 --- a/tests/test_api/test_main/test_router/test_add.py +++ b/tests/test_api/test_main/test_router/test_add.py @@ -245,7 +245,7 @@ async def test_api_correct_add_double_member_of( assert data.get("resultCode") == LDAPCodes.SUCCESS assert data["search_result"][0]["object_name"] == user - created_groups = groups + ["cn=domain users,cn=groups,dc=md,dc=test"] + created_groups = [*groups, "cn=domain users,cn=groups,dc=md,dc=test"] for attr in data["search_result"][0]["partial_attributes"]: if attr["type"] == "memberOf": diff --git a/tests/test_api/test_shadow/conftest.py b/tests/test_api/test_shadow/conftest.py index ebae802b6..b0c641cb5 100644 --- a/tests/test_api/test_shadow/conftest.py +++ b/tests/test_api/test_shadow/conftest.py @@ -15,13 +15,7 @@ class ProxyRequestModel(BaseModel): - """Model for the proxy request. - - Attributes: - principal: Unique user identifier - ip: IP address from which the request is made - - """ + """Model for the proxy request.""" principal: str ip: str diff --git a/tests/test_api/test_shadow/test_router.py b/tests/test_api/test_shadow/test_router.py index 5e4dddafe..83ba6230a 100644 --- a/tests/test_api/test_shadow/test_router.py +++ b/tests/test_api/test_shadow/test_router.py @@ -92,10 +92,9 @@ async def test_shadow_api_whitelist_without_user_group( ) -> None: """Test shadow api whitelist without user group.""" await session.execute( - update(NetworkPolicy).values( - {NetworkPolicy.mfa_status: MFAFlags.WHITELIST} - ), - ) + update(NetworkPolicy) + .values({NetworkPolicy.mfa_status: MFAFlags.WHITELIST}), + ) # fmt: skip response = await http_client.post( "/shadow/mfa/push", @@ -114,10 +113,9 @@ async def test_shadow_api_enable_mfa( ) -> None: """Test shadow api enable mfa.""" await session.execute( - update(NetworkPolicy).values( - {NetworkPolicy.mfa_status: MFAFlags.ENABLED} - ), - ) + update(NetworkPolicy) + .values({NetworkPolicy.mfa_status: MFAFlags.ENABLED}), + ) # fmt: skip response = await http_client.post( "/shadow/mfa/push", diff --git a/tests/test_ldap/test_util/test_add.py b/tests/test_ldap/test_util/test_add.py index 00fe13a0d..f1b35d02b 100644 --- a/tests/test_ldap/test_util/test_add.py +++ b/tests/test_ldap/test_util/test_add.py @@ -37,14 +37,12 @@ async def test_ldap_root_add( search_path = get_search_path(dn) with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "name: test\n" - "cn: test\n" - "objectClass: organization\n" - "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" - ) + f"dn: {dn}\n" + "name: test\n" + "cn: test\n" + "objectClass: organization\n" + "objectClass: top\n" + "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -166,14 +164,12 @@ async def test_ldap_user_add_group_with_group( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {child_group_dn}\n" - "name: twisted\n" - "cn: twisted\n" - "objectClass: group\n" - "objectClass: top\n" - f"memberOf: {group_dn}\n" - ) + f"dn: {child_group_dn}\n" + "name: twisted\n" + "cn: twisted\n" + "objectClass: group\n" + "objectClass: top\n" + f"memberOf: {group_dn}\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -251,13 +247,11 @@ async def test_ldap_add_access_control( async def try_add() -> int: with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "name: test\n" - "cn: test\n" - "objectClass: organization\n" - "objectClass: top\n" - ) + f"dn: {dn}\n" + "name: test\n" + "cn: test\n" + "objectClass: organization\n" + "objectClass: top\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( diff --git a/tests/test_ldap/test_util/test_delete.py b/tests/test_ldap/test_util/test_delete.py index 187dd183c..3bccc3d1a 100644 --- a/tests/test_ldap/test_util/test_delete.py +++ b/tests/test_ldap/test_util/test_delete.py @@ -30,14 +30,12 @@ async def test_ldap_delete( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "name: test\n" - "cn: test\n" - "objectClass: organization\n" - "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" - ) + f"dn: {dn}\n" + "name: test\n" + "cn: test\n" + "objectClass: organization\n" + "objectClass: top\n" + "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -94,13 +92,11 @@ async def test_ldap_delete_w_access_control( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "name: test\n" - "cn: test\n" - "objectClass: organization\n" - "objectClass: top\n" - ) + f"dn: {dn}\n" + "name: test\n" + "cn: test\n" + "objectClass: organization\n" + "objectClass: top\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( # Add as Admin diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index 451f27699..b3af9d4fa 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -53,24 +53,22 @@ async def test_ldap_base_modify( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "replace: mail\n" - "mail: modme@student.of.life.edu\n" - "-\n" - "add: title\n" - "title: Grand Poobah\n" - "title: Grand Poobah1\n" - "title: Grand Poobah2\n" - "title: Grand Poobah3\n" - "-\n" - "add: jpegPhoto\n" - "jpegPhoto: modme.jpeg\n" - "-\n" - "delete: posixEmail\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "replace: mail\n" + "mail: modme@student.of.life.edu\n" + "-\n" + "add: title\n" + "title: Grand Poobah\n" + "title: Grand Poobah1\n" + "title: Grand Poobah2\n" + "title: Grand Poobah3\n" + "-\n" + "add: jpegPhoto\n" + "jpegPhoto: modme.jpeg\n" + "-\n" + "delete: posixEmail\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -141,7 +139,7 @@ async def test_ldap_membersip_user_delete( assert directory.groups with tempfile.NamedTemporaryFile("w") as file: - file.write((f"dn: {dn}\nchangetype: modify\ndelete: memberOf\n-\n")) + file.write(f"dn: {dn}\nchangetype: modify\ndelete: memberOf\n-\n") file.seek(0) proc = await asyncio.create_subprocess_exec( "ldapmodify", @@ -192,13 +190,11 @@ async def test_ldap_membersip_user_add( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "add: memberOf\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "add: memberOf\n" + "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -249,14 +245,12 @@ async def test_ldap_membersip_user_replace( # add new group with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {new_group_dn}" - "name: twisted\n" - "cn: twisted\n" - "objectClass: group\n" - "objectClass: top\n" - "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" - ) + f"dn: {new_group_dn}" + "name: twisted\n" + "cn: twisted\n" + "objectClass: group\n" + "objectClass: top\n" + "memberOf: cn=domain admins,cn=groups,dc=md,dc=test\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -281,13 +275,11 @@ async def test_ldap_membersip_user_replace( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "replace: memberOf\n" - "memberOf: cn=twisted,cn=groups,dc=md,dc=test\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "replace: memberOf\n" + "memberOf: cn=twisted,cn=groups,dc=md,dc=test\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -343,13 +335,11 @@ async def test_ldap_membersip_grp_replace( # add new group with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - "dn: cn=twisted1,cn=groups,dc=md,dc=test\n" - "name: twisted\n" - "cn: twisted\n" - "objectClass: group\n" - "objectClass: top\n" - ) + "dn: cn=twisted1,cn=groups,dc=md,dc=test\n" + "name: twisted\n" + "cn: twisted\n" + "objectClass: group\n" + "objectClass: top\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -374,13 +364,11 @@ async def test_ldap_membersip_grp_replace( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "replace: memberOf\n" - "memberOf: cn=twisted1,cn=groups,dc=md,dc=test\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "replace: memberOf\n" + "memberOf: cn=twisted1,cn=groups,dc=md,dc=test\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -420,13 +408,11 @@ async def test_ldap_modify_dn( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modrdn\n" - "newrdn: cn=user2\n" - "deleteoldrdn: 1\n" - "newsuperior: ou=users,dc=md,dc=test\n" - ) + f"dn: {dn}\n" + "changetype: modrdn\n" + "newrdn: cn=user2\n" + "deleteoldrdn: 1\n" + "newsuperior: ou=users,dc=md,dc=test\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -467,13 +453,11 @@ async def test_ldap_modify_password_change( with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "replace: userPassword\n" - f"userPassword: {new_password}\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "replace: userPassword\n" + f"userPassword: {new_password}\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( @@ -537,24 +521,22 @@ async def test_ldap_modify_with_ap( async def try_modify() -> int: with tempfile.NamedTemporaryFile("w") as file: file.write( - ( - f"dn: {dn}\n" - "changetype: modify\n" - "replace: mail\n" - "mail: modme@student.of.life.edu\n" - "-\n" - "add: title\n" - "title: Grand Poobah\n" - "title: Grand Poobah1\n" - "title: Grand Poobah2\n" - "title: Grand Poobah3\n" - "-\n" - "add: jpegPhoto\n" - "jpegPhoto: modme.jpeg\n" - "-\n" - "delete: posixEmail\n" - "-\n" - ) + f"dn: {dn}\n" + "changetype: modify\n" + "replace: mail\n" + "mail: modme@student.of.life.edu\n" + "-\n" + "add: title\n" + "title: Grand Poobah\n" + "title: Grand Poobah1\n" + "title: Grand Poobah2\n" + "title: Grand Poobah3\n" + "-\n" + "add: jpegPhoto\n" + "jpegPhoto: modme.jpeg\n" + "-\n" + "delete: posixEmail\n" + "-\n" ) file.seek(0) proc = await asyncio.create_subprocess_exec( diff --git a/tests/test_ldap/test_util/test_search.py b/tests/test_ldap/test_util/test_search.py index f16b4f6d2..cb5e029cd 100644 --- a/tests/test_ldap/test_util/test_search.py +++ b/tests/test_ldap/test_util/test_search.py @@ -339,14 +339,12 @@ async def test_ldap_search_access_control_denied( dn_list = [d for d in data if d.startswith("dn:")] assert result == 0 - assert sorted(dn_list) == sorted( - [ - "dn: dc=md,dc=test", - "dn: ou=users,dc=md,dc=test", - "dn: cn=groups,dc=md,dc=test", - "dn: cn=domain admins,cn=groups,dc=md,dc=test", - "dn: cn=developers,cn=groups,dc=md,dc=test", - "dn: cn=domain users,cn=groups,dc=md,dc=test", - "dn: cn=user_non_admin,ou=users,dc=md,dc=test", - ] - ) + assert sorted(dn_list) == sorted([ + "dn: dc=md,dc=test", + "dn: ou=users,dc=md,dc=test", + "dn: cn=groups,dc=md,dc=test", + "dn: cn=domain admins,cn=groups,dc=md,dc=test", + "dn: cn=developers,cn=groups,dc=md,dc=test", + "dn: cn=domain users,cn=groups,dc=md,dc=test", + "dn: cn=user_non_admin,ou=users,dc=md,dc=test", + ])