spark RowQueue 源码

  • 2022-10-20
  • 浏览 (196)

spark RowQueue 代码

文件路径:/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.python

import java.io._

import com.google.common.io.Closeables

import org.apache.spark.SparkEnv
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.memory.{MemoryConsumer, SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock

/**
 * A first-in, first-out queue of UnsafeRows.
 *
 * It is built solely for Python UDF evaluation: there is exactly one writer and exactly one
 * reader, and the reader always runs behind the writer. See [[BatchEvalPythonExec]] for how the
 * two sides are driven.
 */
private[python] trait RowQueue {

  /**
   * Appends `row` at the tail of the queue.
   *
   * @return true if and only if the row was added to the queue
   */
  def add(row: UnsafeRow): Boolean

  /**
   * Takes the row at the head of the queue out of it.
   *
   * Must not be called before `add` has been called; otherwise it fails with a
   * NullPointerException.
   *
   * @return the head row, or null if the queue is empty
   */
  def remove(): UnsafeRow

  /** Releases every resource held by this queue. */
  def close(): Unit
}

/**
 * A RowQueue backed by a single in-memory page. Rows are appended until the page has no room for
 * the next one, while another thread may read from the same page concurrently, staying behind
 * the writer.
 *
 * Page layout, one entry per row:
 * [4-byte record length N] [N bytes of UnsafeRow data] [...]
 *
 * A record length of -1 marks the end of the page. When fewer than 4 bytes remain at the end of
 * the page no marker is written; the reader detects that case from the page bounds instead.
 */
private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields: Int)
  extends RowQueue {
  // Width, in bytes, of the length prefix written in front of every record.
  private val lengthPrefixBytes = 4

  private val base: AnyRef = page.getBaseObject
  // First address past the usable region of the page.
  private val endOfPage: Long = page.getBaseOffset + page.size
  // Address at which the next record (its length prefix) will be written.
  private var writeOffset = page.getBaseOffset
  // Address of the length prefix of the next record to hand out.
  private var readOffset = page.getBaseOffset
  // A single row object re-pointed and returned by every remove() call.
  private val resultRow = new UnsafeRow(numFields)

  /**
   * Appends `row` if the page still has room for its length prefix and its data.
   *
   * @return true if the row was written; false if the page is full, in which case the
   *         end-of-page marker (-1) is written when at least 4 bytes are still free
   */
  def add(row: UnsafeRow): Boolean = synchronized {
    val recordSize = row.getSizeInBytes
    val dataOffset = writeOffset + lengthPrefixBytes
    val fits = dataOffset + recordSize <= endOfPage
    if (fits) {
      Platform.putInt(base, writeOffset, recordSize)
      Platform.copyMemory(row.getBaseObject, row.getBaseOffset, base, dataOffset, recordSize)
      writeOffset = dataOffset + recordSize
    } else if (dataOffset <= endOfPage) {
      // Room for a length prefix but not for this record: leave an explicit end-of-page marker.
      Platform.putInt(base, writeOffset, -1)
    }
    fits
  }

  /**
   * Returns the next row of the page, or null once the end-of-page marker or the physical end of
   * the page is reached. The returned object is reused by the next call.
   */
  def remove(): UnsafeRow = synchronized {
    assert(readOffset <= writeOffset, "reader should not go beyond writer")
    val dataOffset = readOffset + lengthPrefixBytes
    // Never read a length prefix that would extend past the end of the page.
    val recordSize = if (dataOffset > endOfPage) -1 else Platform.getInt(base, readOffset)
    if (recordSize < 0) {
      null
    } else {
      resultRow.pointTo(base, dataOffset, recordSize)
      readOffset = dataOffset + recordSize
      resultRow
    }
  }
}

/**
 * A RowQueue stored in a file on disk. As soon as a reader starts pulling rows, the queue stops
 * accepting new ones.
 *
 * Each row is stored as a 4-byte length followed by the row's bytes, written through
 * `serMgr.wrapForEncryption`.
 */
private[python] case class DiskRowQueue(
    file: File,
    fields: Int,
    serMgr: SerializerManager) extends RowQueue {

  // Sink for new rows; set to null once reading begins or the queue is closed.
  private var out = new DataOutputStream(serMgr.wrapForEncryption(
    new BufferedOutputStream(new FileOutputStream(file.toString))))
  // Bytes (length prefixes included) written to the file but not yet read back.
  private var unreadBytes = 0L

  // Opened by the first remove() call.
  private var in: DataInputStream = _
  // A single row object re-pointed and returned by every remove() call.
  private val resultRow = new UnsafeRow(fields)

  /**
   * Writes `row` to the file.
   *
   * @return true if the row was written; false once reading has started (or the queue was
   *         closed), because the output stream is no longer open
   */
  def add(row: UnsafeRow): Boolean = synchronized {
    val writable = out != null
    if (writable) {
      val rowSize = row.getSizeInBytes
      out.writeInt(rowSize)
      out.write(row.getBytes)
      unreadBytes += 4 + rowSize
    }
    writable
  }

  /**
   * Reads the next row back from the file, switching the queue to read mode on the first call.
   *
   * @return the next row, or null when every written row has been read
   */
  def remove(): UnsafeRow = synchronized {
    if (out != null) {
      beginReading()
    }
    if (unreadBytes <= 0) {
      null
    } else {
      val rowSize = in.readInt()
      val rowBytes = new Array[Byte](rowSize)
      in.readFully(rowBytes)
      unreadBytes -= 4 + rowSize
      resultRow.pointTo(rowBytes, rowSize)
      resultRow
    }
  }

  // Closes the writer and reopens the file for reading. Only called from remove(), which already
  // holds the lock on `this`.
  private def beginReading(): Unit = {
    out.close()
    out = null
    in = new DataInputStream(serMgr.wrapForEncryption(
      new NioBufferedFileInputStream(file)))
  }

  /** Closes both streams, swallowing IOExceptions, and deletes the backing file if present. */
  def close(): Unit = synchronized {
    Closeables.close(out, true)
    out = null
    Closeables.close(in, true)
    in = null
    if (file.exists()) {
      file.delete()
    }
  }
}

/**
 * A RowQueue made of a list of RowQueues, each of which is either an in-memory page or a file on
 * disk.
 *
 * HybridRowQueue can be safely appended to from one thread while rows are pulled from it in
 * another thread at the same time.
 */
private[python] case class HybridRowQueue(
    memManager: TaskMemoryManager,
    tempDir: File,
    numFields: Int,
    serMgr: SerializerManager)
  extends MemoryConsumer(memManager, memManager.getTungstenMemoryMode) with RowQueue {

  // Queues in FIFO order; new queues are appended at the end. Each queue holds at least one row.
  // Every access to this list happens while holding the lock on `this`.
  private var queues = new java.util.LinkedList[RowQueue]()

  // Queue that add() appends to.
  private var writing: RowQueue = _
  // Queue that remove() drains; it has already been taken out of `queues`.
  private var reading: RowQueue = _

  // exposed for testing
  private[python] def numQueues(): Int = synchronized { queues.size() }

  /**
   * Moves every in-memory queue except the last one (which may still be receiving rows) to disk
   * and frees its page.
   *
   * @param size    number of bytes requested; not used, all eligible pages are spilled
   * @param trigger the consumer whose allocation caused this spill
   * @return the number of bytes released
   */
  def spill(size: Long, trigger: MemoryConsumer): Long = {
    // `eq` tests identity, which is what "triggered by itself" means. `==` on this case class
    // compares constructor fields, so another HybridRowQueue built with the same arguments in the
    // same task would be mistaken for this one and this queue would never spill for it.
    if (trigger eq this) {
      // When it's triggered by itself, it should write upcoming rows into disk instead of copying
      // the rows already in the queue.
      0L
    } else {
      synchronized {
        var released = 0L
        // Poll out all the queues and add them back in the same order so rows stay in order.
        val newQueues = new java.util.LinkedList[RowQueue]()
        while (!queues.isEmpty) {
          val queue = queues.remove()
          val isLast = queues.isEmpty
          val newQueue = queue match {
            case inMemory: InMemoryRowQueue if !isLast =>
              val diskQueue = createDiskQueue()
              var row = inMemory.remove()
              while (row != null) {
                diskQueue.add(row)
                row = inMemory.remove()
              }
              released += inMemory.page.size()
              inMemory.close()
              diskQueue
            case other =>
              other
          }
          newQueues.add(newQueue)
        }
        queues = newQueues
        released
      }
    }
  }

  /** Creates a disk-backed queue in a fresh temporary file under `tempDir`. */
  private def createDiskQueue(): RowQueue = {
    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
  }

  /**
   * Creates a new queue and appends it to `queues`. It tries to allocate a page of `required`
   * bytes for an in-memory queue and falls back to a disk-backed queue when the allocation fails
   * with SparkOutOfMemoryError.
   */
  private def createNewQueue(required: Long): RowQueue = {
    val page = try {
      allocatePage(required)
    } catch {
      case _: SparkOutOfMemoryError =>
        null
    }
    val buffer = if (page != null) {
      new InMemoryRowQueue(page, numFields) {
        override def close(): Unit = {
          freePage(page)
        }
      }
    } else {
      createDiskQueue()
    }

    synchronized {
      queues.add(buffer)
    }
    buffer
  }

  /**
   * Appends `row`. If the current queue is missing or refuses the row, a new queue is created
   * for it.
   *
   * @return always true
   * @throws SparkException-style error from QueryExecutionErrors if even a fresh queue refuses
   *         the row
   */
  def add(row: UnsafeRow): Boolean = {
    if (writing == null || !writing.add(row)) {
      writing = createNewQueue(4 + row.getSizeInBytes)
      if (!writing.add(row)) {
        throw QueryExecutionErrors.failedToPushRowIntoRowQueueError(writing.toString)
      }
    }
    true
  }

  /**
   * Returns the next row. When the current queue is exhausted, it is closed and the next queue
   * is taken from `queues`.
   */
  def remove(): UnsafeRow = {
    val row = if (reading != null) reading.remove() else null
    if (row != null) {
      row
    } else {
      if (reading != null) {
        reading.close()
        // Drop the reference right away, so close() cannot release this queue a second time
        // if fetching the next queue below fails.
        reading = null
      }
      synchronized {
        // poll() returns null on an empty list, whereas remove() would throw
        // NoSuchElementException and make the assertion below unreachable.
        reading = queues.poll()
      }
      assert(reading != null, "queue should not be empty")
      val first = reading.remove()
      assert(first != null, s"$reading should have at least one row")
      first
    }
  }

  /** Closes the queue being read and every queue still pending. */
  def close(): Unit = {
    if (reading != null) {
      reading.close()
      reading = null
    }
    synchronized {
      while (!queues.isEmpty) {
        queues.remove().close()
      }
    }
  }
}

private[python] object HybridRowQueue {
  /** Builds a HybridRowQueue that uses the serializer manager of the current SparkEnv. */
  def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue =
    new HybridRowQueue(
      memManager = taskMemoryMgr,
      tempDir = file,
      numFields = fields,
      serMgr = SparkEnv.get.serializerManager)
}

相关信息

spark 源码目录

相关文章

spark AggregateInPandasExec 源码

spark ApplyInPandasWithStatePythonRunner 源码

spark ApplyInPandasWithStateWriter 源码

spark ArrowEvalPythonExec 源码

spark ArrowPythonRunner 源码

spark AttachDistributedSequenceExec 源码

spark BatchEvalPythonExec 源码

spark CoGroupedArrowPythonRunner 源码

spark EvalPythonExec 源码

spark EvaluatePython 源码

0  赞