Skip to content

Commit c05733a

Browse files
Handle some cases during infer schema from dataclass (#37855)
* Handle some cases during infer schema from dataclass * For backward compatibility, only infer schema for frozen dataclasses when it's registered with row coder * Make sure Beam schema ID does not inherit * Fix IndexOutofBoundError trying to infer type from custom Iterable without type hint * Fix #37862: fixed named tuple and effectively fails dataclass inside union typehint * Allow non-frozen dataclass register with other coders as a backup for backward compatibility; add tests * Add upgrade compatibility check for potential coder change * Update CHANGES.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Test case for update_compatibility_version --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent b3d2b06 commit c05733a

File tree

7 files changed

+222
-15
lines changed

7 files changed

+222
-15
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
## New Features / Improvements
7171

7272
* Added support for large pipeline options via a file (Python) ([#37370](https://github.com/apache/beam/issues/37370)).
73+
* Added support for inferring schemas from dataclasses (Python) ([#22085](https://github.com/apache/beam/issues/22085)). The default coder for type-hinted (or set with with_output_type) non-frozen dataclasses changed to RowCoder. To preserve the old behavior (FastPrimitivesCoder), explicitly register the type with FastPrimitivesCoder.
7374

7475
## Breaking Changes
7576

sdks/python/apache_beam/coders/coder_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def encode_special_deterministic(self, value, stream):
493493
stream.write_byte(PROTO_TYPE)
494494
self.encode_type(type(value), stream)
495495
stream.write(value.SerializePartialToString(deterministic=True), True)
496-
elif dataclasses and dataclasses.is_dataclass(value):
496+
elif dataclasses.is_dataclass(value):
497497
if not type(value).__dataclass_params__.frozen:
498498
raise TypeError(
499499
"Unable to deterministically encode non-frozen '%s' of type '%s' "

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,50 @@ def match_is_named_tuple(user_type):
176176
hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields'))
177177

178178

179-
def match_is_dataclass(user_type):
180-
return dataclasses.is_dataclass(user_type) and isinstance(user_type, type)
179+
def match_dataclass_for_row(user_type):
180+
"""Match whether the type is a dataclass handled by row coder.
181+
182+
For frozen dataclasses, only true when explicitly registered with row coder:
183+
184+
beam.coders.typecoders.registry.register_coder(
185+
MyDataClass, beam.coders.RowCoder)
186+
187+
(for backward-compatibility reasons).
188+
189+
For non-frozen dataclasses, defaults to true unless explicitly registered
190+
with a coder other than the row coder.
191+
"""
192+
193+
if not dataclasses.is_dataclass(user_type):
194+
return False
195+
196+
# pylint: disable=wrong-import-position
197+
try:
198+
from apache_beam.options.pipeline_options_context import get_pipeline_options # pylint: disable=line-too-long
199+
except AttributeError:
200+
pass
201+
else:
202+
opts = get_pipeline_options()
203+
if opts and opts.is_compat_version_prior_to("2.73.0"):
204+
return False
205+
206+
is_frozen = user_type.__dataclass_params__.frozen
207+
# avoid circular import
208+
try:
209+
from apache_beam.coders.typecoders import registry as coders_registry
210+
from apache_beam.coders import RowCoder
211+
except AttributeError:
212+
# coder registry not yet initialized, so it must be absent
213+
return not is_frozen
214+
215+
if is_frozen:
216+
return (
217+
user_type in coders_registry._coders and
218+
coders_registry._coders[user_type] == RowCoder)
219+
else:
220+
return (
221+
user_type not in coders_registry._coders or
222+
coders_registry._coders[user_type] == RowCoder)
181223

182224

183225
def _match_is_optional(user_type):

sdks/python/apache_beam/typehints/native_type_compatibility_test.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020
# pytype: skip-file
2121

2222
import collections.abc
23+
import dataclasses
2324
import enum
2425
import re
2526
import typing
2627
import unittest
2728

29+
from parameterized import param
30+
from parameterized import parameterized
31+
32+
from apache_beam.options.pipeline_options import PipelineOptions
33+
from apache_beam.options.pipeline_options_context import scoped_pipeline_options
2834
from apache_beam.typehints import typehints
2935
from apache_beam.typehints.native_type_compatibility import convert_builtin_to_typing
3036
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
@@ -33,6 +39,7 @@
3339
from apache_beam.typehints.native_type_compatibility import convert_to_python_types
3440
from apache_beam.typehints.native_type_compatibility import convert_typing_to_builtin
3541
from apache_beam.typehints.native_type_compatibility import is_any
42+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
3643

3744
_TestNamedTuple = typing.NamedTuple(
3845
'_TestNamedTuple', [('age', int), ('name', bytes)])
@@ -509,6 +516,58 @@ def test_type_alias_type_unwrapped(self):
509516
self.assertEqual(
510517
typehints.Tuple[int, ...], convert_to_beam_type(AliasTuple))
511518

519+
def test_dataclass_default(self):
520+
@dataclasses.dataclass(frozen=True)
521+
class FrozenDC:
522+
foo: int
523+
524+
@dataclasses.dataclass
525+
class NonFrozenDC:
526+
foo: int
527+
528+
self.assertFalse(match_dataclass_for_row(FrozenDC))
529+
self.assertTrue(match_dataclass_for_row(NonFrozenDC))
530+
531+
def test_dataclass_registered(self):
532+
@dataclasses.dataclass(frozen=True)
533+
class FrozenRegisteredDC:
534+
foo: int
535+
536+
@dataclasses.dataclass
537+
class NonFrozenRegisteredDC:
538+
foo: int
539+
540+
# pylint: disable=wrong-import-position
541+
from apache_beam.coders import RowCoder
542+
from apache_beam.coders import typecoders
543+
from apache_beam.coders.coders import FastPrimitivesCoder
544+
545+
typecoders.registry.register_coder(FrozenRegisteredDC, RowCoder)
546+
typecoders.registry.register_coder(
547+
NonFrozenRegisteredDC, FastPrimitivesCoder)
548+
549+
self.assertTrue(match_dataclass_for_row(FrozenRegisteredDC))
550+
self.assertFalse(match_dataclass_for_row(NonFrozenRegisteredDC))
551+
552+
@parameterized.expand([
553+
param(compat_version="2.72.0"),
554+
param(compat_version="2.73.0"),
555+
])
556+
def test_dataclass_update_compatibility(self, compat_version):
557+
@dataclasses.dataclass(frozen=True)
558+
class FrozenDC:
559+
foo: int
560+
561+
@dataclasses.dataclass
562+
class NonFrozenDC:
563+
foo: int
564+
565+
with scoped_pipeline_options(
566+
PipelineOptions(update_compatibility_version=compat_version)):
567+
self.assertFalse(match_dataclass_for_row(FrozenDC))
568+
self.assertEqual(
569+
compat_version == "2.73.0", match_dataclass_for_row(NonFrozenDC))
570+
512571

513572
if __name__ == '__main__':
514573
unittest.main()

sdks/python/apache_beam/typehints/row_type.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from typing import Tuple
2828

2929
from apache_beam.typehints import typehints
30-
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
30+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
3131
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
3232
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
3333

@@ -91,6 +91,9 @@ def __init__(
9191
# Currently registration happens when converting to schema protos, in
9292
# apache_beam.typehints.schemas
9393
self._schema_id = getattr(self._user_type, _BEAM_SCHEMA_ID, None)
94+
if self._schema_id and _BEAM_SCHEMA_ID not in self._user_type.__dict__:
95+
# schema id is not inherited. Unset it if the schema id comes from a base class
96+
self._schema_id = None
9497

9598
self._schema_options = schema_options or []
9699
self._field_options = field_options or {}
@@ -105,7 +108,7 @@ def from_user_type(
105108
if match_is_named_tuple(user_type):
106109
fields = [(name, user_type.__annotations__[name])
107110
for name in user_type._fields]
108-
elif match_is_dataclass(user_type):
111+
elif match_dataclass_for_row(user_type):
109112
fields = [(field.name, field.type)
110113
for field in dataclasses.fields(user_type)]
111114
else:

sdks/python/apache_beam/typehints/row_type_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from apache_beam.testing.util import assert_that
2727
from apache_beam.testing.util import equal_to
2828
from apache_beam.typehints import row_type
29+
from apache_beam.typehints import schemas
2930

3031

3132
class RowTypeTest(unittest.TestCase):
@@ -85,6 +86,94 @@ def generate(num: int):
8586
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
8687
assert_that(result, equal_to([10] * 100))
8788

89+
def test_group_by_key_namedtuple_union(self):
90+
Tuple1 = typing.NamedTuple("Tuple1", [("id", int)])
91+
92+
Tuple2 = typing.NamedTuple("Tuple2", [("id", int), ("name", str)])
93+
94+
def generate(num: int):
95+
for i in range(2):
96+
yield (Tuple1(i), num)
97+
yield (Tuple2(i, 'a'), num)
98+
99+
pipeline = TestPipeline(is_integration_test=False)
100+
101+
with pipeline as p:
102+
result = (
103+
p
104+
| 'Create' >> beam.Create([i for i in range(2)])
105+
| 'Generate' >> beam.ParDo(generate).with_output_types(
106+
tuple[(Tuple1 | Tuple2), int])
107+
| 'GBK' >> beam.GroupByKey()
108+
| 'Count' >> beam.Map(lambda x: len(x[1])))
109+
assert_that(result, equal_to([2] * 4))
110+
111+
# Union of dataclasses as a type hint currently results in FastPrimitivesCoder
112+
# and fails at GBK
113+
@unittest.skip("https://github.com/apache/beam/issues/22085")
114+
def test_group_by_key_inherited_dataclass_union(self):
115+
@dataclass
116+
class DataClassInt:
117+
id: int
118+
119+
@dataclass
120+
class DataClassStr(DataClassInt):
121+
name: str
122+
123+
beam.coders.typecoders.registry.register_coder(
124+
DataClassInt, beam.coders.RowCoder)
125+
beam.coders.typecoders.registry.register_coder(
126+
DataClassStr, beam.coders.RowCoder)
127+
128+
def generate(num: int):
129+
for i in range(10):
130+
yield (DataClassInt(i), num)
131+
yield (DataClassStr(i, 'a'), num)
132+
133+
pipeline = TestPipeline(is_integration_test=False)
134+
135+
with pipeline as p:
136+
result = (
137+
p
138+
| 'Create' >> beam.Create([i for i in range(2)])
139+
| 'Generate' >> beam.ParDo(generate).with_output_types(
140+
tuple[(DataClassInt | DataClassStr), int])
141+
| 'GBK' >> beam.GroupByKey()
142+
| 'Count Elements' >> beam.Map(self._check_key_type_and_count))
143+
assert_that(result, equal_to([2] * 4))
144+
145+
def test_derived_dataclass_schema_id(self):
146+
@dataclass
147+
class BaseDataClass:
148+
id: int
149+
150+
@dataclass
151+
class DerivedDataClass(BaseDataClass):
152+
name: str
153+
154+
self.assertFalse(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
155+
schema_for_base = schemas.schema_from_element_type(BaseDataClass)
156+
self.assertTrue(hasattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
157+
self.assertEqual(
158+
schema_for_base.id, getattr(BaseDataClass, row_type._BEAM_SCHEMA_ID))
159+
160+
# Getting the schema for BaseDataClass sets the _beam_schema_id
161+
schemas.typing_to_runner_api(
162+
BaseDataClass, schema_registry=schemas.SchemaTypeRegistry())
163+
164+
# We create a RowTypeConstraint from DerivedDataClass.
165+
# It should not inherit the _beam_schema_id from BaseDataClass!
166+
derived_row_type = row_type.RowTypeConstraint.from_user_type(
167+
DerivedDataClass)
168+
self.assertIsNone(derived_row_type._schema_id)
169+
170+
schema_for_derived = schemas.schema_from_element_type(DerivedDataClass)
171+
self.assertTrue(hasattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
172+
self.assertEqual(
173+
schema_for_derived.id,
174+
getattr(DerivedDataClass, row_type._BEAM_SCHEMA_ID))
175+
self.assertNotEqual(schema_for_derived.id, schema_for_base.id)
176+
88177

89178
if __name__ == '__main__':
90179
unittest.main()

sdks/python/apache_beam/typehints/schemas.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
from apache_beam.typehints.native_type_compatibility import _safe_issubclass
9797
from apache_beam.typehints.native_type_compatibility import convert_to_python_type
9898
from apache_beam.typehints.native_type_compatibility import extract_optional_type
99-
from apache_beam.typehints.native_type_compatibility import match_is_dataclass
99+
from apache_beam.typehints.native_type_compatibility import match_dataclass_for_row
100100
from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
101101
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY
102102
from apache_beam.typehints.schema_registry import SchemaTypeRegistry
@@ -335,19 +335,23 @@ def typing_to_runner_api(self, type_: type) -> schema_pb2.FieldType:
335335
atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
336336

337337
elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, str):
338-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
339-
return schema_pb2.FieldType(
340-
array_type=schema_pb2.ArrayType(element_type=element_type))
338+
arg_types = _get_args(type_)
339+
if len(arg_types) > 0:
340+
element_type = self.typing_to_runner_api(arg_types[0])
341+
return schema_pb2.FieldType(
342+
array_type=schema_pb2.ArrayType(element_type=element_type))
341343

342344
elif _safe_issubclass(type_, Mapping):
343345
key_type, value_type = map(self.typing_to_runner_api, _get_args(type_))
344346
return schema_pb2.FieldType(
345347
map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type))
346348

347349
elif _safe_issubclass(type_, Iterable) and not _safe_issubclass(type_, str):
348-
element_type = self.typing_to_runner_api(_get_args(type_)[0])
349-
return schema_pb2.FieldType(
350-
array_type=schema_pb2.ArrayType(element_type=element_type))
350+
arg_types = _get_args(type_)
351+
if len(arg_types) > 0:
352+
element_type = self.typing_to_runner_api(arg_types[0])
353+
return schema_pb2.FieldType(
354+
array_type=schema_pb2.ArrayType(element_type=element_type))
351355

352356
try:
353357
if LogicalType.is_known_logical_type(type_):
@@ -630,8 +634,10 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema:
630634
Returns schema as a list of (name, python_type) tuples"""
631635
if isinstance(element_type, row_type.RowTypeConstraint):
632636
return named_fields_to_schema(element_type._fields)
633-
elif match_is_named_tuple(element_type) or match_is_dataclass(element_type):
634-
if hasattr(element_type, row_type._BEAM_SCHEMA_ID):
637+
elif match_is_named_tuple(element_type) or match_dataclass_for_row(
638+
element_type):
639+
# schema id is not inherited from base classes
640+
if row_type._BEAM_SCHEMA_ID in element_type.__dict__:
635641
# if the named tuple's schema is in registry, we just use it instead of
636642
# regenerating one.
637643
schema_id = getattr(element_type, row_type._BEAM_SCHEMA_ID)
@@ -657,8 +663,15 @@ def union_schema_type(element_types):
657663
element_types must be a set of schema-aware types whose fields have the
658664
same naming and ordering.
659665
"""
666+
named_fields_and_types = []
667+
for t in element_types:
668+
n = named_fields_from_element_type(t)
669+
if named_fields_and_types and len(named_fields_and_types[-1]) != len(n):
670+
raise TypeError("element types has different number of fields")
671+
named_fields_and_types.append(n)
672+
660673
union_fields_and_types = []
661-
for field in zip(*[named_fields_from_element_type(t) for t in element_types]):
674+
for field in zip(*named_fields_and_types):
662675
names, types = zip(*field)
663676
name_set = set(names)
664677
if len(name_set) != 1:

0 commit comments

Comments
 (0)