Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"pyjwt[crypto]",
"tomlkit",
"graypy>=2.1.0",
"httpx>=0.28.1",
]
dynamic = ["version"]
license.file = "LICENSE"
Expand All @@ -52,6 +53,7 @@ dev = [
"pyright!=1.1.407", # https://github.com/bluesky/scanspec/issues/190
"pytest-cov",
"pytest-asyncio",
"pytest-httpx>=0.35.0",
"responses",
"ruff",
"semver",
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def __post_init__(self, configuration: ApplicationConfig | None):
# local reference so it's available in _update_scan_num
numtracker = self.numtracker

def _update_scan_num(md: dict[str, Any]) -> int:
scan = numtracker.create_scan(
async def _update_scan_num(md: dict[str, Any]) -> int:
scan = await numtracker.create_scan(
md["instrument_session"], md["instrument"]
)
md["data_session_directory"] = str(scan.scan.directory.path)
Expand Down
15 changes: 8 additions & 7 deletions src/blueapi/utils/numtracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from textwrap import dedent

import requests
import httpx
from pydantic import Field, HttpUrl

from blueapi.utils import BlueapiBaseModel
Expand Down Expand Up @@ -62,7 +62,7 @@ def set_headers(self, headers: Mapping[str, str]) -> None:

self._headers = headers

def create_scan(
async def create_scan(
self, instrument_session: str, instrument: str
) -> NumtrackerScanMutationResponse:
"""
Expand Down Expand Up @@ -94,11 +94,12 @@ def create_scan(
""")
}

response = requests.post(
self._url.unicode_string(),
headers=self._headers,
json=query,
)
async with httpx.AsyncClient() as client:
response = await client.post(
self._url.unicode_string(),
headers=self._headers,
json=query,
)

response.raise_for_status()
json = response.json()
Expand Down
139 changes: 6 additions & 133 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import base64
import time
from collections.abc import Iterable
from pathlib import Path
from textwrap import dedent
from typing import Any, cast
Expand All @@ -20,7 +19,6 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.trace import get_tracer_provider
from responses.matchers import json_params_matcher

from blueapi.config import ApplicationConfig, OIDCConfig
from blueapi.service.model import Cache
Expand Down Expand Up @@ -335,12 +333,9 @@ def mock_jwks_fetch(json_web_keyset: JWK):
return patch("jwt.PyJWKClient.fetch_data", mock)


NOT_CONFIGURED_INSTRUMENT = "p100"


@pytest.fixture(scope="module")
def mock_numtracker_server() -> Iterable[responses.RequestsMock]:
query_working = {
@pytest.fixture
def nt_query() -> dict[str, str]:
return {
"query": dedent("""
mutation{
scan(
Expand All @@ -358,94 +353,11 @@ def mock_numtracker_server() -> Iterable[responses.RequestsMock]:
}
""")
}
query_400 = {
"query": dedent("""
mutation{
scan(
instrument: "p47",
instrumentSession: "ab123"
) {
directory{
instrumentSession
instrument
path
}
scanFile
scanNumber
}
}
""")
}
query_500 = {
"query": dedent("""
mutation{
scan(
instrument: "p48",
instrumentSession: "ab123"
) {
directory{
instrumentSession
instrument
path
}
scanFile
scanNumber
}
}
""")
}
query_key_error = {
"query": dedent("""
mutation{
scan(
instrument: "p49",
instrumentSession: "ab123"
) {
directory{
instrumentSession
instrument
path
}
scanFile
scanNumber
}
}
""")
}
query_200_with_errors = {
"query": dedent(f"""
mutation{{
scan(
instrument: "{NOT_CONFIGURED_INSTRUMENT}",
instrumentSession: "ab123"
) {{
directory{{
instrumentSession
instrument
path
}}
scanFile
scanNumber
}}
}}
""")
}

response_with_errors = {
"data": None,
"errors": [
{
"message": (
"No configuration available for instrument "
f'"{NOT_CONFIGURED_INSTRUMENT}"'
),
"locations": [{"line": 3, "column": 5}],
"path": ["scan"],
}
],
}

working_response = {
@pytest.fixture
def nt_response() -> dict[str, Any]:
return {
"data": {
"scan": {
"scanFile": "p46-11",
Expand All @@ -458,42 +370,3 @@ def mock_numtracker_server() -> Iterable[responses.RequestsMock]:
}
}
}
empty_response = {}

with responses.RequestsMock(assert_all_requests_are_fired=False) as requests_mock:
requests_mock.add(
responses.POST,
url="https://numtracker-example.com/graphql",
match=[json_params_matcher(query_working)],
status=200,
json=working_response,
)
requests_mock.add(
responses.POST,
url="https://numtracker-example.com/graphql",
match=[json_params_matcher(query_400)],
status=400,
json=empty_response,
)
requests_mock.add(
responses.POST,
url="https://numtracker-example.com/graphql",
match=[json_params_matcher(query_500)],
status=500,
json=empty_response,
)
requests_mock.add(
responses.POST,
url="https://numtracker-example.com/graphql",
match=[json_params_matcher(query_key_error)],
status=200,
json=empty_response,
)
requests_mock.add(
responses.POST,
"https://numtracker-example.com/graphql",
match=[json_params_matcher(query_200_with_errors)],
status=200,
json=response_with_errors,
)
yield requests_mock
50 changes: 35 additions & 15 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import uuid
from dataclasses import dataclass
from inspect import isawaitable
from typing import Any
from unittest.mock import ANY, MagicMock, Mock, patch

Expand All @@ -15,6 +16,7 @@
)
from ophyd_async.epics.motor import Motor
from pydantic import HttpUrl
from pytest_httpx import HTTPXMock
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: explicitly import httpx_mock

from stomp.connect import StompConnection11 as Connection

from blueapi.config import (
Expand Down Expand Up @@ -513,8 +515,8 @@ def test_configure_numtracker():
assert nt._url.unicode_string() == "https://numtracker-example.com/graphql"


@patch("blueapi.utils.numtracker.requests.post")
def test_headers_are_cleared(mock_post):
@patch("blueapi.utils.numtracker.httpx.AsyncClient.post")
async def test_headers_are_cleared(mock_post):
mock_response = Mock()
mock_post.return_value = mock_response
mock_response.raise_for_status.side_effect = None
Expand Down Expand Up @@ -544,16 +546,18 @@ def test_headers_are_cleared(mock_post):
interface.begin_task(task=WorkerTask(task_id=None), pass_through_headers=headers)
ctx = interface.context()
assert ctx.run_engine.scan_id_source is not None
ctx.run_engine.scan_id_source(
scan_id = ctx.run_engine.scan_id_source(
{"instrument_session": "cm12345-1", "instrument": "p46"}
)
assert isawaitable(scan_id) and await scan_id
mock_post.assert_called_once()
assert mock_post.call_args.kwargs["headers"] == headers

interface.begin_task(task=WorkerTask(task_id=None))
ctx.run_engine.scan_id_source(
scan_id = ctx.run_engine.scan_id_source(
{"instrument_session": "cm12345-1", "instrument": "p46"}
)
assert isawaitable(scan_id) and await scan_id
assert mock_post.call_count == 2
assert mock_post.call_args.kwargs["headers"] == {}

Expand Down Expand Up @@ -632,7 +636,9 @@ def test_setup_with_numtracker_raises_if_provider_is_defined_in_device_module():


@patch("blueapi.utils.numtracker.NumtrackerClient.create_scan")
def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_scan):
async def test_numtracker_create_scan_called_with_arguments_from_metadata(
mock_create_scan,
):
conf = ApplicationConfig(
numtracker=NumtrackerConfig(
url=HttpUrl("https://numtracker-example.com/graphql")
Expand All @@ -649,14 +655,24 @@ def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_

ctx.numtracker.set_headers(headers)
ctx.run_engine.md["instrument_session"] = "ab123"
ctx.run_engine.scan_id_source(ctx.run_engine.md)
scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md)
assert isawaitable(scan_id) and await scan_id

mock_create_scan.assert_called_once_with("ab123", "p46")


def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md(
mock_numtracker_server,
async def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md(
httpx_mock,
nt_query,
nt_response,
):
httpx_mock.add_response(
method="POST",
url="https://numtracker-example.com/graphql",
match_json=nt_query,
status_code=200,
json=nt_response,
)
conf = ApplicationConfig(
env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")),
numtracker=NumtrackerConfig(
Expand All @@ -669,28 +685,32 @@ def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md(
assert ctx.run_engine.scan_id_source is not None

ctx.run_engine.md["instrument_session"] = "ab123"
ctx.run_engine.scan_id_source(ctx.run_engine.md)
scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md)
assert isawaitable(scan_id) and await scan_id

assert (
ctx.run_engine.md["data_session_directory"] == "/exports/mybeamline/data/2025"
)


def test_update_scan_num_side_effect_sets_scan_file_in_re_md(
mock_numtracker_server,
async def test_update_scan_num_side_effect_sets_scan_file_in_re_md(
httpx_mock: HTTPXMock, nt_query, nt_response
):
nt_url = "https://numtracker-example.com/graphql"
httpx_mock.add_response(
method="POST", url=nt_url, match_json=nt_query, json=nt_response
)
conf = ApplicationConfig(
env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")),
numtracker=NumtrackerConfig(
url=HttpUrl("https://numtracker-example.com/graphql")
),
numtracker=NumtrackerConfig(url=HttpUrl(nt_url)),
)
interface.setup(conf)
ctx = interface.context()

assert ctx.run_engine.scan_id_source is not None

ctx.run_engine.md["instrument_session"] = "ab123"
ctx.run_engine.scan_id_source(ctx.run_engine.md)
scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md)
assert isawaitable(scan_id) and await scan_id

assert ctx.run_engine.md["scan_file"] == "p46-11"
Loading