Tasks
databricks_bundle_decorators.decorators.task(fn=None, *, io_manager=None, output_name=None, depends_on=None, all_partitions=False, partition_by=None, **kwargs)
task(
    *,
    io_manager: IoManager | None = ...,
    output_name: str | None = ...,
    depends_on: TaskProxy | list[TaskProxy] | None = ...,
    all_partitions: bool = ...,
    partition_by: str | list[str] | None = ...,
    **kwargs: Unpack[TaskConfig],
) -> _TaskDecorator
task(fn: types.FunctionType) -> Callable[..., Any]
Register a function as a Databricks task.
When used inside a @job body, the decorated function is
registered under a qualified key (job_name.task_name) and
calling it returns a TaskProxy that wires up the DAG.
When used outside a @job body (e.g. at module level), the
function is registered under its short name for use in tests or
standalone execution. Duplicate names at module level raise
DuplicateResourceError.
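The registration rule described above can be illustrated with a plain dictionary. This is a minimal sketch, not the library's implementation: `REGISTRY`, `register_unique`, `registry_key`, and the local `DuplicateResourceError` class are hypothetical stand-ins for the library internals.

```python
from typing import Optional


# Hypothetical stand-ins; the real registry lives inside the library.
class DuplicateResourceError(Exception):
    """Raised when a key is registered twice."""


REGISTRY = {}


def register_unique(key: str, value: object) -> None:
    """Store value under key, refusing to overwrite an existing entry."""
    if key in REGISTRY:
        raise DuplicateResourceError(f"Duplicate task '{key}'.")
    REGISTRY[key] = value


def registry_key(task_name: str, job_name: Optional[str] = None) -> str:
    """Qualified key inside a job, short name at module level."""
    return f"{job_name}.{task_name}" if job_name is not None else task_name


def extract():
    ...


register_unique(registry_key("extract", "daily_job"), extract)  # inside a @job body
register_unique(registry_key("extract"), extract)               # module level

try:
    register_unique(registry_key("extract"), extract)           # same short name again
except DuplicateResourceError as exc:
    duplicate_message = str(exc)
```

Because the job name is part of the key, two jobs can each define a task called `extract` without colliding, while two module-level definitions with the same name cannot.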
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| io_manager | IoManager \| None | An IoManager instance that controls how the task's return value is persisted and loaded by downstream tasks. When None, no automatic data transfer takes place (use set_task_value for small scalars). | None |
| depends_on | TaskProxy \| list[TaskProxy] \| None | One or more TaskProxy objects returned by calling other @task-decorated functions inside a @job body. Creates control-flow-only dependencies: the current task will run after the specified upstream tasks complete, but no data is transferred via IoManager. Use this when a task must wait for another to finish without consuming its output. For data dependencies, pass TaskProxy objects as regular function arguments instead. | None |
| all_partitions | bool | When True, all upstream data dependencies read the entire dataset (all partitions) instead of filtering to the current backfill_key. For fine-grained control, use the all_partitions function to wrap individual TaskProxy arguments instead. | False |
| **kwargs | Unpack[TaskConfig] | Any additional SDK-native Task fields (e.g. max_retries, timeout_seconds, retry_on_timeout). These are forwarded directly to the databricks.bundles.jobs.Task constructor at deploy time. See TaskConfig for the full list of supported fields. | {} |
Notes
Dependency edges are detected only for TaskProxy objects passed as
direct positional or keyword arguments. Proxies nested inside
lists, dicts, or other container types are not inspected and will
not register dependency edges.
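The note above follows from checking each argument with a single `isinstance` test. The sketch below reproduces that behaviour in isolation; `Proxy` and `collect_edges` are hypothetical names used only for illustration, and the real wrapper does more (warnings, all-partitions tracking).

```python
import inspect


class Proxy:
    """Hypothetical stand-in for TaskProxy: it only carries a task key."""

    def __init__(self, task_key: str) -> None:
        self.task_key = task_key


def collect_edges(fn, *args, **kwargs) -> dict:
    """Map parameter names to upstream task keys for direct Proxy arguments."""
    param_names = list(inspect.signature(fn).parameters)
    edges = {}
    for idx, arg in enumerate(args):
        if isinstance(arg, Proxy):  # only the argument itself is checked
            name = param_names[idx] if idx < len(param_names) else f"arg{idx}"
            edges[name] = arg.task_key
    for name, value in kwargs.items():
        if isinstance(value, Proxy):
            edges[name] = value.task_key
    return edges


def merge(left, right, extras=None):
    ...


a, b, c = Proxy("a"), Proxy("b"), Proxy("c")
direct = collect_edges(merge, a, right=b)        # both edges are recorded
nested = collect_edges(merge, a, extras=[b, c])  # the list is not inspected
```

In the second call only the edge for `left` is recorded, so proxies should be passed as separate arguments rather than bundled into a container.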
Source code in src/databricks_bundle_decorators/decorators.py
def task(
    fn: types.FunctionType | None = None,
    *,
    io_manager: IoManager | None = None,
    output_name: str | None = None,
    depends_on: TaskProxy | list[TaskProxy] | None = None,
    all_partitions: bool = False,
    partition_by: str | list[str] | None = None,
    **kwargs: Unpack[TaskConfig],
):
    """Register a function as a Databricks task.

    When used **inside** a ``@job`` body, the decorated function is
    registered under a qualified key (``job_name.task_name``) and
    calling it returns a `TaskProxy` that wires up the DAG.

    When used **outside** a ``@job`` body (e.g. at module level), the
    function is registered under its short name for use in tests or
    standalone execution. Duplicate names at module level raise
    `DuplicateResourceError`.

    Parameters
    ----------
    io_manager:
        An `IoManager` instance that controls
        how the task's return value is persisted and loaded by downstream
        tasks. When ``None``, no automatic data transfer takes place (use
        `set_task_value` for small scalars).
    depends_on:
        One or more `TaskProxy` objects returned by calling other
        ``@task``-decorated functions inside a ``@job`` body. Creates
        **control-flow-only** dependencies: the current task will run
        after the specified upstream tasks complete, but no data is
        transferred via `IoManager`. Use this when a task must wait
        for another to finish without consuming its output. For
        data dependencies, pass `TaskProxy` objects as regular function
        arguments instead.
    all_partitions:
        When ``True``, **all** upstream data dependencies read the
        entire dataset (all partitions) instead of filtering to the
        current ``backfill_key``. For fine-grained control, use the
        `all_partitions` function to wrap individual `TaskProxy`
        arguments instead.
    **kwargs:
        Any additional SDK-native ``Task`` fields (e.g. ``max_retries``,
        ``timeout_seconds``, ``retry_on_timeout``). These are forwarded
        directly to the ``databricks.bundles.jobs.Task`` constructor at
        deploy time. See `TaskConfig`
        for the full list of supported fields.

    Notes
    -----
    Dependency edges are detected only for `TaskProxy` objects passed as
    **direct** positional or keyword arguments. Proxies nested inside
    lists, dicts, or other container types are **not** inspected and will
    not register dependency edges."""

    def decorator(fn: types.FunctionType) -> Callable[..., Any]:
        task_key = fn.__name__

        # Normalize and validate depends_on
        depends_on_keys: list[str] = []
        if depends_on is not None:
            deps = depends_on if isinstance(depends_on, list) else [depends_on]
            for dep in deps:
                if not isinstance(dep, TaskProxy):
                    raise TypeError(
                        f"@task(depends_on=...) expects TaskProxy objects "
                        f"returned by calling @task-decorated functions "
                        f"inside a @job body, got {type(dep).__name__!r}."
                    )
                depends_on_keys.append(dep.task_key)

        meta = TaskMeta(
            fn=fn,
            task_key=task_key,
            io_manager=io_manager,
            output_name=output_name,
            partition_by=_normalize_partition_by(partition_by),
            sdk_config={**kwargs},
            depends_on=depends_on_keys,
        )

        if _current_job_name is not None:
            # Inside a @job body - register under qualified key and
            # store in the job-local tracker so the wrapper can build
            # the DAG.
            qualified_key = f"{_current_job_name}.{task_key}"
            _register_unique(_TASK_REGISTRY, qualified_key, meta, "task")
            # Also stash in a job-scoped dict so @job can iterate.
            _current_job_tasks[task_key] = meta
        else:
            # Module-level definition (standalone / test usage).
            _register_unique(_TASK_REGISTRY, task_key, meta, "task")

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if _current_job_name is not None:
                # We're being *called* inside a @job body - return a
                # TaskProxy and record DAG edges from any proxy args.
                if task_key in _current_job_dag:
                    raise DuplicateResourceError(
                        f"Task '{task_key}' is called more than once in job "
                        f"'{_current_job_name}'. Each @task may only be invoked "
                        "once per @job body. Use a unique function name for "
                        "each logical step."
                    )
                upstream_deps: list[str] = list(meta.depends_on)
                edge_map: dict[str, str] = {}
                ap_params: set[str] = set()
                param_names = list(inspect.signature(fn).parameters.keys())
                for idx, arg in enumerate(args):
                    if isinstance(arg, (_AllPartitionsProxy, TaskProxy)):
                        upstream_deps.append(arg.task_key)
                        p_name = (
                            param_names[idx] if idx < len(param_names) else f"arg{idx}"
                        )
                        edge_map[p_name] = arg.task_key
                        if isinstance(arg, _AllPartitionsProxy) or all_partitions:
                            ap_params.add(p_name)
                    elif arg is not None:
                        p_name = (
                            param_names[idx] if idx < len(param_names) else f"arg{idx}"
                        )
                        warnings.warn(
                            f"Task '{task_key}' in job '{_current_job_name}' "
                            f"received a non-TaskProxy argument "
                            f"({type(arg).__name__!r}) for parameter "
                            f"'{p_name}'. Inside a @job body, task calls "
                            f"only build the DAG — arguments that are not "
                            f"TaskProxy values returned by other @task "
                            f"calls are silently discarded at runtime. "
                            f"Move data-producing code inside a @task "
                            f"function.",
                            UserWarning,
                            stacklevel=2,
                        )
                for kw_name, kw_val in kwargs.items():
                    if isinstance(kw_val, (_AllPartitionsProxy, TaskProxy)):
                        upstream_deps.append(kw_val.task_key)
                        edge_map[kw_name] = kw_val.task_key
                        if isinstance(kw_val, _AllPartitionsProxy) or all_partitions:
                            ap_params.add(kw_name)
                    elif kw_val is not None:
                        warnings.warn(
                            f"Task '{task_key}' in job '{_current_job_name}' "
                            f"received a non-TaskProxy argument "
                            f"({type(kw_val).__name__!r}) for parameter "
                            f"'{kw_name}'. Inside a @job body, task calls "
                            f"only build the DAG — arguments that are not "
                            f"TaskProxy values returned by other @task "
                            f"calls are silently discarded at runtime. "
                            f"Move data-producing code inside a @task "
                            f"function.",
                            UserWarning,
                            stacklevel=2,
                        )
                # Deduplicate while preserving order
                upstream_deps = list(dict.fromkeys(upstream_deps))
                _current_job_dag[task_key] = upstream_deps
                _current_job_edges[task_key] = edge_map
                if ap_params:
                    _current_job_all_partitions[task_key] = ap_params
                return TaskProxy(task_key)

            # Normal execution (runtime / tests).
            return fn(*args, **kwargs)

        wrapper._task_meta = meta  # ty: ignore[unresolved-attribute]
        return wrapper

    if fn is not None:
        return decorator(fn)
    return decorator
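The last three lines of the source above are what allow both `@task` and `@task(...)` to work. The same pattern is shown here in isolation as a small sketch; `step` and `label` are hypothetical names and carry none of the registration logic.

```python
import functools


def step(fn=None, *, label=None):
    """Usable as @step or as @step(label=...). All names here are hypothetical."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        # Attach metadata to the wrapper, defaulting to the function name.
        wrapper.label = label if label is not None else func.__name__
        return wrapper

    if fn is not None:  # bare form: @step
        return decorator(fn)
    return decorator    # called form: @step(label=...)


@step
def plain():
    return 1


@step(label="custom")
def configured():
    return 2
```

Making every option keyword-only (the `*` in the signature) keeps the two forms unambiguous: the only positional argument the decorator can ever receive is the function itself.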
databricks_bundle_decorators.decorators.for_each_task(*, inputs, concurrency=None, io_manager=None, depends_on=None, **kwargs)
Register a function as a Databricks for-each task.
A for-each task iterates over a list of inputs and executes the
decorated function once per element. The iteration list is
specified via the inputs decorator argument — either a
TaskValueRef created by task_value (referencing a specific
upstream task-value) or a static Python list.
The decorated function must have a parameter named inputs.
At runtime the framework injects the current element from the
iteration list into that parameter.
Inside a @job body the function must be called to add it
to the DAG — just like @task. Call arguments wire IoManager
data dependencies.
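The two rules above (a mandatory `inputs` parameter, and call arguments that map onto the remaining parameters) can be sketched with `inspect.signature`. The helper names `check_inputs_param` and `bind_call_args` are hypothetical, and plain strings stand in for `TaskProxy` values.

```python
import inspect


def check_inputs_param(fn) -> None:
    """Reject functions that lack a parameter named 'inputs'."""
    params = list(inspect.signature(fn).parameters)
    if "inputs" not in params:
        raise ValueError(
            f"function '{fn.__name__}' must have a parameter named 'inputs'. "
            f"Parameters: {params}."
        )


def bind_call_args(fn, *args, **kwargs) -> dict:
    """Name positional call arguments, skipping the injected 'inputs' slot."""
    names = [p for p in inspect.signature(fn).parameters if p != "inputs"]
    bound = {}
    for idx, arg in enumerate(args):
        bound[names[idx] if idx < len(names) else f"arg{idx}"] = arg
    bound.update(kwargs)
    return bound


def process(inputs, data, lookup=None):
    ...


def broken(item):
    ...


check_inputs_param(process)  # passes silently
bound = bind_call_args(process, "data_proxy", lookup="lookup_proxy")

try:
    check_inputs_param(broken)
except ValueError as exc:
    error = str(exc)
```

Because `inputs` is skipped when naming positional arguments, the first positional argument at the call site binds to `data` here, not to `inputs`.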
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | TaskValueRef \| list[Any] | The iteration source. Use task_value(upstream_task, "key") to iterate over a task-value published by an upstream task via set_task_value. Pass a plain Python list (must be JSON-serialisable) for static iteration. | required |
| concurrency | int \| None | Maximum number of parallel iterations. Maps to the ForEachTask.concurrency field in the Databricks SDK. | None |
| io_manager | IoManager \| None | An IoManager instance for persisting the task's return value, identical in behaviour to @task(io_manager=...). | None |
| depends_on | _TaskRef \| list[_TaskRef] \| None | Control-flow-only dependencies with the same semantics as @task(depends_on=...). Accepts @task-decorated functions or TaskProxy objects (@task(depends_on=...) itself accepts only TaskProxy objects). | None |
| **kwargs | Unpack[TaskConfig] | SDK-native Task fields forwarded to the inner task (e.g. max_retries, timeout_seconds). See TaskConfig. | {} |
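The JSON-serialisable requirement on a static `inputs` list can be checked by attempting to serialise it. The sketch below mirrors that check in a standalone helper; `validate_static_inputs` is a hypothetical name.

```python
import json


def validate_static_inputs(inputs: list) -> list:
    """Return inputs unchanged if json.dumps accepts them, else raise TypeError."""
    try:
        json.dumps(inputs)
    except (TypeError, ValueError) as exc:
        raise TypeError(
            f"static inputs must be JSON-serialisable, got error: {exc}"
        ) from exc
    return inputs


regions = validate_static_inputs(["us-east-1", "eu-west-1"])

try:
    validate_static_inputs([{"us-east-1"}])  # a set is not JSON-serialisable
except TypeError:
    rejected = True
else:
    rejected = False
```

Strings, numbers, booleans, None, and lists or dicts of those pass the check; sets, datetimes, and arbitrary objects do not.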
Examples:
Dynamic inputs from an upstream task with an IoManager data dependency:

@job
def my_pipeline():
    @task
    def get_files():
        set_task_value("files", ["a.csv", "b.csv", "c.csv"])

    @task(io_manager=staging_io)
    def load_data():
        return pl.read_parquet("s3://bucket/data.parquet")

    data = load_data()

    @for_each_task(inputs=task_value(get_files, "files"), concurrency=5)
    def process(inputs: str, data):
        subset = data.filter(pl.col("file") == inputs)
        print(f"Processing {inputs}: {len(subset)} rows")

    process(data=data)

Static inputs:

@job
def static_pipeline():
    @for_each_task(inputs=["us-east-1", "eu-west-1"])
    def ingest(inputs: str):
        print(f"Ingesting {inputs}")

    ingest()
Source code in src/databricks_bundle_decorators/decorators.py
def for_each_task(
    *,
    inputs: TaskValueRef | list[Any],
    concurrency: int | None = None,
    io_manager: IoManager | None = None,
    depends_on: _TaskRef | list[_TaskRef] | None = None,
    **kwargs: Unpack[TaskConfig],
) -> _TaskDecorator:
    """Register a function as a Databricks **for-each** task.

    A for-each task iterates over a list of inputs and executes the
    decorated function once per element. The iteration list is
    specified via the ``inputs`` decorator argument — either a
    `TaskValueRef` created by `task_value` (referencing a specific
    upstream task-value) or a static Python list.

    The decorated function **must** have a parameter named ``inputs``.
    At runtime the framework injects the current element from the
    iteration list into that parameter.

    Inside a ``@job`` body the function must be **called** to add it
    to the DAG — just like ``@task``. Call arguments wire `IoManager`
    data dependencies.

    Parameters
    ----------
    inputs:
        The iteration source. Use ``task_value(upstream_task, "key")``
        to iterate over a task-value published by an upstream task via
        `set_task_value`. Pass a plain Python list (must be
        JSON-serialisable) for static iteration.
    concurrency:
        Maximum number of parallel iterations. Maps to the
        ``ForEachTask.concurrency`` field in the Databricks SDK.
    io_manager:
        An `IoManager` instance for persisting the task's return value,
        identical in behaviour to ``@task(io_manager=...)``.
    depends_on:
        Control-flow-only dependencies, identical to
        ``@task(depends_on=...)``. Accepts ``@task``-decorated
        functions or `TaskProxy` objects.
    **kwargs:
        SDK-native ``Task`` fields forwarded to the **inner** task
        (e.g. ``max_retries``, ``timeout_seconds``). See `TaskConfig`.

    Examples
    --------
    Dynamic inputs from an upstream task with an IoManager data dependency::

        @job
        def my_pipeline():
            @task
            def get_files():
                set_task_value("files", ["a.csv", "b.csv", "c.csv"])

            @task(io_manager=staging_io)
            def load_data():
                return pl.read_parquet("s3://bucket/data.parquet")

            data = load_data()

            @for_each_task(inputs=task_value(get_files, "files"), concurrency=5)
            def process(inputs: str, data):
                subset = data.filter(pl.col("file") == inputs)
                print(f"Processing {inputs}: {len(subset)} rows")

            process(data=data)

    Static inputs::

        @job
        def static_pipeline():
            @for_each_task(inputs=["us-east-1", "eu-west-1"])
            def ingest(inputs: str):
                print(f"Ingesting {inputs}")

            ingest()
    """
    if _current_job_name is None:
        raise RuntimeError("@for_each_task can only be used inside a @job body.")

    # --- resolve inputs ---------------------------------------------------
    inputs_task_key: str | None = None
    inputs_value_key: str | None = None
    static_inputs: list[Any] | None = None
    inputs_dep_key: str | None = None
    if isinstance(inputs, list):
        try:
            json.dumps(inputs)
        except (TypeError, ValueError) as exc:
            raise TypeError(
                f"@for_each_task: static inputs must be "
                f"JSON-serialisable, got error: {exc}"
            ) from exc
        static_inputs = inputs
    elif isinstance(inputs, TaskValueRef):
        inputs_task_key = inputs.task_key
        inputs_value_key = inputs.key
        inputs_dep_key = inputs.task_key
    else:
        raise TypeError(
            f"@for_each_task(inputs=...) expects a TaskValueRef from "
            f"task_value() or a static list, got {type(inputs).__name__!r}. "
            f"Use task_value(upstream_task, 'key_name') to reference "
            f"a task-value from an upstream task."
        )

    # --- resolve depends_on -----------------------------------------------
    depends_on_keys: list[str] = []
    if depends_on is not None:
        deps = depends_on if isinstance(depends_on, list) else [depends_on]
        depends_on_keys.extend(
            _resolve_task_ref(dep, "@for_each_task(depends_on=...)") for dep in deps
        )

    # Merge inputs dep into depends_on list
    all_dep_keys = list(depends_on_keys)
    if inputs_dep_key is not None:
        all_dep_keys.append(inputs_dep_key)
    # Deduplicate while preserving order
    all_dep_keys = list(dict.fromkeys(all_dep_keys))

    def decorator(fn: types.FunctionType) -> Callable[..., Any]:
        task_key = fn.__name__

        # Validate that the function has an 'inputs' parameter
        sig = inspect.signature(fn)
        if "inputs" not in sig.parameters:
            raise ValueError(
                f"@for_each_task: function '{task_key}' must have a "
                f"parameter named 'inputs' to receive each element "
                f"from the iteration list. "
                f"Parameters: {list(sig.parameters.keys())}."
            )

        meta = TaskMeta(
            fn=fn,
            task_key=task_key,
            io_manager=io_manager,
            sdk_config={**kwargs},
            depends_on=all_dep_keys,
        )

        assert _current_job_name is not None  # guaranteed by outer check
        qualified_key = f"{_current_job_name}.{task_key}"
        _register_unique(_TASK_REGISTRY, qualified_key, meta, "task")
        _current_job_tasks[task_key] = meta

        # Record ForEachMeta immediately — no call required.
        _current_job_for_each[task_key] = ForEachMeta(
            inputs_task_key=inputs_task_key,
            inputs_value_key=inputs_value_key,
            static_inputs=static_inputs,
            concurrency=concurrency,
        )

        @functools.wraps(fn)
        def wrapper(*args, **call_kwargs):
            if _current_job_name is None:
                # Normal execution (runtime / tests) — call directly.
                return fn(*args, **call_kwargs)

            # Inside a @job body — wire data-dependency edges.
            if task_key in _current_job_dag:
                raise DuplicateResourceError(
                    f"Task '{task_key}' is called more than once in job "
                    f"'{_current_job_name}'. Each @task / @for_each_task "
                    "may only be invoked once per @job body."
                )

            # Map positional args to parameter names (skip 'inputs')
            param_names = [p for p in sig.parameters if p != "inputs"]
            all_call_kwargs: dict[str, Any] = {}
            for idx, arg in enumerate(args):
                p_name = param_names[idx] if idx < len(param_names) else f"arg{idx}"
                all_call_kwargs[p_name] = arg
            all_call_kwargs.update(call_kwargs)

            upstream_deps: list[str] = list(all_dep_keys)
            edge_map: dict[str, str] = {}

            # Process call args as data deps (same as @task)
            for kw_name, kw_val in all_call_kwargs.items():
                if isinstance(kw_val, TaskProxy):
                    upstream_deps.append(kw_val.task_key)
                    edge_map[kw_name] = kw_val.task_key
                elif kw_val is not None:
                    warnings.warn(
                        f"for_each_task '{task_key}' in job "
                        f"'{_current_job_name}' received a non-TaskProxy "
                        f"argument ({type(kw_val).__name__!r}) for "
                        f"parameter '{kw_name}'. Inside a @job body, "
                        f"task calls only build the DAG.",
                        UserWarning,
                        stacklevel=2,
                    )

            upstream_deps = list(dict.fromkeys(upstream_deps))
            _current_job_dag[task_key] = upstream_deps
            _current_job_edges[task_key] = edge_map
            return TaskProxy(task_key)

        wrapper._task_meta = meta  # ty: ignore[unresolved-attribute]
        return wrapper

    return decorator
databricks_bundle_decorators.decorators.task_value(task_ref, key)
Create a reference to a specific task-value from an upstream task.
Use this with @for_each_task(inputs=...) to specify which
upstream task-value provides the iteration list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| task_ref | _TaskRef | A @task-decorated function or a TaskProxy returned by calling one inside a @job body. | required |
| key | str | The task-value key name, i.e. the key argument passed to set_task_value in the upstream task. | required |
Returns:
| Type | Description |
|---|---|
| TaskValueRef | An object that can be passed to @for_each_task(inputs=...). |
Examples:

@job
def my_pipeline():
    @task
    def discover():
        set_task_value("countries", ["US", "UK", "DE"])

    @for_each_task(inputs=task_value(discover, "countries"))
    def process(inputs: str):
        print(f"Processing {inputs}")
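Conceptually, the returned reference is a pair: which upstream task to depend on, and which task-value key holds the iteration list. The sketch below models that pair with a frozen dataclass. `ValueRef` and `value_ref` are hypothetical stand-ins, and the way the task key is resolved here (an object's `task_key` attribute, else the function name) is an assumption, since the library's resolver is not shown on this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueRef:
    """Hypothetical stand-in for TaskValueRef: which task, which value key."""

    task_key: str
    key: str


def value_ref(task_ref, key: str) -> ValueRef:
    """Build a reference from a function or a proxy-like object.

    Assumption: proxies expose `task_key`; plain functions are identified
    by their `__name__`. The real resolver may differ.
    """
    task_key = getattr(task_ref, "task_key", None) or task_ref.__name__
    return ValueRef(task_key=task_key, key=key)


def discover():
    ...


ref = value_ref(discover, "countries")
# A for-each task built from `ref` would depend on ref.task_key and iterate
# over the list published under ref.key.
```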
Source code in src/databricks_bundle_decorators/decorators.py
def task_value(task_ref: _TaskRef, key: str) -> TaskValueRef:
    """Create a reference to a specific task-value from an upstream task.

    Use this with ``@for_each_task(inputs=...)`` to specify which
    upstream task-value provides the iteration list.

    Parameters
    ----------
    task_ref:
        A ``@task``-decorated function or a `TaskProxy` returned by
        calling one inside a ``@job`` body.
    key:
        The task-value key name — the ``key`` argument passed to
        `set_task_value` in the upstream task.

    Returns
    -------
    `TaskValueRef`
        An object that can be passed to ``@for_each_task(inputs=...)``.

    Examples
    --------
    ::

        @job
        def my_pipeline():
            @task
            def discover():
                set_task_value("countries", ["US", "UK", "DE"])

            @for_each_task(inputs=task_value(discover, "countries"))
            def process(inputs: str):
                print(f"Processing {inputs}")
    """
    resolved_key = _resolve_task_ref(task_ref, "task_value()")
    return TaskValueRef(task_key=resolved_key, key=key)
Control-Flow Dependencies
By default, passing a TaskProxy as a function argument creates a
data dependency — the upstream task's output is loaded and passed to
the downstream task at runtime. When you only need ordering ("run A
before B") without data transfer, use depends_on:

@job
def my_job():
    @task
    def setup():
        ...  # e.g. create a table, warm a cache

    setup_proxy = setup()

    @task(depends_on=setup_proxy)
    def work():
        ...  # runs after setup, but receives no data from it

    work()

Since depends_on is a @task decorator parameter, the upstream
TaskProxy must be assigned before the @task(depends_on=...) line.
You can pass a list to wait on multiple tasks:

@job
def my_job():
    @task
    def step_a(): ...

    @task
    def step_b(): ...

    a_proxy = step_a()
    b_proxy = step_b()

    @task(depends_on=[a_proxy, b_proxy])
    def final():
        ...

    final()

depends_on and data arguments can be mixed on the same task:

@job
def my_job():
    @task
    def init(): ...

    @task
    def produce(): ...

    i = init()
    p = produce()

    @task(depends_on=i)  # control-flow dep on init
    def consume(data): ...

    consume(p)  # data dep on produce

See How It Works — Control-flow dependencies
for more details.
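When both kinds of dependency are present, the upstream keys end up in a single list for the task: control-flow keys first, then data keys, with repeats removed. A minimal sketch of that merge, using the same `dict.fromkeys` idiom as the source (`merge_dependencies` is a hypothetical helper name):

```python
def merge_dependencies(control_flow: list, data: list) -> list:
    """Combine both kinds of upstream keys, dropping repeats but keeping order."""
    return list(dict.fromkeys([*control_flow, *data]))


# consume: control-flow dep on init, data dep on produce
consume_deps = merge_dependencies(["init"], ["produce"])

# A task named in both places is only recorded once.
overlap_deps = merge_dependencies(["init", "produce"], ["produce"])
```

`dict.fromkeys` is used instead of `set` because dictionaries preserve insertion order, so the resulting list is deterministic.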
Asset Naming (output_name)
By default, task outputs are stored under the task key (function name).
Use output_name to decouple the storage name from the function name:

@task(io_manager=io, output_name="customers")
def extract_customers():
    ...  # stored as "customers", not "extract_customers"

The resolved name is available as OutputContext.asset_name (write)
and InputContext.upstream_asset_name (read). When output_name is
None (the default), both fall back to the task key.
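The fallback rule amounts to a single conditional. The sketch below states it as a standalone function; `resolve_asset_name` is a hypothetical name, not part of the library's API.

```python
from typing import Optional


def resolve_asset_name(task_key: str, output_name: Optional[str] = None) -> str:
    """Prefer the explicit output_name; otherwise use the task key."""
    return output_name if output_name is not None else task_key


named = resolve_asset_name("extract_customers", output_name="customers")
default = resolve_asset_name("extract_customers")
```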
Per-Task Compute Override
By default every task inherits the shared job cluster defined via
@job(cluster=...). If a specific task needs different compute you
can pass any of the compute-related TaskConfig fields directly to
the decorator:

@task(existing_cluster_id="0123-456789-abcdef01")
def special_task():
    ...

@task(environment_key="my-serverless-env")
def serverless_task():
    ...

These fields are forwarded to the Databricks SDK Task constructor
and take precedence over the job-level cluster. See
TaskConfig for all available fields.
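One way to picture the precedence rule is as a lookup that prefers task-level compute fields and only falls back to the shared job cluster when none are set. This is a sketch under stated assumptions, not the library's deploy logic: `task_compute` and `COMPUTE_FIELDS` are hypothetical, the field tuple lists only the two fields named above, and `job_cluster_key` is assumed to be how the shared cluster is referenced.

```python
# Only the two compute fields mentioned in the text; TaskConfig may define more.
COMPUTE_FIELDS = ("existing_cluster_id", "environment_key")


def task_compute(sdk_config: dict, job_cluster_key: str) -> dict:
    """Pick task-level compute if given, else fall back to the shared job cluster."""
    override = {k: v for k, v in sdk_config.items() if k in COMPUTE_FIELDS}
    return override or {"job_cluster_key": job_cluster_key}


# No compute field on the task: the shared job cluster is used.
inherited = task_compute({"max_retries": 2}, job_cluster_key="shared")

# A compute field on the task replaces the job-level cluster reference.
overridden = task_compute(
    {"existing_cluster_id": "0123-456789-abcdef01"}, job_cluster_key="shared"
)
```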