diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts index b840479a0..a934308ef 100644 --- a/src/js/packages/@reactpy/client/src/reactpy-client.ts +++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts @@ -1,5 +1,5 @@ -import { ReactPyModule } from "./reactpy-vdom"; import logger from "./logger"; +import { ReactPyModule } from "./reactpy-vdom"; /** * A client for communicating with a ReactPy server. @@ -108,6 +108,7 @@ export type SimpleReactPyClientProps = { connectionTimeout?: number; debugMessages?: boolean; socketLoopThrottle?: number; + pingInterval?: number; }; /** @@ -156,6 +157,7 @@ enum messageTypes { clientState = "client-state", stateUpdate = "state-update", layoutUpdate = "layout-update", + pingIntervalSet = "ping-interval-set", }; export class SimpleReactPyClient @@ -180,6 +182,8 @@ export class SimpleReactPyClient private didReconnectingCallback: boolean; private willReconnect: boolean; private socketLoopThrottle: number; + private pingPongIntervalId?: number | null; + private pingInterval: number; constructor(props: SimpleReactPyClientProps) { super(); @@ -193,6 +197,7 @@ export class SimpleReactPyClient ); this.idleDisconnectTimeMillis = (props.idleDisconnectTimeSeconds || 240) * 1000; this.connectionTimeout = props.connectionTimeout || 5000; + this.pingInterval = props.pingInterval || 0; this.lastActivityTime = Date.now() this.reconnectOptions = props.reconnectOptions this.debugMessages = props.debugMessages || false; @@ -215,8 +220,9 @@ export class SimpleReactPyClient this.updateClientState(msg.state_vars); this.invokeLayoutUpdateHandlers(msg.path, msg.model); this.willReconnect = true; // don't indicate a reconnect until at least one successful layout update - }) - + }); + this.onMessage(messageTypes.pingIntervalSet, (msg) => { this.pingInterval = msg.ping_interval; this.updatePingInterval(); }); + this.updatePingInterval() this.reconnect() const handleUserAction = (ev: any) => { @@ -350,11 +356,20 @@ export class SimpleReactPyClient } } + updatePingInterval(): void { + if (this.pingPongIntervalId) { + window.clearInterval(this.pingPongIntervalId); + } + if (this.pingInterval) { + this.pingPongIntervalId = window.setInterval(() => { this.socket.current?.readyState === WebSocket.OPEN && this.socket.current?.send("ping") }, this.pingInterval); + } + } + reconnect(onOpen?: () => void, interval: number = 750, connectionAttemptsRemaining: number = 20, lastAttempt: number = 0): void { const intervalJitter = this.reconnectOptions?.intervalJitter || 0.5; const backoffRate = this.reconnectOptions?.backoffRate || 1.2; - const maxInterval = this.reconnectOptions?.maxInterval || 20000; - const maxRetries = this.reconnectOptions?.maxRetries || 20; + const maxInterval = this.reconnectOptions?.maxInterval || 500; + const maxRetries = this.reconnectOptions?.maxRetries || 40; if (this.layoutUpdateHandlers.length == 0) { setTimeout(() => { this.reconnect(onOpen, interval, connectionAttemptsRemaining, lastAttempt); }, 10); @@ -412,6 +427,8 @@ export class SimpleReactPyClient clearInterval(this.socketLoopIntervalId); if (this.idleCheckIntervalId) clearInterval(this.idleCheckIntervalId); + if (this.pingPongIntervalId) + clearInterval(this.pingPongIntervalId); if (!this.sleeping) { const thisInterval = nextInterval(addJitter(interval, intervalJitter), backoffRate, maxInterval); const newRetriesRemaining = connectionAttemptsRemaining - 1; diff --git a/src/py/reactpy/reactpy/backend/sanic.py b/src/py/reactpy/reactpy/backend/sanic.py index e648747fa..0018c97c3 100644 --- a/src/py/reactpy/reactpy/backend/sanic.py +++ b/src/py/reactpy/reactpy/backend/sanic.py @@ -7,8 +7,8 @@ from typing import Any from urllib import parse as urllib_parse from uuid import uuid4 -import orjson +import orjson from sanic import Blueprint, Sanic, request, response from sanic.config import Config from sanic.server.websockets.connection import WebSocketConnection @@ -213,7 +213,10 @@ async def sock_send(value: Any) -> None: await socket.send(orjson.dumps(value).decode("utf-8")) async def sock_recv() -> Any: - data = await socket.recv() + while True: + data = await socket.recv() + if data != "ping": + break if data is None: raise Stop() return orjson.loads(data) diff --git a/src/py/reactpy/reactpy/core/layout.py b/src/py/reactpy/reactpy/core/layout.py index f2274c6a8..2c4ada19c 100644 --- a/src/py/reactpy/reactpy/core/layout.py +++ b/src/py/reactpy/reactpy/core/layout.py @@ -7,7 +7,6 @@ FIRST_COMPLETED, CancelledError, PriorityQueue, - Queue, Task, create_task, get_running_loop, @@ -19,9 +18,8 @@ from logging import getLogger from typing import ( Any, - Awaitable, + AsyncIterable, Callable, - Coroutine, Generic, NamedTuple, NewType, @@ -57,7 +55,6 @@ Key, LayoutEventMessage, LayoutUpdateMessage, - StateUpdateMessage, VdomChild, VdomDict, VdomJson, @@ -155,7 +152,10 @@ async def finish(self) -> None: del self._root_life_cycle_state_id del self._model_states_by_life_cycle_state_id - clear_hook_state(self._hook_state_token) + try: + clear_hook_state(self._hook_state_token) + except LookupError: + pass def start_rendering(self) -> None: self._schedule_render_task(self._root_life_cycle_state_id) @@ -188,7 +188,7 @@ async def render(self) -> LayoutUpdateMessage: else: # nocov return await self._serial_render() - async def render_until_queue_empty(self) -> None: + async def render_until_queue_empty(self) -> AsyncIterable[LayoutUpdateMessage]: model_state_id = await self._rendering_queue.get() while True: try: @@ -199,7 +199,7 @@ async def render_until_queue_empty(self) -> None: f"{model_state_id!r} - component already unmounted" ) else: - await self._create_layout_update(model_state, get_hook_state()) + yield await self._create_layout_update(model_state, get_hook_state()) # this might seem counterintuitive. What's happening is that events can get kicked off # and currently there's no (obvious) visibility on if we're waiting for them to finish # so this will wait up to 0.15 * 5 = 750 ms to see if any renders come in before diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py index 85799c762..0485304da 100644 --- a/src/py/reactpy/reactpy/core/serve.py +++ b/src/py/reactpy/reactpy/core/serve.py @@ -4,6 +4,7 @@ import string from collections.abc import Awaitable from logging import getLogger +from os import environ from typing import Callable from warnings import warn @@ -22,12 +23,26 @@ LayoutEventMessage, LayoutType, LayoutUpdateMessage, + PingIntervalSetMessage, ReconnectingCheckMessage, RootComponentConstructor, ) logger = getLogger(__name__) +MAX_HOT_RELOADING = environ.get("REACTPY_MAX_HOT_RELOADING", "0") in ( + "1", + "true", + "True", + "yes", +) +if MAX_HOT_RELOADING: + logger.warning("Doing maximum hot reloading") + from reactpy.hot_reloading import ( + monkeypatch_jurigged_to_kill_connections_if_function_update, + ) + + monkeypatch_jurigged_to_kill_connections_if_function_update() SendCoroutine = Callable[ [ @@ -128,24 +143,35 @@ def __init__( async def handle_connection( self, connection: Connection, constructor: RootComponentConstructor ): + if MAX_HOT_RELOADING: + from reactpy.hot_reloading import active_connections + + active_connections.append(connection) layout = Layout( ConnectionContext( constructor(), value=connection, ), ) - async with layout: - await self._handshake(layout) - # salt may be set to client's old salt during handshake - if self._state_recovery_manager: - layout.set_recovery_serializer( - self._state_recovery_manager.create_serializer(self._salt) + try: + async with layout: + await self._handshake(layout) + # salt may be set to client's old salt during handshake + if self._state_recovery_manager: + layout.set_recovery_serializer( + self._state_recovery_manager.create_serializer(self._salt) + ) + await serve_layout( + layout, + self._send, + self._recv, ) - await serve_layout( - layout, - self._send, - self._recv, - ) + finally: + if MAX_HOT_RELOADING: + try: + active_connections.remove(connection) + except ValueError: + pass async def _handshake(self, layout: Layout) -> None: await self._send(ReconnectingCheckMessage(type="reconnecting-check")) @@ -172,8 +198,22 @@ async def _handshake(self, layout: Layout) -> None: await self._indicate_ready(), async def _indicate_ready(self) -> None: + if MAX_HOT_RELOADING: + await self._send( + PingIntervalSetMessage(type="ping-interval-set", ping_interval=250) + ) await self._send(IsReadyMessage(type="is-ready", salt=self._salt)) + if MAX_HOT_RELOADING: + + async def _handle_rebuild_msg(self, msg: LayoutUpdateMessage) -> None: + await self._send(msg) + + else: + + async def _handle_rebuild_msg(self, msg: LayoutUpdateMessage) -> None: + pass # do nothing + async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str: salt = self._salt await self._send(ClientStateMessage(type="client-state")) @@ -197,7 +237,8 @@ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str: salt = client_state_msg["salt"] layout.start_rendering_for_reconnect() - await layout.render_until_queue_empty() + async for msg in layout.render_until_queue_empty(): + await self._handle_rebuild_msg(msg) except StateRecoveryFailureError: logger.warning( "State recovery failed (likely client from different version). Starting fresh" diff --git a/src/py/reactpy/reactpy/core/types.py b/src/py/reactpy/reactpy/core/types.py index a4be74f61..b2b070b24 100644 --- a/src/py/reactpy/reactpy/core/types.py +++ b/src/py/reactpy/reactpy/core/types.py @@ -263,6 +263,12 @@ class LayoutEventMessage(TypedDict): """A list of event data passed to the event handler.""" +class PingIntervalSetMessage(TypedDict): + type: Literal["ping-interval-set"] + + ping_interval: int + + class Context(Protocol[_Type]): """Returns a :class:`ContextProvider` component""" diff --git a/src/py/reactpy/reactpy/hot_reloading.py b/src/py/reactpy/reactpy/hot_reloading.py new file mode 100644 index 000000000..6aae079a9 --- /dev/null +++ b/src/py/reactpy/reactpy/hot_reloading.py @@ -0,0 +1,29 @@ +import asyncio +import logging + +logger = logging.getLogger(__name__) + +active_connections = [] + + +def monkeypatch_jurigged_to_kill_connections_if_function_update(): + import jurigged.codetools as jurigged_codetools # type: ignore + + OrigFunctionDefinition = jurigged_codetools.FunctionDefinition + + class NewFunctionDefinition(OrigFunctionDefinition): + def reevaluate(self, new_node, glb): + if active_connections: + logger.info("Killing active connections") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + tasks = [ + connection.carrier.websocket.close() + for connection in active_connections + ] + loop.run_until_complete(asyncio.gather(*tasks)) + loop.close() + active_connections.clear() + return super().reevaluate(new_node, glb) + + jurigged_codetools.FunctionDefinition = NewFunctionDefinition