Spark EnsureRequirements source code
File path: /sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

/**
 * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
 * of input data meets the
 * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
 * each operator by inserting [[ShuffleExchangeExec]] operators where required. Also ensures that
 * the input partition ordering requirements are met.
 *
 * @param optimizeOutRepartition A flag to indicate that if this rule should optimize out
 *                               user-specified repartition shuffles or not. This is mostly true,
 *                               but can be false in AQE when AQE optimization may change the plan
 *                               output partitioning and need to retain the user-specified
 *                               repartition shuffles in the plan.
 * @param requiredDistribution The root required distribution we should ensure. This value is used
 *                             in AQE in case we change final stage output partitioning.
 */
case class EnsureRequirements(
    optimizeOutRepartition: Boolean = true,
    requiredDistribution: Option[Distribution] = None)
  extends Rule[SparkPlan] {
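  /**
   * For each child: if its output partitioning does not satisfy the corresponding required
   * distribution, insert a [[BroadcastExchangeExec]] or [[ShuffleExchangeExec]] above it; if
   * several children must be co-partitioned (clustered distributions, e.g. the two sides of a
   * join), pick a common shuffle spec and re-shuffle the children that are not compatible with
   * it; finally, add a [[SortExec]] wherever a required child ordering is not already satisfied.
   */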
  private def ensureDistributionAndOrdering(
      originalChildren: Seq[SparkPlan],
      requiredChildDistributions: Seq[Distribution],
      requiredChildOrderings: Seq[Seq[SortOrder]],
      shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = {
    assert(requiredChildDistributions.length == originalChildren.length)
    assert(requiredChildOrderings.length == originalChildren.length)

    // Ensure that the operator's children satisfy their output distribution requirements.
    var children = originalChildren.zip(requiredChildDistributions).map {
      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
        child
      case (child, BroadcastDistribution(mode)) =>
        BroadcastExchangeExec(mode, child)
      case (child, distribution) =>
        val numPartitions = distribution.requiredNumPartitions
          .getOrElse(conf.numShufflePartitions)
        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
    }
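    // At this point every child either already satisfies its required distribution or has a
    // broadcast/shuffle exchange inserted directly above it.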
    // Get the indexes of children which have specified distribution requirements and need to be
    // co-partitioned.
    val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
      case (_: ClusteredDistribution, _) => true
      case _ => false
    }.map(_._2)

    // Special case: if all sides of the join are single partition
    val allSinglePartition =
      childrenIndexes.forall(children(_).outputPartitioning == SinglePartition)

    // If there is more than one child, we'll need to check partitioning & distribution of them
    // and see if extra shuffles are necessary.
    if (childrenIndexes.length > 1 && !allSinglePartition) {
      val specs = childrenIndexes.map(i => {
        val requiredDist = requiredChildDistributions(i)
        assert(requiredDist.isInstanceOf[ClusteredDistribution],
          s"Expected ClusteredDistribution but found ${requiredDist.getClass.getSimpleName}")
        i -> children(i).outputPartitioning.createShuffleSpec(
          requiredDist.asInstanceOf[ClusteredDistribution])
      }).toMap

      // Find out the shuffle spec that gives better parallelism. Currently this is done by
      // picking the spec with the largest number of partitions.
      //
      // NOTE: this is not optimal for the case when there are more than 2 children. Consider:
      //   (10, 10, 11)
      // where the numbers represent the number of partitions for each child. It's better to pick
      // 10 here since we only need to shuffle one side - we'd need to shuffle two sides if we
      // pick 11.
      //
      // However this should be sufficient for now since in Spark nodes with multiple children
      // always have exactly 2 children.

      // Whether we should consider `spark.sql.shuffle.partitions` and ensure enough parallelism
      // during shuffle. To achieve a good trade-off between parallelism and shuffle cost, we only
      // consider the minimum parallelism iff ALL children need to be re-shuffled.
      //
      // A child needs to be re-shuffled iff either one of below is true:
      //   1. It can't create partitioning by itself, i.e., `canCreatePartitioning` returns false
      //      (as for the case of `RangePartitioning`), therefore it needs to be re-shuffled
      //      according to other shuffle spec.
      //   2. It already has `ShuffleExchangeLike`, so we can re-use existing shuffle without
      //      introducing extra shuffle.
      //
      // On the other hand, in scenarios such as:
      //   HashPartitioning(5) <-> HashPartitioning(6)
      // while `spark.sql.shuffle.partitions` is 10, we'll only re-shuffle the left side and make
      // it HashPartitioning(6).
      val shouldConsiderMinParallelism = specs.forall(p =>
        !p._2.canCreatePartitioning || children(p._1).isInstanceOf[ShuffleExchangeLike]
      )
      // Choose all the specs that can be used to shuffle other children
      val candidateSpecs = specs
        .filter(_._2.canCreatePartitioning)
        .filter(p => !shouldConsiderMinParallelism ||
          children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
      val bestSpecOpt = if (candidateSpecs.isEmpty) {
        None
      } else {
        // When choosing specs, we should consider those children with no `ShuffleExchangeLike`
        // node first. For instance, if we have:
        //   A: (No_Exchange, 100) <---> B: (Exchange, 120)
        // it's better to pick A and change B to (Exchange, 100) instead of picking B and
        // inserting a new shuffle for A.
        val candidateSpecsWithoutShuffle = candidateSpecs.filter { case (k, _) =>
          !children(k).isInstanceOf[ShuffleExchangeLike]
        }
        val finalCandidateSpecs = if (candidateSpecsWithoutShuffle.nonEmpty) {
          candidateSpecsWithoutShuffle
        } else {
          candidateSpecs
        }
        // Pick the spec with the best parallelism
        Some(finalCandidateSpecs.values.maxBy(_.numPartitions))
      }

      // Check if 1) all children are of `KeyGroupedPartitioning` and 2) they are all compatible
      // with each other. If both are true, skip shuffle.
      val allCompatible = childrenIndexes.sliding(2).forall {
        case Seq(a, b) =>
          checkKeyGroupedSpec(specs(a)) && checkKeyGroupedSpec(specs(b)) &&
            specs(a).isCompatibleWith(specs(b))
      }

      children = children.zip(requiredChildDistributions).zipWithIndex.map {
        case ((child, _), idx) if allCompatible || !childrenIndexes.contains(idx) =>
          child
        case ((child, dist), idx) =>
          if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
            child
          } else {
            val newPartitioning = bestSpecOpt.map { bestSpec =>
              // Use the best spec to create a new partitioning to re-shuffle this child
              val clustering = dist.asInstanceOf[ClusteredDistribution].clustering
              bestSpec.createPartitioning(clustering)
            }.getOrElse {
              // No best spec available, so we create default partitioning from the required
              // distribution
              val numPartitions = dist.requiredNumPartitions
                .getOrElse(conf.numShufflePartitions)
              dist.createPartitioning(numPartitions)
            }
            child match {
              case ShuffleExchangeExec(_, c, so) => ShuffleExchangeExec(newPartitioning, c, so)
              case _ => ShuffleExchangeExec(newPartitioning, child)
            }
          }
      }
    }

    // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
      // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
      if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
        child
      } else {
        SortExec(requiredOrdering, global = false, child = child)
      }
    }

    children
  }
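  /**
   * Returns true only if the given spec is a [[KeyGroupedShuffleSpec]] (or a
   * [[ShuffleSpecCollection]] containing one). When `REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION`
   * is enabled, the leaf attributes of the partition expressions must additionally match the
   * clustering keys one-to-one and in the same order.
   */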
  private def checkKeyGroupedSpec(shuffleSpec: ShuffleSpec): Boolean = {
    def check(spec: KeyGroupedShuffleSpec): Boolean = {
      val attributes = spec.partitioning.expressions.flatMap(_.collectLeaves())
      val clustering = spec.distribution.clustering

      if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
        attributes.length == clustering.length && attributes.zip(clustering).forall {
          case (l, r) => l.semanticEquals(r)
        }
      } else {
        true // already validated in `KeyGroupedPartitioning.satisfies`
      }
    }

    shuffleSpec match {
      case spec: KeyGroupedShuffleSpec => check(spec)
      case ShuffleSpecCollection(specs) => specs.exists(checkKeyGroupedSpec)
      case _ => false
    }
  }
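  /**
   * Reorders the position-aligned `leftKeys`/`rightKeys` pairs so that they follow
   * `expectedOrderOfKeys`, using `currentOrderOfKeys` to locate each key's current position.
   * Returns None if the reordering is impossible, e.g. the sizes differ or an expected key
   * cannot be found among the current keys.
   */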
  private def reorder(
      leftKeys: IndexedSeq[Expression],
      rightKeys: IndexedSeq[Expression],
      expectedOrderOfKeys: Seq[Expression],
      currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
      return None
    }

    // Check if the current order already satisfies the expected order.
    if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) {
      return Some(leftKeys, rightKeys)
    }

    // Build a lookup between an expression and the positions it holds in the current key seq.
    val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
    currentOrderOfKeys.zipWithIndex.foreach {
      case (key, index) =>
        keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
    }

    // Reorder the keys.
    val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
    val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
    val iterator = expectedOrderOfKeys.iterator
    while (iterator.hasNext) {
      // Lookup the current index of this key.
      keyToIndexMap.get(iterator.next().canonicalized) match {
        case Some(indices) if indices.nonEmpty =>
          // Take the first available index from the map.
          val index = indices.firstKey
          indices.remove(index)
          // Add the keys for that index to the reordered keys.
          leftKeysBuffer += leftKeys(index)
          rightKeysBuffer += rightKeys(index)
        case _ =>
          // The expression cannot be found, or we have exhausted all indices for that expression.
          return None
      }
    }
    Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
  }
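  /**
   * Reorders the join keys so that they line up with the output partitioning of one side of the
   * join, trying the left partitioning first and then the right. For example, if the left side
   * is hash-partitioned by (b, a) while the join condition is a = x AND b = y, the keys are
   * reordered to (b = y, a = x) so the existing partitioning can be reused without an extra
   * shuffle. Non-deterministic keys are left untouched.
   */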
  private def reorderJoinKeys(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      leftPartitioning: Partitioning,
      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
      reorderJoinKeysRecursively(
        leftKeys,
        rightKeys,
        Some(leftPartitioning),
        Some(rightPartitioning))
        .getOrElse((leftKeys, rightKeys))
    } else {
      (leftKeys, rightKeys)
    }
  }
  /**
   * Recursively reorders the join keys based on partitioning. It starts reordering the
   * join keys to match HashPartitioning on either side, followed by PartitioningCollection.
   */
  private def reorderJoinKeysRecursively(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      leftPartitioning: Option[Partitioning],
      rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = {
    (leftPartitioning, rightPartitioning) match {
      case (Some(HashPartitioning(leftExpressions, _)), _) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, None, rightPartitioning))
      case (_, Some(HashPartitioning(rightExpressions, _))) =>
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, leftPartitioning, None))
      case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
        val leafExprs = clustering.flatMap(_.collectLeaves())
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, None, rightPartitioning))
      case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
        val leafExprs = clustering.flatMap(_.collectLeaves())
        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
          .orElse(reorderJoinKeysRecursively(
            leftKeys, rightKeys, leftPartitioning, None))
      case (Some(PartitioningCollection(partitionings)), _) =>
        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning))
        }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))
      case (_, Some(PartitioningCollection(partitionings))) =>
        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)))
        }.orElse(None)
      case _ =>
        None
    }
  }
  /**
   * When the physical operators are created for JOIN, the ordering of join keys is based on the
   * order in which the join keys appear in the user query. That might not match with the output
   * partitioning of the join node's children (thus leading to extra sort / shuffle being
   * introduced). This rule will change the ordering of the join keys to match with the
   * partitioning of the join nodes' children.
   */
  private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
    plan match {
      case ShuffledHashJoinExec(
          leftKeys, rightKeys, joinType, buildSide, condition, left, right, isSkew) =>
        val (reorderedLeftKeys, reorderedRightKeys) =
          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
          left, right, isSkew)

      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkew) =>
        val (reorderedLeftKeys, reorderedRightKeys) =
          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition,
          left, right, isSkew)

      case other => other
    }
  }
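  /**
   * Walks the plan bottom-up. A user-specified repartition shuffle (REPARTITION_BY_COL /
   * REPARTITION_BY_NUM) is removed when `optimizeOutRepartition` is enabled and its child already
   * produces a semantically equal hash partitioning. Every other operator gets its join keys
   * reordered and any required exchanges and sorts inserted. If `requiredDistribution` is set, it
   * is additionally enforced on the root of the resulting plan.
   */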
  def apply(plan: SparkPlan): SparkPlan = {
    val newPlan = plan.transformUp {
      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
          if optimizeOutRepartition &&
            (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
        def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
          partitioning match {
            case lower: HashPartitioning if upper.semanticEquals(lower) => true
            case lower: PartitioningCollection =>
              lower.partitionings.exists(hasSemanticEqualPartitioning)
            case _ => false
          }
        }
        if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
          child
        } else {
          operator
        }
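      // For every other operator: first try to reorder join keys to match the children's existing
      // partitioning, then insert whatever exchanges and sorts its required child distributions
      // and orderings still demand.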
      case operator: SparkPlan =>
        val reordered = reorderJoinPredicates(operator)
        val newChildren = ensureDistributionAndOrdering(
          reordered.children,
          reordered.requiredChildDistribution,
          reordered.requiredChildOrdering,
          ENSURE_REQUIREMENTS)
        reordered.withNewChildren(newChildren)
    }

    if (requiredDistribution.isDefined) {
      val shuffleOrigin = if (requiredDistribution.get.requiredNumPartitions.isDefined) {
        REPARTITION_BY_NUM
      } else {
        REPARTITION_BY_COL
      }
      val finalPlan = ensureDistributionAndOrdering(
        newPlan :: Nil,
        requiredDistribution.get :: Nil,
        Seq(Nil),
        shuffleOrigin)
      assert(finalPlan.size == 1)
      finalPlan.head
    } else {
      newPlan
    }
  }
}
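To see the rule's effect end to end, here is a minimal sketch (the object name EnsureRequirementsDemo and the column names are made up for illustration). Broadcast joins are disabled so the planner chooses a sort-merge join, whose clustered distribution and ordering requirements force EnsureRequirements to insert shuffle exchanges and sorts on both sides, visible in the explain() output.

import org.apache.spark.sql.SparkSession

object EnsureRequirementsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("EnsureRequirementsDemo")
      // Disable broadcast joins so the tiny tables below still use a sort-merge join.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    val left = Seq((1, "a"), (2, "b")).toDF("id", "l")
    val right = Seq((1, "x"), (2, "y")).toDF("id", "r")

    // Neither side is hash-partitioned on `id`, so SortMergeJoinExec's required clustered
    // distribution on `id` is not satisfied. EnsureRequirements inserts a ShuffleExchangeExec
    // (hash partitioning on `id`) and a SortExec under each side; they show up as
    // "Exchange hashpartitioning(id, ...)" and "Sort" nodes in the printed plan.
    left.join(right, "id").explain()

    spark.stop()
  }
}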