diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 45245121a0..3c8d2caaa0 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -19,11 +19,15 @@ package org.apache.comet.vector +import java.nio.channels.ReadableByteChannel + import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data} -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.dictionary.DictionaryProvider +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.dictionary.{Dictionary, DictionaryProvider} +import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.SparkException import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.vectorized.ColumnarBatch @@ -310,4 +314,66 @@ object NativeUtil { } new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount) } + + /** + * Reads all Arrow IPC batches from a channel and returns them as an array of independent + * ColumnarBatch objects. Each batch's vectors are transferred to new allocations so they don't + * reference the stream reader's internal state and can be safely cached and reused. + */ + def materializeBatches(channel: ReadableByteChannel): Array[ColumnarBatch] = { + val allocator = CometArrowAllocator + val allBatches = new ArrayBuffer[ColumnarBatch]() + val arrowReader = new ArrowStreamReader(channel, allocator) + + // Lazily copy dictionaries on first batch load + var dictProvider: DictionaryProvider = null + + while (arrowReader.loadNextBatch()) { + if (dictProvider == null) { + dictProvider = copyDictionaries(arrowReader) + } + + val root = arrowReader.getVectorSchemaRoot + val numRows = root.getRowCount + + // Transfer each field vector to an independent allocation. + // transfer() moves buffer ownership, leaving the originals empty for reuse + // by the next loadNextBatch() call. + val fieldVectors = new java.util.ArrayList[FieldVector]() + root.getFieldVectors.forEach { vec => + val tp = vec.getTransferPair(allocator) + tp.transfer() + fieldVectors.add(tp.getTo.asInstanceOf[FieldVector]) + } + + val newRoot = new VectorSchemaRoot(fieldVectors) + newRoot.setRowCount(numRows) + allBatches += rootAsBatch(newRoot, dictProvider) + } + + arrowReader.close() + allBatches.toArray + } + + /** + * Copy dictionary vectors from the reader's provider so they persist after the reader is + * closed. If there are no dictionaries, returns the reader itself as a no-op provider. + */ + private def copyDictionaries(reader: ArrowStreamReader): DictionaryProvider = { + val dictIds = reader.getDictionaryIds + if (dictIds.isEmpty) { + return reader + } + + val allocator = CometArrowAllocator + val copiedDicts = new java.util.ArrayList[Dictionary]() + dictIds.forEach { id => + val dict = reader.lookup(id) + val tp = dict.getVector.getTransferPair(allocator) + tp.transfer() + copiedDicts.add(new Dictionary(tp.getTo.asInstanceOf[FieldVector], dict.getEncoding)) + } + new DictionaryProvider.MapDictionaryProvider( + copiedDicts.toArray(new Array[Dictionary](0)): _*) + } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 15bbabe883..5fd913bf93 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -3905,6 +3905,7 @@ mod tests { join_type: 0, condition: None, build_side: 0, + is_broadcast: false, })), }; diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..63a2be2c1d 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -334,6 +334,7 @@ message HashJoin { JoinType join_type = 3; optional spark.spark_expression.Expr condition = 4; BuildSide build_side = 5; + bool is_broadcast = 6; } message SortMergeJoin { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..ba70521d61 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1674,6 +1674,7 @@ trait CometHashJoin { .addAllRightJoinKeys(rightKeys.map(_.get).asJava) .setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft else OperatorOuterClass.BuildSide.BuildRight) + .setIsBroadcast(join.isInstanceOf[BroadcastHashJoinExec]) condition.foreach(joinBuilder.setCondition) Some(builder.setHashJoin(joinBuilder).build()) } else { diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala new file mode 100644 index 0000000000..339e3e785a --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBroadcastHashJoinBenchmark.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.{CometConf, CometSparkSessionExtensions} + +/** + * Benchmark to measure Comet broadcast hash join performance. To run this benchmark: + * `SPARK_GENERATE_BENCHMARK_FILES=1 make + * benchmark-org.apache.spark.sql.benchmark.CometBroadcastHashJoinBenchmark` Results will be + * written to "spark/benchmarks/CometBroadcastHashJoinBenchmark-**results.txt". + */ +object CometBroadcastHashJoinBenchmark extends CometBenchmarkBase { + override def getSparkSession: SparkSession = { + val conf = new SparkConf() + .setAppName("CometBroadcastHashJoinBenchmark") + .set("spark.master", "local[5]") + .setIfMissing("spark.driver.memory", "3g") + .setIfMissing("spark.executor.memory", "3g") + .set("spark.executor.memoryOverhead", "10g") + + val sparkSession = SparkSession.builder + .config(conf) + .withExtensions(new CometSparkSessionExtensions) + .getOrCreate() + + sparkSession.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(CometConf.COMET_ENABLED.key, "false") + sparkSession.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false") + sparkSession.conf.set("parquet.enable.dictionary", "false") + + sparkSession + } + + def broadcastHashJoinBenchmark( + streamedRows: Int, + broadcastRows: Int, + joinType: String): Unit = { + val benchmark = new Benchmark( + s"Broadcast Hash Join ($joinType, stream=$streamedRows, broadcast=$broadcastRows)", + streamedRows, + output = output) + + withTempPath { dir => + import spark.implicits._ + + // Create streamed (large) table + val streamedDir = dir.getCanonicalPath + "/streamed" + spark + .range(streamedRows) + .select(($"id" % broadcastRows).as("key"), $"id".as("value")) + .write + .mode("overwrite") + .parquet(streamedDir) + + // Create broadcast (small) table + val broadcastDir = dir.getCanonicalPath + "/broadcast" + spark + .range(broadcastRows) + .select($"id".as("key"), ($"id" * 10).as("payload")) + .write + .mode("overwrite") + .parquet(broadcastDir) + + spark.read.parquet(streamedDir).createOrReplaceTempView("streamed") + spark.read.parquet(broadcastDir).createOrReplaceTempView("broadcast") + + val query = + "SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " + + s"FROM streamed s $joinType JOIN broadcast b ON s.key = b.key" + + withTempTable("streamed", "broadcast") { + benchmark.addCase("Spark") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> s"${256 * 1024 * 1024}") { + spark.sql(query).noop() + } + } + + benchmark.addCase("Comet (Scan + Exec)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> s"${256 * 1024 * 1024}") { + spark.sql(query).noop() + } + } + + benchmark.run() + } + } + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + val streamedRows = 2 * 1024 * 1024 + val broadcastRows = 1000 + + for (joinType <- Seq("INNER", "LEFT", "RIGHT")) { + broadcastHashJoinBenchmark(streamedRows, broadcastRows, joinType) + } + + // Test with larger broadcast table + broadcastHashJoinBenchmark(streamedRows, 10000, "INNER") + } +}