-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[WIP][SPARK-56248][PYTHON][SS] Optimize python TWS stateful processor serialization calls #55039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jiateoh
wants to merge
3
commits into
apache:master
Choose a base branch
from
jiateoh:tws_python_serialization_improvements
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+81
−35
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,16 +23,69 @@ | |
| from pyspark.serializers import write_int, read_int, UTF8Deserializer | ||
| from pyspark.sql.pandas.serializers import ArrowStreamSerializer | ||
| from pyspark.sql.types import ( | ||
| AtomicType, | ||
| StructType, | ||
| Row, | ||
| ) | ||
| from pyspark.sql.pandas.types import convert_pandas_using_numpy_type | ||
| from pyspark.sql.utils import has_numpy | ||
| from pyspark.serializers import CPickleSerializer | ||
| from pyspark.errors import PySparkRuntimeError | ||
| import uuid | ||
|
|
||
| __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"] | ||
|
|
||
| if has_numpy: | ||
| import numpy as np | ||
|
|
||
| def _normalize_value(v: Any) -> Any: | ||
| # Convert NumPy types to Python primitive types. | ||
| if isinstance(v, np.generic): | ||
| return v.tolist() | ||
| # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both | ||
| # require positional arguments and cannot be instantiated | ||
| # with a generator expression. | ||
| if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")): | ||
| return type(v)(*[_normalize_value(e) for e in v]) | ||
| # List / tuple: recursively normalize each element | ||
| if isinstance(v, (list, tuple)): | ||
| return type(v)(_normalize_value(e) for e in v) | ||
| # Dict: normalize both keys and values | ||
| if isinstance(v, dict): | ||
| return {_normalize_value(k): _normalize_value(val) for k, val in v.items()} | ||
| # Address a couple of pandas dtypes too. | ||
| elif hasattr(v, "to_pytimedelta"): | ||
| return v.to_pytimedelta() | ||
| elif hasattr(v, "to_pydatetime"): | ||
| return v.to_pydatetime() | ||
| return v | ||
|
|
||
| def _normalize_value_simple(v: Any) -> Any: | ||
| """Fast path for flat schemas: check for numpy scalars and pandas dtypes only.""" | ||
| if isinstance(v, np.generic): | ||
| return v.tolist() | ||
| if hasattr(v, "to_pytimedelta"): | ||
| return v.to_pytimedelta() | ||
| if hasattr(v, "to_pydatetime"): | ||
| return v.to_pydatetime() | ||
| return v | ||
|
|
||
| def _normalize_tuple(data: Tuple) -> Tuple: | ||
| return tuple(_normalize_value(v) for v in data) | ||
|
|
||
| def _normalize_tuple_simple(data: Tuple) -> Tuple: | ||
| return tuple(_normalize_value_simple(v) for v in data) | ||
| else: | ||
| def _normalize_tuple(data: Tuple) -> Tuple: | ||
| return data # toInternal handles tuples natively | ||
|
|
||
| _normalize_tuple_simple = _normalize_tuple | ||
|
|
||
|
|
||
| def _is_simple_schema(schema: StructType) -> bool: | ||
| """True if every field is an atomic type (no nested structs, arrays, or maps).""" | ||
| return all(isinstance(f.dataType, AtomicType) for f in schema.fields) | ||
|
|
||
|
|
||
| class StatefulProcessorHandleState(Enum): | ||
| PRE_INIT = 0 | ||
|
|
@@ -81,6 +134,10 @@ def __init__( | |
| self.list_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {} | ||
| self.expiry_timer_iterator_cursors: Dict[str, Tuple[Any, int, bool]] = {} | ||
|
|
||
| # Cache of schema-id -> fast-serialize callable, so we avoid | ||
| # repeated attribute lookups on every _serialize_to_bytes call. | ||
| self._serializer_cache: Dict[int, Any] = {} | ||
|
|
||
| # statefulProcessorApiClient is initialized per batch per partition, | ||
| # so we will have new timestamps for a new batch | ||
| self._batch_timestamp = -1 | ||
|
|
@@ -487,43 +544,24 @@ def _receive_proto_message_with_timers(self) -> Tuple[int, str, Any, bool]: | |
| def _receive_str(self) -> str: | ||
| return self.utf8_deserializer.loads(self.sockfile) | ||
|
|
||
| def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: | ||
| from pyspark.testing.utils import have_numpy | ||
|
|
||
| if have_numpy: | ||
| import numpy as np | ||
|
|
||
| def normalize_value(v: Any) -> Any: | ||
| # Convert NumPy types to Python primitive types. | ||
| if isinstance(v, np.generic): | ||
| return v.tolist() | ||
| # Named tuples (collections.namedtuple or typing.NamedTuple) and Row both | ||
| # require positional arguments and cannot be instantiated | ||
| # with a generator expression. | ||
| if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")): | ||
| return type(v)(*[normalize_value(e) for e in v]) | ||
| # List / tuple: recursively normalize each element | ||
| if isinstance(v, (list, tuple)): | ||
| return type(v)(normalize_value(e) for e in v) | ||
| # Dict: normalize both keys and values | ||
| if isinstance(v, dict): | ||
| return {normalize_value(k): normalize_value(val) for k, val in v.items()} | ||
| # Address a couple of pandas dtypes too. | ||
| elif hasattr(v, "to_pytimedelta"): | ||
| return v.to_pytimedelta() | ||
| elif hasattr(v, "to_pydatetime"): | ||
| return v.to_pydatetime() | ||
| else: | ||
| return v | ||
|
|
||
| converted = [normalize_value(v) for v in data] | ||
| else: | ||
| converted = list(data) | ||
| def _get_serializer(self, schema: StructType) -> Any: | ||
| schema_id = id(schema) | ||
| serializer = self._serializer_cache.get(schema_id) | ||
| if serializer is not None: | ||
| return serializer | ||
|
|
||
| field_names = [f.name for f in schema.fields] | ||
| row_value = Row(**dict(zip(field_names, converted))) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Copying from the PR description for reference on why this field_name/Row creation is no longer necessary:
|
||
| to_internal = schema.toInternal | ||
| dumps = self.pickleSer.dumps | ||
| normalize = _normalize_tuple_simple if _is_simple_schema(schema) else _normalize_tuple | ||
|
|
||
| return self.pickleSer.dumps(schema.toInternal(row_value)) | ||
| def _fast_serialize(data: Tuple) -> bytes: | ||
| return dumps(to_internal(normalize(data))) | ||
|
|
||
| self._serializer_cache[schema_id] = _fast_serialize | ||
| return _fast_serialize | ||
|
|
||
| def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: | ||
| return self._get_serializer(schema)(data) | ||
|
|
||
| def _deserialize_from_bytes(self, value: bytes) -> Any: | ||
| return self.pickleSer.loads(value) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[edit: this prototype is now in the diff, latter discussion is still relevant though]
Another fastpath option being considered: prevalidate the schema, if it's all non-nested we can skip some checks. Prototype outlined below but with more analysis/tradeoffs in the follow up comment:
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gets a bit tricky with the last two timestamp checks though:
`Schema.toInternal` would properly handle them today without the explicit conversions in the normal path. However, removing these checks does introduce an expectation that pandas types will always subclass the datetime classes (this is indeed the current expectation, though). Supporting this backwards compatibility increases the cost from
1 `isinstance` check to 1 `isinstance` + 2 `hasattr` checks, though. For point 2, I ran a microbenchmark with a [timestamp (datetime) + long (numeric)] struct on each of the normalize function variants. Averaging over 2 million iterations on an input
`(datetime(2024, 1, 1), 42)`: The broader speedup on a full TWS workload will of course vary depending on the schemas and row values used, though.