Superset Hive 源码

2022-10-20
Superset Hive 代码


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import re
import tempfile
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from flask import current_app, g
from sqlalchemy import Column, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

from superset.common.db_query_status import QueryStatus
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.exceptions import SupersetException
from superset.extensions import cache_manager
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils

if TYPE_CHECKING:
    # prevent circular imports
    from superset.models.core import Database

# Module-level logger, named after this module so records can be filtered by it.
logger = logging.getLogger(name=__name__)

def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
    """
    Upload the file to S3.

    The object key is ``<upload_prefix>/<table.table>/<basename(filename)>`` in
    the bucket configured by ``CSV_TO_HIVE_UPLOAD_S3_BUCKET``; the returned
    ``s3a://`` location is the "directory" containing that object.

    :param filename: The file to upload
    :param upload_prefix: The S3 prefix
    :param table: The table that will be created
    :returns: The S3 location of the table
    :raises Exception: If ``CSV_TO_HIVE_UPLOAD_S3_BUCKET`` is not configured
    """

    # pylint: disable=import-outside-toplevel
    import boto3
    from boto3.s3.transfer import TransferConfig

    bucket_path = current_app.config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]

    if not bucket_path:
        logger.info("No upload bucket specified")
        raise Exception(
            "No upload bucket specified. You can specify one in the config file."
        )

    s3 = boto3.client("s3")
    location = os.path.join("s3a://", bucket_path, upload_prefix, table.table)
    s3.upload_file(
        filename,
        bucket_path,
        os.path.join(upload_prefix, table.table, os.path.basename(filename)),
        Config=TransferConfig(use_threads=False),  # Threading is broken in Python 3.9.
    )
    return location

class HiveEngineSpec(PrestoEngineSpec):
    """
    Engine spec for Apache Hive.

    Reuses PrestoEngineSpec functionality, overriding the parts where Hive
    differs: time-grain SQL, DataFrame upload (Parquet files on S3), progress
    tracking from HiveServer2 logs, partition introspection and impersonation.
    """

    engine = "hive"
    engine_name = "Apache Hive"
    max_column_name_length = 767
    allows_alias_to_source_column = True
    allows_hidden_ordeby_agg = False

    # When running `SHOW FUNCTIONS`, what is the name of the column with the
    # function names?
    _show_functions_column = "tab_name"

    # pylint: disable=line-too-long
    _time_grain_expressions = {
        None: "{col}",
        "PT1S": "from_unixtime(unix_timestamp({col}), 'yyyy-MM-dd HH:mm:ss')",
        "PT1M": "from_unixtime(unix_timestamp({col}), 'yyyy-MM-dd HH:mm:00')",
        "PT1H": "from_unixtime(unix_timestamp({col}), 'yyyy-MM-dd HH:00:00')",
        "P1D": "from_unixtime(unix_timestamp({col}), 'yyyy-MM-dd 00:00:00')",
        "P1W": "date_format(date_sub({col}, CAST(7-from_unixtime(unix_timestamp({col}),'u') as int)), 'yyyy-MM-dd 00:00:00')",
        "P1M": "from_unixtime(unix_timestamp({col}), 'yyyy-MM-01 00:00:00')",
        "P3M": "date_format(add_months(trunc({col}, 'MM'), -(month({col})-1)%3), 'yyyy-MM-dd 00:00:00')",
        "P1Y": "from_unixtime(unix_timestamp({col}), 'yyyy-01-01 00:00:00')",
        "P1W/1970-01-03T00:00:00Z": "date_format(date_add({col}, INT(6-from_unixtime(unix_timestamp({col}), 'u'))), 'yyyy-MM-dd 00:00:00')",
        "1969-12-28T00:00:00Z/P1W": "date_format(date_add({col}, -INT(from_unixtime(unix_timestamp({col}), 'u'))), 'yyyy-MM-dd 00:00:00')",
    }

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(r".*INFO.*Total jobs = (?P<max_jobs>[0-9]+)")
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        r".*INFO.*Launching Job (?P<job_number>[0-9]+) out of (?P<max_jobs>[0-9]+)"
    )
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%,  reduce = 0%
    # `progress()` reads the `stage_number` group, so the pattern must capture it.
    stage_progress_r = re.compile(
        r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
        r"map = (?P<map_progress>[0-9]+)%.*"
        r"reduce = (?P<reduce_progress>[0-9]+)%.*"
    )

    @classmethod
    def patch(cls) -> None:
        """
        Monkey-patch PyHive's Thrift modules and log fetching.

        Points ``pyhive.hive`` at the top-level ``TCLIService`` package's
        ``TCLIService``/``constants``/``ttypes`` modules and installs
        ``superset.db_engines.hive.fetch_logs`` as ``hive.Cursor.fetch_logs``
        (which ``handle_cursor`` calls).
        """
        # pylint: disable=import-outside-toplevel
        from pyhive import hive
        from TCLIService import (
            constants as patched_constants,
            TCLIService as patched_TCLIService,
            ttypes as patched_ttypes,
        )

        from superset.db_engines import hive as patched_hive

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    def fetch_data(
        cls, cursor: Any, limit: Optional[int] = None
    ) -> List[Tuple[Any, ...]]:
        """
        Fetch result rows from a PyHive cursor.

        :param cursor: PyHive cursor of an executed statement
        :param limit: Optional maximum number of rows (passed to the parent)
        :returns: The fetched rows, or ``[]`` if PyHive raises
            ``ProgrammingError`` while fetching (presumably a statement with no
            result set)
        :raises Exception: If the HiveServer2 operation is in ERROR_STATE
        """
        # pylint: disable=import-outside-toplevel
        import pyhive
        from TCLIService import ttypes

        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception("Query error", state.errorMessage)
        try:
            return super().fetch_data(cursor, limit)
        except pyhive.exc.ProgrammingError:
            return []

    @classmethod
    def df_to_sql(
        cls,
        database: "Database",
        table: Table,
        df: pd.DataFrame,
        to_sql_kwargs: Dict[str, Any],
    ) -> None:
        """
        Upload data from a Pandas DataFrame to a database.

        The data is stored via the binary Parquet format which is both less problematic
        and more performant than a text file. More specifically storing a table as a
        CSV text file has severe limitations including the fact that the Hive CSV SerDe
        does not support multiline fields.

        Note this method does not create metadata for the table.

        :param database: The database to upload the data to
        :param table: The table to upload the data to
        :param df: The dataframe with data to be uploaded
        :param to_sql_kwargs: The kwargs to be passed to `pandas.DataFrame.to_sql`
            method; only ``if_exists`` is read here (defaults to ``"fail"``, like
            pandas)
        :raises SupersetException: If ``if_exists`` is ``"append"``, or is
            ``"fail"`` and the table already exists
        """

        engine = cls.get_engine(database)
        if_exists = to_sql_kwargs.get("if_exists", "fail")

        if if_exists == "append":
            raise SupersetException("Append operation not currently supported")

        if if_exists == "fail":
            # Ensure table doesn't already exist.
            # NOTE(review): schema/table names are interpolated unescaped into
            # the SQL below; confirm callers validate them.
            if table.schema:
                show_tables_sql = (
                    f"SHOW TABLES IN {table.schema} LIKE '{table.table}'"
                )
            else:
                show_tables_sql = f"SHOW TABLES LIKE '{table.table}'"
            table_exists = not database.get_df(show_tables_sql).empty

            if table_exists:
                raise SupersetException("Table already exists")
        elif if_exists == "replace":
            engine.execute(f"DROP TABLE IF EXISTS {str(table)}")

        def _get_hive_type(dtype: np.dtype) -> str:
            """Map a pandas column dtype to a Hive type, defaulting to STRING."""
            hive_type_by_dtype = {
                np.dtype("bool"): "BOOLEAN",
                np.dtype("float64"): "DOUBLE",
                np.dtype("int64"): "BIGINT",
                np.dtype("object"): "STRING",
            }

            return hive_type_by_dtype.get(dtype, "STRING")

        def _quote_identifier(name: Any) -> str:
            """Backtick-quote a column name, doubling any embedded backticks."""
            escaped = str(name).replace("`", "``")
            return f"`{escaped}`"

        schema_definition = ", ".join(
            f"{_quote_identifier(name)} {_get_hive_type(dtype)}"
            for name, dtype in df.dtypes.items()
        )

        with tempfile.NamedTemporaryFile(
            dir=current_app.config["UPLOAD_FOLDER"], suffix=".parquet"
        ) as file:
            pq.write_table(pa.Table.from_pandas(df), where=file.name)

            # The Parquet file is uploaded to S3 and the table is created on top
            # of that location; `:location` is a bound parameter.
            engine.execute(
                text(
                    f"""
                    CREATE TABLE {str(table)} ({schema_definition})
                    STORED AS PARQUET
                    LOCATION :location
                    """
                ),
                location=upload_to_s3(
                    filename=file.name,
                    upload_prefix=current_app.config[
                        "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"
                    ](database, g.user, table.schema),
                    table=table,
                ),
            )

    @classmethod
    def convert_dttm(
        cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Render ``dttm`` as a Hive SQL literal for ``target_type``.

        :returns: A ``CAST(... AS DATE)`` / ``CAST(... AS TIMESTAMP)`` expression,
            or ``None`` for any other target type
        """
        tt = target_type.upper()
        if tt == utils.TemporalType.DATE:
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        if tt == utils.TemporalType.TIMESTAMP:
            timestamp = dttm.isoformat(sep=" ", timespec="microseconds")
            return f"CAST('{timestamp}' AS TIMESTAMP)"
        return None

    @classmethod
    def adjust_database_uri(
        cls, uri: URL, selected_schema: Optional[str] = None
    ) -> URL:
        """Return ``uri`` with its database set to the URL-quoted schema, if any."""
        if selected_schema:
            uri = uri.set(database=parse.quote(selected_schema, safe=""))

        return uri

    @classmethod
    def _extract_error_message(cls, ex: Exception) -> str:
        """
        Return the ``errorMessage="..."`` payload from ``str(ex)`` if present,
        otherwise the full exception string.
        """
        msg = str(ex)
        # Non-greedy up to the first double quote not preceded by a backslash.
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg

    @classmethod
    def progress(cls, log_lines: List[str]) -> int:
        """
        Estimate overall query progress (0-100) from HiveServer2 log lines.

        Each finished job contributes ``100 / total_jobs``; the current job adds
        the mean of its stages' ``(map% + reduce%) / 2``, scaled by
        ``1 / total_jobs``.
        """
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages: Dict[int, float] = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                total_jobs = int(match.groupdict()["max_jobs"]) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.groupdict()["job_number"])
                total_jobs = int(match.groupdict()["max_jobs"]) or 1
                # Stage progress is per job; start over when a new job launches.
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.groupdict()["stage_number"])
                map_progress = int(match.groupdict()["map_progress"])
                reduce_progress = int(match.groupdict()["reduce_progress"])
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logger.info(
            "Progress detail: %s, current job %s, total jobs: %s",
            stages,
            current_job,
            total_jobs,
        )

        stage_progress = sum(stages.values()) / len(stages.values()) if stages else 0

        progress = 100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        return int(progress)

    @classmethod
    def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]:
        """Return the text after the first ``Tracking URL = `` marker, or None."""
        lkp = "Tracking URL = "
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]
        return None

    @classmethod
    def handle_cursor(  # pylint: disable=too-many-locals
        cls, cursor: Any, query: Query, session: Session
    ) -> None:
        """
        Updates progress information.

        Polls ``cursor`` while the operation is INITIALIZED or RUNNING. On each
        pass it cancels the cursor and stops if the query was marked STOPPED,
        otherwise it parses the fetched logs to update ``query.progress`` and
        ``query.tracking_url`` (committing when either changes), re-logs new log
        lines prefixed with the query and job ids, and sleeps for
        ``HIVE_POLL_INTERVAL`` before polling again.
        """
        # pylint: disable=import-outside-toplevel
        from pyhive import hive

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            # Queries don't terminate when user clicks the STOP button on SQL LAB.
            # Refresh session so that the `query.status` modified by the
            # stop-query endpoint is reflected here.
            session.refresh(query)
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            try:
                log = cursor.fetch_logs() or ""
            except Exception:  # pylint: disable=broad-except
                # Best effort: missing logs only degrade progress reporting.
                logger.warning("Call to GetLog() failed")
                log = ""

            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info(
                    "Query %s: Progress total: %s", str(query_id), str(progress)
                )
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url_from_logs(log_lines)
                    if tracking_url:
                        # NOTE(review): assumes the URL ends with ".../<job id>/";
                        # confirm for the deployed Hadoop/YARN version.
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            "Query %s: Found the tracking url: %s",
                            str(query_id),
                            tracking_url,
                        )
                        query.tracking_url = tracking_url
                        logger.info(
                            "Query %s: Job id: %s", str(query_id), str(job_id)
                        )
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for log_line in log_lines[last_log_line:]:
                        logger.info(
                            "Query %s: [%s] %s", str(query_id), str(job_id), log_line
                        )
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(current_app.config["HIVE_POLL_INTERVAL"])
            polled = cursor.poll()

    @classmethod
    def get_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Return column metadata straight from the SQLAlchemy inspector."""
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        query: Select,
        columns: Optional[List[Dict[str, str]]] = None,
    ) -> Optional[Select]:
        """
        Restrict ``query`` to the latest partition.

        Adds ``WHERE <partition column> = <latest value>`` for every partition
        column whose name appears in ``columns``.

        :returns: The filtered query, or ``None`` if ``latest_partition`` raised
            (treated as "not partitioned") or ``values``/``columns`` is None
        """
        try:
            col_names, values = cls.latest_partition(
                table_name, schema, database, show_first=True
            )
        except Exception:  # pylint: disable=broad-except
            # table is not partitioned
            return None
        if values is not None and columns is not None:
            for col_name, value in zip(col_names, values):
                for clm in columns:
                    if clm.get("name") == col_name:
                        query = query.where(Column(col_name) == value)

            return query
        return None

    @classmethod
    def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
        """Use BaseEngineSpec's field rendering rather than Presto's."""
        return BaseEngineSpec._get_fields(cols)  # pylint: disable=protected-access

    @classmethod
    def latest_sub_partition(  # type: ignore
        cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
    ) -> str:
        """Not implemented for Hive; currently returns ``None``."""
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
        """
        Hive partitions look like ds={partition name}/ds={partition name}.

        Takes the (string) maximum of the first column of ``SHOW PARTITIONS``
        output and returns the value part of each ``key=value`` segment, or
        ``None`` when ``df`` is empty.
        """
        if not df.empty:
            return [
                # maxsplit=1 keeps any "=" inside the value intact
                partition_str.split("=", 1)[1]
                for partition_str in df.iloc[:, 0].max().split("/")
            ]
        return None

    @classmethod
    def _partition_query(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        database: "Database",
        limit: int = 0,
        order_by: Optional[List[Tuple[str, bool]]] = None,
        filters: Optional[Dict[Any, Any]] = None,
    ) -> str:
        """Return ``SHOW PARTITIONS``; limit/order_by/filters are ignored."""
        return f"SHOW PARTITIONS {table_name}"

    @classmethod
    def select_star(  # pylint: disable=too-many-arguments
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Build SELECT * via the class *after* PrestoEngineSpec in the MRO."""
        return super(  # pylint: disable=bad-super-call
            PrestoEngineSpec, cls
        ).select_star(
            database,
            table_name,
            engine,
            schema,
            limit,
            show_cols,
            indent,
            latest_partition,
            cols,
        )

    @classmethod
    def get_url_for_impersonation(
        cls, url: URL, impersonate_user: bool, username: Optional[str]
    ) -> URL:
        """
        Return a modified URL with the username set.

        :param url: SQLAlchemy URL object
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        :returns: ``url`` unchanged
        """
        # Do nothing in the URL object since instead this should modify
        # the configuration dictionary. See update_impersonation_config
        return url

    @classmethod
    def update_impersonation_config(
        cls,
        connect_args: Dict[str, Any],
        uri: str,
        username: Optional[str],
    ) -> None:
        """
        Update a configuration dictionary
        that can set the correct properties for impersonating users

        :param connect_args: Connection arguments; its ``configuration`` dict is
            created if missing and updated in place
        :param uri: URI string
        :param username: Effective username
        :return: None
        """
        url = make_url_safe(uri)
        backend_name = url.get_backend_name()

        # Must be Hive connection, enable impersonation, and set optional param
        # auth=LDAP|KERBEROS
        # this will set hive.server2.proxy.user=$effective_username on connect_args['configuration']
        if backend_name == "hive" and username is not None:
            configuration = connect_args.get("configuration", {})
            configuration["hive.server2.proxy.user"] = username
            connect_args["configuration"] = configuration

    @staticmethod
    def execute(  # type: ignore
        cursor, query: str, async_: bool = False
    ):  # pylint: disable=arguments-differ
        """Execute ``query``, forwarding ``async_`` as PyHive's ``async`` kwarg."""
        # `async` is a reserved word, so it can only be passed via **kwargs.
        kwargs = {"async": async_}
        cursor.execute(query, **kwargs)

    @classmethod
    @cache_manager.cache.memoize()
    def get_function_names(cls, database: "Database") -> List[str]:
        """
        Get a list of function names that are able to be called on the database.
        Used for SQL Lab autocomplete.

        :param database: The database to get functions for
        :return: A list of function names usable in the database
        """
        df = database.get_df("SHOW FUNCTIONS")
        if cls._show_functions_column in df:
            return df[cls._show_functions_column].tolist()

        columns = df.columns.values.tolist()
        logger.error(
            "Payload from `SHOW FUNCTIONS` has the incorrect format. "
            "Expected column `%s`, found: %s.",
            cls._show_functions_column,
            ", ".join(str(column) for column in columns),
        )
        # if the results have a single column, use that
        if len(columns) == 1:
            return df[columns[0]].tolist()

        # otherwise, return no function names to prevent errors
        return []

    @classmethod
    def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
        """Pessimistic readonly, 100% sure statement won't mutate anything"""
        return (
            super().is_readonly_query(parsed_query)
            or parsed_query.is_set()
            or parsed_query.is_show()
        )

    @classmethod
    def has_implicit_cancel(cls) -> bool:
        """
        Return True if the live cursor handles the implicit cancelation of the query,
        False otherwise.

        :return: Whether the live cursor implicitly cancels the query
        :see: handle_cursor
        """

        return True


