From e0c6cd506c4c865437478574f127445b8ca9bd0d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Mar 2026 15:19:07 -0600 Subject: [PATCH 1/5] perf: optimize broadcast hash join with CollectLeft mode and decompression caching Use PartitionMode::CollectLeft instead of Partitioned for broadcast hash joins so DataFusion can optimize hash table construction for the broadcast side. Also cache decompressed broadcast data at executor level to avoid repeated LZ4 decompression across tasks. --- native/core/src/execution/planner.rs | 11 +- native/proto/src/proto/operator.proto | 1 + .../comet/CometBroadcastExchangeExec.scala | 50 ++++++- .../apache/spark/sql/comet/operators.scala | 1 + .../CometBroadcastHashJoinBenchmark.scala | 129 ++++++++++++++++++ 5 files changed, 186 insertions(+), 6 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 15bbabe883..9af260df1f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1663,6 +1663,12 @@ impl PhysicalPlanner { let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); + let partition_mode = if join.is_broadcast { + PartitionMode::CollectLeft + } else { + PartitionMode::Partitioned + }; + let hash_join = Arc::new(HashJoinExec::try_new( left, right, @@ -1670,7 +1676,7 @@ impl PhysicalPlanner { join_params.join_filter, &join_params.join_type, None, - PartitionMode::Partitioned, + partition_mode, // null doesn't equal to null in Spark join key. If the join key is // `EqualNullSafe`, Spark will rewrite it during planning. NullEquality::NullEqualsNothing, @@ -1688,7 +1694,7 @@ impl PhysicalPlanner { )) } else { let swapped_hash_join = - hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?; + hash_join.as_ref().swap_inputs(partition_mode)?; let mut additional_native_plans = vec![]; if swapped_hash_join.as_any().is::() { @@ -3905,6 +3911,7 @@ mod tests { join_type: 0, condition: None, build_side: 0, + is_broadcast: false, })), }; diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..63a2be2c1d 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -334,6 +334,7 @@ message HashJoin { JoinType join_type = 3; optional spark.spark_expression.Expr condition = 4; BuildSide build_side = 5; + bool is_broadcast = 6; } message SortMergeJoin { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index f40e05ea0c..a9fe3c127b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,19 +19,23 @@ package org.apache.spark.sql.comet +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream} +import java.nio.channels.Channels import java.util.UUID -import java.util.concurrent.{Future, TimeoutException, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, Future, TimeoutException, TimeUnit} import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, SparkException, TaskContext} +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} @@ -311,8 +315,46 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - partition.value.value.toIterator - .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName)) + val broadcastId = partition.value.id + val decompressedBytes = CometBatchRDD.decompressedCache.computeIfAbsent( + broadcastId, + _ => { + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + partition.value.value.map { chunkedBuffer => + val cbbis = chunkedBuffer.toInputStream() + val ins = codec.compressedInputStream(cbbis) + val baos = new ByteArrayOutputStream() + val buf = new Array[Byte](8192) + var n = ins.read(buf) + while (n != -1) { + baos.write(buf, 0, n) + n = ins.read(buf) + } + ins.close() + baos.toByteArray + } + }) + decompressedBytes.iterator.flatMap { bytes => + new ArrowReaderIterator( + Channels.newChannel(new ByteArrayInputStream(bytes)), + this.getClass.getSimpleName) + } + } +} + +object CometBatchRDD { + + /** + * Executor-level cache of decompressed broadcast data keyed by broadcast ID. This avoids + * repeated LZ4 decompression when multiple tasks on the same executor process the same + * broadcast relation. Each entry stores decompressed Arrow IPC byte arrays. + */ + private[comet] val decompressedCache = + new ConcurrentHashMap[Long, Array[Array[Byte]]]() + + /** Invalidate cached decompressed data for a broadcast. */ + def invalidateCache(broadcastId: Long): Unit = { + decompressedCache.remove(broadcastId) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..ba70521d61 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1674,6 +1674,7 @@ trait CometHashJoin { .addAllRightJoinKeys(rightKeys.map(_.get).asJava) .setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft else OperatorOuterClass.BuildSide.BuildRight) + .setIsBroadcast(join.isInstanceOf[BroadcastHashJoinExec]) condition.foreach(joinBuilder.setCondition) Some(builder.setHashJoin(joinBuilder).build()) } else { diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala new file mode 100644 index 0000000000..0e5933014b --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, CometSparkSessionExtensions} + +/** + * Benchmark to measure Comet broadcast hash join performance. To run this benchmark: + * `SPARK_GENERATE_BENCHMARK_FILES=1 make + * benchmark-org.apache.spark.sql.benchmark.CometBroadcastHashJoinBenchmark` Results will be + * written to "spark/benchmarks/CometBroadcastHashJoinBenchmark-**results.txt". + */ +object CometBroadcastHashJoinBenchmark extends CometBenchmarkBase { + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometBroadcastHashJoinBenchmark") + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.executor.memoryOverhead", "10g") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set("parquet.enable.dictionary", "false") + + sparkSession + } + + def broadcastHashJoinBenchmark( + streamedRows: Int, + broadcastRows: Int, + joinType: String): Unit = { + val benchmark = new Benchmark( + s"Broadcast Hash Join ($joinType, stream=$streamedRows, broadcast=$broadcastRows)", + streamedRows, + output = output) + + withTempPath { dir => + import spark.implicits._ + + // Create streamed (large) table + val streamedDir = dir.getCanonicalPath + "/streamed" + spark + .range(streamedRows) + .select(($"id" % broadcastRows).as("key"), $"id".as("value")) + .write + .mode("overwrite") + .parquet(streamedDir) + + // Create broadcast (small) table + val broadcastDir = dir.getCanonicalPath + "/broadcast" + spark + .range(broadcastRows) + .select($"id".as("key"), ($"id" * 10).as("payload")) + .write + .mode("overwrite") + .parquet(broadcastDir) + + spark.read.parquet(streamedDir).createOrReplaceTempView("streamed") + spark.read.parquet(broadcastDir).createOrReplaceTempView("broadcast") + + val query = + s"SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " + + s"FROM streamed s $joinType JOIN broadcast b ON s.key = b.key" + + withTempTable("streamed", "broadcast") { + benchmark.addCase("Spark") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> s"${256 * 1024 * 1024}") { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Scan + Exec)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> s"${256 * 1024 * 1024}") { + spark.sql(query).noop() + } + } + + benchmark.run() + } + } + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val streamedRows = 2 * 1024 * 1024 + val broadcastRows = 1000 + + for (joinType <- Seq("INNER", "LEFT", "RIGHT")) { + broadcastHashJoinBenchmark(streamedRows, broadcastRows, joinType) + } + + // Test with larger broadcast table + broadcastHashJoinBenchmark(streamedRows, 10000, "INNER") + } +} From de889ba1a36e910d3d6578a122e62c953ecd3986 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Mar 2026 15:38:11 -0600 Subject: [PATCH 2/5] style: apply formatting --- native/core/src/execution/planner.rs | 3 +-- .../apache/spark/sql/comet/CometBroadcastExchangeExec.scala | 3 +-- .../spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 9af260df1f..bc4aeb68de 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1693,8 +1693,7 @@ impl PhysicalPlanner { )), )) } else { - let swapped_hash_join = - hash_join.as_ref().swap_inputs(partition_mode)?; + let swapped_hash_join = hash_join.as_ref().swap_inputs(partition_mode)?; let mut additional_native_plans = vec![]; if swapped_hash_join.as_any().is::() { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index a9fe3c127b..0fc462a1b4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.comet -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.channels.Channels import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, Future, TimeoutException, TimeUnit} @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator -import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala index 0e5933014b..339e3e785a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala @@ -89,7 +89,7 @@ object CometBroadcastHashJoinBenchmark extends CometBenchmarkBase { spark.read.parquet(broadcastDir).createOrReplaceTempView("broadcast") val query = - s"SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " + + "SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " + s"FROM streamed s $joinType JOIN broadcast b ON s.key = b.key" withTempTable("streamed", "broadcast") { From 072bb293b52f70bcc023e4f06be2051d0f5a1931 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Mar 2026 17:43:00 -0600 Subject: [PATCH 3/5] revert: remove CollectLeft partition mode for broadcast hash joins Reverting the CollectLeft change as it causes multiple test failures in CI including ArrayIndexOutOfBoundsException in NativeUtil.exportBatch and assertion failures in native code. --- native/core/src/execution/planner.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bc4aeb68de..5fd913bf93 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1663,12 +1663,6 @@ impl PhysicalPlanner { let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); - let partition_mode = if join.is_broadcast { - PartitionMode::CollectLeft - } else { - PartitionMode::Partitioned - }; - let hash_join = Arc::new(HashJoinExec::try_new( left, right, @@ -1676,7 +1670,7 @@ impl PhysicalPlanner { join_params.join_filter, &join_params.join_type, None, - partition_mode, + PartitionMode::Partitioned, // null doesn't equal to null in Spark join key. If the join key is // `EqualNullSafe`, Spark will rewrite it during planning. NullEquality::NullEqualsNothing, @@ -1693,7 +1687,8 @@ impl PhysicalPlanner { )), )) } else { - let swapped_hash_join = hash_join.as_ref().swap_inputs(partition_mode)?; + let swapped_hash_join = + hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?; let mut additional_native_plans = vec![]; if swapped_hash_join.as_any().is::() { From d9df44cbefa0fefc8aaf0aaa57a3048be4ff94a6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Mar 2026 18:36:18 -0600 Subject: [PATCH 4/5] perf: cache deserialized Arrow batches for broadcast hash join Instead of caching only decompressed bytes (which still required Arrow IPC deserialization per task), cache the fully materialized ColumnarBatch objects with transferred Arrow vectors. This avoids both LZ4 decompression and Arrow IPC parsing on subsequent task accesses to the same broadcast relation. Vector data is transferred to independent allocations via Arrow's TransferPair so cached batches don't reference stream reader state and can be safely reused across tasks. --- .../org/apache/comet/vector/NativeUtil.scala | 70 ++++++++++++++++++- .../comet/CometBroadcastExchangeExec.scala | 47 ++++++------- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 45245121a0..3c8d2caaa0 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -19,11 +19,15 @@ package org.apache.comet.vector +import java.nio.channels.ReadableByteChannel + import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data} -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.dictionary.DictionaryProvider +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.dictionary.{Dictionary, DictionaryProvider} +import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.SparkException import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.vectorized.ColumnarBatch @@ -310,4 +314,66 @@ object NativeUtil { } new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount) } + + /** + * Reads all Arrow IPC batches from a channel and returns them as an array of independent + * ColumnarBatch objects. Each batch's vectors are transferred to new allocations so they don't + * reference the stream reader's internal state and can be safely cached and reused. + */ + def materializeBatches(channel: ReadableByteChannel): Array[ColumnarBatch] = { + val allocator = CometArrowAllocator + val allBatches = new ArrayBuffer[ColumnarBatch]() + val arrowReader = new ArrowStreamReader(channel, allocator) + + // Lazily copy dictionaries on first batch load + var dictProvider: DictionaryProvider = null + + while (arrowReader.loadNextBatch()) { + if (dictProvider == null) { + dictProvider = copyDictionaries(arrowReader) + } + + val root = arrowReader.getVectorSchemaRoot + val numRows = root.getRowCount + + // Transfer each field vector to an independent allocation. + // transfer() moves buffer ownership, leaving the originals empty for reuse + // by the next loadNextBatch() call. + val fieldVectors = new java.util.ArrayList[FieldVector]() + root.getFieldVectors.forEach { vec => + val tp = vec.getTransferPair(allocator) + tp.transfer() + fieldVectors.add(tp.getTo.asInstanceOf[FieldVector]) + } + + val newRoot = new VectorSchemaRoot(fieldVectors) + newRoot.setRowCount(numRows) + allBatches += rootAsBatch(newRoot, dictProvider) + } + + arrowReader.close() + allBatches.toArray + } + + /** + * Copy dictionary vectors from the reader's provider so they persist after the reader is + * closed. If there are no dictionaries, returns the reader itself as a no-op provider. + */ + private def copyDictionaries(reader: ArrowStreamReader): DictionaryProvider = { + val dictIds = reader.getDictionaryIds + if (dictIds.isEmpty) { + return reader + } + + val allocator = CometArrowAllocator + val copiedDicts = new java.util.ArrayList[Dictionary]() + dictIds.forEach { id => + val dict = reader.lookup(id) + val tp = dict.getVector.getTransferPair(allocator) + tp.transfer() + copiedDicts.add(new Dictionary(tp.getTo.asInstanceOf[FieldVector], dict.getEncoding)) + } + new DictionaryProvider.MapDictionaryProvider( + copiedDicts.toArray(new Array[Dictionary](0)): _*) + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 0fc462a1b4..ed3e8079fd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.comet -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.channels.Channels import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, Future, TimeoutException, TimeUnit} @@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} -import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} @@ -52,6 +50,7 @@ import org.apache.comet.{CometConf, CometRuntimeException, ConfigEntry} import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometBroadcastExchangeExec +import org.apache.comet.vector.NativeUtil /** * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a @@ -315,45 +314,41 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] val broadcastId = partition.value.id - val decompressedBytes = CometBatchRDD.decompressedCache.computeIfAbsent( + val cachedBatches = CometBatchRDD.batchCache.computeIfAbsent( broadcastId, _ => { val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - partition.value.value.map { chunkedBuffer => - val cbbis = chunkedBuffer.toInputStream() - val ins = codec.compressedInputStream(cbbis) - val baos = new ByteArrayOutputStream() - val buf = new Array[Byte](8192) - var n = ins.read(buf) - while (n != -1) { - baos.write(buf, 0, n) - n = ins.read(buf) + partition.value.value.flatMap { chunkedBuffer => + if (chunkedBuffer.size > 0) { + val ins = codec.compressedInputStream(chunkedBuffer.toInputStream()) + NativeUtil.materializeBatches(Channels.newChannel(ins)) + } else { + Array.empty[ColumnarBatch] } - ins.close() - baos.toByteArray } }) - decompressedBytes.iterator.flatMap { bytes => - new ArrowReaderIterator( - Channels.newChannel(new ByteArrayInputStream(bytes)), - this.getClass.getSimpleName) - } + cachedBatches.iterator } } object CometBatchRDD { /** - * Executor-level cache of decompressed broadcast data keyed by broadcast ID. This avoids - * repeated LZ4 decompression when multiple tasks on the same executor process the same - * broadcast relation. Each entry stores decompressed Arrow IPC byte arrays. + * Executor-level cache of fully deserialized broadcast batches keyed by broadcast ID. This + * avoids repeated LZ4 decompression and Arrow IPC deserialization when multiple tasks on the + * same executor process the same broadcast relation. The cached batches contain Arrow vectors + * that are independent of the original stream readers, so they can be safely reused across + * tasks. Native code copies the data on each access via Arrow FFI export. */ - private[comet] val decompressedCache = - new ConcurrentHashMap[Long, Array[Array[Byte]]]() + private[comet] val batchCache = + new ConcurrentHashMap[Long, Array[ColumnarBatch]]() - /** Invalidate cached decompressed data for a broadcast. */ + /** Invalidate cached batch data for a broadcast, freeing Arrow memory. */ def invalidateCache(broadcastId: Long): Unit = { - decompressedCache.remove(broadcastId) + val batches = batchCache.remove(broadcastId) + if (batches != null) { + batches.foreach(_.close()) + } } } From b9dc9606e39b7cd239684fe28e368044a8fdd466 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 13 Mar 2026 20:46:35 -0600 Subject: [PATCH 5/5] revert: remove broadcast batch caching to fix CI failures The Arrow batch caching caused test failures in CI due to unsafe reuse of exported Arrow vectors via FFI. Reverting to the original Utils.decodeBatches approach which creates fresh batches per task. --- .../comet/CometBroadcastExchangeExec.scala | 46 ++----------------- 1 file changed, 5 insertions(+), 41 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index ed3e8079fd..f40e05ea0c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,21 +19,20 @@ package org.apache.spark.sql.comet -import java.nio.channels.Channels import java.util.UUID -import java.util.concurrent.{ConcurrentHashMap, Future, TimeoutException, TimeUnit} +import java.util.concurrent.{Future, TimeoutException, TimeUnit} import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, SparkException, TaskContext} -import org.apache.spark.io.CompressionCodec +import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} @@ -50,7 +49,6 @@ import org.apache.comet.{CometConf, CometRuntimeException, ConfigEntry} import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometBroadcastExchangeExec -import org.apache.comet.vector.NativeUtil /** * A [[CometBroadcastExchangeExec]] collects, transforms and finally broadcasts the result of a @@ -313,42 +311,8 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - val broadcastId = partition.value.id - val cachedBatches = CometBatchRDD.batchCache.computeIfAbsent( - broadcastId, - _ => { - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - partition.value.value.flatMap { chunkedBuffer => - if (chunkedBuffer.size > 0) { - val ins = codec.compressedInputStream(chunkedBuffer.toInputStream()) - NativeUtil.materializeBatches(Channels.newChannel(ins)) - } else { - Array.empty[ColumnarBatch] - } - } - }) - cachedBatches.iterator - } -} - -object CometBatchRDD { - - /** - * Executor-level cache of fully deserialized broadcast batches keyed by broadcast ID. This - * avoids repeated LZ4 decompression and Arrow IPC deserialization when multiple tasks on the - * same executor process the same broadcast relation. The cached batches contain Arrow vectors - * that are independent of the original stream readers, so they can be safely reused across - * tasks. Native code copies the data on each access via Arrow FFI export. - */ - private[comet] val batchCache = - new ConcurrentHashMap[Long, Array[ColumnarBatch]]() - - /** Invalidate cached batch data for a broadcast, freeing Arrow memory. */ - def invalidateCache(broadcastId: Long): Unit = { - val batches = batchCache.remove(broadcastId) - if (batches != null) { - batches.foreach(_.close()) - } + partition.value.value.toIterator + .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName)) } }