Tasks

databricks_bundle_decorators.decorators.task(fn=None, *, io_manager=None, depends_on=None, all_partitions=False, partition_by=None, **kwargs)

task(
    *,
    io_manager: IoManager | None = ...,
    depends_on: TaskProxy | list[TaskProxy] | None = ...,
    all_partitions: bool = ...,
    partition_by: str | list[str] | None = ...,
    **kwargs: Unpack[TaskConfig],
) -> _TaskDecorator
task(fn: types.FunctionType) -> Callable[..., Any]

Register a function as a Databricks task.

When used inside a @job body, the decorated function is registered under a qualified key (job_name.task_name) and calling it returns a TaskProxy that wires up the DAG.

When used outside a @job body (e.g. at module level), the function is registered under its short name for use in tests or standalone execution. Duplicate names at module level raise DuplicateResourceError.

Parameters:

io_manager : IoManager | None, default None
    An IoManager instance that controls how the task's return value is persisted and loaded by downstream tasks. When None, no automatic data transfer takes place (use set_task_value for small scalars).

depends_on : TaskProxy | list[TaskProxy] | None, default None
    One or more TaskProxy objects returned by calling other @task-decorated functions inside a @job body. Creates control-flow-only dependencies: the current task runs after the specified upstream tasks complete, but no data is transferred via IoManager. Use this when a task must wait for another to finish without consuming its output. For data dependencies, pass TaskProxy objects as regular function arguments instead.

all_partitions : bool, default False
    When True, all upstream data dependencies read the entire dataset (all partitions) instead of filtering to the current backfill_key. For fine-grained control, wrap individual TaskProxy arguments with the all_partitions function instead.

**kwargs : Unpack[TaskConfig], default {}
    Any additional SDK-native Task fields (e.g. max_retries, timeout_seconds, retry_on_timeout). These are forwarded directly to the databricks.bundles.jobs.Task constructor at deploy time. See TaskConfig for the full list of supported fields.
Notes

Dependency edges are detected only for TaskProxy objects passed as direct positional or keyword arguments. Proxies nested inside lists, dicts, or other container types are not inspected and will not register dependency edges.
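This rule can be demonstrated with a small standalone sketch (plain Python; TaskProxy here is a stand-in for the real class, and detect_edges is a hypothetical helper mirroring the scan in the source below):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TaskProxy:
    task_key: str

def detect_edges(*args, **kwargs):
    # Mirror the edge scan: only *direct* positional or keyword
    # TaskProxy arguments are recorded; containers are not inspected.
    return [
        v.task_key
        for v in (*args, *kwargs.values())
        if isinstance(v, TaskProxy)
    ]

a, b = TaskProxy("extract"), TaskProxy("clean")
print(detect_edges(a, right=b))  # → ['extract', 'clean']
print(detect_edges([a, b]))      # → [] — proxies inside a list are invisible
```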

Source code in src/databricks_bundle_decorators/decorators.py
def task(
    fn: types.FunctionType | None = None,
    *,
    io_manager: IoManager | None = None,
    depends_on: TaskProxy | list[TaskProxy] | None = None,
    all_partitions: bool = False,
    partition_by: str | list[str] | None = None,
    **kwargs: Unpack[TaskConfig],
):
    """Register a function as a Databricks task.

    When used **inside** a ``@job`` body, the decorated function is
    registered under a qualified key (``job_name.task_name``) and
    calling it returns a `TaskProxy` that wires up the DAG.

    When used **outside** a ``@job`` body (e.g. at module level), the
    function is registered under its short name for use in tests or
    standalone execution.  Duplicate names at module level raise
    `DuplicateResourceError`.

    Parameters
    ----------
    io_manager:
        An `IoManager` instance that controls
        how the task's return value is persisted and loaded by downstream
        tasks.  When ``None``, no automatic data transfer takes place (use
        `set_task_value` for small scalars).
    depends_on:
        One or more `TaskProxy` objects returned by calling other
        ``@task``-decorated functions inside a ``@job`` body.  Creates
        **control-flow-only** dependencies: the current task will run
        after the specified upstream tasks complete, but no data is
        transferred via `IoManager`.  Use this when a task must wait
        for another to finish without consuming its output.  For
        data dependencies, pass `TaskProxy` objects as regular function
        arguments instead.
    all_partitions:
        When ``True``, **all** upstream data dependencies read the
        entire dataset (all partitions) instead of filtering to the
        current ``backfill_key``.  For fine-grained control, use the
        `all_partitions` function to wrap individual `TaskProxy`
        arguments instead.
    **kwargs:
        Any additional SDK-native ``Task`` fields (e.g. ``max_retries``,
        ``timeout_seconds``, ``retry_on_timeout``).  These are forwarded
        directly to the ``databricks.bundles.jobs.Task`` constructor at
        deploy time.  See `TaskConfig`
        for the full list of supported fields.
    Notes
    -----
    Dependency edges are detected only for `TaskProxy` objects passed as
    **direct** positional or keyword arguments.  Proxies nested inside
    lists, dicts, or other container types are **not** inspected and will
    not register dependency edges."""

    def decorator(fn: types.FunctionType) -> Callable[..., Any]:
        task_key = fn.__name__

        # Normalize and validate depends_on
        depends_on_keys: list[str] = []
        if depends_on is not None:
            deps = depends_on if isinstance(depends_on, list) else [depends_on]
            for dep in deps:
                if not isinstance(dep, TaskProxy):
                    raise TypeError(
                        f"@task(depends_on=...) expects TaskProxy objects "
                        f"returned by calling @task-decorated functions "
                        f"inside a @job body, got {type(dep).__name__!r}."
                    )
                depends_on_keys.append(dep.task_key)

        meta = TaskMeta(
            fn=fn,
            task_key=task_key,
            io_manager=io_manager,
            partition_by=_normalize_partition_by(partition_by),
            sdk_config=dict(kwargs),
            depends_on=depends_on_keys,
        )

        if _current_job_name is not None:
            # Inside a @job body – register under qualified key and
            # store in the job-local tracker so the wrapper can build
            # the DAG.
            qualified_key = f"{_current_job_name}.{task_key}"
            _register_unique(_TASK_REGISTRY, qualified_key, meta, "task")
            # Also stash in a job-scoped dict so @job can iterate.
            _current_job_tasks[task_key] = meta
        else:
            # Module-level definition (standalone / test usage).
            _register_unique(_TASK_REGISTRY, task_key, meta, "task")

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if _current_job_name is not None:
                # We're being *called* inside a @job body – return a
                # TaskProxy and record DAG edges from any proxy args.
                if task_key in _current_job_dag:
                    raise DuplicateResourceError(
                        f"Task '{task_key}' is called more than once in job "
                        f"'{_current_job_name}'. Each @task may only be invoked "
                        "once per @job body. Use a unique function name for "
                        "each logical step."
                    )

                upstream_deps: list[str] = list(meta.depends_on)
                edge_map: dict[str, str] = {}
                ap_params: set[str] = set()

                param_names = list(inspect.signature(fn).parameters.keys())

                for idx, arg in enumerate(args):
                    if isinstance(arg, (_AllPartitionsProxy, TaskProxy)):
                        upstream_deps.append(arg.task_key)
                        p_name = (
                            param_names[idx] if idx < len(param_names) else f"arg{idx}"
                        )
                        edge_map[p_name] = arg.task_key
                        if isinstance(arg, _AllPartitionsProxy) or all_partitions:
                            ap_params.add(p_name)
                    elif arg is not None:
                        p_name = (
                            param_names[idx] if idx < len(param_names) else f"arg{idx}"
                        )
                        warnings.warn(
                            f"Task '{task_key}' in job '{_current_job_name}' "
                            f"received a non-TaskProxy argument "
                            f"({type(arg).__name__!r}) for parameter "
                            f"'{p_name}'. Inside a @job body, task calls "
                            f"only build the DAG — arguments that are not "
                            f"TaskProxy values returned by other @task "
                            f"calls are silently discarded at runtime. "
                            f"Move data-producing code inside a @task "
                            f"function.",
                            UserWarning,
                            stacklevel=2,
                        )

                for kw_name, kw_val in kwargs.items():
                    if isinstance(kw_val, (_AllPartitionsProxy, TaskProxy)):
                        upstream_deps.append(kw_val.task_key)
                        edge_map[kw_name] = kw_val.task_key
                        if isinstance(kw_val, _AllPartitionsProxy) or all_partitions:
                            ap_params.add(kw_name)
                    elif kw_val is not None:
                        warnings.warn(
                            f"Task '{task_key}' in job '{_current_job_name}' "
                            f"received a non-TaskProxy argument "
                            f"({type(kw_val).__name__!r}) for parameter "
                            f"'{kw_name}'. Inside a @job body, task calls "
                            f"only build the DAG — arguments that are not "
                            f"TaskProxy values returned by other @task "
                            f"calls are silently discarded at runtime. "
                            f"Move data-producing code inside a @task "
                            f"function.",
                            UserWarning,
                            stacklevel=2,
                        )

                # Deduplicate while preserving order
                upstream_deps = list(dict.fromkeys(upstream_deps))
                _current_job_dag[task_key] = upstream_deps
                _current_job_edges[task_key] = edge_map
                if ap_params:
                    _current_job_all_partitions[task_key] = ap_params

                return TaskProxy(task_key)
            else:
                # Normal execution (runtime / tests).
                return fn(*args, **kwargs)

        wrapper._task_meta = meta  # type: ignore[attr-defined]
        return wrapper

    if fn is not None:
        return decorator(fn)
    return decorator

databricks_bundle_decorators.decorators.for_each_task(*, inputs, concurrency=None, io_manager=None, depends_on=None, **kwargs)

Register a function as a Databricks for-each task.

A for-each task iterates over a list of inputs and executes the decorated function once per element. The iteration list is specified via the inputs decorator argument — either a TaskValueRef created by task_value (referencing a specific upstream task-value) or a static Python list.

The decorated function must have a parameter named inputs. At runtime the framework injects the current element from the iteration list into that parameter.

Inside a @job body the function must be called to add it to the DAG — just like @task. Call arguments wire IoManager data dependencies.
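The inputs-parameter requirement is enforced at decoration time via inspect.signature, as the source below shows. A minimal standalone version of that check (require_inputs_param is a hypothetical helper, not part of the library):

```python
import inspect

def require_inputs_param(fn):
    # The decorated function must declare an 'inputs' parameter,
    # mirroring the validation performed by @for_each_task.
    sig = inspect.signature(fn)
    if "inputs" not in sig.parameters:
        raise ValueError(
            f"'{fn.__name__}' must have an 'inputs' parameter; "
            f"found {list(sig.parameters)}"
        )
    return fn

@require_inputs_param
def process(inputs: str) -> str:
    return f"Processing {inputs}"

print(process("a.csv"))  # → Processing a.csv
```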

Parameters:

inputs : TaskValueRef | list[Any], required
    The iteration source. Use task_value(upstream_task, "key") to iterate over a task-value published by an upstream task via set_task_value. Pass a plain Python list (must be JSON-serialisable) for static iteration.

concurrency : int | None, default None
    Maximum number of parallel iterations. Maps to the ForEachTask.concurrency field in the Databricks SDK.

io_manager : IoManager | None, default None
    An IoManager instance for persisting the task's return value, identical in behaviour to @task(io_manager=...).

depends_on : _TaskRef | list[_TaskRef] | None, default None
    Control-flow-only dependencies, identical to @task(depends_on=...). Accepts @task-decorated functions or TaskProxy objects.

**kwargs : Unpack[TaskConfig], default {}
    SDK-native Task fields forwarded to the inner task (e.g. max_retries, timeout_seconds). See TaskConfig.
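Static inputs are validated with json.dumps at decoration time, as the source below shows. A standalone version of the check (validate_static_inputs is a hypothetical helper):

```python
import json
from typing import Any

def validate_static_inputs(inputs: list[Any]) -> list[Any]:
    # Mirror the decoration-time check: static inputs must be
    # JSON-serialisable so they can be embedded in the bundle config.
    try:
        json.dumps(inputs)
    except (TypeError, ValueError) as exc:
        raise TypeError(
            f"static inputs must be JSON-serialisable, got error: {exc}"
        ) from exc
    return inputs

print(validate_static_inputs(["us-east-1", "eu-west-1"]))  # → ['us-east-1', 'eu-west-1']
```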

Examples:

Dynamic inputs from an upstream task with an IoManager data dependency:

@job
def my_pipeline():
    @task
    def get_files():
        set_task_value("files", ["a.csv", "b.csv", "c.csv"])

    @task(io_manager=staging_io)
    def load_data():
        return pl.read_parquet("s3://bucket/data.parquet")

    data = load_data()

    @for_each_task(inputs=task_value(get_files, "files"), concurrency=5)
    def process(inputs: str, data):
        subset = data.filter(pl.col("file") == inputs)
        print(f"Processing {inputs}: {len(subset)} rows")

    process(data=data)

Static inputs:

@job
def static_pipeline():
    @for_each_task(inputs=["us-east-1", "eu-west-1"])
    def ingest(inputs: str):
        print(f"Ingesting {inputs}")

    ingest()
Source code in src/databricks_bundle_decorators/decorators.py
def for_each_task(
    *,
    inputs: TaskValueRef | list[Any],
    concurrency: int | None = None,
    io_manager: IoManager | None = None,
    depends_on: _TaskRef | list[_TaskRef] | None = None,
    **kwargs: Unpack[TaskConfig],
) -> _TaskDecorator:
    """Register a function as a Databricks **for-each** task.

    A for-each task iterates over a list of inputs and executes the
    decorated function once per element.  The iteration list is
    specified via the ``inputs`` decorator argument — either a
    `TaskValueRef` created by `task_value` (referencing a specific
    upstream task-value) or a static Python list.

    The decorated function **must** have a parameter named ``inputs``.
    At runtime the framework injects the current element from the
    iteration list into that parameter.

    Inside a ``@job`` body the function must be **called** to add it
    to the DAG — just like ``@task``.  Call arguments wire `IoManager`
    data dependencies.

    Parameters
    ----------
    inputs:
        The iteration source.  Use ``task_value(upstream_task, "key")``
        to iterate over a task-value published by an upstream task via
        `set_task_value`.  Pass a plain Python list (must be
        JSON-serialisable) for static iteration.
    concurrency:
        Maximum number of parallel iterations.  Maps to the
        ``ForEachTask.concurrency`` field in the Databricks SDK.
    io_manager:
        An `IoManager` instance for persisting the task's return value,
        identical in behaviour to ``@task(io_manager=...)``.
    depends_on:
        Control-flow-only dependencies, identical to
        ``@task(depends_on=...)``.  Accepts ``@task``-decorated
        functions or `TaskProxy` objects.
    **kwargs:
        SDK-native ``Task`` fields forwarded to the **inner** task
        (e.g. ``max_retries``, ``timeout_seconds``).  See `TaskConfig`.

    Examples
    --------
    Dynamic inputs from an upstream task with an IoManager data dependency::

        @job
        def my_pipeline():
            @task
            def get_files():
                set_task_value("files", ["a.csv", "b.csv", "c.csv"])

            @task(io_manager=staging_io)
            def load_data():
                return pl.read_parquet("s3://bucket/data.parquet")

            data = load_data()

            @for_each_task(inputs=task_value(get_files, "files"), concurrency=5)
            def process(inputs: str, data):
                subset = data.filter(pl.col("file") == inputs)
                print(f"Processing {inputs}: {len(subset)} rows")

            process(data=data)

    Static inputs::

        @job
        def static_pipeline():
            @for_each_task(inputs=["us-east-1", "eu-west-1"])
            def ingest(inputs: str):
                print(f"Ingesting {inputs}")

            ingest()
    """

    if _current_job_name is None:
        raise RuntimeError("@for_each_task can only be used inside a @job body.")

    # --- resolve inputs ---------------------------------------------------
    inputs_task_key: str | None = None
    inputs_value_key: str | None = None
    static_inputs: list[Any] | None = None
    inputs_dep_key: str | None = None

    if isinstance(inputs, list):
        try:
            json.dumps(inputs)
        except (TypeError, ValueError) as exc:
            raise TypeError(
                f"@for_each_task: static inputs must be "
                f"JSON-serialisable, got error: {exc}"
            ) from exc
        static_inputs = inputs
    elif isinstance(inputs, TaskValueRef):
        inputs_task_key = inputs.task_key
        inputs_value_key = inputs.key
        inputs_dep_key = inputs.task_key
    else:
        raise TypeError(
            f"@for_each_task(inputs=...) expects a TaskValueRef from "
            f"task_value() or a static list, got {type(inputs).__name__!r}. "
            f"Use task_value(upstream_task, 'key_name') to reference "
            f"a task-value from an upstream task."
        )

    # --- resolve depends_on -----------------------------------------------
    depends_on_keys: list[str] = []
    if depends_on is not None:
        deps = depends_on if isinstance(depends_on, list) else [depends_on]
        for dep in deps:
            depends_on_keys.append(
                _resolve_task_ref(dep, "@for_each_task(depends_on=...)")
            )

    # Merge inputs dep into depends_on list
    all_dep_keys = list(depends_on_keys)
    if inputs_dep_key is not None:
        all_dep_keys.append(inputs_dep_key)
    # Deduplicate while preserving order
    all_dep_keys = list(dict.fromkeys(all_dep_keys))

    def decorator(fn: types.FunctionType) -> Callable[..., Any]:
        task_key = fn.__name__

        # Validate that the function has an 'inputs' parameter
        sig = inspect.signature(fn)
        if "inputs" not in sig.parameters:
            raise ValueError(
                f"@for_each_task: function '{task_key}' must have a "
                f"parameter named 'inputs' to receive each element "
                f"from the iteration list. "
                f"Parameters: {list(sig.parameters.keys())}."
            )

        meta = TaskMeta(
            fn=fn,
            task_key=task_key,
            io_manager=io_manager,
            sdk_config=dict(kwargs),
            depends_on=all_dep_keys,
        )

        assert _current_job_name is not None  # guaranteed by outer check

        qualified_key = f"{_current_job_name}.{task_key}"
        _register_unique(_TASK_REGISTRY, qualified_key, meta, "task")
        _current_job_tasks[task_key] = meta

        # Record ForEachMeta immediately — no call required.
        _current_job_for_each[task_key] = ForEachMeta(
            inputs_task_key=inputs_task_key,
            inputs_value_key=inputs_value_key,
            static_inputs=static_inputs,
            concurrency=concurrency,
        )

        @functools.wraps(fn)
        def wrapper(*args, **call_kwargs):
            if _current_job_name is None:
                # Normal execution (runtime / tests) — call directly.
                return fn(*args, **call_kwargs)

            # Inside a @job body — wire data-dependency edges.
            if task_key in _current_job_dag:
                raise DuplicateResourceError(
                    f"Task '{task_key}' is called more than once in job "
                    f"'{_current_job_name}'. Each @task / @for_each_task "
                    "may only be invoked once per @job body."
                )

            # Map positional args to parameter names (skip 'inputs')
            param_names = [p for p in sig.parameters.keys() if p != "inputs"]
            all_call_kwargs: dict[str, Any] = {}
            for idx, arg in enumerate(args):
                p_name = param_names[idx] if idx < len(param_names) else f"arg{idx}"
                all_call_kwargs[p_name] = arg
            all_call_kwargs.update(call_kwargs)

            upstream_deps: list[str] = list(all_dep_keys)
            edge_map: dict[str, str] = {}

            # Process call args as data deps (same as @task)
            for kw_name, kw_val in all_call_kwargs.items():
                if isinstance(kw_val, TaskProxy):
                    upstream_deps.append(kw_val.task_key)
                    edge_map[kw_name] = kw_val.task_key
                elif kw_val is not None:
                    warnings.warn(
                        f"for_each_task '{task_key}' in job "
                        f"'{_current_job_name}' received a non-TaskProxy "
                        f"argument ({type(kw_val).__name__!r}) for "
                        f"parameter '{kw_name}'. Inside a @job body, "
                        f"task calls only build the DAG.",
                        UserWarning,
                        stacklevel=2,
                    )

            upstream_deps = list(dict.fromkeys(upstream_deps))
            _current_job_dag[task_key] = upstream_deps
            _current_job_edges[task_key] = edge_map

            return TaskProxy(task_key)

        wrapper._task_meta = meta  # type: ignore[attr-defined]
        return wrapper

    return decorator

databricks_bundle_decorators.decorators.task_value(task_ref, key)

Create a reference to a specific task-value from an upstream task.

Use this with @for_each_task(inputs=...) to specify which upstream task-value provides the iteration list.

Parameters:

task_ref : _TaskRef, required
    A @task-decorated function or a TaskProxy returned by calling one inside a @job body.

key : str, required
    The task-value key name: the key argument passed to set_task_value in the upstream task.

Returns:

TaskValueRef
    An object that can be passed to @for_each_task(inputs=...).

Examples:


@job
def my_pipeline():
    @task
    def discover():
        set_task_value("countries", ["US", "UK", "DE"])

    @for_each_task(inputs=task_value(discover, "countries"))
    def process(inputs: str):
        print(f"Processing {inputs}")
Source code in src/databricks_bundle_decorators/decorators.py
def task_value(task_ref: _TaskRef, key: str) -> TaskValueRef:
    """Create a reference to a specific task-value from an upstream task.

    Use this with ``@for_each_task(inputs=...)`` to specify which
    upstream task-value provides the iteration list.

    Parameters
    ----------
    task_ref:
        A ``@task``-decorated function or a `TaskProxy` returned by
        calling one inside a ``@job`` body.
    key:
        The task-value key name — the ``key`` argument passed to
        `set_task_value` in the upstream task.

    Returns
    -------
    `TaskValueRef`
        An object that can be passed to ``@for_each_task(inputs=...)``.

    Examples
    --------
    ::

        @job
        def my_pipeline():
            @task
            def discover():
                set_task_value("countries", ["US", "UK", "DE"])

            @for_each_task(inputs=task_value(discover, "countries"))
            def process(inputs: str):
                print(f"Processing {inputs}")
    """
    resolved_key = _resolve_task_ref(task_ref, "task_value()")
    return TaskValueRef(task_key=resolved_key, key=key)

Per-Task Compute Override

By default, every task inherits the shared job cluster defined via @job(cluster=...). If a specific task needs different compute, you can pass any of the compute-related TaskConfig fields directly to the decorator:

@task(existing_cluster_id="0123-456789-abcdef01")
def special_task():
    ...

@task(environment_key="my-serverless-env")
def serverless_task():
    ...

These fields are forwarded to the Databricks SDK Task constructor and take precedence over the job-level cluster. See TaskConfig for all available fields.