Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,69 @@
from pyspark.serializers import write_int, read_int, UTF8Deserializer
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.types import (
AtomicType,
StructType,
Row,
)
from pyspark.sql.pandas.types import convert_pandas_using_numpy_type
from pyspark.sql.utils import has_numpy
from pyspark.serializers import CPickleSerializer
from pyspark.errors import PySparkRuntimeError
import uuid

__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]

if has_numpy:
import numpy as np

def _normalize_value(v: Any) -> Any:
    """Recursively replace NumPy scalars (and pandas datetime-like values)
    inside *v* with plain Python values, preserving container types."""
    # NumPy scalar -> the equivalent Python primitive.
    if isinstance(v, np.generic):
        return v.tolist()
    # Row and named tuples (collections.namedtuple / typing.NamedTuple) must
    # be rebuilt with positional arguments; their constructors reject a
    # single generator argument.
    is_named_tuple = isinstance(v, tuple) and hasattr(v, "_fields")
    if isinstance(v, Row) or is_named_tuple:
        return type(v)(*[_normalize_value(item) for item in v])
    # Plain list / tuple: normalize each element, keep the container type.
    if isinstance(v, (list, tuple)):
        return type(v)(_normalize_value(item) for item in v)
    # Dict: both keys and values may need normalization.
    if isinstance(v, dict):
        return {_normalize_value(key): _normalize_value(val) for key, val in v.items()}
    # pandas Timedelta / Timestamp expose these converters.
    if hasattr(v, "to_pytimedelta"):
        return v.to_pytimedelta()
    if hasattr(v, "to_pydatetime"):
        return v.to_pydatetime()
    return v

def _normalize_value_simple(v: Any) -> Any:
"""Fast path for flat schemas: check for numpy scalars and pandas dtypes only."""
if isinstance(v, np.generic):
return v.tolist()
if hasattr(v, "to_pytimedelta"):
return v.to_pytimedelta()
if hasattr(v, "to_pydatetime"):
return v.to_pydatetime()
return v

def _normalize_tuple(data: Tuple) -> Tuple:
    """Normalize every element of *data* through the full recursive path."""
    return tuple(map(_normalize_value, data))
Copy link
Copy Markdown
Contributor Author

@jiateoh jiateoh Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[edit: this prototype is now in the diff, latter discussion is still relevant though]

Another fastpath option being considered: prevalidate the schema, if it's all non-nested we can skip some checks. Prototype outlined below but with more analysis/tradeoffs in the follow up comment:

    """True if schema has no nested structs, arrays, or maps."""
    return all(
        not isinstance(f.dataType, (StructType, ArrayType, MapType))
        for f in schema.fields
    )

   def _normalize_tuple_simple(data: Tuple) -> Tuple:
        """Fast path: 1 isinstance per element instead of 7 (no nested types)."""
        return tuple(v.tolist() if isinstance(v, np.generic) else v for v in data)

Copy link
Copy Markdown
Contributor Author

@jiateoh jiateoh Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets a bit tricky with the last two timestamp checks though:

  1. As indicated in the existing comment "Address a couple of pandas dtypes too.", the target case is pandas objects. These already subclass datetime.datetime/timedelta so Schema.toInternal would properly handle them today without the explicit conversions in the normal path. However removing these checks does introduce an expectation that pandas types will always subclass the datetime classes (this is indeed the current expectation though).
  2. The existing API inadvertently supports any python values so long as they can be converted with to_pytimedelta/to_pydatetime. So introducing a fastpath without these checks could be a breaking change for custom datetime values. It's worth pointing out this is already within a branch where we have checked that numpy is present though.
    Supporting this backwards compatibility increases cost from 1 isinstance check to 1 isinstance + 2 hasattr checks though.

For point 2, I ran a microbenchmark with [timestamp (datetime) + long (numeric)] struct on each of the normalize function variants. Averaging over 2 million iterations on an input (datetime(2024, 1, 1), 42):

  1. Full normalize (existing): 0.55 us/row
  2. Fastpath (1-check): 0.26 us/row [52% faster]
  3. Fastpath (3-check): 0.34 us/row [38% faster]

The broader speedup on a full TWS workload will of course vary depending on schemas and row values used though.


def _normalize_tuple_simple(data: Tuple) -> Tuple:
    """Normalize every element of *data* using the cheap flat-schema checks."""
    return tuple(map(_normalize_value_simple, data))
else:
    # NumPy is not installed, so row values cannot be NumPy scalars and no
    # per-value conversion is needed.
    def _normalize_tuple(data: Tuple) -> Tuple:
        return data  # toInternal handles tuples natively

    # Without NumPy the fast path degenerates to the same no-op.
    _normalize_tuple_simple = _normalize_tuple


def _is_simple_schema(schema: StructType) -> bool:
    """Return True when every top-level field is an atomic type — i.e. no
    nested structs, arrays, or maps — so the fast normalization path applies."""
    for field in schema.fields:
        if not isinstance(field.dataType, AtomicType):
            return False
    return True


class StatefulProcessorHandleState(Enum):
PRE_INIT = 0
Expand Down Expand Up @@ -81,6 +134,10 @@ def __init__(
self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}
self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {}

# Cache of schema-id -> fast-serialize callable, so we avoid
# repeated attribute lookups on every _serialize_to_bytes call.
self._serializer_cache: Dict[int, Any] = {}

# statefulProcessorApiClient is initialized per batch per partition,
# so we will have new timestamps for a new batch
self._batch_timestamp = -1
Expand Down Expand Up @@ -487,43 +544,24 @@ def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]:
def _receive_str(self) -> str:
    """Read a single UTF-8 string from self.sockfile via the UTF-8 deserializer."""
    deserializer = self.utf8_deserializer
    return deserializer.loads(self.sockfile)

def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
from pyspark.testing.utils import have_numpy

if have_numpy:
import numpy as np

def normalize_value(v: Any) -> Any:
# Convert NumPy types to Python primitive types.
if isinstance(v, np.generic):
return v.tolist()
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
# require positional arguments and cannot be instantiated
# with a generator expression.
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
return type(v)(*[normalize_value(e) for e in v])
# List / tuple: recursively normalize each element
if isinstance(v, (list, tuple)):
return type(v)(normalize_value(e) for e in v)
# Dict: normalize both keys and values
if isinstance(v, dict):
return {normalize_value(k): normalize_value(val) for k, val in v.items()}
# Address a couple of pandas dtypes too.
elif hasattr(v, "to_pytimedelta"):
return v.to_pytimedelta()
elif hasattr(v, "to_pydatetime"):
return v.to_pydatetime()
else:
return v

converted = [normalize_value(v) for v in data]
else:
converted = list(data)
def _get_serializer(self, schema: StructType) -> Any:
    """Return a cached callable that serializes a row tuple for *schema*.

    The callable normalizes NumPy/pandas values in the tuple, converts it to
    the internal representation via ``schema.toInternal``, and pickles the
    result. Caching keyed on ``id(schema)`` avoids repeated attribute lookups
    on every ``_serialize_to_bytes`` call. NOTE: the cached closure holds the
    bound ``schema.toInternal`` method, which keeps the schema object alive,
    so its id cannot be recycled for a different schema while the cache
    entry exists.
    """
    schema_id = id(schema)
    serializer = self._serializer_cache.get(schema_id)
    if serializer is not None:
        return serializer

    # Hoist attribute lookups out of the per-row hot path.
    to_internal = schema.toInternal
    dumps = self.pickleSer.dumps
    # Flat (all-atomic) schemas can use the cheaper scalar-only checks.
    normalize = _normalize_tuple_simple if _is_simple_schema(schema) else _normalize_tuple

    def _fast_serialize(data: Tuple) -> bytes:
        return dumps(to_internal(normalize(data)))

    self._serializer_cache[schema_id] = _fast_serialize
    return _fast_serialize

def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
    """Pickle the row tuple *data* in the internal representation of *schema*."""
    serialize = self._get_serializer(schema)
    return serialize(data)

def _deserialize_from_bytes(self, value: bytes) -> Any:
    """Unpickle *value* (bytes produced by _serialize_to_bytes) back into Python objects."""
    loads = self.pickleSer.loads
    return loads(value)
Expand Down
8 changes: 8 additions & 0 deletions python/pyspark/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@
from pyspark.errors.exceptions.captured import CapturedException # noqa: F401
from pyspark.find_spark_home import _find_spark_home

# True when NumPy is importable in this environment; consumers use it to
# decide whether NumPy scalar values may appear and need normalization.
has_numpy: bool = False
try:
    import numpy as np  # noqa: F401
except ImportError:
    pass
else:
    has_numpy = True

if TYPE_CHECKING:
from py4j.java_collections import JavaArray
from py4j.java_gateway import (
Expand Down