airflow expandinput 源码
airflow expandinput 代码
文件路径:/airflow/models/expandinput.py
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import collections.abc
import functools
import operator
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
from airflow.compat.functools import cache
from airflow.utils.context import Context
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.xcom_arg import XComArg
# Either of the two storage types a mapped operator's expand input can take.
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# A single keyword argument passed to expand(): an XComArg, a sequence, or a
# dict. Arbitrary mappings are excluded because the value must be ordered.
OperatorExpandArgument = Union[
    "XComArg",
    Sequence,
    Dict[str, Any],
]

# The sole argument passed to expand_kwargs(): an XComArg, or a list whose
# elements are each either an XComArg or a dict.
OperatorExpandKwargsArgument = Union[
    "XComArg",
    Sequence[Union["XComArg", Mapping[str, Any]]],
]
# For isinstance() check.
@cache
def get_mappable_types() -> tuple[type, ...]:
    """Return the tuple of types accepted as a mappable argument.

    The tuple is suitable as the second argument of ``isinstance()``.
    ``XComArg`` is imported at call time (at module level it is only imported
    for type checking), and the result is memoized by ``cache``.
    """
    from airflow.models.xcom_arg import XComArg

    mappable_types = (XComArg, list, tuple, dict)
    return mappable_types
class NotFullyPopulated(RuntimeError):
    """Signal that ``get_map_lengths`` could not fill in all mapping metadata.

    The usual cause is that some upstream tasks had not finished yet when the
    function was called.
    """

    def __init__(self, missing: set[str]) -> None:
        # Names of the entries that could not be populated; reported by __str__.
        self.missing = missing

    def __str__(self) -> str:
        # Sort so the message is deterministic regardless of set ordering.
        keys = ", ".join(map(repr, sorted(self.missing)))
        return f"Failed to populate all mapping metadata; missing: {keys}"
class DictOfListsExpandInput(NamedTuple):
    """Hold the keyword arguments a mapped operator is expanded over.

    Instances are built from ``expand(**kwargs)``.
    """

    # Argument name -> value to map over (an XComArg, a sequence, or a dict).
    value: dict[str, OperatorExpandArgument]

    def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
        """Iterate over the ``(name, value)`` pairs whose value is not an XComArg.

        Those are the values already available at parse time.
        """
        from airflow.models.xcom_arg import XComArg

        return (item for item in self.value.items() if not isinstance(item[1], XComArg))

    def get_parse_time_mapped_ti_count(self) -> int | None:
        """Return the mapped task instance count if known from literals alone.

        :return: ``0`` when nothing is mapped, ``None`` when at least one value
            is an XComArg, otherwise the product of all value lengths.
        """
        if not self.value:
            return 0
        literal_lengths = [len(arg) for _, arg in self._iter_parse_time_resolved_kwargs()]
        if len(literal_lengths) == len(self.value):
            return functools.reduce(operator.mul, literal_lengths, 1)
        return None  # A non-literal value was encountered, so give up.

    def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
        """Return a dict of argument name to map length.

        The length of an XComArg is obtained with ``get_task_map_length``; any
        other value is measured with ``len()``.

        :raises NotFullyPopulated: if ``get_task_map_length`` returned ``None``
            for at least one argument; the exception carries the names of the
            arguments whose length is missing.
        """
        from airflow.models.xcom_arg import XComArg

        # TODO: This initiates one database call for each XComArg. Would it be
        # more efficient to do one single db call and unpack the value here?
        map_lengths: dict[str, int] = {}
        for name, arg in self.value.items():
            if isinstance(arg, XComArg):
                length = arg.get_task_map_length(run_id, session=session)
            else:
                length = len(arg)
            if length is not None:
                map_lengths[name] = length
        if len(map_lengths) < len(self.value):
            raise NotFullyPopulated(set(self.value).difference(map_lengths))
        return map_lengths

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the product of all arguments' map lengths (``0`` if none are mapped).

        :raises NotFullyPopulated: propagated from ``_get_map_lengths``.
        """
        if not self.value:
            return 0
        lengths = self._get_map_lengths(run_id, session=session)
        total = 1
        for name in self.value:
            total *= lengths[name]
        return total

    def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
        """Pick the element of *value* that belongs to the current map index.

        An XComArg *value* is resolved first. The task instance's ``map_index``
        is then decomposed over the argument lengths, with the last argument
        varying fastest, to find the position to take from *value*.

        :return: ``value[position]`` for a sequence, a ``(key, value)`` pair for
            a dict, or *value* itself when *key* is not one of the mapped
            arguments.
        :raises RuntimeError: if ``map_index`` is negative, or an argument is
            mapped to a length below one.
        :raises TypeError: if *value* is neither a sequence nor a dict.
        :raises IndexError: if the position lies beyond the end of a dict.
        """
        from airflow.models.xcom_arg import XComArg

        if isinstance(value, XComArg):
            value = value.resolve(context, session=session)
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")
        all_lengths = self._get_map_lengths(context["run_id"], session=session)

        found_index = -1
        remaining = map_index
        # Need to use the original user input to retain argument order.
        for mapped_key in reversed(list(self.value)):
            mapped_length = all_lengths[mapped_key]
            if mapped_length < 1:
                raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
            if mapped_key == key:
                found_index = remaining % mapped_length
                break
            remaining //= mapped_length

        if found_index < 0:
            # The key is not among the mapped arguments; nothing to pick.
            return value
        if isinstance(value, collections.abc.Sequence):
            return value[found_index]
        if not isinstance(value, dict):
            raise TypeError(f"can't map over value of type {type(value)}")
        for position, (item_key, item_value) in enumerate(value.items()):
            if position == found_index:
                return item_key, item_value
        raise IndexError(f"index {map_index} is over mapped length")

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve every mapped argument for the current task instance.

        :return: A 2-tuple of the resolved ``{name: value}`` dict, and the set
            of ``id()``s of the resolved values whose argument was not a
            parse-time literal.
        """
        data = {}
        for name, arg in self.value.items():
            data[name] = self._expand_mapped_field(name, arg, context, session=session)
        literal_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
        resolved_oids = {id(resolved) for name, resolved in data.items() if name not in literal_keys}
        return data, resolved_oids
def _describe_type(value: Any) -> str:
if value is None:
return "None"
return type(value).__name__
class ListOfDictsExpandInput(NamedTuple):
    """Hold the list of keyword-argument dicts a mapped operator is expanded over.

    Instances are built from ``expand_kwargs(xcom_arg)``.
    """

    # Either a sized collection of dicts/XComArgs, or a single XComArg.
    value: OperatorExpandKwargsArgument

    def get_parse_time_mapped_ti_count(self) -> int | None:
        """Return ``len(value)`` when the value is sized, otherwise ``None``."""
        return len(self.value) if isinstance(self.value, collections.abc.Sized) else None

    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
        """Return the number of dicts to expand over.

        A sized value is measured with ``len()``; otherwise the length is
        obtained from the value's ``get_task_map_length``.

        :raises NotFullyPopulated: if ``get_task_map_length`` returns ``None``.
        """
        if not isinstance(self.value, collections.abc.Sized):
            length = self.value.get_task_map_length(run_id, session=session)
            if length is not None:
                return length
            raise NotFullyPopulated({"expand_kwargs() argument"})
        return len(self.value)

    def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Resolve the keyword arguments for the current task instance.

        The entry at the task instance's ``map_index`` is selected, resolving
        the whole value or the selected entry first when it is not literal.

        :return: A 2-tuple of the selected mapping and the set of ``id()``s of
            all its values.
        :raises RuntimeError: if ``map_index`` is negative.
        :raises ValueError: if the resolved value is not a sequence, the
            selected entry is not a mapping, or a key of it is not a ``str``.
        """
        map_index = context["ti"].map_index
        if map_index < 0:
            raise RuntimeError("can't resolve task-mapping argument without expanding")

        mapping: Any
        if not isinstance(self.value, collections.abc.Sized):
            # The value as a whole must be resolved before it can be indexed.
            mappings = self.value.resolve(context, session)
            if not isinstance(mappings, collections.abc.Sequence):
                raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
            mapping = mappings[map_index]
        else:
            mapping = self.value[map_index]
            if not isinstance(mapping, collections.abc.Mapping):
                # A non-mapping element is resolved on its own.
                mapping = mapping.resolve(context, session)

        if not isinstance(mapping, collections.abc.Mapping):
            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

        for key in mapping:
            if isinstance(key, str):
                continue
            raise ValueError(
                f"expand_kwargs() input dict keys must all be str, "
                f"but {key!r} is of type {_describe_type(key)}"
            )
        return mapping, {id(item) for item in mapping.values()}
# Sentinel value: an expand input with no mapped arguments.
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})

# Registry from the string key of each expand-input kind to the class that
# implements it.
_EXPAND_INPUT_TYPES: dict[str, type] = {
    "dict-of-lists": DictOfListsExpandInput,
    "list-of-dicts": ListOfDictsExpandInput,
}
def get_map_type_key(expand_input: ExpandInput) -> str:
    """Return the key under which the type of *expand_input* is registered.

    The lookup compares against the exact type of *expand_input*, so
    subclasses of a registered type do not match. ``StopIteration``
    propagates when no entry of ``_EXPAND_INPUT_TYPES`` matches.
    """
    input_type = type(expand_input)
    return next(key for key, registered in _EXPAND_INPUT_TYPES.items() if registered == input_type)
def create_expand_input(kind: str, value: Any) -> ExpandInput:
    """Build the expand input registered under *kind*, wrapping *value*.

    :raises KeyError: if *kind* is not a key of ``_EXPAND_INPUT_TYPES``.
    """
    expand_input_type = _EXPAND_INPUT_TYPES[kind]
    return expand_input_type(value)
相关信息
相关文章
0
赞
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦