diff --git a/Cargo.lock b/Cargo.lock index 2cdc685..706a029 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,7 +1598,6 @@ dependencies = [ "matchit", "minijinja", "multer", - "once_cell", "pyo3", "pyo3-async-runtimes", "pyo3-stub-gen", diff --git a/Cargo.toml b/Cargo.toml index 1c28134..5d0ac69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "oxapy" version = "0.7.9" -edition = "2021" +edition = "2024" authors = ["FITAHIANA Nomeniavo Joe <24nomeniavo@gmail.com>"] repository = "https://github.com/j03-dev/oxapy" @@ -43,7 +43,6 @@ tera = "1.20" ahash = "0.8.12" ctrlc = "3.5.1" glob = "0.3.3" -once_cell = "1.21.3" rand = "0.10.0-rc.5" url = "2.5.7" diff --git a/oxapy/__init__.py b/oxapy/__init__.py index 2803bc2..6e05c4f 100644 --- a/oxapy/__init__.py +++ b/oxapy/__init__.py @@ -1,8 +1,19 @@ from .oxapy import * +import os import mimetypes +def secure_join(base: str, *paths: str) -> str: + base = os.path.realpath(base) + target = os.path.realpath(os.path.join(base, *paths)) + + if target != base and not target.startswith(base + os.sep): + raise exceptions.ForbiddenError("Access denied") + + return target + + def static_file(path: str = "/static", directory: str = "./static"): r""" Create a route for serving static files. @@ -22,11 +33,8 @@ def static_file(path: str = "/static", directory: str = "./static"): @get(f"{path}/{{*path}}") def handler(_request, path: str): - file_path = f"{directory}/{path}" - try: - return send_file(file_path) - except FileNotFoundError: - return Response("File not found", Status.NOT_FOUND) + file_path = secure_join(directory, path) + return send_file(file_path) return handler @@ -40,6 +48,12 @@ def send_file(path: str) -> Response: Returns: Response: A Response with file content """ + if not os.path.exists(path): + raise exceptions.NotFoundError("Requested file not found") + + if not os.path.isfile(path): + raise exceptions.ForbiddenError("Not a file") + with open(path, "rb") as f: content = f.read() content_type, _ = mimetypes.guess_type(path) diff --git a/oxapy/serializer/__init__.pyi b/oxapy/serializer/__init__.pyi index 524d4d1..44ef726 100644 --- a/oxapy/serializer/__init__.pyi +++ b/oxapy/serializer/__init__.pyi @@ -267,9 +267,9 @@ class Serializer(Field): @raw_data.setter def raw_data(self, value: typing.Optional[builtins.str]) -> None: ... @property - def context(self) -> typing.Optional[dict]: ... + def context(self) -> dict: ... @context.setter - def context(self, value: typing.Optional[dict]) -> None: ... + def context(self, value: dict) -> None: ... @property def data(self) -> typing.Any: r""" diff --git a/src/cors.rs b/src/cors.rs index 33bd453..9862e6d 100644 --- a/src/cors.rs +++ b/src/cors.rs @@ -34,15 +34,19 @@ pub struct Cors { /// List of allowed origins, default is ["*"] (all origins) #[pyo3(get, set)] pub origins: Vec, + /// List of allowed HTTP methods, default includes common methods #[pyo3(get, set)] pub methods: Vec, + /// List of allowed HTTP headers, default includes common headers #[pyo3(get, set)] pub headers: Vec, + /// Whether to allow credentials (cookies, authorization headers), default is true #[pyo3(get, set)] pub allow_credentials: bool, + /// Maximum age of preflight requests in seconds, default is 86400 (1 day) #[pyo3(get, set)] pub max_age: u32, @@ -52,8 +56,20 @@ impl Default for Cors { fn default() -> Self { Self { origins: vec!["*".to_string()], - methods: vec!["GET, POST, PUT, DELETE, PATCH, OPTIONS".to_string()], - headers: vec!["Content-Type, Authorization, X-Requested-With, Accept".to_string()], + methods: vec![ + "DELETE".to_string(), + "GET".to_string(), + "OPTIONS".to_string(), + "PATCH".to_string(), + "POST".to_string(), + "PUT".to_string(), + ], + headers: vec![ + "Accept".to_string(), + "Authorization".to_string(), + "Content-Type".to_string(), + "X-Requested-With".to_string(), + ], allow_credentials: true, max_age: 86400, } diff --git a/src/exceptions.rs b/src/exceptions.rs index 2b5e222..f155fb9 100644 --- a/src/exceptions.rs +++ b/src/exceptions.rs @@ -1,5 +1,5 @@ use pyo3::exceptions::PyException; -use pyo3::{impl_exception_boilerplate, prelude::*}; +use pyo3::prelude::*; use pyo3_stub_gen::derive::*; pub trait IntoPyException { @@ -13,7 +13,22 @@ impl IntoPyException for Result { } macro_rules! extend_exception { + ($name:ident) => { + pyo3::impl_exception_boilerplate!($name); + + #[pyo3_stub_gen::derive::gen_stub_pymethods] + #[pyo3::prelude::pymethods] + impl $name { + #[new] + fn new(e: pyo3::Py) -> $name { + Self(e) + } + } + }; + ($name:ident, $extend:ident) => { + pyo3::impl_exception_boilerplate!($name); + #[pyo3_stub_gen::derive::gen_stub_pymethods] #[pyo3::prelude::pymethods] impl $name { @@ -36,17 +51,7 @@ macro_rules! extend_exception { #[gen_stub_pyclass] #[pyclass(subclass, extends=PyException, module="oxapy.exceptions")] pub struct ClientError(pub Py); - -impl_exception_boilerplate!(ClientError); - -#[gen_stub_pymethods] -#[pymethods] -impl ClientError { - #[new] - fn new(e: Py) -> ClientError { - Self(e) - } -} +extend_exception!(ClientError); /// HTTP 400 Bad Request error exception. /// @@ -58,8 +63,6 @@ impl ClientError { #[gen_stub_pyclass] #[pyclass(extends=ClientError, module="oxapy.exceptions")] pub struct BadRequestError; - -impl_exception_boilerplate!(BadRequestError); extend_exception!(BadRequestError, ClientError); /// HTTP 401 Unauthorized error exception. @@ -71,8 +74,6 @@ extend_exception!(BadRequestError, ClientError); #[gen_stub_pyclass] #[pyclass(extends=ClientError, module="oxapy.exceptions")] pub struct UnauthorizedError; - -impl_exception_boilerplate!(UnauthorizedError); extend_exception!(UnauthorizedError, ClientError); /// HTTP 403 Forbidden error exception. @@ -85,8 +86,6 @@ extend_exception!(UnauthorizedError, ClientError); #[gen_stub_pyclass] #[pyclass(extends=ClientError, module="oxapy.exceptions")] pub struct ForbiddenError; - -impl_exception_boilerplate!(ForbiddenError); extend_exception!(ForbiddenError, ClientError); /// HTTP 404 Not Found error exception. @@ -99,8 +98,6 @@ extend_exception!(ForbiddenError, ClientError); #[gen_stub_pyclass] #[pyclass(extends=ClientError, module="oxapy.exceptions")] pub struct NotFoundError; - -impl_exception_boilerplate!(NotFoundError); extend_exception!(NotFoundError, ClientError); /// HTTP 409 Conflict error exception. @@ -114,8 +111,6 @@ extend_exception!(NotFoundError, ClientError); #[gen_stub_pyclass] #[pyclass(extends=ClientError, module="oxapy.exceptions")] pub struct ConflictError; - -impl_exception_boilerplate!(ConflictError); extend_exception!(ConflictError, ClientError); /// HTTP 500 Internal Server Error exception. @@ -129,17 +124,7 @@ extend_exception!(ConflictError, ClientError); #[pyclass(extends=PyException, module="oxapy.exceptions")] #[repr(transparent)] pub struct InternalError(Py); - -impl_exception_boilerplate!(InternalError); - -#[gen_stub_pymethods] -#[pymethods] -impl InternalError { - #[new] - fn new(e: Py) -> InternalError { - Self(e) - } -} +extend_exception!(InternalError); pub fn exceptions(m: &Bound<'_, PyModule>) -> PyResult<()> { let exceptions = PyModule::new(m.py(), "exceptions")?; diff --git a/src/into_response.rs b/src/into_response.rs index 671dc29..d2cd092 100644 --- a/src/into_response.rs +++ b/src/into_response.rs @@ -96,7 +96,7 @@ impl From for Response { Status::INTERNAL_SERVER_ERROR } }; - let response: Response = status.into(); + let response = Response::from(status); response.set_body(format!( r#"{{"detail": "{}"}}"#, value.value(py).to_string().replace('"', "'") @@ -106,9 +106,9 @@ impl From for Response { } impl From for Response { - fn from(val: Cors) -> Self { - let mut response = Status::NO_CONTENT.into(); - val.apply_headers(&mut response); + fn from(cors: Cors) -> Self { + let mut response = Response::from(Status::NO_CONTENT); + cors.apply_headers(&mut response); response } } diff --git a/src/json.rs b/src/json.rs index 152eeae..2cb9447 100644 --- a/src/json.rs +++ b/src/json.rs @@ -1,17 +1,15 @@ -use once_cell::sync::OnceCell; -use pyo3::{prelude::*, types::PyDict}; +use pyo3::{prelude::*, sync::PyOnceLock, types::PyDict}; use serde::{Deserialize, Serialize}; -static ORJSON: OnceCell> = OnceCell::new(); +static ORJSON: PyOnceLock> = PyOnceLock::new(); #[inline] pub fn dumps(data: &Py) -> PyResult { Python::attach(|py| { - let orjson = ORJSON.get_or_init(|| PyModule::import(py, "orjson").unwrap().into()); - let serialized_data = - orjson - .call_method1(py, "dumps", (data,))? - .call_method1(py, "decode", ("utf-8",))?; + let serialized_data = ORJSON + .get_or_try_init(py, || PyModule::import(py, "orjson").map(|m| m.into()))? + .call_method1(py, "dumps", (data,))? + .call_method1(py, "decode", ("utf-8",))?; Ok(serialized_data.extract(py)?) }) } @@ -19,8 +17,9 @@ pub fn dumps(data: &Py) -> PyResult { #[inline] pub fn loads(data: &str) -> PyResult> { Python::attach(|py| { - let orjson = ORJSON.get_or_init(|| PyModule::import(py, "orjson").unwrap().into()); - let deserialized_data = orjson.call_method1(py, "loads", (data,))?; + let deserialized_data = ORJSON + .get_or_try_init(py, || PyModule::import(py, "orjson").map(|m| m.into()))? + .call_method1(py, "loads", (data,))?; Ok(deserialized_data.extract(py)?) }) } diff --git a/src/jwt.rs b/src/jwt.rs index c25083b..f809955 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -2,7 +2,7 @@ use jsonwebtoken::errors::ErrorKind; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation}; use pyo3::exceptions::PyException; use pyo3::types::PyDict; -use pyo3::{exceptions::PyValueError, impl_exception_boilerplate, prelude::*}; +use pyo3::{exceptions::PyValueError, prelude::*}; use pyo3_stub_gen::derive::*; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -18,46 +18,30 @@ use crate::json::Wrap; #[pyclass(subclass, extends=PyException, module="oxapy.jwt")] #[repr(transparent)] pub struct JwtError(Py); - -impl_exception_boilerplate!(JwtError); - -#[gen_stub_pymethods] -#[pymethods] -impl JwtError { - #[new] - fn new(e: Py) -> JwtError { - Self(e) - } -} +extend_exception!(JwtError); /// Occurs when there's an error during JWT encoding. #[gen_stub_pyclass] #[pyclass(extends=JwtError, module="oxapy.jwt")] pub struct JwtEncodingError; - -impl_exception_boilerplate!(JwtEncodingError); extend_exception!(JwtEncodingError, JwtError); /// Occurs when there's an error during JWT decoding/verification. #[gen_stub_pyclass] #[pyclass(extends=JwtError, module="oxapy.jwt")] pub struct JwtDecodingError; - -impl_exception_boilerplate!(JwtDecodingError); extend_exception!(JwtDecodingError, JwtError); /// Occurs when the JWT algorithm is invalid or not supported. #[gen_stub_pyclass] #[pyclass(extends=JwtError, module="oxapy.jwt")] pub struct JwtInvalidAlgorithm; -impl_exception_boilerplate!(JwtInvalidAlgorithm); extend_exception!(JwtInvalidAlgorithm, JwtError); /// Occurs when a JWT claim is invalid (e.g., wrong format). #[gen_stub_pyclass] #[pyclass(extends=JwtError, module="oxapy.jwt")] pub struct JwtInvalidClaim; -impl_exception_boilerplate!(JwtInvalidClaim); extend_exception!(JwtInvalidClaim, JwtError); #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/lib.rs b/src/lib.rs index 571e96b..4d67c8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,8 +17,19 @@ mod templating; use std::net::SocketAddr; use std::ops::Deref; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyInt, PyString}; +use pyo3_async_runtimes::tokio::{future_into_py, into_future}; +use pyo3_stub_gen::derive::*; + +use ahash::HashMap; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::Semaphore; +use tokio::sync::mpsc::{Receiver, Sender, channel}; use crate::catcher::Catcher; use crate::cors::Cors; @@ -33,27 +44,15 @@ use crate::session::{Session, SessionStore}; use crate::status::Status; use crate::templating::Template; -use pyo3::exceptions::PyValueError; -use pyo3::types::{PyDict, PyInt, PyString}; -use pyo3_async_runtimes::tokio::{future_into_py, into_future}; -use pyo3_stub_gen::derive::*; - -use ahash::HashMap; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio::sync::Semaphore; - -use pyo3::prelude::*; - pyo3_stub_gen::define_stub_info_gatherer!(stub_info); struct ProcessRequest { - request: Arc, + catchers: Option>>>, + cors: Option>, layer: Option>, match_route: Option>, + request: Arc, tx: Sender, - cors: Option>, - catchers: Option>>>, } #[derive(Clone)] @@ -62,8 +61,8 @@ struct RequestContext { catchers: Option>>>, channel_capacity: usize, cors: Option>, - request_sender: Sender, layers: Vec>, + request_sender: Sender, session_store: Option>, template: Option>, } diff --git a/src/request.rs b/src/request.rs index a6b3d07..b26abb0 100644 --- a/src/request.rs +++ b/src/request.rs @@ -15,11 +15,10 @@ use url::form_urlencoded; use crate::routing::MatchRoute; use crate::status::Status; use crate::{ - json, + IntoPyException, ProcessRequest, RequestContext, json, multipart::File, session::{Session, SessionStore}, templating::Template, - IntoPyException, ProcessRequest, RequestContext, }; use crate::{multipart::parse_multipart, response::Body}; use crate::{response::Response, routing::Layer}; @@ -214,11 +213,11 @@ impl Request { /// ``` #[getter] pub fn session(&self) -> PyResult { - let message = "Session not available. Make sure you've configured SessionStore."; - let session = self - .session - .as_ref() - .ok_or_else(|| PyAttributeError::new_err(message))?; + let session = self.session.as_ref().ok_or_else(|| { + PyAttributeError::new_err( + "Session not available. Make sure you've configured SessionStore.", + ) + })?; Ok(session.as_ref().clone()) } @@ -331,13 +330,12 @@ impl Request { process_request: ProcessRequest, mut rx: tokio::sync::mpsc::Receiver, ) -> Result, hyper::http::Error> { - if ctx.request_sender.send(process_request).await.is_ok() { - if let Some(response) = rx.recv().await { - return response.try_into(); - } + if ctx.request_sender.send(process_request).await.is_ok() + && let Some(response) = rx.recv().await + { + return response.try_into(); } - let response: Response = Status::NOT_FOUND.into(); - response.try_into() + Response::from(Status::NOT_FOUND).try_into() } } diff --git a/src/response.rs b/src/response.rs index c1129ed..c2dcbc7 100644 --- a/src/response.rs +++ b/src/response.rs @@ -3,9 +3,9 @@ use http_body_util::combinators::BoxBody; use hyper::body::Frame; use hyper::http::HeaderValue; use hyper::{ - body::Bytes, - header::{HeaderName, CONTENT_TYPE, LOCATION}, HeaderMap, + body::Bytes, + header::{CONTENT_TYPE, HeaderName, LOCATION}, }; use futures_util::stream; @@ -17,11 +17,12 @@ use pyo3::types::{PyBytes, PyString}; use pyo3_stub_gen::derive::*; use std::convert::Infallible; +use std::fs; use std::io::Read; +use std::str::{self, FromStr}; use std::sync::Arc; -use std::{fs, str}; -use crate::{convert_to_response, json, Cors, IntoPyException, ProcessRequest, Request, Status}; +use crate::{Cors, IntoPyException, ProcessRequest, Request, Status, convert_to_response, json}; pub type Body = BoxBody; @@ -117,7 +118,7 @@ impl Response { fn body(&self) -> PyResult { match &self.body { ResponseBody::Bytes(b) => { - let s = str::from_utf8(b.as_ref()).into_py_exception()?; + let s = str::from_utf8(&b).into_py_exception()?; Ok(s.to_string()) } _ => { @@ -167,10 +168,8 @@ impl Response { /// response.insert_header("Cache-Control", "no-cache") /// ``` pub fn insert_header(&mut self, key: &str, value: String) { - self.headers.insert( - HeaderName::from_bytes(key.as_bytes()).unwrap(), - value.parse().unwrap(), - ); + self.headers + .insert(HeaderName::from_str(key).unwrap(), value.parse().unwrap()); } /// Append a header to the response. @@ -193,10 +192,8 @@ impl Response { /// response.append_header("Set-Cookie", "theme=dark") /// ``` pub fn append_header(&mut self, key: &str, value: String) { - self.headers.append( - HeaderName::from_bytes(key.as_bytes()).unwrap(), - value.parse().unwrap(), - ); + self.headers + .append(HeaderName::from_str(key).unwrap(), value.parse().unwrap()); } } @@ -231,24 +228,23 @@ impl Response { } fn from_json(obj: Bound, status: Status, content_type: HeaderValue) -> PyResult { - let json = json::dumps(&obj.into())?; Ok(Self { status, - body: ResponseBody::Bytes(Bytes::from(json.clone())), + body: ResponseBody::Bytes(Bytes::from(json::dumps(&obj.into())?)), headers: HeaderMap::from_iter([(CONTENT_TYPE, content_type)]), }) } pub(crate) fn apply_catcher(mut self, req: &ProcessRequest) -> Self { - if let Some(catchers) = &req.catchers { - if let Some(handler) = catchers.get(&self.status) { - let request: Request = req.request.as_ref().clone(); - self = Python::attach(|py| { - let result = handler.call(py, (request, self), None)?; - convert_to_response(result, py) - }) - .unwrap_or_else(Response::from); - } + if let Some(catchers) = &req.catchers + && let Some(handler) = catchers.get(&self.status) + { + let request = req.request.as_ref().clone(); + self = Python::attach(|py| { + let result = handler.call(py, (request, self), None)?; + convert_to_response(result, py) + }) + .unwrap_or_else(Response::from); } self } diff --git a/src/serializer/mod.rs b/src/serializer/mod.rs index c5e29ae..d153849 100644 --- a/src/serializer/mod.rs +++ b/src/serializer/mod.rs @@ -6,20 +6,19 @@ use std::{ }; use self::fields::*; -use crate::{exceptions::ClientError, json, IntoPyException}; +use crate::{IntoPyException, exceptions::ClientError, json}; -use once_cell::sync::{Lazy, OnceCell}; use pyo3::{ + IntoPyObjectExt, exceptions::PyException, - impl_exception_boilerplate, prelude::*, + sync::PyOnceLock, types::{PyDict, PyList, PyType}, - IntoPyObjectExt, }; use pyo3_stub_gen::derive::*; -use serde_json::{json, Value}; +use serde_json::{Value, json}; -static SQL_ALCHEMY_INSPECT: OnceCell> = OnceCell::new(); +static SQL_ALCHEMY_INSPECT: PyOnceLock> = PyOnceLock::new(); #[gen_stub_pyclass] #[pyclass(module="oxapy.serializer", subclass, extends=Field)] @@ -32,7 +31,7 @@ struct Serializer { #[pyo3(get, set)] raw_data: Option, #[pyo3(get, set)] - context: Option>, + context: Py, } #[gen_stub_pymethods] @@ -96,7 +95,7 @@ impl Serializer { validated_data: PyDict::new(py).into(), raw_data: data, instance, - context, + context: context.unwrap_or_else(|| PyDict::new(py).into()), }, Field { required, @@ -129,8 +128,8 @@ impl Serializer { /// print(schema) /// ``` #[pyo3(signature=())] - fn schema(slf: Bound<'_, Self>) -> PyResult> { - let schema_value = Self::json_schema_value(&slf.get_type(), false)?; + fn schema(slf: Bound<'_, Self>, py: Python<'_>) -> PyResult> { + let schema_value = Self::json_schema_value(&slf.get_type(), false, py)?; json::loads(&schema_value.to_string()) } @@ -189,10 +188,14 @@ impl Serializer { /// serializer.validate({"email": "user@example.com"}) /// ``` #[pyo3(signature=(attr))] - fn validate<'a>(slf: Bound<'a, Self>, attr: Bound<'a, PyDict>) -> PyResult> { + fn validate<'a>( + slf: Bound<'a, Self>, + attr: Bound<'a, PyDict>, + py: Python<'a>, + ) -> PyResult> { let json::Wrap(json_value) = attr.clone().try_into()?; - let schema_value = Self::json_schema_value(&slf.get_type(), false)?; + let schema_value = Self::json_schema_value(&slf.get_type(), false, py)?; let validator = jsonschema::options() .should_validate_formats(true) @@ -205,11 +208,11 @@ impl Serializer { for k in attr.keys() { let key = k.to_string(); - if let Ok(field) = slf.getattr(&key) { - let field = field.extract::()?; - if field.read_only { - attr.del_item(&key)?; - } + if let Ok(f) = slf.getattr(&key) + && let Ok(field) = f.extract::() + && field.read_only + { + attr.del_item(&key)?; } } @@ -394,13 +397,11 @@ impl Serializer { ) -> PyResult> { let dict = PyDict::new(py); - let inspect = SQL_ALCHEMY_INSPECT.get_or_init(|| { - let sqlalchemy = - PyModule::import(py, "sqlalchemy").expect("sqlalchemy is not installed!"); - let inspection = sqlalchemy.getattr("inspection").unwrap(); - let inspect = inspection.getattr("inspect").unwrap(); - inspect.into() - }); + let inspect = SQL_ALCHEMY_INSPECT.get_or_try_init(py, || { + let sqlalchemy = PyModule::import(py, "sqlalchemy")?; + let inspection = sqlalchemy.getattr("inspection")?; + inspection.getattr("inspect").map(|i| i.into()) + })?; let mapper = inspect.call1(py, (instance.get_type(),))?; @@ -408,10 +409,10 @@ impl Serializer { for c in columns { let col = c?.getattr("name")?.to_string(); - if let Ok(field) = slf.getattr(&col) { - if !field.extract::()?.write_only { - dict.set_item(&col, instance.getattr(&col)?)?; - } + if let Ok(field) = slf.getattr(&col) + && !field.extract::()?.write_only + { + dict.set_item(&col, instance.getattr(&col)?)?; } } @@ -422,30 +423,37 @@ impl Serializer { for r in relationships { let key = r?.getattr("key")?.to_string(); - if let Ok(field) = slf.getattr(&key) { - if !field.extract::()?.write_only { - slf.getattr("context") - .and_then(|ctx| field.setattr("context", ctx))?; - field.setattr("instance", instance.getattr(&key)?)?; - dict.set_item(key, field.getattr("data")?)?; - } + if let Ok(field) = slf.getattr(&key) + && !field.extract::()?.write_only + { + slf.getattr("context") + .and_then(|ctx| field.setattr("context", ctx))?; + field.setattr("instance", instance.getattr(&key)?)?; + dict.set_item(key, field.getattr("data")?)?; } } Ok(dict) } } -static CACHES_JSON_SCHEMA_VALUE: Lazy>>> = - Lazy::new(|| Arc::new(Mutex::new(HashMap::new()))); +static CACHE: PyOnceLock>>> = PyOnceLock::new(); + +fn cache(py: Python<'_>) -> &Arc>> { + CACHE.get_or_init(py, || Arc::new(Mutex::new(HashMap::new()))) +} impl Serializer { - fn json_schema_value(cls: &Bound<'_, PyType>, nullable: bool) -> PyResult { + fn json_schema_value( + cls: &Bound<'_, PyType>, + nullable: bool, + py: Python<'_>, + ) -> PyResult { let mut properties = serde_json::Map::with_capacity(16); let mut required_fields = Vec::with_capacity(8); let class_name = cls.name()?; - if let Some(value) = CACHES_JSON_SCHEMA_VALUE + if let Some(value) = cache(py) .lock() .into_py_exception()? .get(&class_name.to_string()) @@ -469,7 +477,7 @@ impl Serializer { .required .then(|| required_fields.push(attr_name.clone())); let nested_schema = - Self::json_schema_value(&attr_obj.get_type(), field.nullable)?; + Self::json_schema_value(&attr_obj.get_type(), field.nullable, py)?; if field.many { let mut array_schema = serde_json::Map::with_capacity(2); @@ -505,7 +513,7 @@ impl Serializer { let final_schema = json!(schema); - CACHES_JSON_SCHEMA_VALUE + cache(py) .lock() .into_py_exception()? .insert(class_name.to_string(), final_schema.clone()); @@ -522,8 +530,6 @@ impl Serializer { #[gen_stub_pyclass] #[pyclass(module = "oxapy.serializer", extends=ClientError)] pub struct ValidationException; - -impl_exception_boilerplate!(ValidationException); extend_exception!(ValidationException, ClientError); pub fn serializer_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { diff --git a/src/session.rs b/src/session.rs index cc26afb..1b4d6f2 100644 --- a/src/session.rs +++ b/src/session.rs @@ -4,9 +4,9 @@ use std::{ }; use ahash::HashMap; -use pyo3::{prelude::*, types::PyTuple, IntoPyObjectExt}; +use pyo3::{IntoPyObjectExt, prelude::*, types::PyTuple}; use pyo3_stub_gen::derive::*; -use rand::{distr::Alphanumeric, Rng}; +use rand::{Rng, distr::Alphanumeric}; use crate::IntoPyException; @@ -395,15 +395,15 @@ impl SessionStore { pub fn get_session(&self, session_id: Option<&str>) -> PyResult { let mut sessions = self.sessions.write().into_py_exception()?; - if let Some(id) = session_id { - if let Some(session) = sessions.get(id) { - *session.last_accessed.lock().unwrap() = SystemTime::now() - .duration_since(UNIX_EPOCH) - .into_py_exception()? - .as_secs(); + if let Some(id) = session_id + && let Some(session) = sessions.get(id) + { + *session.last_accessed.lock().unwrap() = SystemTime::now() + .duration_since(UNIX_EPOCH) + .into_py_exception()? + .as_secs(); - return Ok(session.as_ref().clone()); - } + return Ok(session.as_ref().clone()); } let session = Session::new(None)?; diff --git a/src/templating/minijinja.rs b/src/templating/minijinja.rs index c446d5a..b57d7e2 100644 --- a/src/templating/minijinja.rs +++ b/src/templating/minijinja.rs @@ -5,8 +5,8 @@ use pyo3::{prelude::*, types::PyDict}; use pyo3_stub_gen::derive::*; use std::sync::Arc; -use crate::json; use crate::IntoPyException; +use crate::json; #[gen_stub_pyclass] #[pyclass(module = "oxapy.templating")] diff --git a/tests/app.py b/tests/app.py new file mode 100644 index 0000000..4bed198 --- /dev/null +++ b/tests/app.py @@ -0,0 +1,9 @@ +from oxapy import HttpServer, Router, get + + +@get("/greet/{name}") +def greet(_r, name: str): + return f"Hello, {name}!" + + +HttpServer(("0.0.0.0", 5555)).attach(Router().route(greet)).run() diff --git a/tests/test_bench.py b/tests/test_bench.py index a3ccc0e..65817a1 100644 --- a/tests/test_bench.py +++ b/tests/test_bench.py @@ -1,5 +1,5 @@ import time -from oxapy import Response, serializer +from oxapy import Response def test_response_benchmark():