Skip to content

Commit ef23652

Browse files
committed
Improve dataclass serialization
1 parent c658a52 commit ef23652

File tree

2 files changed

+347
-5
lines changed

2 files changed

+347
-5
lines changed

durabletask/internal/shared.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,18 @@ def encode(self, obj: Any) -> str:
103103
return super().encode(obj)
104104

105105
def default(self, obj):
106-
if dataclasses.is_dataclass(obj):
106+
if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
107107
# Dataclasses are not serializable by default, so we convert them to a dict and mark them for
108-
# automatic deserialization by the receiver
109-
d = dataclasses.asdict(obj) # type: ignore
108+
# automatic deserialization by the receiver. We use a shallow field extraction instead of
109+
# dataclasses.asdict() so that nested dataclass values are re-processed by the encoder
110+
# individually (each receiving their own AUTO_SERIALIZED marker).
111+
d = {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
110112
d[AUTO_SERIALIZED] = True
111113
return d
112114
elif isinstance(obj, SimpleNamespace):
113-
# Most commonly used for serializing custom objects that were previously serialized using our encoder
114-
d = vars(obj)
115+
# Most commonly used for serializing custom objects that were previously serialized using our encoder.
116+
# Copy the dict to avoid mutating the original object.
117+
d = dict(vars(obj))
115118
d[AUTO_SERIALIZED] = True
116119
return d
117120
# This will typically raise a TypeError
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import dataclasses
5+
from collections import namedtuple
6+
from types import SimpleNamespace
7+
8+
import pytest
9+
10+
from durabletask.internal.shared import (
11+
AUTO_SERIALIZED,
12+
from_json,
13+
to_json,
14+
)
15+
16+
17+
# --- Dataclass fixtures ---
18+
19+
@dataclasses.dataclass
class SimpleData:
    """Flat dataclass fixture with one integer and one string field."""

    x: int
    y: str
23+
24+
25+
@dataclasses.dataclass
class InnerData:
    """Leaf dataclass fixture; used as the innermost value in nesting tests."""

    value: int
28+
29+
30+
@dataclasses.dataclass
class OuterData:
    """Dataclass fixture that nests an ``InnerData`` next to a plain string."""

    inner: InnerData
    label: str
34+
35+
36+
@dataclasses.dataclass
class DeeplyNested:
    """Dataclass fixture with two levels of nesting (``OuterData`` -> ``InnerData``)."""

    outer: OuterData
    flag: bool
40+
41+
42+
# --- Namedtuple fixtures ---
43+
44+
# Two-field namedtuple fixture used by the namedtuple serialization tests.
Point = namedtuple("Point", ("x", "y"))
45+
46+
47+
class TestDataclassSerialization:
    """Tests for dataclass serialization/deserialization via to_json/from_json."""

    def test_simple_dataclass_round_trip(self):
        """A simple dataclass should serialize and deserialize to a SimpleNamespace."""
        obj = SimpleData(x=1, y="hello")
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert result.x == 1
        assert result.y == "hello"

    def test_simple_dataclass_json_contains_auto_serialized_marker(self):
        """The JSON output should contain the AUTO_SERIALIZED marker."""
        obj = SimpleData(x=1, y="hello")
        json_str = to_json(obj)

        assert AUTO_SERIALIZED in json_str

    def test_nested_dataclass_round_trip(self):
        """Nested dataclasses should all deserialize to SimpleNamespace, not dict."""
        obj = OuterData(inner=InnerData(value=42), label="test")
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert isinstance(result.inner, SimpleNamespace), (
            "Inner dataclass should deserialize to SimpleNamespace, not dict"
        )
        assert result.inner.value == 42
        assert result.label == "test"

    def test_deeply_nested_dataclass_round_trip(self):
        """Deeply nested dataclasses should all deserialize to SimpleNamespace."""
        obj = DeeplyNested(
            outer=OuterData(inner=InnerData(value=7), label="deep"),
            flag=True,
        )
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert isinstance(result.outer, SimpleNamespace)
        assert isinstance(result.outer.inner, SimpleNamespace)
        assert result.outer.inner.value == 7
        assert result.outer.label == "deep"
        assert result.flag is True

    def test_dataclass_inside_dict(self):
        """A dataclass value inside a dict should round-trip as SimpleNamespace."""
        obj = {"data": SimpleData(x=10, y="world")}
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, dict)
        assert isinstance(result["data"], SimpleNamespace)
        assert result["data"].x == 10
        assert result["data"].y == "world"

    def test_dataclass_inside_list(self):
        """Dataclass items inside a list should round-trip as SimpleNamespace."""
        items = [SimpleData(x=1, y="a"), SimpleData(x=2, y="b")]
        json_str = to_json(items)
        result = from_json(json_str)

        assert isinstance(result, list)
        assert len(result) == 2
        for item in result:
            assert isinstance(item, SimpleNamespace)
        assert result[0].x == 1
        assert result[1].y == "b"

    def test_array_of_nested_dataclasses(self):
        """An array of dataclasses with nested dataclass fields should fully round-trip."""
        items = [
            OuterData(inner=InnerData(value=1), label="first"),
            OuterData(inner=InnerData(value=2), label="second"),
        ]
        json_str = to_json(items)
        result = from_json(json_str)

        assert isinstance(result, list)
        assert len(result) == 2
        for item in result:
            assert isinstance(item, SimpleNamespace)
            assert isinstance(item.inner, SimpleNamespace)
        assert result[0].inner.value == 1
        assert result[0].label == "first"
        assert result[1].inner.value == 2
        assert result[1].label == "second"

    def test_nested_array_of_dataclasses(self):
        """An array nested inside another array of dataclasses should round-trip."""
        items = [
            [SimpleData(x=1, y="a"), SimpleData(x=2, y="b")],
            [SimpleData(x=3, y="c")],
        ]
        json_str = to_json(items)
        result = from_json(json_str)

        assert isinstance(result, list)
        assert len(result) == 2
        assert isinstance(result[0], list)
        assert len(result[0]) == 2
        assert isinstance(result[1], list)
        assert len(result[1]) == 1
        for sublist in result:
            for item in sublist:
                assert isinstance(item, SimpleNamespace)
        assert result[0][0].x == 1
        assert result[0][1].y == "b"
        assert result[1][0].x == 3

    def test_dict_with_nested_dataclass_values(self):
        """Dict values that are nested dataclasses should fully round-trip."""
        obj = {"item": OuterData(inner=InnerData(value=99), label="nested")}
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, dict)
        assert isinstance(result["item"], SimpleNamespace)
        assert isinstance(result["item"].inner, SimpleNamespace)
        assert result["item"].inner.value == 99
        assert result["item"].label == "nested"

    def test_dict_with_multiple_dataclass_values(self):
        """A dict with several dataclass values should all round-trip."""
        obj = {
            "a": SimpleData(x=1, y="one"),
            "b": SimpleData(x=2, y="two"),
        }
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, dict)
        for key in ("a", "b"):
            assert isinstance(result[key], SimpleNamespace)
        assert result["a"].x == 1
        assert result["b"].y == "two"

    def test_dict_with_array_of_dataclasses(self):
        """A dict whose value is a list of dataclasses should round-trip."""
        obj = {"items": [SimpleData(x=1, y="a"), SimpleData(x=2, y="b")]}
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, dict)
        assert isinstance(result["items"], list)
        assert len(result["items"]) == 2
        for item in result["items"]:
            assert isinstance(item, SimpleNamespace)
        assert result["items"][0].x == 1
        assert result["items"][1].y == "b"

    def test_dict_with_array_of_nested_dataclasses(self):
        """A dict whose value is a list of nested dataclasses should fully round-trip."""
        obj = {
            "records": [
                OuterData(inner=InnerData(value=10), label="r1"),
                OuterData(inner=InnerData(value=20), label="r2"),
            ]
        }
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, dict)
        assert isinstance(result["records"], list)
        for item in result["records"]:
            assert isinstance(item, SimpleNamespace)
            assert isinstance(item.inner, SimpleNamespace)
        assert result["records"][0].inner.value == 10
        assert result["records"][1].label == "r2"

    def test_dataclass_with_list_of_dataclass_field(self):
        """A dataclass containing a list-of-dataclass field should round-trip."""
        @dataclasses.dataclass
        class Container:
            items: list

        obj = Container(items=[InnerData(value=1), InnerData(value=2)])
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert isinstance(result.items, list)
        assert len(result.items) == 2
        for item in result.items:
            assert isinstance(item, SimpleNamespace)
        assert result.items[0].value == 1
        assert result.items[1].value == 2

    def test_dataclass_with_dict_of_dataclass_field(self):
        """A dataclass containing a dict-of-dataclass field should round-trip."""
        @dataclasses.dataclass
        class Mapping:
            entries: dict

        obj = Mapping(entries={"a": InnerData(value=5), "b": InnerData(value=6)})
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert isinstance(result.entries, dict)
        for val in result.entries.values():
            assert isinstance(val, SimpleNamespace)
        assert result.entries["a"].value == 5
        assert result.entries["b"].value == 6
255+
256+
257+
class TestSimpleNamespaceSerialization:
    """Tests for SimpleNamespace serialization."""

    def test_simple_namespace_round_trip(self):
        """SimpleNamespace should serialize and deserialize correctly."""
        obj = SimpleNamespace(a=1, b="two")
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert result.a == 1
        assert result.b == "two"

    def test_simple_namespace_not_mutated(self):
        """Serializing a SimpleNamespace should NOT mutate the original object."""
        obj = SimpleNamespace(x=1, y=2)
        # Snapshot the attribute names so any key added during encoding is detected.
        original_attrs = set(vars(obj).keys())

        to_json(obj)

        current_attrs = set(vars(obj).keys())
        assert current_attrs == original_attrs, (
            f"Original SimpleNamespace was mutated: gained {current_attrs - original_attrs}"
        )
        assert not hasattr(obj, AUTO_SERIALIZED)

    def test_nested_simple_namespace_round_trip(self):
        """Nested SimpleNamespace should deserialize as SimpleNamespace."""
        obj = SimpleNamespace(child=SimpleNamespace(val=99))
        json_str = to_json(obj)
        result = from_json(json_str)

        assert isinstance(result, SimpleNamespace)
        assert isinstance(result.child, SimpleNamespace)
        assert result.child.val == 99
292+
293+
294+
class TestNamedtupleSerialization:
    """Tests for namedtuple serialization."""

    def test_namedtuple_top_level_round_trip(self):
        """A namedtuple at the top level should serialize with field names."""
        p = Point(x=3, y=4)
        json_str = to_json(p)
        result = from_json(json_str)

        # Field names must survive the round trip, so the result is attribute-addressable.
        assert isinstance(result, SimpleNamespace)
        assert result.x == 3
        assert result.y == 4
306+
307+
308+
class TestPrimitiveSerialization:
    """Tests for primitive/basic type round-trips."""

    @pytest.mark.parametrize("value", [
        42,
        3.14,
        "hello",
        True,
        False,
        None,
        [1, 2, 3],
        {"key": "value"},
    ])
    def test_primitive_round_trip(self, value):
        """Primitive types should round-trip unchanged."""
        json_str = to_json(value)
        result = from_json(json_str)
        assert result == value
326+
327+
328+
class TestDataclassNotMutated:
    """Ensure serialization does not mutate dataclass inputs."""

    def test_dataclass_not_mutated(self):
        """Serializing a dataclass should not add attributes to the original."""
        obj = SimpleData(x=1, y="test")
        to_json(obj)

        # dataclass fields should be unchanged
        assert obj.x == 1
        assert obj.y == "test"
        # The marker belongs only in the serialized payload, never on the source object.
        assert not hasattr(obj, AUTO_SERIALIZED)

0 commit comments

Comments
 (0)