spark KafkaOffsetReaderConsumer 源码

  • 2022-10-20
  • 浏览 (341)

spark KafkaOffsetReaderConsumer 代码

文件路径:/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderConsumer.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, OffsetAndTimestamp}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.kafka010.KafkaSourceProvider.StrategyOnNoMatchStartingOffset
import org.apache.spark.util.{UninterruptibleThread, UninterruptibleThreadRunner}

/**
 * This class uses Kafka's own [[org.apache.kafka.clients.consumer.KafkaConsumer]] API to
 * read data offsets from Kafka.
 * The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read
 * by this source. These strategies directly correspond to the different consumption options.
 * This class is designed to return a configured
 * [[org.apache.kafka.clients.consumer.KafkaConsumer]] that is used by the
 * [[KafkaSource]] to query for the offsets. See the docs on
 * [[org.apache.spark.sql.kafka010.ConsumerStrategy]]
 * for more details.
 *
 * Note: This class is not ThreadSafe
 */
private[kafka010] class KafkaOffsetReaderConsumer(
    consumerStrategy: ConsumerStrategy,
    override val driverKafkaParams: ju.Map[String, Object],
    readerOptions: CaseInsensitiveMap[String],
    driverGroupIdPrefix: String) extends KafkaOffsetReader with Logging {

  /**
   * [[UninterruptibleThreadRunner]] ensures that all
   * [[org.apache.kafka.clients.consumer.KafkaConsumer]] communication called in an
   * [[UninterruptibleThread]]. In the case of streaming queries, we are already running in an
   * [[UninterruptibleThread]], however for batch mode this is not the case.
   */
  val uninterruptibleThreadRunner = new UninterruptibleThreadRunner("Kafka Offset Reader")

  /**
   * Place [[groupId]] and [[nextId]] here so that they are initialized before any consumer is
   * created -- see SPARK-19564.
   */
  private var groupId: String = null
  private var nextId = 0

  /**
   * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
   * offsets and never commits them.
   */
  @volatile protected var _consumer: Consumer[Array[Byte], Array[Byte]] = null

  /**
   * Lazily creates and returns the driver-side consumer.
   *
   * Must be called from an [[UninterruptibleThread]]. When the caller-supplied
   * `driverKafkaParams` carry no group id, a generated one (see `nextGroupId`) is added to a
   * copy of the parameters before the consumer is created; the original map is not modified.
   */
  protected def consumer: Consumer[Array[Byte], Array[Byte]] = synchronized {
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
    if (_consumer == null) {
      // Work on a copy so that the parameters handed in by the caller stay untouched.
      val effectiveParams = new ju.HashMap[String, Object](driverKafkaParams)
      val userGroupId = driverKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
      if (userGroupId == null) {
        effectiveParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
      }
      _consumer = consumerStrategy.createConsumer(effectiveParams)
    }
    _consumer
  }

  /**
   * Upper bound on the number of attempts made when fetching offsets. Read from the
   * `KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY` reader option; 3 when the option is absent.
   */
  private[kafka010] val maxOffsetFetchAttempts = {
    val configured = readerOptions.get(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY)
    configured.map(_.toInt).getOrElse(3)
  }

  /**
   * Number of partitions to read from Kafka. If this value is greater than the number of Kafka
   * topicPartitions, we will split up the read tasks of the skewed partitions to multiple Spark
   * tasks. The number of Spark tasks will be *approximately* `minPartitions`. It can be less or
   * more depending on rounding errors or Kafka partitions that didn't receive any new data.
   */
  private val minPartitions =
    readerOptions.get(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY).map(_.toInt)

  private val rangeCalculator = new KafkaOffsetRangeCalculator(minPartitions)

  /**
   * Time in milliseconds to sleep between two offset fetch attempts. Read from the
   * `KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS` reader option; 1000 when absent.
   */
  private[kafka010] val offsetFetchAttemptIntervalMs = {
    val configured = readerOptions.get(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS)
    configured.map(_.toLong).getOrElse(1000L)
  }

  /**
   * Whether we should divide Kafka TopicPartitions with a lot of data into smaller Spark tasks.
   */
  private def shouldDivvyUpLargePartitions(numTopicPartitions: Int): Boolean = {
    // True only when `minPartitions` is configured and exceeds the number of topic partitions.
    minPartitions.exists(_ > numTopicPartitions)
  }

  /**
   * Generates the next consumer group id as `<driverGroupIdPrefix>-<counter>`, remembers it in
   * `groupId` and advances the counter, so that each call yields a different id.
   */
  private def nextGroupId(): String = {
    groupId = s"$driverGroupIdPrefix-$nextId"
    nextId += 1
    groupId
  }

  override def toString(): String = consumerStrategy.toString

  /**
   * Closes the driver-side consumer, if one has been created, and shuts down the
   * [[UninterruptibleThreadRunner]] used for Kafka communication.
   *
   * The runner is shut down even when closing the consumer throws; the exception is still
   * propagated to the caller afterwards.
   */
  override def close(): Unit = {
    try {
      // The consumer must be closed from an UninterruptibleThread (asserted by stopConsumer).
      if (_consumer != null) uninterruptibleThreadRunner.runUninterruptibly { stopConsumer() }
    } finally {
      // Always release the runner; otherwise a failing consumer close would leak it.
      uninterruptibleThreadRunner.shutdown()
    }
  }

  /**
   * @return The Set of TopicPartitions for a given topic
   */
  private def fetchTopicPartitions(): Set[TopicPartition] =
    uninterruptibleThreadRunner.runUninterruptibly {
      assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
      // Poll first so that the assignment read below is the latest one.
      consumer.poll(0)
      val assigned = consumer.assignment()
      consumer.pause(assigned)
      // Hand back an immutable Scala snapshot of the assignment.
      assigned.asScala.toSet
    }

  /**
   * Maps an offset range limit to a per-partition offset map for the currently assigned
   * partitions.
   *
   * Earliest/latest limits are returned as the `KafkaOffsetRangeLimit.EARLIEST` / `LATEST`
   * markers (late binding), specific offsets are returned as given after checking that they
   * cover exactly the assigned partitions, and timestamp based limits are delegated to the
   * timestamp based fetch methods.
   *
   * @param offsetRangeLimit the limit to translate
   * @param isStartingOffsets whether the limit describes starting offsets; only forwarded to
   *                          the timestamp based lookups
   */
  override def fetchPartitionOffsets(
      offsetRangeLimit: KafkaOffsetRangeLimit,
      isStartingOffsets: Boolean): Map[TopicPartition, Long] = {
    // The assignment is fetched up front, whichever kind of limit was requested.
    val assigned = fetchTopicPartitions()

    // Assigns the same marker offset to every assigned partition.
    def allAt(marker: Long): Map[TopicPartition, Long] =
      assigned.map(tp => tp -> marker).toMap

    // Obtain TopicPartition offsets with late binding support
    offsetRangeLimit match {
      case EarliestOffsetRangeLimit =>
        allAt(KafkaOffsetRangeLimit.EARLIEST)

      case LatestOffsetRangeLimit =>
        allAt(KafkaOffsetRangeLimit.LATEST)

      case SpecificOffsetRangeLimit(requested) =>
        assert(assigned == requested.keySet,
          "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
            "Use -1 for latest, -2 for earliest.\n" +
            s"Specified: ${requested.keySet} Assigned: ${assigned}")
        logDebug(s"Partitions assigned to consumer: $assigned. Seeking to $requested")
        requested

      case SpecificTimestampRangeLimit(partitionTimestamps, strategy) =>
        fetchSpecificTimestampBasedOffsets(
          partitionTimestamps, isStartingOffsets, strategy).partitionToOffsets

      case GlobalTimestampRangeLimit(timestamp, strategy) =>
        fetchGlobalTimestampBasedOffsets(
          timestamp, isStartingOffsets, strategy).partitionToOffsets
    }
  }

  /**
   * Seeks the consumer to the requested offsets and returns the positions it ends up at.
   *
   * @param partitionOffsets requested offset per partition; must cover exactly the assigned
   *                         partitions. `KafkaOffsetRangeLimit.LATEST` / `EARLIEST` are
   *                         accepted as markers.
   * @param reportDataLoss invoked with a message for every partition whose fetched position
   *                       differs from a concrete (non-marker) requested offset
   */
  override def fetchSpecificOffsets(
      partitionOffsets: Map[TopicPartition, Long],
      reportDataLoss: String => Unit): KafkaSourceOffset = {
    // Step 1: the request has to name exactly the partitions assigned to the consumer.
    val checkAssignment: ju.Set[TopicPartition] => Unit = { assigned =>
      assert(assigned.asScala == partitionOffsets.keySet,
        "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
          "Use -1 for latest, -2 for earliest, if you don't care.\n" +
          s"Specified: ${partitionOffsets.keySet} Assigned: ${assigned.asScala}")
      logDebug(s"Partitions assigned to consumer: $assigned. Seeking to $partitionOffsets")
    }

    // Step 2: the offsets to seek to are the requested ones, unchanged.
    val requestedOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] =
      _ => partitionOffsets

    // Step 3: compare what was requested with where the consumer actually ended up.
    val checkFetched: Map[TopicPartition, Long] => Unit = { fetched =>
      partitionOffsets.foreach { case (tp, requested) =>
        // no real way to check that beginning or end is reasonable
        val isMarker = requested == KafkaOffsetRangeLimit.LATEST ||
          requested == KafkaOffsetRangeLimit.EARLIEST
        if (!isMarker && fetched(tp) != requested) {
          reportDataLoss(
            s"startingOffsets for $tp was $requested but consumer reset to ${fetched(tp)}")
        }
      }
    }

    fetchSpecificOffsets0(checkAssignment, requestedOffsets, checkFetched)
  }

  /**
   * Resolves per-partition timestamps to offsets through `Consumer.offsetsForTimes`, seeks the
   * consumer there and returns the resulting positions.
   *
   * @param partitionTimestamps timestamp per partition; must cover exactly the assigned
   *                            partitions
   * @param isStartingOffsets whether the result is used as starting offsets; together with
   *                          the strategy this decides what happens when no offset matches
   * @param strategyOnNoMatchStartingOffset what to do for starting offsets when no offset
   *                                        matches the timestamp
   */
  override def fetchSpecificTimestampBasedOffsets(
      partitionTimestamps: Map[TopicPartition, Long],
      isStartingOffsets: Boolean,
      strategyOnNoMatchStartingOffset: StrategyOnNoMatchStartingOffset.Value)
    : KafkaSourceOffset = {

    val checkAssignment: ju.Set[TopicPartition] => Unit = { assigned =>
      assert(assigned.asScala == partitionTimestamps.keySet,
        "If starting/endingOffsetsByTimestamp contains specific offsets, you must specify all " +
          s"topics. Specified: ${partitionTimestamps.keySet} Assigned: ${assigned.asScala}")
      logDebug(s"Partitions assigned to consumer: $assigned. Seeking to $partitionTimestamps")
    }

    val lookupOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { _ =>
      // The Kafka API expects boxed timestamps in a Java map.
      val boxedTimestamps = partitionTimestamps.map { case (tp, ts) =>
        tp -> Long.box(ts)
      }.asJava

      val matched: ju.Map[TopicPartition, OffsetAndTimestamp] =
        consumer.offsetsForTimes(boxedTimestamps)

      readTimestampOffsets(
        matched.asScala.toMap,
        isStartingOffsets,
        strategyOnNoMatchStartingOffset,
        partitionTimestamps)
    }

    // Nothing to verify after the fetch for timestamp based lookups.
    val skipFetchedCheck: Map[TopicPartition, Long] => Unit = _ => ()

    fetchSpecificOffsets0(checkAssignment, lookupOffsets, skipFetchedCheck)
  }

  /**
   * Resolves one timestamp, applied to every assigned partition, to offsets through
   * `Consumer.offsetsForTimes`, seeks the consumer there and returns the resulting positions.
   *
   * @param timestamp the timestamp looked up for all assigned partitions
   * @param isStartingOffsets whether the result is used as starting offsets; together with
   *                          the strategy this decides what happens when no offset matches
   * @param strategyOnNoMatchStartingOffset what to do for starting offsets when no offset
   *                                        matches the timestamp
   */
  override def fetchGlobalTimestampBasedOffsets(
      timestamp: Long,
      isStartingOffsets: Boolean,
      strategyOnNoMatchStartingOffset: StrategyOnNoMatchStartingOffset.Value)
    : KafkaSourceOffset = {

    // No validation of the assignment is needed here; it is only logged.
    val logAssignment: ju.Set[TopicPartition] => Unit = { assigned =>
      logDebug(s"Partitions assigned to consumer: $assigned. Seeking to $timestamp")
    }

    val lookupOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { assigned =>
      // Same boxed timestamp for every assigned partition.
      val boxedTimestamps = assigned.asScala.map(tp => tp -> Long.box(timestamp)).toMap.asJava

      val matched: ju.Map[TopicPartition, OffsetAndTimestamp] =
        consumer.offsetsForTimes(boxedTimestamps)

      readTimestampOffsets(
        matched.asScala.toMap,
        isStartingOffsets,
        strategyOnNoMatchStartingOffset,
        _ => timestamp)
    }

    // Nothing to verify after the fetch for timestamp based lookups.
    val skipFetchedCheck: Map[TopicPartition, Long] => Unit = _ => ()

    fetchSpecificOffsets0(logAssignment, lookupOffsets, skipFetchedCheck)
  }

  /**
   * Converts the result of a timestamp lookup into plain offsets.
   *
   * A `null` entry is treated as "no offset matched". For ending offsets this yields
   * `KafkaOffsetRangeLimit.LATEST`; for starting offsets the outcome is decided by
   * `strategyOnNoMatchStartingOffset` (either `LATEST` or an `AssertionError`).
   *
   * @param tpToOffsetMap lookup result per partition; values may be `null`
   * @param isStartingOffsets whether the offsets are used as starting offsets
   * @param strategyOnNoMatchStartingOffset behavior for unmatched starting offsets
   * @param partitionTimestampFn timestamp that was requested for a partition; only used in
   *                             the error message
   */
  private def readTimestampOffsets(
      tpToOffsetMap: Map[TopicPartition, OffsetAndTimestamp],
      isStartingOffsets: Boolean,
      strategyOnNoMatchStartingOffset: StrategyOnNoMatchStartingOffset.Value,
      partitionTimestampFn: TopicPartition => Long): Map[TopicPartition, Long] = {

    // Offset used for a partition whose lookup produced no match.
    def onNoMatch(tp: TopicPartition): Long = {
      if (!isStartingOffsets) {
        KafkaOffsetRangeLimit.LATEST
      } else {
        strategyOnNoMatchStartingOffset match {
          case StrategyOnNoMatchStartingOffset.LATEST =>
            KafkaOffsetRangeLimit.LATEST

          case StrategyOnNoMatchStartingOffset.ERROR =>
            // This is to match the old behavior - we used assert to check the condition.
            // scalastyle:off throwerror
            throw new AssertionError("No offset " +
              s"matched from request of topic-partition $tp and timestamp " +
              s"${partitionTimestampFn(tp)}.")
            // scalastyle:on throwerror
        }
      }
    }

    tpToOffsetMap.map { case (tp, offsetSpec) =>
      tp -> (if (offsetSpec != null) offsetSpec.offset() else onNoMatch(tp))
    }
  }

  /**
   * Shared skeleton of the "seek to given offsets" operations.
   *
   * With the consumer's assignment at hand it validates the request, obtains the target
   * offsets, seeks every partition (honoring the `LATEST` / `EARLIEST` markers) and reads back
   * the resulting positions. The positions are then handed to `fnAssertFetchedOffsets` and
   * wrapped into a [[KafkaSourceOffset]].
   *
   * @param fnAssertParametersWithPartitions validation run against the assigned partitions
   * @param fnRetrievePartitionOffsets produces the target offset per partition
   * @param fnAssertFetchedOffsets check run on the positions that were read back
   */
  private def fetchSpecificOffsets0(
      fnAssertParametersWithPartitions: ju.Set[TopicPartition] => Unit,
      fnRetrievePartitionOffsets: ju.Set[TopicPartition] => Map[TopicPartition, Long],
      fnAssertFetchedOffsets: Map[TopicPartition, Long] => Unit): KafkaSourceOffset = {
    val positions = partitionsAssignedToConsumer { assigned =>
      fnAssertParametersWithPartitions(assigned)

      val targets = fnRetrievePartitionOffsets(assigned)

      // Move every partition to its target; markers map to the dedicated seek calls.
      targets.foreach { case (tp, target) =>
        target match {
          case KafkaOffsetRangeLimit.LATEST =>
            consumer.seekToEnd(ju.Arrays.asList(tp))
          case KafkaOffsetRangeLimit.EARLIEST =>
            consumer.seekToBeginning(ju.Arrays.asList(tp))
          case concrete =>
            consumer.seek(tp, concrete)
        }
      }

      // Read back where the consumer actually is for each requested partition.
      targets.map { case (tp, _) =>
        tp -> consumer.position(tp)
      }
    }

    fnAssertFetchedOffsets(positions)

    KafkaSourceOffset(positions)
  }

  /**
   * Seeks all assigned partitions to the beginning and returns the resulting position of
   * each partition.
   */
  override def fetchEarliestOffsets(): Map[TopicPartition, Long] = {
    val seekAndRead: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { assigned =>
      logDebug("Seeking to the beginning")

      consumer.seekToBeginning(assigned)
      val earliest = assigned.asScala.map(p => p -> consumer.position(p)).toMap
      logDebug(s"Got earliest offsets for partition : $earliest")
      earliest
    }
    partitionsAssignedToConsumer(seekAndRead, fetchingEarliestOffset = true)
  }

  /**
   * Specific to `KafkaOffsetReaderConsumer`:
   * Kafka may return earliest offsets when we are requesting latest offsets if `poll` is called
   * right before `seekToEnd` (KAFKA-7703). As a workaround, we will call `position` right after
   * `poll` to wait until the potential offset request triggered by `poll(0)` is done.
   */
  override def fetchLatestOffsets(
      knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap =
    partitionsAssignedToConsumer { partitions => {
      logDebug("Seeking to the end.")

      if (knownOffsets.isEmpty) {
        // Nothing to validate against: a single seek-to-end and position read is enough.
        consumer.seekToEnd(partitions)
        partitions.asScala.map(p => p -> consumer.position(p)).toMap
      } else {
        // Result of the most recent fetch; reassigned on every iteration of the loop below.
        var partitionOffsets: PartitionOffsetMap = Map.empty

        /**
         * Compare `knownOffsets` and `partitionOffsets`. Returns all partitions that have incorrect
         * latest offset (offset in `knownOffsets` is greater than the one in `partitionOffsets`).
         */
        def findIncorrectOffsets(): Seq[(TopicPartition, Long, Long)] = {
          val incorrectOffsets = ArrayBuffer[(TopicPartition, Long, Long)]()
          partitionOffsets.foreach { case (tp, offset) =>
            // Partitions that are absent from `knownOffsets` are not checked.
            knownOffsets.foreach(_.get(tp).foreach { knownOffset =>
              if (knownOffset > offset) {
                val incorrectOffset = (tp, knownOffset, offset)
                incorrectOffsets += incorrectOffset
              }
            })
          }
          incorrectOffsets.toSeq
        }

        // Retry to fetch latest offsets when detecting incorrect offsets. We don't use
        // `withRetriesWithoutInterrupt` to retry because:
        //
        // - `withRetriesWithoutInterrupt` will reset the consumer for each attempt but a fresh
        //    consumer has a much bigger chance to hit KAFKA-7703.
        // - Avoid calling `consumer.poll(0)` which may cause KAFKA-7703.
        var incorrectOffsets: Seq[(TopicPartition, Long, Long)] = Nil
        var attempt = 0
        do {
          consumer.seekToEnd(partitions)
          partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
          attempt += 1

          incorrectOffsets = findIncorrectOffsets()
          if (incorrectOffsets.nonEmpty) {
            logWarning("Found incorrect offsets in some partitions " +
              s"(partition, previous offset, fetched offset): $incorrectOffsets")
            // Only wait when another attempt will follow.
            if (attempt < maxOffsetFetchAttempts) {
              logWarning("Retrying to fetch latest offsets because of incorrect offsets")
              Thread.sleep(offsetFetchAttemptIntervalMs)
            }
          }
        } while (incorrectOffsets.nonEmpty && attempt < maxOffsetFetchAttempts)

        // Once the attempts are used up, the last fetched offsets are returned as they are,
        // even if some of them are still below the known offsets.
        logDebug(s"Got latest offsets for partition : $partitionOffsets")
        partitionOffsets
      }
    }
  }

  /**
   * Returns the earliest offsets of the given partitions.
   *
   * All assigned partitions are sought to the beginning, but positions are only read for
   * those entries of `newPartitions` that are part of the current assignment. An empty
   * input yields an empty map without touching the consumer.
   *
   * @param newPartitions partitions whose earliest offsets are wanted
   */
  override def fetchEarliestOffsets(
      newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = {
    if (newPartitions.nonEmpty) {
      val seekAndRead: ju.Set[TopicPartition] => Map[TopicPartition, Long] = { assigned =>
        // Get the earliest offset of each partition
        consumer.seekToBeginning(assigned)
        // When deleting topics happen at the same time, some partitions may not be in
        // `assigned`. So we need to ignore them
        val earliest = newPartitions.collect {
          case p if assigned.contains(p) => p -> consumer.position(p)
        }.toMap
        logDebug(s"Got earliest offsets for new partitions: $earliest")
        earliest
      }
      partitionsAssignedToConsumer(seekAndRead, fetchingEarliestOffset = true)
    } else {
      Map.empty[TopicPartition, Long]
    }
  }

  /**
   * Builds the offset ranges to read from a pair of (possibly unresolved) range limits.
   *
   * Both limits are first translated with `fetchPartitionOffsets`. The two results must cover
   * the same partitions. When `minPartitions` exceeds the number of partitions, the offsets
   * are resolved to concrete positions so that the range calculator can split them; the first
   * and last piece of every partition keep the original, unresolved boundaries.
   *
   * @param startingOffsets limit describing where reading starts
   * @param endingOffsets limit describing where reading ends
   * @return one or more ranges per partition
   * @throws IllegalStateException if the partitions seen for the starting offsets differ from
   *                               those seen for the ending offsets
   */
  override def getOffsetRangesFromUnresolvedOffsets(
      startingOffsets: KafkaOffsetRangeLimit,
      endingOffsets: KafkaOffsetRangeLimit): Seq[KafkaOffsetRange] = {
    val fromPartitionOffsets = fetchPartitionOffsets(startingOffsets, isStartingOffsets = true)
    val untilPartitionOffsets = fetchPartitionOffsets(endingOffsets, isStartingOffsets = false)

    // The two lookups above run one after the other, so partitions may have been added and/or
    // deleted in between. Such a mismatch is not tolerated: fail with both partition lists.
    if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) {
      // Order by topic and then partition so that the message is deterministic; ordering by
      // topic alone would leave partitions of the same topic in arbitrary set order.
      implicit val topicPartitionOrdering: Ordering[TopicPartition] =
        Ordering.by(tp => (tp.topic(), tp.partition()))
      val fromTopics = fromPartitionOffsets.keySet.toList.sorted.mkString(",")
      val untilTopics = untilPartitionOffsets.keySet.toList.sorted.mkString(",")
      throw new IllegalStateException("different topic partitions " +
        s"for starting offsets topics[${fromTopics}] and " +
        s"ending offsets topics[${untilTopics}]")
    }

    // Calculate offset ranges
    val offsetRangesBase = untilPartitionOffsets.keySet.map { tp =>
      // Defensive only: the key sets were just verified to be equal, so a from offset exists.
      val fromOffset = fromPartitionOffsets.getOrElse(tp,
        throw new IllegalStateException(s"$tp doesn't have a from offset"))
      val untilOffset = untilPartitionOffsets(tp)
      KafkaOffsetRange(tp, fromOffset, untilOffset, None)
    }.toSeq

    if (shouldDivvyUpLargePartitions(offsetRangesBase.size)) {
      val fromOffsetsMap =
        offsetRangesBase.map(range => (range.topicPartition, range.fromOffset)).toMap
      val untilOffsetsMap =
        offsetRangesBase.map(range => (range.topicPartition, range.untilOffset)).toMap

      // Resolve the offsets to concrete positions so the ranges can be split.
      // No need to report data loss here
      val resolvedFromOffsets = fetchSpecificOffsets(fromOffsetsMap, _ => ()).partitionToOffsets
      val resolvedUntilOffsets = fetchSpecificOffsets(untilOffsetsMap, _ => ()).partitionToOffsets
      val ranges = offsetRangesBase.map(_.topicPartition).map { tp =>
        KafkaOffsetRange(tp, resolvedFromOffsets(tp), resolvedUntilOffsets(tp), preferredLoc = None)
      }
      val divvied = rangeCalculator.getRanges(ranges).groupBy(_.topicPartition)
      divvied.flatMap { case (tp, splitOffsetRanges) =>
        if (splitOffsetRanges.length == 1) {
          // Not split: keep the original (unresolved) boundaries on both sides.
          Seq(KafkaOffsetRange(tp, fromOffsetsMap(tp), untilOffsetsMap(tp), None))
        } else {
          // Split: only the outermost boundaries are switched back to the original offsets.
          // the list can't be empty
          val first = splitOffsetRanges.head.copy(fromOffset = fromOffsetsMap(tp))
          val end = splitOffsetRanges.last.copy(untilOffset = untilOffsetsMap(tp))
          Seq(first) ++ splitOffsetRanges.drop(1).dropRight(1) :+ end
        }
      }.toArray.toSeq
    } else {
      offsetRangesBase
    }
  }

  /**
   * Returns the string form of the executor locations of all peers of this block manager,
   * sorted in descending order by host and, for equal hosts, by executor id.
   */
  private def getSortedExecutorList(): Array[String] = {
    val blockManager = SparkEnv.get.blockManager
    val peerLocations = blockManager.master
      .getPeers(blockManager.blockManagerId)
      .toArray
      .map(peer => ExecutorCacheTaskLocation(peer.host, peer.executorId))

    peerLocations
      .sortWith { (a, b) =>
        // Descending: `a` comes first when it is the larger of the two.
        if (a.host != b.host) a.host > b.host else a.executorId > b.executorId
      }
      .map(_.toString)
  }

  /**
   * Builds the offset ranges between two resolved offset maps, reporting possible data loss
   * along the way.
   *
   * Partitions that only appear in `untilPartitionOffsets` start from their earliest offset.
   * Partitions that only appear in `fromPartitionOffsets` are reported as gone and left out.
   *
   * @param fromPartitionOffsets start offset per partition
   * @param untilPartitionOffsets end offset per partition
   * @param reportDataLoss invoked with a message whenever data may have been missed
   * @return the ranges produced by the range calculator for the sorted executor list
   */
  override def getOffsetRangesFromResolvedOffsets(
      fromPartitionOffsets: PartitionOffsetMap,
      untilPartitionOffsets: PartitionOffsetMap,
      reportDataLoss: String => Unit): Seq[KafkaOffsetRange] = {
    // Find the new partitions, and get their earliest offsets
    val addedPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet)
    val addedInitialOffsets = fetchEarliestOffsets(addedPartitions.toSeq)
    if (addedInitialOffsets.keySet != addedPartitions) {
      // We cannot get from offsets for some partitions. It means they got deleted.
      val vanished = addedPartitions.diff(addedInitialOffsets.keySet)
      reportDataLoss(
        s"Cannot find earliest offsets of ${vanished}. Some data may have been missed")
    }
    logInfo(s"Partitions added: $addedInitialOffsets")
    // An added partition that does not start at offset 0 is reported as possible data loss.
    addedInitialOffsets.foreach { case (p, o) =>
      if (o != 0) {
        reportDataLoss(
          s"Added partition $p starts from $o instead of 0. Some data may have been missed")
      }
    }

    val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet)
    if (deletedPartitions.nonEmpty) {
      // A dedicated hint is used when the caller configured the group id.
      val hint =
        if (driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
          KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE
        } else {
          "Some data may have been missed."
        }
      reportDataLoss(s"$deletedPartitions are gone. $hint")
    }

    // Use the until partitions to calculate offset ranges to ignore partitions that have
    // been deleted
    val topicPartitions = untilPartitionOffsets.keySet.filter { tp =>
      // Ignore partitions that we don't know the from offsets.
      fromPartitionOffsets.contains(tp) || addedInitialOffsets.contains(tp)
    }.toSeq
    logDebug("TopicPartitions: " + topicPartitions.mkString(", "))

    val startOffsets = fromPartitionOffsets ++ addedInitialOffsets
    val ranges = topicPartitions.map { tp =>
      val fromOffset = startOffsets(tp)
      val untilOffset = untilPartitionOffsets(tp)
      // An end offset below the start offset is reported; the range is still emitted.
      if (untilOffset < fromOffset) {
        reportDataLoss(s"Partition $tp's offset was changed from " +
          s"$fromOffset to $untilOffset, some data may have been missed")
      }
      KafkaOffsetRange(tp, fromOffset, untilOffset, preferredLoc = None)
    }
    rangeCalculator.getRanges(ranges, getSortedExecutorList())
  }

  /**
   * Runs `body` against the partitions currently assigned to the consumer.
   *
   * Everything happens on the uninterruptible runner and inside
   * `withRetriesWithoutInterrupt`: the consumer is polled to refresh its assignment, the
   * assigned partitions are paused and then handed to `body`.
   *
   * @param body computes offsets for the assigned partitions
   * @param fetchingEarliestOffset when `false`, `position` is called on every partition before
   *                               `body` runs (KAFKA-7703 workaround)
   */
  private def partitionsAssignedToConsumer(
      body: ju.Set[TopicPartition] => Map[TopicPartition, Long],
      fetchingEarliestOffset: Boolean = false)
    : Map[TopicPartition, Long] = uninterruptibleThreadRunner.runUninterruptibly {

    withRetriesWithoutInterrupt {
      // Poll to get the latest assigned partitions
      consumer.poll(0)
      val assigned = consumer.assignment()

      // Call `position` to wait until the potential offset request triggered by `poll(0)` is
      // done. This is a workaround for KAFKA-7703, which an async `seekToBeginning` triggered by
      // `poll(0)` may reset offsets that should have been set by another request.
      val mustAwaitPendingReset = !fetchingEarliestOffset
      if (mustAwaitPendingReset) {
        // Only the call matters here; the returned positions are discarded.
        assigned.asScala.foreach(tp => consumer.position(tp))
      }

      consumer.pause(assigned)
      logDebug(s"Partitions assigned to consumer: $assigned.")
      body(assigned)
    }
  }

  /**
   * Helper function that does multiple retries on a body of code that returns offsets.
   * Retries are needed to handle transient failures. For e.g. race conditions between getting
   * assignment and getting position while topics/partitions are deleted can cause NPEs.
   *
   * This method also makes sure `body` won't be interrupted to workaround a potential issue in
   * `KafkaConsumer.poll`. (KAFKA-1894)
   */
  private def withRetriesWithoutInterrupt(
      body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
    // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894)
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])

    synchronized {
      // `result` stays empty until one attempt succeeds; `attempt` is 1-based.
      var result: Option[Map[TopicPartition, Long]] = None
      var attempt = 1
      var lastException: Throwable = null
      // Stop on the first success, once the attempts are used up, or when the thread has been
      // interrupted in the meantime.
      while (result.isEmpty && attempt <= maxOffsetFetchAttempts
        && !Thread.currentThread().isInterrupted) {
        Thread.currentThread match {
          case ut: UninterruptibleThread =>
            // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query
            // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it.
            //
            // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may
            // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the
            // issue.
            ut.runUninterruptibly {
              try {
                result = Some(body)
              } catch {
                case NonFatal(e) =>
                  // Remember the failure, wait, and start over with a fresh consumer. The wait
                  // and the reset also happen after the final failed attempt.
                  lastException = e
                  logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
                  attempt += 1
                  Thread.sleep(offsetFetchAttemptIntervalMs)
                  resetConsumer()
              }
            }
          case _ =>
            throw new IllegalStateException(
              "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread")
        }
      }
      // `Thread.interrupted()` clears the interrupt flag before the exception is thrown.
      if (Thread.interrupted()) {
        throw new InterruptedException()
      }
      if (result.isEmpty) {
        // Without a result and without an interrupt, every attempt must have failed.
        assert(attempt > maxOffsetFetchAttempts)
        assert(lastException != null)
        throw lastException
      }
      result.get
    }
  }

  /**
   * Closes the current consumer, if there is one. Must be called from an
   * [[UninterruptibleThread]]. The `_consumer` reference itself is left untouched.
   */
  private def stopConsumer(): Unit = synchronized {
    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
    Option(_consumer).foreach(_.close())
  }

  /**
   * Closes the current consumer and clears the reference to it, so that the next access to
   * `consumer` creates a new one.
   */
  private def resetConsumer(): Unit = synchronized {
    stopConsumer()
    _consumer = null  // will automatically get reinitialized again
  }
}

相关信息

spark 源码目录

相关文章

spark ConsumerStrategy 源码

spark JsonUtils 源码

spark KafkaBatch 源码

spark KafkaBatchPartitionReader 源码

spark KafkaBatchWrite 源码

spark KafkaContinuousStream 源码

spark KafkaDataWriter 源码

spark KafkaMicroBatchStream 源码

spark KafkaOffsetRangeCalculator 源码

spark KafkaOffsetRangeLimit 源码

0  赞