
Commit 418cfd1

[SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol
### What changes were proposed in this pull request?

This PR proposes to get rid of the usage of Arrow for sending multiple elements of ListState and replace it with a simple custom protocol. The proposed custom protocol is very simple and already widely used:

1. Write the size of the element (in bytes); if there is no more element, write -1.
2. Write the element (as bytes).
3. Go back to 1.

Note that this PR only changes ListState - we are aware that there are more usages of Arrow in other state types and other functionality (timer). We want to improve those over time via benchmarking, addressing them if they show a latency implication.

### Why are the changes needed?

For a small number of elements, Arrow does not perform very well compared to the custom protocol. In the benchmark, three elements are exchanged between the Python worker and the JVM, and replacing Arrow with the custom protocol cut the elapsed time of the state interaction by 1/3. Given the natural performance difference between the Scala version of transformWithState and the PySpark version of transformWithStateInPandas, users should use the Scala version to handle a noticeable volume of workload. We can position the PySpark version for more lightweight workloads - and revisit if we see the opposite demand.

### Does this PR introduce _any_ user-facing change?

No, it's an internal change.

### How was this patch tested?

Existing UTs, with modifications to the mock expectations.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50488 from HeartSaVioR/SPARK-51690.

Authored-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
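The length-prefixed framing described in the PR description can be illustrated with a minimal sketch. This is an illustration only, not the actual Spark implementation: an in-memory buffer stands in for the socket between the Python worker and the JVM, and the helper names are made up for the example.

```python
import struct
from io import BytesIO
from typing import Iterator, List


def write_elements(out, elements: List[bytes]) -> None:
    # For each element: write its length as a 4-byte big-endian int, then its bytes.
    for element in elements:
        out.write(struct.pack(">i", len(element)))
        out.write(element)
    # A length of -1 marks the end of the list.
    out.write(struct.pack(">i", -1))
    out.flush()


def read_elements(inp) -> Iterator[bytes]:
    # Keep reading (length, payload) pairs until the -1 end marker is seen.
    while True:
        (length,) = struct.unpack(">i", inp.read(4))
        if length < 0:
            return
        yield inp.read(length)


buf = BytesIO()
write_elements(buf, [b"row-1", b"row-2", b"row-3"])
buf.seek(0)
print([e.decode() for e in read_elements(buf)])  # ['row-1', 'row-2', 'row-3']
```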
1 parent 75d80c7 commit 418cfd1

File tree

5 files changed: +105, -48 lines changed


python/pyspark/sql/streaming/list_state_client.py

Lines changed: 15 additions & 31 deletions
@@ -14,16 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, List, Union, Tuple
+from typing import Any, Dict, Iterator, List, Union, Tuple

 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING
+from pyspark.sql.types import StructType
 from pyspark.errors import PySparkRuntimeError
 import uuid

-if TYPE_CHECKING:
-    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
 __all__ = ["ListStateClient"]


@@ -38,9 +35,9 @@ def __init__(
             self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
         else:
             self.schema = schema
-        # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
+        # A dictionary to store the mapping between list state name and a tuple of data batch
         # and the index of the last row that was read.
-        self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}

     def exists(self, state_name: str) -> bool:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -67,9 +64,9 @@ def exists(self, state_name: str) -> bool:
     def get(self, state_name: str, iterator_id: str) -> Tuple:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

-        if iterator_id in self.pandas_df_dict:
+        if iterator_id in self.data_batch_dict:
             # If the state is already in the dictionary, return the next row.
-            pandas_df, index = self.pandas_df_dict[iterator_id]
+            data_batch, index = self.data_batch_dict[iterator_id]
         else:
             # If the state is not in the dictionary, fetch the state from the server.
             get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -85,33 +82,20 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
             response_message = self._stateful_processor_api_client._receive_proto_message()
             status = response_message[0]
             if status == 0:
-                iterator = self._stateful_processor_api_client._read_arrow_state()
-                # We need to exhaust the iterator here to make sure all the arrow batches are read,
-                # even though there is only one batch in the iterator. Otherwise, the stream might
-                # block further reads since it thinks there might still be some arrow batches left.
-                # We only need to read the first batch in the iterator because it's guaranteed that
-                # there would only be one batch sent from the JVM side.
-                data_batch = None
-                for batch in iterator:
-                    if data_batch is None:
-                        data_batch = batch
-                if data_batch is None:
-                    # TODO(SPARK-49233): Classify user facing errors.
-                    raise PySparkRuntimeError("Error getting next list state row.")
-                pandas_df = data_batch.to_pandas()
+                data_batch = self._stateful_processor_api_client._read_list_state()
                 index = 0
             else:
                 raise StopIteration()

         new_index = index + 1
-        if new_index < len(pandas_df):
+        if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
+            self.data_batch_dict[iterator_id] = (data_batch, new_index)
         else:
-            # If the index is at the end of the DataFrame, remove the state from the dictionary.
-            self.pandas_df_dict.pop(iterator_id, None)
-        pandas_row = pandas_df.iloc[index]
-        return tuple(pandas_row)
+            # If the index is at the end of the data batch, remove the state from the dictionary.
+            self.data_batch_dict.pop(iterator_id, None)
+        row = data_batch[index]
+        return tuple(row)

     def append_value(self, state_name: str, value: Tuple) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -143,7 +127,7 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:

         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())

-        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+        self._stateful_processor_api_client._send_list_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -160,7 +144,7 @@ def put(self, state_name: str, values: List[Tuple]) -> None:

         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())

-        self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+        self._stateful_processor_api_client._send_list_state(self.schema, values)
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:

python/pyspark/sql/streaming/stateful_processor_api_client.py

Lines changed: 20 additions & 0 deletions
@@ -467,6 +467,26 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
     def _read_arrow_state(self) -> Any:
         return self.serializer.load_stream(self.sockfile)

+    def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
+        for value in state:
+            bytes = self._serialize_to_bytes(schema, value)
+            length = len(bytes)
+            write_int(length, self.sockfile)
+            self.sockfile.write(bytes)
+
+        write_int(-1, self.sockfile)
+        self.sockfile.flush()
+
+    def _read_list_state(self) -> List[Any]:
+        data_array = []
+        while True:
+            length = read_int(self.sockfile)
+            if length < 0:
+                break
+            bytes = self.sockfile.read(length)
+            data_array.append(self._deserialize_from_bytes(bytes))
+        return data_array
+
     # Parse a string schema into a StructType schema. This method will perform an API call to
     # JVM side to parse the schema string.
     def _parse_string_schema(self, schema: str) -> StructType:

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala

Lines changed: 21 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Des
     reader.close(false)
     rows.toSeq
   }
+
+  def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = {
+    val rows = new scala.collection.mutable.ArrayBuffer[Row]
+
+    var endOfLoop = false
+    while (!endOfLoop) {
+      val size = stream.readInt()
+      if (size < 0) {
+        endOfLoop = true
+      } else {
+        val bytes = new Array[Byte](size)
+        stream.read(bytes, 0, size)
+        val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema,
+          listStateInfo.deserializer)
+        rows.append(newRow)
+      }
+    }
+
+    rows.toSeq
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala

Lines changed: 25 additions & 5 deletions
@@ -475,7 +475,7 @@ class TransformWithStateInPandasStateServer(
           sendResponse(2, s"state $stateName doesn't exist")
         }
       case ListStateCall.MethodCase.LISTSTATEPUT =>
-        val rows = deserializer.readArrowBatches(inputStream)
+        val rows = deserializer.readListElements(inputStream, listStateInfo)
        listStateInfo.listState.put(rows.toArray)
        sendResponse(0)
      case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -487,20 +487,18 @@
        }
        if (!iteratorOption.get.hasNext) {
          sendResponse(2, s"List state $stateName doesn't contain any value.")
-          return
        } else {
          sendResponse(0)
+          sendIteratorForListState(iteratorOption.get)
        }
-        sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
-          arrowStreamWriterForTest) { data => listStateInfo.serializer(data)}
      case ListStateCall.MethodCase.APPENDVALUE =>
        val byteArray = message.getAppendValue.getValue.toByteArray
        val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema,
          listStateInfo.deserializer)
        listStateInfo.listState.appendValue(newRow)
        sendResponse(0)
      case ListStateCall.MethodCase.APPENDLIST =>
-        val rows = deserializer.readArrowBatches(inputStream)
+        val rows = deserializer.readListElements(inputStream, listStateInfo)
        listStateInfo.listState.appendList(rows.toArray)
        sendResponse(0)
      case ListStateCall.MethodCase.CLEAR =>
@@ -511,6 +509,28 @@
      }
    }

+  private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
+    // Only write a single batch in each GET request. Stops writing row if rowCount reaches
+    // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
+    // when there are multiple state variables, user tries to access a different state variable
+    // while the current state variable is not exhausted yet.
+    var rowCount = 0
+    while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+      val data = iter.next()
+
+      // Serialize the value row as a byte array
+      val valueBytes = PythonSQLUtils.toPyRow(data)
+      val lenBytes = valueBytes.length
+
+      outputStream.writeInt(lenBytes)
+      outputStream.write(valueBytes)
+
+      rowCount += 1
+    }
+    outputStream.writeInt(-1)
+    outputStream.flush()
+  }
+
   private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
     val stateName = message.getStateName
     if (!mapStates.contains(stateName)) {
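As the comment in sendIteratorForListState above notes, each GET request transfers at most arrowTransformWithStateInPandasMaxRecordsPerBatch rows before the -1 end-of-data marker, and whatever remains in the iterator is sent on a later GET request. The following is a rough sketch of that capping behavior only, with a plain Python iterator and in-memory buffers standing in for the server-side iterator and output stream; the names are illustrative, not the actual Spark code.

```python
import struct
from io import BytesIO
from typing import Iterator


def write_one_batch(out, rows: Iterator[bytes], max_records_per_batch: int) -> None:
    # Write at most max_records_per_batch length-prefixed rows, then the -1
    # end-of-data marker. Anything left in `rows` is sent on the next request.
    count = 0
    while count < max_records_per_batch:
        row = next(rows, None)
        if row is None:
            break
        out.write(struct.pack(">i", len(row)))
        out.write(row)
        count += 1
    out.write(struct.pack(">i", -1))
    out.flush()


rows = iter([b"a", b"b", b"c"])
first, second = BytesIO(), BytesIO()
write_one_batch(first, rows, max_records_per_batch=2)   # carries b"a" and b"b"
write_one_batch(second, rows, max_records_per_batch=2)  # carries the remaining b"c"
```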

sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala

Lines changed: 24 additions & 12 deletions
@@ -103,6 +103,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
       listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, expiryTimerIter, listTimerMap)
     when(transformWithStateInPandasDeserializer.readArrowBatches(any))
       .thenReturn(Seq(getIntegerRow(1)))
+    when(transformWithStateInPandasDeserializer.readListElements(any, any))
+      .thenReturn(Seq(getIntegerRow(1)))
   }

   test("set handle state") {
@@ -260,8 +262,10 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // 1 for row, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(3)).writeInt(any)
+    // 1 for sending an actual row, 1 for sending proto message
+    verify(outputStream, times(2)).write(any[Array[Byte]])
   }

   test("list state get - iterator in map with multiple batches") {
@@ -278,15 +282,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
     // Second call should send the remaining 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    // Since Mockito's verify counts the total number of calls, the expected number of writeRow call
-    // should be 2 * maxRecordsPerBatch.
-    verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
+    // Since Mockito's verify counts the total number of calls, the expected number of writeInt
+    // and write should be accumulated from the prior count; the number of calls are the same
+    // with prior one.
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 2)).write(any[Array[Byte]])
   }

   test("list state get - iterator not in map") {
@@ -302,17 +311,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleListStateRequest(message)
     verify(listState).get()
+
     // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
     // having 1 row left in the iterator.
-    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+    // 2 for rows, 1 for proto message
+    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
   }

   test("list state put") {
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setListStatePut(ListStatePut.newBuilder().build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+    verify(transformWithStateInPandasDeserializer).readListElements(any, any)
     verify(listState).put(any)
   }

@@ -328,7 +340,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setAppendList(AppendList.newBuilder().build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+    verify(transformWithStateInPandasDeserializer).readListElements(any, any)
     verify(listState).appendList(any)
   }
