Skip to content

Commit bbd0315

Browse files
chore: use signal.signal and add test
1 parent 869f7d4 commit bbd0315

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

crossplane/function/resource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def update(r: fnv1.Resource, source: dict | structpb.Struct | pydantic.BaseModel
4545
# apiVersion is set to its default value 's3.aws.upbound.io/v1beta2'
4646
# (and not explicitly provided during initialization), it will be
4747
# excluded from the serialized output.
48-
data['apiVersion'] = source.apiVersion
49-
data['kind'] = source.kind
48+
data["apiVersion"] = source.apiVersion
49+
data["kind"] = source.kind
5050
r.resource.update(data)
5151
case structpb.Struct():
5252
# TODO(negz): Use struct_to_dict and update to match other semantics?

crossplane/function/runtime.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def load_credentials(tls_certs_dir: str) -> grpc.ServerCredentials:
6666
)
6767

6868

69-
async def _stop(server, timeout): # noqa: ASYNC109
70-
await server.stop(grace=timeout)
69+
async def _stop(server, grace=GRACE_PERIOD):
70+
await server.stop(grace=grace)
7171

7272

7373
def serve(
@@ -96,8 +96,9 @@ def serve(
9696

9797
server = grpc.aio.server()
9898

99-
loop.add_signal_handler(
100-
signal.SIGTERM, lambda: asyncio.create_task(_stop(server, timeout=GRACE_PERIOD))
99+
signal.signal(
100+
signal.SIGTERM,
101+
lambda _, __: asyncio.create_task(_stop(server)),
101102
)
102103

103104
grpcv1.add_FunctionRunnerServiceServicer_to_server(function, server)
@@ -126,7 +127,8 @@ async def start():
126127
try:
127128
loop.run_until_complete(start())
128129
finally:
129-
loop.run_until_complete(server.stop(grace=GRACE_PERIOD))
130+
if server._server.is_running():
131+
loop.run_until_complete(server.stop(grace=GRACE_PERIOD))
130132
loop.close()
131133

132134

tests/test_runtime.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import dataclasses
17+
import os
18+
import signal
1619
import unittest
1720

1821
import grpc
@@ -52,6 +55,25 @@ class TestCase:
5255

5356
self.assertEqual(rsp, case.want, "-want, +got")
5457

58+
async def test_sigterm_handling(self) -> None:
59+
async def mock_server():
60+
await server.start()
61+
await asyncio.sleep(1)
62+
self.assertTrue(server._server.is_running(), "Server should be running")
63+
os.kill(os.getpid(), signal.SIGTERM)
64+
await server.wait_for_termination()
65+
self.assertFalse(
66+
server._server.is_running(),
67+
"Server should have been stopped on SIGTERM",
68+
)
69+
70+
server = grpc.aio.server()
71+
signal.signal(
72+
signal.SIGTERM,
73+
lambda _, __: asyncio.create_task(runtime._stop(server)),
74+
)
75+
await mock_server()
76+
5577

5678
class EchoRunner(grpcv1.FunctionRunnerService):
5779
def __init__(self):

0 commit comments

Comments
 (0)