Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions scripts/windows_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import win32serviceutil
import win32service
import servicemanager
import subprocess
import sys
import os
import argparse
import shlex
import win32con
import win32api
import pytest
from getpass import getpass


Expand Down Expand Up @@ -54,32 +54,20 @@ def SvcDoRun(self):
)
code_location = os.environ["CODE_LOCATION"]
pytest_args = os.environ.get("PYTEST_ARGS", None)
log_file_name = os.path.join(code_location, "test.log")

args = ["pytest", os.path.join(code_location, "test")]
# We need to disable xdist as it runs each test in a Python
# subprocess, which results in the tests not running as a
# service as we want.
args = [os.path.join(code_location, "test"), "--numprocesses=0"]

if pytest_args:
args.extend(shlex.split(pytest_args, posix=False))
with open(log_file_name, mode="w") as f:
sys.stdout = f
sys.stderr = f

logging.basicConfig(
filename=os.path.join(code_location, "test.log"),
encoding="utf-8",
level=logging.INFO,
filemode="w",
)
process = subprocess.Popen(
args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
cwd=code_location,
)

while True:
output = process.stdout.readline()
if not output and process.poll() is not None:
break

logger.info(output.strip())
pytest.main(args)

servicemanager.LogMsg(
servicemanager.EVENTLOG_INFORMATION_TYPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ def signal_process(pgid: int):
if not kernel32.GenerateConsoleCtrlEvent(CTRL_BREAK_EVENT, pgid):
raise ctypes.WinError()

if not kernel32.FreeConsole():
raise ctypes.WinError()
if not kernel32.AttachConsole(ATTACH_PARENT_PROCESS):
raise ctypes.WinError()


if __name__ == "__main__":
signal_process(int(sys.argv[1]))
26 changes: 15 additions & 11 deletions src/openjd/sessions/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from queue import Queue, Empty
from subprocess import DEVNULL, PIPE, STDOUT, Popen, list2cmdline, run
from threading import Event, Thread
from typing import Any
from typing import Callable, Literal, Optional, Sequence, cast
from typing import Callable, Literal, Optional, Sequence, cast, Any

from ._linux._capabilities import try_use_cap_kill
from ._linux._sudo import find_sudo_child_process_group_id
Expand Down Expand Up @@ -623,17 +622,22 @@ def _windows_notify_subprocess(self) -> None:
str(WINDOWS_SIGNAL_SUBPROC_SCRIPT_PATH),
str(self._process.pid),
]
result = run(
cmd,
stdout=PIPE,
stderr=STDOUT,
stdin=DEVNULL,
creationflags=CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW,
process = LoggingSubprocess(
logger=self._logger,
args=cmd,
encoding=self._encoding,
user=self._user,
os_env_vars=self._os_env_vars,
working_dir=self._working_dir,
creation_flags=CREATE_NO_WINDOW,
)
if result.returncode != 0:

# Blocking call
process.run()

if process.exit_code != 0:
self._logger.warning(
f"Failed to send signal 'CTRL_BREAK_EVENT' to subprocess {self._process.pid}: %s",
result.stdout.decode("utf-8"),
f"Failed to send signal 'CTRL_BREAK_EVENT' to subprocess {self._process.pid}",
extra=LogExtraInfo(
openjd_log_content=LogContent.PROCESS_CONTROL | LogContent.EXCEPTION_INFO
),
Expand Down
34 changes: 34 additions & 0 deletions src/openjd/sessions/_win32/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
# Constants
LOGON32_LOGON_INTERACTIVE,
LOGON32_PROVIDER_DEFAULT,
PI_NOUI,
PROFILEINFO,
# Functions
CloseHandle,
CreateEnvironmentBlock,
DestroyEnvironmentBlock,
GetCurrentProcessId,
LogonUserW,
ProcessIdToSessionId,
LoadUserProfileW,
UnloadUserProfile,
)


Expand Down Expand Up @@ -166,3 +170,33 @@ def environment_block_from_dict(env: dict[str, str]) -> c_wchar_p:
env_block_str = null_delimited + "\0"

return c_wchar_p(env_block_str)


def load_user_profile(token: HANDLE, username: str) -> PROFILEINFO:
"""
Load the user profile for the given logon token and user name

NOTE: The caller *MUST* call unload_user_profile when finished with the user profile

See: https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-loaduserprofilew
"""
profile_info = PROFILEINFO()
profile_info.dwSize = sizeof(PROFILEINFO)
profile_info.lpUserName = username
profile_info.dwFlags = PI_NOUI
profile_info.lpProfilePath = None

if not LoadUserProfileW(token, byref(profile_info)):
raise WinError()

return profile_info


def unload_user_profile(token: HANDLE, profile_info: PROFILEINFO) -> None:
"""
Unload the user profile for the given token and profile.

See: https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-unloaduserprofile
"""
if not UnloadUserProfile(token, profile_info.hProfile):
raise WinError()
14 changes: 12 additions & 2 deletions test/openjd/sessions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from hashlib import sha256
from unittest.mock import MagicMock
import pytest
import sys

from openjd.sessions import PosixSessionUser, WindowsSessionUser, BadCredentialsException
from openjd.sessions._os_checker import is_posix, is_windows
Expand All @@ -21,6 +22,8 @@
from openjd.sessions._win32._helpers import ( # type: ignore
get_current_process_session_id,
logon_user_context,
load_user_profile,
unload_user_profile,
)

TEST_RUNNING_IN_WINDOWS_SESSION_0 = 0 == get_current_process_session_id()
Expand Down Expand Up @@ -251,10 +254,10 @@ def windows_user() -> Generator[WindowsSessionUser, None, None]:

if TEST_RUNNING_IN_WINDOWS_SESSION_0:
try:
# Note: We don't load the user profile; it's currently not needed by our tests,
# and we're getting a mysterious crash when unloading it.
with logon_user_context(user, password) as logon_token:
profile_info = load_user_profile(logon_token, user)
yield WindowsSessionUser(user, logon_token=logon_token)
unload_user_profile(logon_token, profile_info)
except OSError as e:
raise Exception(
f"Could not logon as {user}. Check the password that was provided in {WIN_PASS_ENV_VAR}."
Expand Down Expand Up @@ -282,3 +285,10 @@ def queue_handler(message_queue: SimpleQueue) -> QueueHandler:
@pytest.fixture(scope="function")
def session_id() -> str:
return "some Id"


@pytest.fixture(scope="function")
def python_exe() -> str:
if is_windows() and TEST_RUNNING_IN_WINDOWS_SESSION_0:
return sys.executable.lower().replace("pythonservice.exe", "python.exe")
return sys.executable
Loading