# 2022-10-20
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Implements the ``@task_group`` function decorator.

When the decorated function is called, a task group will be created to represent
a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""

from __future__ import annotations

import functools
import inspect
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Mapping, Sequence, TypeVar, overload

import attr
from sqlalchemy.orm import Session

from airflow.decorators.base import ExpandableFactory
from airflow.models.expandinput import (
    DictOfListsExpandInput,
    ListOfDictsExpandInput,
)
from airflow.models.taskmixin import DAGNode
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec
from airflow.utils.context import Context
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.task_group import MappedTaskGroup, TaskGroup

if TYPE_CHECKING:
    from airflow.models.dag import DAG
    from airflow.models.expandinput import (
        ExpandInput,
        OperatorExpandArgument,
        OperatorExpandKwargsArgument,
    )

# Parameter specification of the decorated function; the factory's
# ``__call__`` is annotated with it so the wrapped function's argument types
# are preserved for type checkers.
FParams = ParamSpec("FParams")
# Return type of the decorated function, constrained to ``None`` or a DAGNode.
FReturn = TypeVar("FReturn", None, DAGNode)

# Signature of ``TaskGroup.__init__``, computed once at import time.
task_group_sig = inspect.signature(TaskGroup.__init__)

# attrs generates the keyword-only ``__init__(input=..., key=...)`` that the
# task group factory relies on (leading underscores are stripped from the
# argument names). Instances are never mutated after creation, hence frozen.
@attr.define(kw_only=True, frozen=True)
class _MappedArgument(ResolveMixin):
    """Stand-in for one keyword argument of a mapped task group function.

    The stub is handed to the decorated function in place of a real value.
    Resolving it resolves the whole expand input and picks out the entry
    stored under this argument's name.
    """

    _input: ExpandInput
    _key: str

    # NEW_SESSION is only a placeholder default; ``provide_session`` is what
    # supplies a real session when the caller does not pass one.
    @provide_session
    def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
        """Return the value of this argument for the given context.

        :param context: Context forwarded to the expand input's ``resolve()``.
        :param session: Database session forwarded to the expand input.
        :return: The entry named ``key`` in the resolved expand input.
        """
        data, _ = self._input.resolve(context, session=session)
        return data[self._key]

@attr.define()
class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]):
    """Factory returned by ``@task_group`` that builds task groups on demand.

    Calling the factory runs the wrapped function inside a new ``TaskGroup``.
    ``expand()`` and ``expand_kwargs()`` run it inside a ``MappedTaskGroup``
    instead, optionally after ``partial()`` has fixed some of the function's
    arguments. ``override()`` returns a copy with different TaskGroup
    arguments.
    """

    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    tg_kwargs: dict[str, Any] = attr.ib(factory=dict)  # Parameters forwarded to TaskGroup.
    partial_kwargs: dict[str, Any] = attr.ib(factory=dict)  # Parameters forwarded to 'function'.

    # Set once a task group has actually been built; read by ``__del__``.
    _task_group_created: bool = attr.ib(False, init=False)

    tg_class: ClassVar[type[TaskGroup]] = TaskGroup

    @tg_kwargs.validator
    def _validate(self, _, kwargs):
        """Reject ``tg_kwargs`` that ``TaskGroup.__init__`` would not accept.

        ``bind_partial`` raises ``TypeError`` on unknown argument names while
        still allowing required arguments to be absent.
        """
        task_group_sig.bind_partial(**kwargs)

    def __attrs_post_init__(self):
        """Default the group ID to the wrapped function's name."""
        self.tg_kwargs.setdefault("group_id", self.function.__name__)

    def __del__(self):
        """Warn when a partial factory is discarded without being mapped."""
        if self.partial_kwargs and not self._task_group_created:
            try:
                group_id = repr(self.tg_kwargs["group_id"])
            except KeyError:
                # No group ID available; identify the factory by address.
                group_id = f"at {hex(id(self))}"
            warnings.warn(f"Partial task group {group_id} was never mapped!")

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode:
        """Instantiate the task group.

        This uses the wrapped function to create a task group. Depending on the
        return type of the wrapped function, this either returns the last task
        in the group, or the group itself, to support task chaining.
        """
        return self._create_task_group(TaskGroup, *args, **kwargs)

    def _create_task_group(self, tg_factory: Callable[..., TaskGroup], *args: Any, **kwargs: Any) -> DAGNode:
        """Run the wrapped function inside a task group built by ``tg_factory``.

        :param tg_factory: Callable returning the task group to use as context.
        :param args: Positional arguments forwarded to the wrapped function.
        :param kwargs: Keyword arguments forwarded to the wrapped function.
        :return: The wrapped function's return value if it is not ``None``,
            otherwise the task group itself.
        """
        # Suffixing is on by default, but an explicit ``add_suffix_on_collision``
        # in ``tg_kwargs`` wins. Passing both as separate keywords would raise
        # a "multiple values for keyword argument" TypeError.
        group_kwargs = {"add_suffix_on_collision": True, **self.tg_kwargs}
        with tg_factory(**group_kwargs) as task_group:
            if self.function.__doc__ and not task_group.tooltip:
                task_group.tooltip = self.function.__doc__

            # Invoke function to run Tasks inside the TaskGroup
            retval = self.function(*args, **kwargs)

        self._task_group_created = True

        # If the task-creating function returns a task, forward the return value
        # so dependencies bind to it. This is equivalent to
        #   with TaskGroup(...) as tg:
        #       t2 = task_2(task_1())
        #   start >> t2 >> end
        if retval is not None:
            return retval

        # Otherwise return the task group as a whole, equivalent to
        #   with TaskGroup(...) as tg:
        #       task_1()
        #       task_2()
        #   start >> tg >> end
        return task_group

    def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
        """Return a copy whose TaskGroup arguments are updated with ``kwargs``."""
        return attr.evolve(self, tg_kwargs={**self.tg_kwargs, **kwargs})

    def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
        """Return a copy with ``kwargs`` fixed as arguments of the function.

        Arguments stored by earlier ``partial()`` calls are kept, so calls can
        be chained. ``prevent_duplicates`` is given both sets so that an
        argument cannot be fixed twice.
        """
        self._validate_arg_names("partial", kwargs)
        prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="duplicate partial")
        # Merge rather than replace: evolving with only the new kwargs would
        # silently drop what earlier partial() calls stored.
        return attr.evolve(self, partial_kwargs={**self.partial_kwargs, **kwargs})

    def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
        """Create a mapped task group, mapping over each keyword argument.

        :param kwargs: Arguments of the wrapped function to map over.
        :raises TypeError: If no arguments are given.
        :return: Same as calling the factory: the function's return value, or
            the created task group.
        """
        if not kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", kwargs)
        prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial")
        expand_input = DictOfListsExpandInput(kwargs)
        # The function receives the fixed partial arguments as-is and a stub
        # in place of every mapped argument.
        return self._create_task_group(
            functools.partial(MappedTaskGroup, expand_input=expand_input),
            **self.partial_kwargs,
            **{k: _MappedArgument(input=expand_input, key=k) for k in kwargs},
        )

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
        """Create a mapped task group from an XComArg or a sequence of mappings.

        :param kwargs: An ``XComArg``, or a sequence whose items are each an
            ``XComArg`` or a mapping.
        :raises TypeError: If ``kwargs`` has the wrong shape, or the wrapped
            function takes ``*args`` or ``**kwargs``.
        :return: Same as calling the factory: the function's return value, or
            the created task group.
        """
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")

        # It's impossible to build a dict of stubs as keyword arguments if the
        # function uses * or ** wildcard arguments.
        function_has_vararg = any(
            v.kind == inspect.Parameter.VAR_POSITIONAL or v.kind == inspect.Parameter.VAR_KEYWORD
            for v in self.function_signature.parameters.values()
        )
        if function_has_vararg:
            raise TypeError("calling expand_kwargs() on task group function with * or ** is not supported")

        # We can't be sure how each argument is used in the function (well
        # technically we can with AST but let's not), so we have to create stubs
        # for every argument, including those with default values.
        map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs)

        expand_input = ListOfDictsExpandInput(kwargs)
        # Partial arguments are excluded from the stubs above, so they are
        # passed through here with their fixed values.
        return self._create_task_group(
            functools.partial(MappedTaskGroup, expand_input=expand_input),
            **self.partial_kwargs,
            **{k: _MappedArgument(input=expand_input, key=k) for k in map_kwargs},
        )

# This covers the @task_group() case. Annotations are copied from the TaskGroup
# class, only providing a default to 'group_id' (this is optional for the
# decorator and defaults to the decorated function's name). Please keep them in
# sync with TaskGroup when you can! Note that since this is an overload, these
# argument defaults aren't actually used at runtime--the real implementation
# does not use them, and simply relies on TaskGroup's defaults, so it's not
# disastrous if they go out of sync with TaskGroup.
@overload
def task_group(
    group_id: str | None = None,
    prefix_group_id: bool = True,
    parent_group: TaskGroup | None = None,
    dag: DAG | None = None,
    default_args: dict[str, Any] | None = None,
    tooltip: str = "",
    ui_color: str = "CornflowerBlue",
    ui_fgcolor: str = "#000",
    add_suffix_on_collision: bool = False,
) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]:
    ...


# This covers the @task_group case (no parentheses).
@overload
def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]:
    ...


def task_group(python_callable=None, **tg_kwargs):
    """Python TaskGroup decorator.

    This wraps a function into an Airflow TaskGroup. When used as the
    ``@task_group()`` form, all arguments are forwarded to the underlying
    TaskGroup class. Can be used to parametrize TaskGroup.

    :param python_callable: Function to decorate.
    :param tg_kwargs: Keyword arguments for the TaskGroup object.
    :return: A task group factory wrapping ``python_callable`` when used as
        bare ``@task_group``; otherwise a decorator that builds the factory
        with ``tg_kwargs`` once applied to a function.
    """
    # Bare ``@task_group``: the decorated function arrives as the only argument.
    if callable(python_callable) and not tg_kwargs:
        return _TaskGroupFactory(function=python_callable, tg_kwargs=tg_kwargs)
    # ``@task_group(...)``: return a decorator awaiting the function.
    return functools.partial(_TaskGroupFactory, tg_kwargs=tg_kwargs)


