diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index eab91e0c3f840..ce0c0df8c7d15 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from datetime import datetime from enum import Enum import json import os @@ -27,12 +28,43 @@ Row, ) from pyspark.sql.pandas.types import convert_pandas_using_numpy_type -from pyspark.serializers import CPickleSerializer +from pyspark.serializers import PickleSerializer from pyspark.errors import PySparkRuntimeError import uuid __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"] +try: + import numpy as np + + has_numpy = True + SCALAR_TYPES = frozenset((bool, int, float, str, bytes, datetime, type(None))) + + def _normalize_state_value(v: Any) -> Any: + if type(v) in SCALAR_TYPES: # Fast path for common scalar values. + return v + # Convert NumPy scalar values to Python primitive values. + if isinstance(v, np.generic): + return v.tolist() + # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both + # require positional arguments and cannot be instantiated with a generator expression. + if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")): + return type(v)(*map(_normalize_state_value, v)) + # List / tuple: recursively normalize each element. + if isinstance(v, (list, tuple)): + return type(v)(map(_normalize_state_value, v)) + # Dict: normalize both keys and values. + if isinstance(v, dict): + return {_normalize_state_value(k): _normalize_state_value(val) for k, val in v.items()} + # Address a couple of pandas dtypes too. + if hasattr(v, "to_pytimedelta"): + return v.to_pytimedelta() + if hasattr(v, "to_pydatetime"): + return v.to_pydatetime() + return v +except ImportError: + has_numpy = False + class StatefulProcessorHandleState(Enum): PRE_INIT = 0 @@ -74,7 +106,7 @@ def __init__( else: self.handle_state = StatefulProcessorHandleState.CREATED self.utf8_deserializer = UTF8Deserializer() - self.pickleSer = CPickleSerializer() + self.pickleSer = PickleSerializer() self.serializer = ArrowStreamSerializer() # Dictionaries to store the mapping between iterator id and a tuple of data batch # and the index of the last row that was read. @@ -488,35 +520,8 @@ def _receive_str(self) -> str: return self.utf8_deserializer.loads(self.sockfile) def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: - from pyspark.testing.utils import have_numpy - - if have_numpy: - import numpy as np - - def normalize_value(v: Any) -> Any: - # Convert NumPy types to Python primitive types. - if isinstance(v, np.generic): - return v.tolist() - # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both - # require positional arguments and cannot be instantiated - # with a generator expression. - if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")): - return type(v)(*[normalize_value(e) for e in v]) - # List / tuple: recursively normalize each element - if isinstance(v, (list, tuple)): - return type(v)(normalize_value(e) for e in v) - # Dict: normalize both keys and values - if isinstance(v, dict): - return {normalize_value(k): normalize_value(val) for k, val in v.items()} - # Address a couple of pandas dtypes too. - elif hasattr(v, "to_pytimedelta"): - return v.to_pytimedelta() - elif hasattr(v, "to_pydatetime"): - return v.to_pydatetime() - else: - return v - - converted = tuple(normalize_value(v) for v in data) + if has_numpy: + converted = tuple(map(_normalize_state_value, data)) else: converted = data