While I had been using a simpler version in the past, it is only recently that I wrapped my head around how to simplify the creation (and, really, the maintenance) of tests for upstream & downstream task dependencies in a DAG. The way I had been doing it involved writing these massive tuples with the task names over & over…
And it was even worse for DAGs that weren't linear; in fact, out of frustration I stopped maintaining that test at one point! Ultimately, the solution involved figuring out a way to build a list for each path through the DAG & then slice the list to create the upstream tasks & downstream tasks. And to aid simplicity & future maintenance of the test, I used `namedtuple` from the standard library's `collections` module to give the fields obvious names.
```python
from collections import namedtuple

from pytest import mark, param

# `dag` (the Airflow DAG under test) is assumed to be defined or imported
# elsewhere in this test module.

Task = namedtuple("Task", ["task", "upstream_task", "downstream_task", "name"])

task_paths = {  # Key = path name, Value = list of tasks in that path
    "first": [
        "EMPTY",  # This must be here for upstream task of "start" task!
        "start",
        "idempotency_check",
        "extract_first",
        "transform_first",
        "load_first",
        "send_slack_message_success",
        "EMPTY",  # This must be here for downstream_task of "send_slack_message_success" task!
    ],
    "second": [
        "EMPTY",
        "start",
        "extract_second_only",
        "send_slack_message_success",
        "EMPTY",
    ],  # Add any additional paths below here if necessary
}


def tasks_to_test(table_type):
    """This function creates a list of tuples with the task, its upstream task,
    downstream task, and name (for a pretty task identifier).
    """
    task_list = task_paths.get(table_type)
    tasks = task_list[1:-1]  # Creates a list that doesn't start/end with "EMPTY"
    upstream_tasks = task_list[:-1]  # Starts with "EMPTY"; zip() ignores its extra last item
    downstream_tasks = task_list[2:]  # Ends with "EMPTY" but doesn't start with it
    return [
        Task(t, u, d, f"{table_type}_{t}")
        for t, u, d in zip(tasks, upstream_tasks, downstream_tasks)
        if t != "EMPTY"
    ]


@mark.parametrize(
    "task_id, upstream_task_id, downstream_task_id",
    [
        param(task.task, task.upstream_task, task.downstream_task, id=task.name)
        for task_path in task_paths
        for task in tasks_to_test(task_path)
    ],
)
def test_task_dependencies(task_id, upstream_task_id, downstream_task_id):
    """Verify the upstream/downstream task_ids for tasks in DAG.

    If there is no upstream or downstream task_id, it will be "EMPTY".
    Where a task has several upstream or downstream tasks (e.g. tasks that
    can run simultaneously), the expected task_id only needs to be one of them.
    """
    dag_task = dag.get_task(task_id)
    # Fall back to ["EMPTY"] so a task with no neighbours matches the sentinel
    upstream_task_ids = [task.task_id for task in dag_task.upstream_list] or ["EMPTY"]
    downstream_task_ids = [
        task.task_id for task in dag_task.downstream_list
    ] or ["EMPTY"]
    assert upstream_task_id in upstream_task_ids
    assert downstream_task_id in downstream_task_ids
```
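To see why checking membership (rather than equality) copes with a non-linear DAG, here is a minimal sketch that runs the same checks without Airflow. `FakeTask`, `FakeDag`, and `dependencies_ok` are hypothetical stand-ins that mimic only what the test touches (`task_id`, `upstream_list`, `downstream_list`, and `get_task`); they are not Airflow classes.

```python
class FakeTask:
    """Hypothetical stand-in for an Airflow task (not an Airflow class)."""

    def __init__(self, task_id):
        self.task_id = task_id
        self.upstream_list = []
        self.downstream_list = []


class FakeDag:
    """Hypothetical stand-in for an Airflow DAG, built from (upstream, downstream) edges."""

    def __init__(self, edges):
        self.tasks = {}
        for upstream_id, downstream_id in edges:
            # Create each task on first sight, then wire the edge in both directions
            up = self.tasks.setdefault(upstream_id, FakeTask(upstream_id))
            down = self.tasks.setdefault(downstream_id, FakeTask(downstream_id))
            up.downstream_list.append(down)
            down.upstream_list.append(up)

    def get_task(self, task_id):
        return self.tasks[task_id]


def dependencies_ok(dag, task_id, upstream_task_id, downstream_task_id):
    """Mirrors the checks in test_task_dependencies, returning a bool instead of asserting."""
    task = dag.get_task(task_id)
    upstream_ids = [t.task_id for t in task.upstream_list] or ["EMPTY"]
    downstream_ids = [t.task_id for t in task.downstream_list] or ["EMPTY"]
    return upstream_task_id in upstream_ids and downstream_task_id in downstream_ids


# "start" fans out to both paths, which rejoin at the Slack task
fake_dag = FakeDag([
    ("start", "idempotency_check"),
    ("idempotency_check", "extract_first"),
    ("extract_first", "transform_first"),
    ("transform_first", "load_first"),
    ("load_first", "send_slack_message_success"),
    ("start", "extract_second_only"),
    ("extract_second_only", "send_slack_message_success"),
])

print(dependencies_ok(fake_dag, "start", "EMPTY", "extract_second_only"))  # True
print(dependencies_ok(fake_dag, "extract_first", "start", "transform_first"))  # False
```

Because `start` has two downstream tasks, each path only has to name one of them, which is what lets every path be written as a simple linear list.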
Based on the non-“EMPTY” values in both lists, this generates a total of 9 tests. A test fails if the expected upstream or downstream task is not properly connected to the task being tested.
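The count follows directly from the path lists: each path yields one test per entry, minus its two “EMPTY” sentinels. A quick sanity check, reusing the `task_paths` data from the test module:

```python
# Copy of the task_paths data from the test module
task_paths = {
    "first": ["EMPTY", "start", "idempotency_check", "extract_first",
              "transform_first", "load_first", "send_slack_message_success", "EMPTY"],
    "second": ["EMPTY", "start", "extract_second_only",
               "send_slack_message_success", "EMPTY"],
}

# Every entry except the two "EMPTY" sentinels becomes one parametrized test
tests_per_path = {name: len(path) - 2 for name, path in task_paths.items()}
total_tests = sum(tests_per_path.values())

print(tests_per_path)  # {'first': 6, 'second': 3}
print(total_tests)  # 9
```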
The real magic starts with `tasks_to_test`, which slices each list into a list of tasks, a list of upstream tasks, and a list of downstream tasks, and then zips them all together into `Task` named tuples. The list comprehension in the `@mark.parametrize()` decorator of `test_task_dependencies` then finishes putting all the task data into the list of parameters needed for each test.
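Here is the slicing step on its own, applied to the short “second” path, so you can see exactly which neighbours line up with each task. It repeats the same slices and `zip()` that `tasks_to_test` uses:

```python
from collections import namedtuple

Task = namedtuple("Task", ["task", "upstream_task", "downstream_task", "name"])

# The "second" path, with its "EMPTY" sentinels at both ends
task_list = ["EMPTY", "start", "extract_second_only", "send_slack_message_success", "EMPTY"]

tasks = task_list[1:-1]  # ['start', 'extract_second_only', 'send_slack_message_success']
upstream_tasks = task_list[:-1]  # shifted one position left relative to tasks
downstream_tasks = task_list[2:]  # shifted one position right relative to tasks

# zip() pairs the lists position by position and stops at the shortest one,
# so the extra trailing item in upstream_tasks is simply ignored
rows = [Task(t, u, d, f"second_{t}") for t, u, d in zip(tasks, upstream_tasks, downstream_tasks)]

for row in rows:
    print(row)
# Task(task='start', upstream_task='EMPTY', downstream_task='extract_second_only', name='second_start')
# Task(task='extract_second_only', upstream_task='start', downstream_task='send_slack_message_success', name='second_extract_second_only')
# Task(task='send_slack_message_success', upstream_task='extract_second_only', downstream_task='EMPTY', name='second_send_slack_message_success')
```

The sentinels are what make this work: without a leading and trailing “EMPTY”, the first task would have no upstream partner and the last no downstream partner, and `zip()` would silently drop them.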