airflow serialized_objects 源码

  • 2022-10-20
  • 浏览 (659)

airflow serialized_objects 代码


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Serialized DAG and BaseOperator"""
from __future__ import annotations

import collections.abc
import datetime
import enum
import logging
import warnings
import weakref
from dataclasses import dataclass
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Type, Union

import cattr
import lazy_object_proxy
import pendulum
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG, create_timetable
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, ExpandInput, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import Validator, load_dag_schema
from airflow.settings import DAGS_FOLDER, json
from airflow.timetables.base import Timetable
from airflow.utils.code_utils import get_python_source
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import as_importable_string, import_string
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup

if TYPE_CHECKING:
    from airflow.ti_deps.deps.base_ti_dep import BaseTIDep

    HAS_KUBERNETES: bool
    try:
        from kubernetes.client import models as k8s

        from airflow.kubernetes.pod_generator import PodGenerator
    except ImportError:
        pass

log = logging.getLogger(__name__)

# Importable paths of the operator extra-link classes that ship with Airflow core.
_OPERATOR_EXTRA_LINKS: set[str] = {
    "airflow.operators.trigger_dagrun.TriggerDagRunLink",
    "airflow.sensors.external_task.ExternalDagLink",
    # Deprecated names, so that existing serialized dags load straight away.
    "airflow.sensors.external_task.ExternalTaskSensorLink",
    "airflow.operators.dagrun_operator.TriggerDagRunLink",
    "airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
}


@cache
def get_operator_extra_links() -> set[str]:
    """Get the operator extra links.

    This includes both the built-in ones, and those come from the providers.
    The provider links are merged into the module-level set on the first call.
    The result is cached, so later calls return the same set object.

    :return: Importable class paths of every known operator extra link.
    """
    _OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
    return _OPERATOR_EXTRA_LINKS

@cache
def _get_default_mapped_partial() -> dict[str, Any]:
    """Get default partial kwargs in a mapped operator.

    This is used to simplify a serialized mapped operator by excluding default
    values supplied in the implementation from the serialized dict. Since those
    are defaults, they are automatically supplied on de-serialization, so we
    don't need to store them.

    The returned dict is cached and shared between callers; treat it as read-only.

    :return: The serialized (``Encoding.VAR`` payload) default partial kwargs.
    """
    # Use the private _expand() method to avoid the empty kwargs check.
    default = BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
    return BaseSerialization.serialize(default)[Encoding.VAR]

def encode_relativedelta(var: relativedelta.relativedelta) -> dict[str, Any]:
    """Encode a ``relativedelta`` into a JSON-friendly dict.

    Only public attributes with a truthy value are kept. The ``weekday``
    attribute is flattened to ``[weekday]``, or ``[weekday, n]`` when an
    ``n`` offset is present, so ``decode_relativedelta`` can rebuild it.

    :param var: The relativedelta to encode.
    :return: A dict of keyword arguments for ``relativedelta``.
    """
    result: dict[str, Any] = {}
    for name, value in var.__dict__.items():
        if name.startswith("_") or not value:
            continue
        result[name] = value

    wd = var.weekday
    if wd:
        # ``n`` is the "every n'th <weekday>" part, e.g. every 2nd Friday.
        result['weekday'] = [wd.weekday, wd.n] if wd.n else [wd.weekday]
    return result

def decode_relativedelta(var: dict[str, Any]) -> relativedelta.relativedelta:
    """Rebuild a ``relativedelta`` from the output of ``encode_relativedelta``.

    The input dict is not modified. Decoding the same serialized payload more
    than once must keep working, and rewriting ``weekday`` in place would break
    the second decode.

    :param var: Keyword arguments for ``relativedelta``; ``weekday`` (if present)
        is the ``[weekday]`` / ``[weekday, n]`` list produced by the encoder.
    :return: The reconstructed relativedelta.
    """
    kwargs = dict(var)
    if 'weekday' in kwargs:
        kwargs['weekday'] = relativedelta.weekday(*kwargs['weekday'])  # type: ignore
    return relativedelta.relativedelta(**kwargs)

def encode_timezone(var: Timezone) -> str | int:
    """Encode a Pendulum Timezone for serialization.

    Airflow only supports timezone objects that implements Pendulum's Timezone
    interface. We try to keep as much information as possible to make conversion
    round-tripping possible (see ``decode_timezone``). We need to special-case
    UTC; Pendulum implements it as a FixedTimezone (i.e. it gets encoded as
    0 without the special case), but passing 0 into ``pendulum.timezone`` does
    not give us UTC (but ``+00:00``).

    :param var: The timezone to encode.
    :return: ``"UTC"`` for a zero-offset fixed timezone, the offset for any other
        fixed timezone, or the timezone's name for a named timezone.
    :raises ValueError: If ``var`` is not a Pendulum ``Timezone``.
    """
    if isinstance(var, FixedTimezone):
        if var.offset == 0:
            return "UTC"
        return var.offset
    if isinstance(var, Timezone):
        return var.name
    raise ValueError(
        f"DAG timezone should be a pendulum.tz.Timezone, not {var!r}. "
        f"See {get_docs_url('timezone.html#time-zone-aware-dags')}"
    )

def decode_timezone(var: str | int) -> Timezone:
    """Decode a previously serialized Pendulum Timezone.

    :param var: A value produced by ``encode_timezone``: a timezone name
        (including ``"UTC"``) or a fixed offset.
    :return: The corresponding Pendulum timezone.
    """
    return pendulum.tz.timezone(var)

def _get_registered_timetable(importable_string: str) -> type[Timetable] | None:
    """Resolve a timetable class from its importable path.

    Built-in timetables (under ``airflow.timetables.``) are imported directly;
    anything else must be registered through a plugin.

    :param importable_string: Dotted path of the timetable class.
    :return: The timetable class, or ``None`` if it is not registered.
    """
    from airflow import plugins_manager

    if importable_string.startswith("airflow.timetables."):
        return import_string(importable_string)
    # Make sure plugin-provided timetables are loaded before looking them up.
    plugins_manager.initialize_timetables_plugins()
    if plugins_manager.timetable_classes:
        return plugins_manager.timetable_classes.get(importable_string)
    return None

class _TimetableNotRegistered(ValueError):
    """Raised when a timetable class cannot be resolved from its importable path.

    :param type_string: Dotted path of the timetable class that was looked up.
    """

    def __init__(self, type_string: str) -> None:
        self.type_string = type_string

    def __str__(self) -> str:
        return (
            f"Timetable class {self.type_string!r} is not registered or "
            "you have a top level database access that disrupted the session. "
            "Please check the airflow best practices documentation."
        )

def _encode_timetable(var: Timetable) -> dict[str, Any]:
    """Encode a timetable instance.

    This delegates most of the serialization work to the type, so the behavior
    can be completely controlled by a custom subclass.

    :param var: The timetable to encode.
    :return: ``{Encoding.TYPE: <importable class path>, Encoding.VAR: var.serialize()}``.
    :raises _TimetableNotRegistered: If the timetable's class is neither built in
        nor registered through a plugin.
    """
    timetable_class = type(var)
    importable_string = as_importable_string(timetable_class)
    # Refuse to serialize a class that could not be resolved again on deserialization.
    if _get_registered_timetable(importable_string) is None:
        raise _TimetableNotRegistered(importable_string)
    return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}

def _decode_timetable(var: dict[str, Any]) -> Timetable:
    """Decode a previously serialized timetable.

    Most of the deserialization logic is delegated to the actual type, which
    we import from string.

    :param var: A dict produced by ``_encode_timetable``.
    :return: The reconstructed timetable instance.
    :raises _TimetableNotRegistered: If the stored class path cannot be resolved.
    """
    importable_string = var[Encoding.TYPE]
    timetable_class = _get_registered_timetable(importable_string)
    if timetable_class is None:
        raise _TimetableNotRegistered(importable_string)
    return timetable_class.deserialize(var[Encoding.VAR])

class _XComRef(NamedTuple):
    """Used to store info needed to create XComArg.

    We can't turn it in to a XComArg until we've loaded _all_ the tasks, so when
    deserializing an operator, we need to create something in its place, and
    post-process it in ``deserialize_dag``.
    """

    # The serialized XComArg payload.
    data: dict

    def deref(self, dag: DAG) -> XComArg:
        """Resolve the stored payload into a real XComArg bound to ``dag``."""
        return deserialize_xcom_arg(self.data, dag)

# These two should be kept in sync. Note that these are intentionally not using
# the type declarations in expandinput.py so we always remember to update
# serialization logic when adding new ExpandInput variants. If you add things to
# the unions, be sure to update _ExpandInputRef to match.
_ExpandInputOriginalValue = Union[
    # For .expand(**kwargs).
    Mapping[str, Any],
    # For expand_kwargs(arg).
    Collection[Union[XComArg, Mapping[str, Any]]],
]
_ExpandInputSerializedValue = Union[
    # For .expand(**kwargs).
    Mapping[str, Any],
    # For expand_kwargs(arg).
    Collection[Union[_XComRef, Mapping[str, Any]]],
]

class _ExpandInputRef(NamedTuple):
    """Used to store info needed to create a mapped operator's expand input.

    This references a ``ExpandInput`` type, but replaces ``XComArg`` objects
    with ``_XComRef`` (see documentation on the latter type for reasoning).
    """

    # Map-type key passed to ``create_expand_input`` (from ``get_map_type_key``).
    key: str
    value: _ExpandInputSerializedValue

    @classmethod
    def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
        """Validate we've covered all ``ExpandInput.value`` types.

        This function does not actually do anything, but is called during
        serialization so Mypy will *statically* check we have handled all
        possible ExpandInput cases.
        """

    def deref(self, dag: DAG) -> ExpandInput:
        """De-reference into a concrete ExpandInput object.

        If you add more cases here, be sure to update _ExpandInputOriginalValue
        and _ExpandInputSerializedValue to match the logic.
        """
        if isinstance(self.value, _XComRef):
            value: Any = self.value.deref(dag)
        elif isinstance(self.value, collections.abc.Mapping):
            # .expand(**kwargs): resolve any XCom references among the values.
            value = {k: v.deref(dag) if isinstance(v, _XComRef) else v for k, v in self.value.items()}
        else:
            # .expand_kwargs(...) with a literal collection: resolve item by item.
            value = [v.deref(dag) if isinstance(v, _XComRef) else v for v in self.value]
        return create_expand_input(self.key, value)

class BaseSerialization:
    """BaseSerialization provides utils for serialization."""

    # JSON primitive types.
    _primitive_types = (int, bool, float, str)

    # Time types.
    # datetime.date and datetime.time are converted to strings.
    _datetime_types = (datetime.datetime,)

    # Object types that are always excluded in serialization.
    _excluded_types = (logging.Logger, Connection, type)

    _json_schema: Validator | None = None

    # Should the extra operator link be loaded via plugins when
    # de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
    # are not loaded to not run User code in Scheduler.
    _load_operator_extra_links = True

    # Constructor parameter defaults of the serialized type; subclasses fill this in.
    _CONSTRUCTOR_PARAMS: dict[str, Parameter] = {}

    SERIALIZER_VERSION = 1

    @classmethod
    def to_json(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> str:
        """Stringifies DAGs and operators contained by var and returns a JSON string of var."""
        return json.dumps(cls.to_dict(var), ensure_ascii=True)

    @classmethod
    def to_dict(cls, var: DAG | BaseOperator | dict | list | set | tuple) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var."""
        # Don't call on this class directly - only SerializedDAG or
        # SerializedBaseOperator should be used as the "entrypoint"
        raise NotImplementedError()

    @classmethod
    def from_json(cls, serialized_obj: str) -> BaseSerialization | dict | list | set | tuple:
        """Deserializes json_str and reconstructs all DAGs and operators it contains."""
        return cls.from_dict(json.loads(serialized_obj))

    @classmethod
    def from_dict(cls, serialized_obj: dict[Encoding, Any]) -> BaseSerialization | dict | list | set | tuple:
        """Deserializes a python dict stored with type decorators and
        reconstructs all DAGs and operators it contains.
        """
        return cls.deserialize(serialized_obj)

    @classmethod
    def validate_schema(cls, serialized_obj: str | dict) -> None:
        """Validate serialized_obj satisfies JSON schema.

        :param serialized_obj: The serialized object, as a dict or a JSON string.
        :raises AirflowException: If no schema is configured for this class.
        :raises TypeError: If ``serialized_obj`` is neither a dict nor a str.
        """
        if cls._json_schema is None:
            raise AirflowException(f'JSON schema of {cls.__name__:s} is not set.')

        if isinstance(serialized_obj, dict):
            cls._json_schema.validate(serialized_obj)
        elif isinstance(serialized_obj, str):
            cls._json_schema.validate(json.loads(serialized_obj))
        else:
            raise TypeError("Invalid type: Only dict and str are supported.")

    @staticmethod
    def _encode(x: Any, type_: Any) -> dict[Encoding, Any]:
        """Encode data by a JSON dict."""
        return {Encoding.VAR: x, Encoding.TYPE: type_}

    @classmethod
    def _is_primitive(cls, var: Any) -> bool:
        """Primitive types."""
        return var is None or isinstance(var, cls._primitive_types)

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, instance: Any) -> bool:
        """Types excluded from serialization."""
        if var is None:
            if not cls._is_constructor_param(attrname, instance):
                # Any instance attribute, that is not a constructor argument, we exclude None as the default
                return True

            return cls._value_is_hardcoded_default(attrname, var, instance)
        return isinstance(var, cls._excluded_types) or cls._value_is_hardcoded_default(
            attrname, var, instance
        )

    @classmethod
    def serialize_to_json(
        cls, object_to_serialize: BaseOperator | MappedOperator | DAG, decorated_fields: set
    ) -> dict[str, Any]:
        """Serializes an object to json.

        :param object_to_serialize: The operator or DAG to serialize.
        :param decorated_fields: Fields whose serialized form keeps its type decoration.
        :return: A dict of the object's serialized fields.
        """
        serialized_object: dict[str, Any] = {}
        keys_to_serialize = object_to_serialize.get_serialized_fields()
        for key in keys_to_serialize:
            # None is ignored in serialized form and is added back in deserialization.
            value = getattr(object_to_serialize, key, None)
            if cls._is_excluded(value, key, object_to_serialize):
                continue

            if key == '_operator_name':
                # when operator_name matches task_type, we can remove
                # it to reduce the JSON payload
                task_type = getattr(object_to_serialize, '_task_type', None)
                if value != task_type:
                    serialized_object[key] = cls.serialize(value)
            elif key in decorated_fields:
                serialized_object[key] = cls.serialize(value)
            elif key == "timetable" and value is not None:
                serialized_object[key] = _encode_timetable(value)
            else:
                value = cls.serialize(value)
                # Non-decorated fields are stored without their {TYPE, VAR} wrapper.
                if isinstance(value, dict) and Encoding.TYPE in value:
                    value = value[Encoding.VAR]
                serialized_object[key] = value
        return serialized_object

    @classmethod
    def serialize(cls, var: Any) -> Any:  # Unfortunately there is no support for recursive types in mypy
        """Helper function of depth first search for serialization.

        The serialization protocol is:

        (1) keeping JSON supported types: primitives, dict, list;
        (2) encoding other types as ``{TYPE: 'foo', VAR: 'bar'}``, the deserialization
            step decode VAR according to TYPE;
        (3) Operator has a special field CLASS to record the original class
            name for displaying in UI.

        :meta private:
        """
        if cls._is_primitive(var):
            # enum.IntEnum is an int instance, it causes json dumps error so we use its value.
            if isinstance(var, enum.Enum):
                return var.value
            return var
        elif isinstance(var, dict):
            return cls._encode({str(k): cls.serialize(v) for k, v in var.items()}, type_=DAT.DICT)
        elif isinstance(var, list):
            return [cls.serialize(v) for v in var]
        elif var.__class__.__name__ == 'V1Pod' and _has_kubernetes() and isinstance(var, k8s.V1Pod):
            json_pod = PodGenerator.serialize_pod(var)
            return cls._encode(json_pod, type_=DAT.POD)
        elif isinstance(var, DAG):
            return SerializedDAG.serialize_dag(var)
        elif isinstance(var, Resources):
            return var.to_dict()
        elif isinstance(var, MappedOperator):
            return SerializedBaseOperator.serialize_mapped_operator(var)
        elif isinstance(var, BaseOperator):
            return SerializedBaseOperator.serialize_operator(var)
        elif isinstance(var, cls._datetime_types):
            return cls._encode(var.timestamp(), type_=DAT.DATETIME)
        elif isinstance(var, datetime.timedelta):
            return cls._encode(var.total_seconds(), type_=DAT.TIMEDELTA)
        elif isinstance(var, Timezone):
            return cls._encode(encode_timezone(var), type_=DAT.TIMEZONE)
        elif isinstance(var, relativedelta.relativedelta):
            return cls._encode(encode_relativedelta(var), type_=DAT.RELATIVEDELTA)
        elif callable(var):
            return str(get_python_source(var))
        elif isinstance(var, set):
            # FIXME: casts set to list in customized serialization in future.
            try:
                # Sorted for a stable output; falls back to iteration order for unorderable items.
                return cls._encode(sorted(cls.serialize(v) for v in var), type_=DAT.SET)
            except TypeError:
                return cls._encode([cls.serialize(v) for v in var], type_=DAT.SET)
        elif isinstance(var, tuple):
            # FIXME: casts tuple to list in customized serialization in future.
            return cls._encode([cls.serialize(v) for v in var], type_=DAT.TUPLE)
        elif isinstance(var, TaskGroup):
            return SerializedTaskGroup.serialize_task_group(var)
        elif isinstance(var, Param):
            return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
        elif isinstance(var, XComArg):
            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
        elif isinstance(var, Dataset):
            return cls._encode(dict(uri=var.uri, extra=var.extra), type_=DAT.DATASET)
        else:
            log.debug('Cast type %s to str in serialization.', type(var))
            return str(var)

    @classmethod
    def deserialize(cls, encoded_var: Any) -> Any:
        """Helper function of depth first search for deserialization.

        :meta private:
        """
        # JSON primitives (except for dict) are not encoded.
        if cls._is_primitive(encoded_var):
            return encoded_var
        elif isinstance(encoded_var, list):
            return [cls.deserialize(v) for v in encoded_var]

        if not isinstance(encoded_var, dict):
            raise ValueError(f"The encoded_var should be dict and is {type(encoded_var)}")
        var = encoded_var[Encoding.VAR]
        type_ = encoded_var[Encoding.TYPE]

        if type_ == DAT.DICT:
            return {k: cls.deserialize(v) for k, v in var.items()}
        elif type_ == DAT.DAG:
            return SerializedDAG.deserialize_dag(var)
        elif type_ == DAT.OP:
            return SerializedBaseOperator.deserialize_operator(var)
        elif type_ == DAT.DATETIME:
            return pendulum.from_timestamp(var)
        elif type_ == DAT.POD:
            if not _has_kubernetes():
                raise RuntimeError("Cannot deserialize POD objects without kubernetes libraries installed!")
            pod = PodGenerator.deserialize_model_dict(var)
            return pod
        elif type_ == DAT.TIMEDELTA:
            return datetime.timedelta(seconds=var)
        elif type_ == DAT.TIMEZONE:
            return decode_timezone(var)
        elif type_ == DAT.RELATIVEDELTA:
            return decode_relativedelta(var)
        elif type_ == DAT.SET:
            return {cls.deserialize(v) for v in var}
        elif type_ == DAT.TUPLE:
            return tuple(cls.deserialize(v) for v in var)
        elif type_ == DAT.PARAM:
            return cls._deserialize_param(var)
        elif type_ == DAT.XCOM_REF:
            return _XComRef(var)  # Delay deserializing XComArg objects until we have the entire DAG.
        elif type_ == DAT.DATASET:
            return Dataset(**var)
        else:
            raise TypeError(f'Invalid type {type_!s} in deserialization.')

    _deserialize_datetime = pendulum.from_timestamp
    _deserialize_timezone = pendulum.tz.timezone

    @classmethod
    def _deserialize_timedelta(cls, seconds: int) -> datetime.timedelta:
        """Rebuild a timedelta from the total seconds stored at serialization time."""
        return datetime.timedelta(seconds=seconds)

    @classmethod
    def _is_constructor_param(cls, attrname: str, instance: Any) -> bool:
        """Return whether ``attrname`` is a constructor parameter with a default."""
        return attrname in cls._CONSTRUCTOR_PARAMS

    @classmethod
    def _value_is_hardcoded_default(cls, attrname: str, value: Any, instance: Any) -> bool:
        """
        Return true if ``value`` is the hard-coded default for the given attribute.

        This takes in to account cases where the ``max_active_tasks`` parameter is
        stored in the ``_max_active_tasks`` attribute.

        And by using `is` here only and not `==` this copes with the case a
        user explicitly specifies an attribute with the same "value" as the
        default. (This is because ``"default" is "default"`` will be False as
        they are different strings with the same characters.)

        Also returns True if the value is an empty list or empty dict. This is done
        to account for the case where the default value of the field is None but has the
        ``field = field or {}`` set.
        """
        if attrname in cls._CONSTRUCTOR_PARAMS and (
            cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])
        ):
            return True
        return False

    @classmethod
    def _serialize_param(cls, param: Param):
        """Serialize a single Param, recording its class path so it can be re-imported."""
        return dict(
            __class=f"{param.__module__}.{param.__class__.__name__}",
            default=cls.serialize(param.value),
            description=cls.serialize(param.description),
            schema=cls.serialize(param.schema),
        )

    @classmethod
    def _deserialize_param(cls, param_dict: dict):
        """
        In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
        this class's ``serialize`` method.  So before running through ``deserialize``,
        we first verify that it's necessary to do.
        """
        class_name = param_dict['__class']
        class_ = import_string(class_name)  # type: Type[Param]
        attrs = ('default', 'description', 'schema')
        kwargs = {}
        for attr in attrs:
            if attr not in param_dict:
                continue
            val = param_dict[attr]
            is_serialized = isinstance(val, dict) and '__type' in val
            if is_serialized:
                deserialized_val = cls.deserialize(param_dict[attr])
                kwargs[attr] = deserialized_val
            else:
                kwargs[attr] = val
        return class_(**kwargs)

    @classmethod
    def _serialize_params_dict(cls, params: ParamsDict | dict):
        """Serialize Params dict for a DAG/Task.

        :raises ValueError: If any value is not exactly ``airflow.models.param.Param``.
        """
        serialized_params = {}
        for k, v in params.items():
            # TODO: As of now, we would allow serialization of params which are of type Param only.
            try:
                class_identity = f"{v.__module__}.{v.__class__.__name__}"
            except AttributeError:
                class_identity = ""
            if class_identity == "airflow.models.param.Param":
                serialized_params[k] = cls._serialize_param(v)
            else:
                raise ValueError(
                    f"Params to a DAG or a Task can be only of type airflow.models.param.Param, "
                    f"but param {k!r} is {v.__class__}"
                )
        return serialized_params

    @classmethod
    def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict:
        """Deserialize a DAG's Params dict"""
        op_params = {}
        for k, v in encoded_params.items():
            if isinstance(v, dict) and "__class" in v:
                op_params[k] = cls._deserialize_param(v)
            else:
                # Old style params, convert it
                op_params[k] = Param(v)

        return ParamsDict(op_params)

class DependencyDetector:
    """
    Detects dependencies between DAGs.

    :meta private:
    """

    @staticmethod
    def detect_task_dependencies(task: Operator) -> list[DagDependency]:
        """Detects dependencies caused by tasks.

        Trigger operators, external-task sensors and Dataset outlets each
        produce one ``DagDependency``.
        """
        from airflow.operators.trigger_dagrun import TriggerDagRunOperator
        from airflow.sensors.external_task import ExternalTaskSensor

        deps = []
        if isinstance(task, TriggerDagRunOperator):
            deps.append(
                DagDependency(
                    source=task.dag_id,
                    target=getattr(task, "trigger_dag_id"),
                    dependency_type="trigger",
                    dependency_id=task.task_id,
                )
            )
        elif isinstance(task, ExternalTaskSensor):
            deps.append(
                DagDependency(
                    source=getattr(task, "external_dag_id"),
                    target=task.dag_id,
                    dependency_type="sensor",
                    dependency_id=task.task_id,
                )
            )
        for obj in task.outlets or []:
            if isinstance(obj, Dataset):
                deps.append(
                    DagDependency(
                        source=task.dag_id,
                        target='dataset',
                        dependency_type='dataset',
                        dependency_id=obj.uri,
                    )
                )
        return deps

    @staticmethod
    def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
        """Detects dependencies set directly on the DAG object.

        Yields one dependency per dataset the DAG is scheduled on; yields
        nothing when ``dag`` is falsy.
        """
        if not dag:
            return
        for x in dag.dataset_triggers:
            yield DagDependency(
                source='dataset',
                target=dag.dag_id,
                dependency_type='dataset',
                dependency_id=x.uri,
            )

class SerializedBaseOperator(BaseOperator, BaseSerialization):
    """A JSON serializable representation of operator.

    All operators are cast to SerializedBaseOperator after deserialization.
    Class specific attributes used by UI are moved to object attributes.
    """

    _decorated_fields = {'executor_config'}

    # Defaults of every BaseOperator.__init__ parameter that has one; consulted by
    # ``_value_is_hardcoded_default`` so default values can be left out of the payload.
    _CONSTRUCTOR_PARAMS = {
        k: v.default
        for k, v in signature(BaseOperator.__init__).parameters.items()
        if v.default is not v.empty
    }

    def __init__(self, *args, **kwargs):
        """Create a placeholder operator that the deserializer then populates."""
        super().__init__(*args, **kwargs)
        # The UI only receives BaseOperator from deserialized DAGs, so it reads
        # ``_task_type`` (rather than the class name) to show the correct type.
        self._task_type = 'BaseOperator'
        # Copy these BaseOperator class attributes onto the instance.
        for attr in ('ui_color', 'ui_fgcolor', 'template_ext', 'template_fields', 'operator_extra_links'):
            setattr(self, attr, getattr(BaseOperator, attr))

    @property  # type: ignore[override]
    def task_type(self) -> str:
        # Overwrites task_type of BaseOperator to use _task_type instead of
        # __class__.__name__.
        return self._task_type

    @task_type.setter
    def task_type(self, task_type: str):
        self._task_type = task_type

    @property
    def operator_name(self) -> str:
        # Overwrites operator_name of BaseOperator to use _operator_name instead of
        # __class__.operator_name.
        return self._operator_name

    @operator_name.setter
    def operator_name(self, operator_name: str):
        self._operator_name = operator_name

    @classmethod
    def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
        """Serialize a mapped operator, including its expand input.

        Partial kwargs equal to the defaults from ``_get_default_mapped_partial``
        are dropped from the payload.
        """
        serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
        # Handle expand_input and op_kwargs_expand_input.
        expansion_kwargs = op._get_specified_expand_input()
        if TYPE_CHECKING:  # Let Mypy check the input type for us!
            _ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
        serialized_op[op._expand_input_attr] = {
            "type": get_map_type_key(expansion_kwargs),
            "value": cls.serialize(expansion_kwargs.value),
        }

        # Simplify partial_kwargs by comparing it to the most barebone object.
        # Remove all entries that are simply default values.
        serialized_partial = serialized_op["partial_kwargs"]
        for k, default in _get_default_mapped_partial().items():
            try:
                v = serialized_partial[k]
            except KeyError:
                continue
            if v == default:
                del serialized_partial[k]

        serialized_op["_is_mapped"] = True
        return serialized_op

    @classmethod
    def serialize_operator(cls, op: BaseOperator) -> dict[str, Any]:
        """Serialize a non-mapped operator; deps are stored only when they differ from the default."""
        return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)

    @classmethod
    def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool) -> dict[str, Any]:
        """Serializes operator into a JSON object.

        :param op: The operator to serialize.
        :param include_deps: Whether to store the operator's ``deps``.
        :return: The serialized operator dict.
        """
        serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
        serialize_op['_task_type'] = getattr(op, "_task_type", type(op).__name__)
        serialize_op['_task_module'] = getattr(op, "_task_module", type(op).__module__)
        if op.operator_name != serialize_op['_task_type']:
            serialize_op['_operator_name'] = op.operator_name

        # Used to determine if an Operator is inherited from EmptyOperator
        serialize_op['_is_empty'] = op.inherits_from_empty_operator

        if op.operator_extra_links:
            # operator_extra_links may be a property object when it is read off an
            # operator class (as a MappedOperator may do); resolve it against ``op``.
            serialize_op['_operator_extra_links'] = cls._serialize_operator_extra_links(
                op.operator_extra_links.__get__(op)
                if isinstance(op.operator_extra_links, property)
                else op.operator_extra_links
            )

        if include_deps:
            serialize_op['deps'] = cls._serialize_deps(op.deps)

        # Store all template_fields as they are if there are JSON Serializable
        # If not, store them as strings
        if op.template_fields:
            for template_field in op.template_fields:
                value = getattr(op, template_field, None)
                if not cls._is_excluded(value, template_field, op):
                    serialize_op[template_field] = serialize_template_field(value)

        if op.params:
            serialize_op['params'] = cls._serialize_params_dict(op.params)

        return serialize_op

    @classmethod
    def _serialize_deps(cls, op_deps: Iterable[BaseTIDep]) -> list[str]:
        """Serialize TI deps as a sorted list of their qualified class names.

        :raises AirflowException: If plugins could not be loaded.
        :raises SerializationError: If a non-core dep class is not registered via plugins.
        """
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Can not load plugins")

        deps = []
        for dep in op_deps:
            klass = type(dep)
            module_name = klass.__module__
            qualname = f'{module_name}.{klass.__name__}'
            if (
                not qualname.startswith("airflow.ti_deps.deps.")
                and qualname not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qualname} not serialized, please register it through plugins."
                )
            deps.append(qualname)
        # deps needs to be sorted here, because op_deps is a set, which is unstable when traversing,
        # and the same call may get different results.
        # When calling json.dumps(self.data, sort_keys=True) to generate dag_hash, misjudgment will occur
        return sorted(deps)

    @classmethod
    def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
        """Set attributes on ``op`` from its serialized form ``encoded_op``.

        Note: ``encoded_op`` is modified in place (``label`` and ``_operator_name``
        are filled in when missing).

        :raises AirflowException: If extra-link plugins are requested but cannot be loaded.
        """
        if "label" not in encoded_op:
            # Handle deserialization of old data before the introduction of TaskGroup
            encoded_op["label"] = encoded_op["task_id"]

        # Extra Operator Links defined in Plugins
        op_extra_links_from_plugin = {}

        if "_operator_name" not in encoded_op:
            encoded_op["_operator_name"] = encoded_op["_task_type"]

        # We don't want to load Extra Operator links in Scheduler
        if cls._load_operator_extra_links:
            from airflow import plugins_manager

            plugins_manager.initialize_extra_operators_links_plugins()

            if plugins_manager.operator_extra_links is None:
                raise AirflowException("Can not load plugins")

            for ope in plugins_manager.operator_extra_links:
                for operator in ope.operators:
                    if (
                        operator.__name__ == encoded_op["_task_type"]
                        and operator.__module__ == encoded_op["_task_module"]
                    ):
                        op_extra_links_from_plugin.update({ope.name: ope})

            # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
            # set the Operator links attribute
            # The case for "If OperatorLinks are defined in the operator that is being Serialized"
            # is handled in the deserialization loop where it matches k == "_operator_extra_links"
            if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
                setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))

        for k, v in encoded_op.items():
            # TODO: Remove in Airflow 3.0 when dummy operator is removed
            if k == "_is_dummy":
                k = "_is_empty"

            if k in ("_outlets", "_inlets"):
                # `_outlets` -> `outlets`
                k = k[1:]
            if k == "_downstream_task_ids":
                # Upgrade from old format/name
                k = "downstream_task_ids"
            if k == "label":
                # Label shouldn't be set anymore --  it's computed from task_id now
                continue
            elif k == "downstream_task_ids":
                v = set(v)
            elif k == "subdag":
                v = SerializedDAG.deserialize_dag(v)
            elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
                v = cls._deserialize_timedelta(v)
            elif k in encoded_op["template_fields"]:
                # Template fields were stored as-is (or as strings); keep them unchanged.
                pass
            elif k == "resources":
                v = Resources.from_dict(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "_operator_extra_links":
                if cls._load_operator_extra_links:
                    op_predefined_extra_links = cls._deserialize_operator_extra_links(v)

                    # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
                    op_predefined_extra_links.update(op_extra_links_from_plugin)
                else:
                    op_predefined_extra_links = {}

                v = list(op_predefined_extra_links.values())
                k = "operator_extra_links"

            elif k == "deps":
                v = cls._deserialize_deps(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
                if op.params:  # Merge existing params if needed.
                    v, new = op.params, v
                    v.update(new)
            elif k == "partial_kwargs":
                v = {arg: cls.deserialize(value) for arg, value in v.items()}
            elif k in {"expand_input", "op_kwargs_expand_input"}:
                v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
            elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                v = cls.deserialize(v)
            elif k in ("outlets", "inlets"):
                v = cls.deserialize(v)

            # else use v as it is

            setattr(op, k, v)

        for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
            # could go away.
            if not hasattr(op, k):
                setattr(op, k, None)

        # Set all the template_field to None that were not present in Serialized JSON
        for field in op.template_fields:
            if not hasattr(op, field):
                setattr(op, field, None)

        # Used to determine if an Operator is inherited from EmptyOperator
        setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

    # NOTE(review): this copy of the method looks corrupted. Callers invoke it on the class
    # (``SerializedBaseOperator.deserialize_operator(task)``), so a ``@classmethod`` decorator
    # is expected but absent; the ``try:`` pairing with ``except KeyError``, the argument list
    # and closing parenthesis of ``MappedOperator(``, and the ``else:`` guarding the
    # ``SerializedBaseOperator`` branch are also missing. Restore them from upstream rather
    # than guessing the MappedOperator arguments.
    def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
        """Deserializes an operator from a JSON object.

        :param encoded_op: serialized operator; a truthy ``_is_mapped`` entry selects the
            ``MappedOperator`` path, otherwise a ``SerializedBaseOperator`` is created from
            ``encoded_op['task_id']``.
        :return: the new operator after ``populate_operator`` has applied ``encoded_op`` to it.
        """
        op: Operator
        if encoded_op.get("_is_mapped", False):
            # Most of these will be loaded later, these are just some stand-ins.
            # Keep only the keys that BaseOperator lists as serialized fields.
            op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
                # Prefer "_operator_name"; fall back to "_task_type" when that key is absent.
                operator_name = encoded_op["_operator_name"]
            except KeyError:
                operator_name = encoded_op["_task_type"]
            op = MappedOperator(
            op = SerializedBaseOperator(task_id=encoded_op['task_id'])

        # Copy every remaining serialized attribute onto the stand-in operator.
        cls.populate_operator(op, encoded_op)
        return op

    @classmethod
    def detect_dependencies(cls, op: Operator) -> set[DagDependency]:
        """Detect dependencies between DAGs for the given operator.

        Dependencies come from the built-in ``DependencyDetector``. If a different detector
        class is configured under ``[scheduler] dependency_detector`` (deprecated), its
        result is merged in as well.

        :param op: the operator to inspect.
        :return: set of ``DagDependency`` objects found for ``op``.
        """

        def get_custom_dep() -> list[DagDependency]:
            """
            If custom dependency detector is configured, use it.

            TODO: Remove this logic in 3.0.
            """
            custom_dependency_detector_cls = conf.getimport('scheduler', 'dependency_detector', fallback=None)
            if not (
                custom_dependency_detector_cls is None or custom_dependency_detector_cls is DependencyDetector
            ):
                warnings.warn(
                    "Use of a custom dependency detector is deprecated. "
                    "Support will be removed in a future release.",
                    RemovedInAirflow3Warning,
                )
                dep = custom_dependency_detector_cls().detect_task_dependencies(op)
                # The legacy custom API yields a single dependency rather than an iterable.
                if type(dep) is DagDependency:
                    return [dep]
            return []

        dependency_detector = DependencyDetector()
        deps = set(dependency_detector.detect_task_dependencies(op))
        deps.update(get_custom_dep())  # todo: remove in 3.0
        return deps

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
        """Return whether attribute ``attrname`` (current value ``var``) of ``op`` is left out.

        A ``*_date`` attribute whose value equals the owning DAG's attribute of the same
        name is excluded, because the DAG already stores it. Every other case is delegated
        to the parent implementation.
        """
        if var is not None and op.has_dag() and attrname.endswith("_date"):
            # If this date is the same as the matching field in the dag, then
            # don't store it again at the task level.
            dag_date = getattr(op.dag, attrname, None)
            if var is dag_date or var == dag_date:
                return True
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def _deserialize_deps(cls, deps: list[str]) -> set[BaseTIDep]:
        """Rebuild task-instance dependency objects from their qualified class names.

        :param deps: dotted import paths of dependency classes (duplicates are ignored).
        :return: one instance per class that could be imported.
        :raises AirflowException: if the plugin registry of dep classes is not loaded.
        :raises SerializationError: if a class outside ``airflow.ti_deps.deps.`` is not
            registered through plugins.
        """
        from airflow import plugins_manager

        # NOTE(review): nothing here registers the ti-dep plugins, so this check relies on
        # that having happened earlier (possibly a ``plugins_manager.initialize_ti_deps_plugins()``
        # call was lost from this copy) — confirm.
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Can not load plugins")

        instances = set()
        for qualname in set(deps):
            # Only Airflow's own deps, or deps registered through plugins, may be imported.
            if (
                not qualname.startswith("airflow.ti_deps.deps.")
                and qualname not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qualname} not deserialized, please register it through plugins."
                )

            try:
                instances.add(import_string(qualname)())
            except ImportError:
                # Best effort: an unimportable dep is logged and skipped.
                log.warning("Error importing dep %r", qualname, exc_info=True)
        return instances

    @classmethod
    def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> dict[str, BaseOperatorLink]:
        """
        Deserialize Operator Links from their serialized import paths and arguments.

        A link class is accepted if it is one of the known operator extra links or if it is
        registered through Airflow plugins. If any class is neither, an error is logged and
        an empty dict is returned.

        :param encoded_op_links: Serialized Operator Links, each a single-key dict mapping the
            link class's import path to the arguments it was created with.
        :return: De-Serialized Operator Links, keyed by link name.
        :raises AirflowException: if the plugin registry of operator links is not loaded.
        """
        from airflow import plugins_manager

        # NOTE(review): nothing here registers the operator-link plugins, so this check relies
        # on that having happened earlier — confirm.
        if plugins_manager.registered_operator_link_classes is None:
            raise AirflowException("Can't load plugins")
        op_predefined_extra_links = {}

        for _operator_links_source in encoded_op_links:
            # Each entry looks like {'<module>.<LinkClass>': {'index': 0}}: the single key is
            # the link class's import path, the value holds its constructor arguments.
            _operator_link_class_path, data = list(_operator_links_source.items())[0]
            if _operator_link_class_path in get_operator_extra_links():
                single_op_link_class = import_string(_operator_link_class_path)
            elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
                single_op_link_class = plugins_manager.registered_operator_link_classes[
                    _operator_link_class_path
                ]
            else:
                log.error("Operator Link class %r not registered", _operator_link_class_path)
                return {}

            op_predefined_extra_link: BaseOperatorLink = cattr.structure(data, single_op_link_class)

            # Key by link name so a later link with the same name replaces an earlier one.
            op_predefined_extra_links[op_predefined_extra_link.name] = op_predefined_extra_link

        return op_predefined_extra_links

    @classmethod
    def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
        """
        Serialize Operator Links.

        Store the import path of the OperatorLink and the arguments passed to it, e.g.
        ``[{'<module>.<LinkClass>': {}}]``.

        :param operator_extra_links: Operator Links
        :return: Serialized Operator Links, one single-key dict per link
        """
        serialize_operator_extra_links = []
        for operator_extra_link in operator_extra_links:
            op_link_arguments = cattr.unstructure(operator_extra_link)
            # Links that do not unstructure to a mapping are stored without arguments.
            if not isinstance(op_link_arguments, dict):
                op_link_arguments = {}

            # Full import path of the link class, so it can be re-imported on load.
            link_class = operator_extra_link.__class__
            module_path = f"{link_class.__module__}.{link_class.__name__}"
            serialize_operator_extra_links.append({module_path: op_link_arguments})

        return serialize_operator_extra_links

class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and customer defined classes, are cast to
    strings.

    Compared with SimpleDAG: SerializedDAG contains all information for webserver.
    Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    _decorated_fields = {'schedule_interval', 'default_args', '_access_control'}

    @staticmethod
    def __get_constructor_defaults():
        """Map each defaulted ``DAG.__init__`` parameter to its default value.

        Keys are the attribute names the parameters are stored under (e.g.
        ``max_active_tasks`` -> ``_max_active_tasks``).
        """
        param_to_attr = {
            'max_active_tasks': '_max_active_tasks',
            'description': '_description',
            'default_view': '_default_view',
            'access_control': '_access_control',
        }
        return {
            param_to_attr.get(k, k): v.default
            for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
        }

    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = lazy_object_proxy.Proxy(load_dag_schema)

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serializes a DAG into a JSON object.

        :raises SerializationError: on any failure; other exceptions are wrapped and chained.
        """
        try:
            serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

            serialized_dag['_processor_dags_folder'] = DAGS_FOLDER

            # If schedule_interval is backed by timetable, serialize only
            # timetable; vice versa for a timetable backed by schedule_interval.
            if dag.timetable.summary == dag.schedule_interval:
                del serialized_dag["schedule_interval"]
            else:
                del serialized_dag["timetable"]

            serialized_dag["tasks"] = [cls.serialize(task) for task in dag.task_dict.values()]
            dag_deps = {
                dep
                for task in dag.task_dict.values()
                for dep in SerializedBaseOperator.detect_dependencies(task)
            }
            # Sets iterate in an arbitrary order; sort so identical DAGs serialize identically
            # (the same reason TaskGroup id sets are sorted before serialization).
            serialized_dag["dag_dependencies"] = [
                x.__dict__
                for x in sorted(
                    dag_deps,
                    key=lambda dep: (dep.source, dep.target, dep.dependency_type, dep.dependency_id or ""),
                )
            ]
            serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)

            # Edge info in the JSON exactly matches our internal structure
            serialized_dag["edge_info"] = dag.edge_info
            serialized_dag["params"] = cls._serialize_params_dict(dag.params)

            # has_on_*_callback are only stored if the value is True, as the default is False
            if dag.has_on_success_callback:
                serialized_dag['has_on_success_callback'] = True
            if dag.has_on_failure_callback:
                serialized_dag['has_on_failure_callback'] = True
            return serialized_dag
        except SerializationError:
            raise
        except Exception as e:
            raise SerializationError(f'Failed to serialize DAG {dag.dag_id!r}: {e}') from e

    @classmethod
    def deserialize_dag(cls, encoded_dag: dict[str, Any]) -> SerializedDAG:
        """Deserializes a DAG from a JSON object.

        :param encoded_dag: the ``"dag"`` payload produced by ``serialize_dag``.
        :return: a ``SerializedDAG`` with its tasks, task group and task relationships rebuilt.
        """
        dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])

        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                # Operators must see the same extra-link loading setting as the DAG.
                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links

                v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
                k = "task_dict"
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "edge_info":
                # Value structure matches exactly
                pass
            elif k == "timetable":
                v = _decode_timetable(v)
            elif k in cls._decorated_fields:
                v = cls.deserialize(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            elif k == "dataset_triggers":
                v = cls.deserialize(v)
            # else use v as it is

            setattr(dag, k, v)

        # A DAG is always serialized with only one of schedule_interval and
        # timetable. This back-populates the other to ensure the two attributes
        # line up correctly on the DAG instance.
        if "timetable" in encoded_dag:
            dag.schedule_interval = dag.timetable.summary
        else:
            dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)

        # Set _task_group
        if "_task_group" in encoded_dag:
            dag._task_group = SerializedTaskGroup.deserialize_task_group(
                encoded_dag["_task_group"], None, dag.task_dict, dag
            )
        else:
            # This must be old data that had no task_group. Create a root TaskGroup and add
            # all tasks to it.
            dag._task_group = TaskGroup.create_root(dag)
            for task in dag.tasks:
                dag.task_group.add(task)

        # Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
        if "has_on_success_callback" in encoded_dag:
            dag.has_on_success_callback = True
        if "has_on_failure_callback" in encoded_dag:
            dag.has_on_failure_callback = True

        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        for task in dag.task_dict.values():
            task.dag = dag

            # Tasks without their own dates inherit the DAG's.
            for date_attr in ["start_date", "end_date"]:
                if getattr(task, date_attr) is None:
                    setattr(task, date_attr, getattr(dag, date_attr))

            if task.subdag is not None:
                setattr(task.subdag, 'parent_dag', dag)

            # Dereference expand_input and op_kwargs_expand_input.
            for k in ("expand_input", "op_kwargs_expand_input"):
                kwargs_ref = getattr(task, k, None)
                if isinstance(kwargs_ref, _ExpandInputRef):
                    setattr(task, k, kwargs_ref.deref(dag))

            for task_id in task.downstream_task_ids:
                # Bypass set_upstream etc here - it does more than we want
                dag.task_dict[task_id].upstream_task_ids.add(task.task_id)

        return dag

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var."""
        json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}

        # Validate Serialized DAG with Json Schema. Raises Error if it mismatches
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
        """Deserializes a python dict in to the DAG and operators it contains.

        :raises ValueError: if ``__version`` does not match ``SERIALIZER_VERSION``.
        """
        ver = serialized_obj.get('__version', '<not present>')
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(f"Unsure how to deserialize version {ver!r}")
        return cls.deserialize_dag(serialized_obj['dag'])

class SerializedTaskGroup(TaskGroup, BaseSerialization):
    """A JSON serializable representation of TaskGroup."""

    @classmethod
    def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
        """Serializes TaskGroup into a JSON object.

        :return: ``None`` for a falsy ``task_group``, otherwise a dict of its settings,
            children and sorted upstream/downstream ids.
        """
        if not task_group:
            return None

        # The task_group.*_ids attributes are sets, whose iteration order is arbitrary. Sort
        # them so that json.dumps(..., sort_keys=True), used to compute dag_hash, produces the
        # same output for the same DAG instead of reporting a spurious change.
        serialize_group = {
            "_group_id": task_group._group_id,
            "prefix_group_id": task_group.prefix_group_id,
            "tooltip": task_group.tooltip,
            "ui_color": task_group.ui_color,
            "ui_fgcolor": task_group.ui_fgcolor,
            "children": {
                label: child.serialize_for_task_group() for label, child in task_group.children.items()
            },
            "upstream_group_ids": cls.serialize(sorted(task_group.upstream_group_ids)),
            "downstream_group_ids": cls.serialize(sorted(task_group.downstream_group_ids)),
            "upstream_task_ids": cls.serialize(sorted(task_group.upstream_task_ids)),
            "downstream_task_ids": cls.serialize(sorted(task_group.downstream_task_ids)),
        }

        return serialize_group

    @classmethod
    def deserialize_task_group(
        cls,
        encoded_group: dict[str, Any],
        parent_group: TaskGroup | None,
        task_dict: dict[str, Operator],
        dag: SerializedDAG,
    ) -> TaskGroup:
        """Deserializes a TaskGroup from a JSON object.

        :param encoded_group: a dict produced by ``serialize_task_group``.
        :param parent_group: the group this one is nested in, or ``None`` for the root.
        :param task_dict: operators by task id; children tagged ``DAT.OP`` are looked up here.
        :param dag: the DAG the group belongs to.
        :return: the rebuilt group. Each child operator's ``task_group`` is set to a weak
            proxy of it.
        """
        group_id = cls.deserialize(encoded_group["_group_id"])
        kwargs = {
            key: cls.deserialize(encoded_group[key])
            for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
        }
        group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)

        def set_ref(task: Operator) -> Operator:
            # A weak proxy avoids a strong task -> group -> task reference cycle.
            task.task_group = weakref.proxy(group)
            return task

        group.children = {
            label: set_ref(task_dict[val])  # type: ignore
            if _type == DAT.OP  # type: ignore
            else SerializedTaskGroup.deserialize_task_group(val, group, task_dict, dag=dag)
            for label, (_type, val) in encoded_group["children"].items()
        }

        # Read back the dependency ids written by serialize_task_group; tolerate blobs
        # that predate these keys.
        group.upstream_group_ids = set(cls.deserialize(encoded_group.get("upstream_group_ids", [])))
        group.downstream_group_ids = set(cls.deserialize(encoded_group.get("downstream_group_ids", [])))
        group.upstream_task_ids = set(cls.deserialize(encoded_group.get("upstream_task_ids", [])))
        group.downstream_task_ids = set(cls.deserialize(encoded_group.get("downstream_task_ids", [])))
        return group

@dataclass(frozen=True)
class DagDependency:
    """Dataclass for representing dependencies between DAGs.

    These are calculated during serialization and attached to serialized DAGs.
    Instances are frozen, and therefore hashable, because they are collected into sets.
    """

    source: str
    target: str
    dependency_type: str
    dependency_id: str | None = None

    @property
    def node_id(self):
        """Node ID for graph rendering.

        ``<type>:<source>:<target>[:<dependency_id>]``, except that dataset dependencies
        omit the source and target: ``dataset[:<dependency_id>]``.
        """
        val = f"{self.dependency_type}"
        if not self.dependency_type == 'dataset':
            val += f":{self.source}:{self.target}"
        if self.dependency_id:
            val += f":{self.dependency_id}"
        return val

def _has_kubernetes() -> bool:
    if "HAS_KUBERNETES" in globals():
        return HAS_KUBERNETES

    # Loading kube modules is expensive, so delay it until the last moment

        from kubernetes.client import models as k8s

        from airflow.kubernetes.pod_generator import PodGenerator

        globals()['k8s'] = k8s
        globals()['PodGenerator'] = PodGenerator

        # isort: on
        HAS_KUBERNETES = True
    except ImportError:
        HAS_KUBERNETES = False


airflow 源码目录


airflow init 源码

airflow enums 源码

airflow helpers 源码

airflow json_schema 源码

0  赞