Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions api/ErrorHandler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import logging

from fastapi import Response

from api.ErrorResponseContentProvider import ErrorResponseContentProvider
from api.errors import ApiError


class ErrorHandler:

def __init__(self, response_provider: ErrorResponseContentProvider):
self.response_provider = response_provider
self.logger = logging.getLogger(__name__)

def handle_error(self, error: type[ApiError]) -> Response:
"""
Handles an ApiError and creates an according Response using a ErrorResponseContentProvider.
"""
error.log_error()
return Response(
self.response_provider.provide_error_response_content(error.message),
status_code=error.status_code,
media_type="application/xml; charset=utf-8",
)
60 changes: 60 additions & 0 deletions api/ErrorHandlerTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import logging
import unittest

from api.ErrorHandler import ErrorHandler
from api.ErrorResponseContentProvider import ErrorResponseContentProvider
from api.errors.ApiError import ApiError
from api.errors.InternalServerError import InternalServerError
from api.ojp1.ErrorResponseContentProviderOjp1 import ErrorResponseContentProviderOjp1
from api.ojp2.ErrorResponseContentProviderOjp2 import ErrorResponseContentProviderOjp2


class ErrorHandlerTest(unittest.TestCase):

logger = logging.getLogger(__name__)

def test_ojp1_WHEN_internal_server_error_EXPECT_error_response(self):
self._catch_internal_server_error(ErrorResponseContentProviderOjp1())

def test_ojp2_WHEN_internal_server_EXPECT_error_response(self):
self._catch_internal_server_error(ErrorResponseContentProviderOjp2())

def test_ojp1_WHEN_api_error_EXPECT_error_response(self):
self._catch_api_error(ErrorResponseContentProviderOjp1())

def _catch_internal_server_error(self, error_response_provider: ErrorResponseContentProvider):

# prepare test case
error_handler = ErrorHandler(error_response_provider)
self.assertIsNotNone(error_handler)
message = "Oups, terrible failure ;-)"
try:
raise InternalServerError(message)

# run test case
except InternalServerError as e:
self.logger.info("caught InternalServerError ...")
response = error_handler.handle_error(e)

# assert expectations
self.assertIsNotNone(response)

def _catch_api_error(self, error_response_provider: ErrorResponseContentProvider):

# prepare test case
error_handler = ErrorHandler(error_response_provider)
self.assertIsNotNone(error_handler)
message = "Oups, terrible failure ;-)"
try:
raise InternalServerError(message)

# run test case
except ApiError as e:
self.logger.info("caught ApiError ...")
response = error_handler.handle_error(e)

# assert expectations
self.assertIsNotNone(response)

if __name__ == '__main__':
unittest.main()
10 changes: 10 additions & 0 deletions api/ErrorResponseContentProvider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import abstractmethod, ABC


class ErrorResponseContentProvider(ABC):
@abstractmethod
def provide_error_response_content(self, message: str) -> str:
"""
Provides the error response content for the given error message.
"""
return message
17 changes: 17 additions & 0 deletions api/OjpFareService.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/usr/bin/env python3
from abc import abstractmethod, ABC

from fastapi import Response


class OjpFareService(ABC):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

@abstractmethod
def handle_request(self, body: bytes) -> Response:
"""
Handles an ojp request.
"""
return Response(body)
46 changes: 46 additions & 0 deletions api/OjpVersionParser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from xml.etree import ElementTree

from api.errors.InvalidOjpRequestError import InvalidOjpRequestError


class OjpVersionParser:
chunk_size = 8192 # ausreichend gross, um alle Root-Attribute zu erwischen
root_element_name = "OJP"

def parse_version(self, xml_content: str) -> str:
"""
Reads until first start tag and gets the value of the 'version' attribute
of the OJP root element.

Raises:
InvalidOjpRequestError upon invalid XML, unexcepted root element or missing version attribute.
"""

parser = ElementTree.XMLPullParser(events=('start',))
pos = 0

while True:
if pos >= len(xml_content):
# end of the document reached without having seen the root element
try:
parser.close()
except ElementTree.ParseError as e:
raise InvalidOjpRequestError(f"Invalid XML: {e}") from e
raise InvalidOjpRequestError("Root element not found.")

parser.feed(xml_content[pos:pos + self.chunk_size])
pos += self.chunk_size

for event, elem in parser.read_events():
# first start event is the root
tag = elem.tag
local = tag.split('}', 1)[1] if tag.startswith('{') else tag

if local != self.root_element_name:
raise InvalidOjpRequestError(f"Root-Element ist '{local}', erwartet '{self.root_element_name}'")

ver = elem.get("version")
if ver is None or ver.strip() == "":
raise InvalidOjpRequestError(f"'{self.root_element_name}'-Element hat kein 'version'-Attribut.")

return ver.strip()
43 changes: 43 additions & 0 deletions api/OjpVersionParserTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
import unittest

from api.OjpVersionParser import OjpVersionParser
from api.errors.InvalidOjpRequestError import InvalidOjpRequestError


class OjpVersionParserTest(unittest.TestCase):

logger = logging.getLogger(__name__)

def test_ojp1_WHEN_parse_EXPECT_1(self):
version_parser = OjpVersionParser()
payload = '<OJP version="1.0"></OJP>'
version = version_parser.parse_version(payload)
self.assertEqual(version, '1.0')

def test_ojp2_WHEN_parse_EXPECT_2(self):
version_parser = OjpVersionParser()
payload = '<OJP version="2.0"></OJP>'
version = version_parser.parse_version(payload)
self.assertEqual(version, '2.0')

def test_missing_version_WHEN_parse_EXPECT_failure(self):
version_parser = OjpVersionParser()
payload = '<OJP></OJP>'
with self.assertRaises(InvalidOjpRequestError):
version_parser.parse_version(payload)

def test_missing_ojp_element_WHEN_parse_EXPECT_failure(self):
version_parser = OjpVersionParser()
payload = '<TEST></TEST>'
with self.assertRaises(InvalidOjpRequestError):
version_parser.parse_version(payload)

def test_no_xml_WHEN_parse_EXPECT_failure(self):
version_parser = OjpVersionParser()
payload = 'Test'
with self.assertRaises(InvalidOjpRequestError):
version_parser.parse_version(payload)

if __name__ == '__main__':
unittest.main()
7 changes: 7 additions & 0 deletions api/SerializerUtil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from xsdata.formats.dataclass.serializers import XmlSerializer
from xsdata.formats.dataclass.serializers.config import SerializerConfig

class SerializerUtil:
ns_map = {"": "http://www.siri.org.uk/siri", "ojp": "http://www.vdv.de/ojp"}
_serializer_config = SerializerConfig(ignore_default_attributes=True, pretty_print=True)
serializer = XmlSerializer(config=_serializer_config)
Empty file added api/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions api/errors/ApiError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import logging
from abc import abstractmethod, ABC

class ApiError(Exception, ABC):
def __init__(self, message="An unspecific error occurred.", status_code=500):
self.message = message
self.status_code = status_code
self.logger = logging.getLogger(__name__)
super().__init__(self.message)

@abstractmethod
def log_error(self):
self.logger.warning(self.message)
9 changes: 9 additions & 0 deletions api/errors/InternalServerError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from api.errors.ApiError import ApiError


class InternalServerError(ApiError):
def __init__(self, message="An internal server error occurred."):
super().__init__(message)

def log_error(self):
self.logger.error(self.message)
9 changes: 9 additions & 0 deletions api/errors/InvalidNovaResponseError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from api.errors.ApiError import ApiError


class InvalidNovaResponseError(ApiError):
def __init__(self, message="There was no valid NOVA response."):
super().__init__(message,400)

def log_error(self):
self.logger.warning(self.message)
9 changes: 9 additions & 0 deletions api/errors/InvalidOjpRequestError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from api.errors.ApiError import ApiError


class InvalidOjpRequestError(ApiError):
def __init__(self, message="There was no (valid) OJP request."):
super().__init__(message, 400)

def log_error(self):
self.logger.warning(self.message)
9 changes: 9 additions & 0 deletions api/errors/NoNovaResponseError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from api.errors.ApiError import ApiError


class NoNovaResponseError(ApiError):
def __init__(self, message="There was no NOVA response."):
super().__init__(message)

def log_error(self):
self.logger.warning(self.message)
10 changes: 10 additions & 0 deletions api/errors/OjpRequestParseError.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from api.errors.ApiError import ApiError


class OjpRequestParseError(ApiError):
def __init__(self, cause: Exception = None, message="Failed to parse OJP request."):
self.cause = cause
super().__init__(message, 400)

def log_error(self):
self.logger.warning(self.message + ": " + str(self.cause))
Empty file added api/errors/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions api/ojp1/ErrorResponseContentProviderOjp1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import datetime

from xsdata.models.datatype import XmlDateTime

from api.ErrorResponseContentProvider import ErrorResponseContentProvider
from api.SerializerUtil import SerializerUtil
from ojp import Ojp, Ojpresponse, ServiceDelivery, ServiceDeliveryStructure, OtherError


class ErrorResponseContentProviderOjp1(ErrorResponseContentProvider):
ns_map = SerializerUtil.ns_map
serializer = SerializerUtil.serializer

def provide_error_response_content(self, message: str) -> str:
"""
Provides the ojp1 error response content for the given error message.
"""
ojp = Ojp(ojpresponse=Ojpresponse(service_delivery=
ServiceDelivery(response_timestamp=XmlDateTime.from_datetime(
datetime.datetime.now(datetime.timezone.utc)),
producer_ref="OJP2NOVA",
error_condition=ServiceDeliveryStructure.ErrorCondition(
other_error=OtherError(message)))))
return self.serializer.render(ojp, ns_map=self.ns_map)
76 changes: 76 additions & 0 deletions api/ojp1/OjpFareServiceOjp1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging

from fastapi import Response

from api.ErrorHandler import ErrorHandler
from api.OjpFareService import OjpFareService
from api.SerializerUtil import SerializerUtil
from api.errors.ApiError import ApiError
from api.errors.InvalidOjpRequestError import InvalidOjpRequestError
from api.errors.OjpRequestParseError import OjpRequestParseError
from api.ojp1.ErrorResponseContentProviderOjp1 import ErrorResponseContentProviderOjp1
from map_nova_to_ojp import map_nova_reply_to_ojp_fare_delivery
from map_ojp_to_ojp import parse_ojp
from ojp import OjpfareDelivery, Ojpresponse, ServiceDelivery, Ojp
from test_network_flow import test_nova_request_reply, call_ojp_2000


class FareServiceOjp1(OjpFareService):
def __init__(self, *args, **kwargs) -> None:
self.logger = logging.getLogger(__name__)
super().__init__(*args, **kwargs)

def handle_request(self, body: bytes) -> Response:
"""
Handles an ojp1 request.
"""
error_handler = ErrorHandler(ErrorResponseContentProviderOjp1())
try:
ojp_fare_request = _parse_request(body)
_validate_request(ojp_fare_request)
self.logger.debug("Request passed validation: " + str(ojp_fare_request))

if ojp_fare_request.ojprequest.service_request.ojpfare_request:
self.logger.debug("Fare request - about to query NOVA: " + str(ojp_fare_request))
nova_response = test_nova_request_reply(ojp_fare_request)
ojp_fare_delivery = map_nova_reply_to_ojp_fare_delivery(nova_response)
self.logger.debug("Workable NOVA response put into OJP: " + str(ojp_fare_delivery))
return _create_response(ojp_fare_delivery)
else:
self.logger.debug("OJP request - returning the call to the OJP server:" + str(body.decode("utf-8")))
s, r = call_ojp_2000(body.decode("utf-8"))
return Response(r, media_type="application/xml; charset=utf-8", status_code=s)

except ApiError as error:
return error_handler.handle_error(error)


def _validate_request(ojp_fare_request: Ojp):
if ojp_fare_request.ojprequest is None:
raise InvalidOjpRequestError(message="missing Element OJPRequest.");

if ojp_fare_request.ojprequest.service_request.ojpfare_request is None:
raise InvalidOjpRequestError()


def _parse_request(body: bytes) -> Ojp:
try:
return parse_ojp(body.decode("utf-8"))
except Exception as e:
raise OjpRequestParseError(cause=e)


def _create_response(ojp_fare_delivery: OjpfareDelivery) -> Response:
xml = SerializerUtil.serializer.render(
Ojp(
ojpresponse=Ojpresponse(
service_delivery=ServiceDelivery(
response_timestamp=ojp_fare_delivery.response_timestamp,
producer_ref="OJP2NOVA",
ojpfare_delivery=[ojp_fare_delivery],
)
)
),
ns_map=SerializerUtil.ns_map,
)
return Response(xml, media_type="application/xml; charset=utf-8")
Empty file added api/ojp1/__init__.py
Empty file.
Loading