Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

package org.apache.comet.vector

import java.nio.channels.ReadableByteChannel

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
import org.apache.arrow.vector.dictionary.{Dictionary, DictionaryProvider}
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.SparkException
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch
Expand Down Expand Up @@ -310,4 +314,66 @@ object NativeUtil {
}
new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount)
}

/**
 * Reads all Arrow IPC batches from a channel and returns them as an array of independent
 * ColumnarBatch objects. Each batch's vectors are transferred to new allocations so they don't
 * reference the stream reader's internal state and can be safely cached and reused.
 *
 * @param channel
 *   source of the Arrow IPC stream; fully consumed by this call
 * @return
 *   one ColumnarBatch per IPC record batch, each owning its own buffers
 */
def materializeBatches(channel: ReadableByteChannel): Array[ColumnarBatch] = {
  val allocator = CometArrowAllocator
  val allBatches = new ArrayBuffer[ColumnarBatch]()
  val arrowReader = new ArrowStreamReader(channel, allocator)

  try {
    // Lazily copy dictionaries on first batch load
    var dictProvider: DictionaryProvider = null

    while (arrowReader.loadNextBatch()) {
      if (dictProvider == null) {
        dictProvider = copyDictionaries(arrowReader)
      }

      val root = arrowReader.getVectorSchemaRoot
      val numRows = root.getRowCount

      // Transfer each field vector to an independent allocation.
      // transfer() moves buffer ownership, leaving the originals empty for reuse
      // by the next loadNextBatch() call.
      val fieldVectors = new java.util.ArrayList[FieldVector]()
      root.getFieldVectors.forEach { vec =>
        val tp = vec.getTransferPair(allocator)
        tp.transfer()
        fieldVectors.add(tp.getTo.asInstanceOf[FieldVector])
      }

      val newRoot = new VectorSchemaRoot(fieldVectors)
      newRoot.setRowCount(numRows)
      allBatches += rootAsBatch(newRoot, dictProvider)
    }
  } catch {
    case NonFatal(e) =>
      // Release any batches materialized before the failure so their transferred
      // buffers are not leaked, then propagate the original error.
      allBatches.foreach(_.close())
      throw e
  } finally {
    // Always close the reader, even on malformed streams, so its internal
    // buffers are returned to the allocator.
    arrowReader.close()
  }
  allBatches.toArray
}

/**
 * Copy dictionary vectors from the reader's provider so they persist after the reader is
 * closed. If there are no dictionaries, returns the reader itself as a no-op provider.
 *
 * @param reader
 *   the stream reader whose dictionary vectors should be detached
 * @return
 *   a provider whose dictionary vectors are independent of the reader's lifetime
 */
private def copyDictionaries(reader: ArrowStreamReader): DictionaryProvider = {
  val dictIds = reader.getDictionaryIds
  if (dictIds.isEmpty) {
    // No dictionary-encoded columns: the reader itself serves as an empty provider.
    reader
  } else {
    val allocator = CometArrowAllocator
    val copiedDicts = new java.util.ArrayList[Dictionary]()
    dictIds.forEach { id =>
      val dict = reader.lookup(id)
      // transfer() moves buffer ownership out of the reader, so closing the
      // reader later does not free the dictionary data.
      val tp = dict.getVector.getTransferPair(allocator)
      tp.transfer()
      copiedDicts.add(new Dictionary(tp.getTo.asInstanceOf[FieldVector], dict.getEncoding))
    }
    new DictionaryProvider.MapDictionaryProvider(
      copiedDicts.toArray(new Array[Dictionary](0)): _*)
  }
}
}
1 change: 1 addition & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3905,6 +3905,7 @@ mod tests {
join_type: 0,
condition: None,
build_side: 0,
is_broadcast: false,
})),
};

Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ message HashJoin {
JoinType join_type = 3;
optional spark.spark_expression.Expr condition = 4;
BuildSide build_side = 5;
bool is_broadcast = 6;
}

message SortMergeJoin {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,7 @@ trait CometHashJoin {
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
.setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft
else OperatorOuterClass.BuildSide.BuildRight)
.setIsBroadcast(join.isInstanceOf[BroadcastHashJoinExec])
condition.foreach(joinBuilder.setCondition)
Some(builder.setHashJoin(joinBuilder).build())
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.benchmark

import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.{CometConf, CometSparkSessionExtensions}

/**
 * Benchmark to measure Comet broadcast hash join performance. To run this benchmark:
 * `SPARK_GENERATE_BENCHMARK_FILES=1 make
 * benchmark-org.apache.spark.sql.benchmark.CometBroadcastHashJoinBenchmark` Results will be
 * written to "spark/benchmarks/CometBroadcastHashJoinBenchmark-**results.txt".
 */
object CometBroadcastHashJoinBenchmark extends CometBenchmarkBase {
  override def getSparkSession: SparkSession = {
    val sparkConf = new SparkConf()
      .setAppName("CometBroadcastHashJoinBenchmark")
      .set("spark.master", "local[5]")
      .setIfMissing("spark.driver.memory", "3g")
      .setIfMissing("spark.executor.memory", "3g")
      .set("spark.executor.memoryOverhead", "10g")

    val session = SparkSession.builder
      .config(sparkConf)
      .withExtensions(new CometSparkSessionExtensions)
      .getOrCreate()

    // Comet starts disabled at the session level; each benchmark case toggles
    // it explicitly through withSQLConf.
    session.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
    session.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
    session.conf.set(CometConf.COMET_ENABLED.key, "false")
    session.conf.set(CometConf.COMET_EXEC_ENABLED.key, "false")
    session.conf.set("parquet.enable.dictionary", "false")

    session
  }

  /**
   * Runs one Spark-vs-Comet comparison for a single join type and table sizing.
   *
   * @param streamedRows
   *   row count of the large (streamed) side
   * @param broadcastRows
   *   row count of the small (broadcast) side; also the key cardinality
   * @param joinType
   *   SQL join keyword spliced into the query, e.g. "INNER"
   */
  def broadcastHashJoinBenchmark(
      streamedRows: Int,
      broadcastRows: Int,
      joinType: String): Unit = {
    // Threshold large enough that the small side always qualifies for broadcast.
    val broadcastThreshold = s"${256 * 1024 * 1024}"
    val benchmark = new Benchmark(
      s"Broadcast Hash Join ($joinType, stream=$streamedRows, broadcast=$broadcastRows)",
      streamedRows,
      output = output)

    withTempPath { dir =>
      import spark.implicits._

      val basePath = dir.getCanonicalPath

      // Streamed (large) side: keys cycle through [0, broadcastRows).
      spark
        .range(streamedRows)
        .select(($"id" % broadcastRows).as("key"), $"id".as("value"))
        .write
        .mode("overwrite")
        .parquet(basePath + "/streamed")

      // Broadcast (small) side: unique keys with a derived payload column.
      spark
        .range(broadcastRows)
        .select($"id".as("key"), ($"id" * 10).as("payload"))
        .write
        .mode("overwrite")
        .parquet(basePath + "/broadcast")

      spark.read.parquet(basePath + "/streamed").createOrReplaceTempView("streamed")
      spark.read.parquet(basePath + "/broadcast").createOrReplaceTempView("broadcast")

      val query =
        "SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " +
          s"FROM streamed s $joinType JOIN broadcast b ON s.key = b.key"

      withTempTable("streamed", "broadcast") {
        benchmark.addCase("Spark") { _ =>
          withSQLConf(
            CometConf.COMET_ENABLED.key -> "false",
            SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastThreshold) {
            spark.sql(query).noop()
          }
        }

        benchmark.addCase("Comet (Scan + Exec)") { _ =>
          withSQLConf(
            CometConf.COMET_ENABLED.key -> "true",
            CometConf.COMET_EXEC_ENABLED.key -> "true",
            SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastThreshold) {
            spark.sql(query).noop()
          }
        }

        benchmark.run()
      }
    }
  }

  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
    val streamedRows = 2 * 1024 * 1024

    // Sweep join types at a fixed small broadcast size, then one case with a
    // larger broadcast table.
    Seq("INNER", "LEFT", "RIGHT").foreach { joinType =>
      broadcastHashJoinBenchmark(streamedRows, 1000, joinType)
    }
    broadcastHashJoinBenchmark(streamedRows, 10000, "INNER")
  }
}
Loading