Testing Task Dependencies In Your DAG

While I had been using a simpler version in the past, it is only recently that I wrapped my head around how to simplify the creation (and, really, the maintenance) of upstream & downstream task dependencies in a DAG. The way I had been doing it involved writing these massive tuples with the task names over & over…

And it was even worse for DAGs that weren't linear; in fact, out of frustration I stopped maintaining that test at one point! Ultimately, the fix involved figuring out a way to build a list for each path through the DAG & then slice that list to create the upstream tasks & downstream tasks. And to aid simplicity & future maintenance of the test, I used namedtuple from the collections library to give the fields obvious names.

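from collections import namedtuple

from airflow.models import DagBag
from pytest import mark, param

# The test below needs the DAG object. One way to get it is from a DagBag;
# "your_dag_id" is a placeholder for your DAG's dag_id.
dag = DagBag(include_examples=False).get_dag("your_dag_id")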


Task = namedtuple("Task", ["task", "upstream_task", "downstream_task", "name"])

task_paths = {  # Key = path name, Value = list of tasks in that path
    "first": [
        "EMPTY",  # Must be here as the upstream task of the "start" task!
        "start",
        "idempotency_check",
        "extract_first",
        "transform_first",
        "load_first",
        "send_slack_message_success",
        "EMPTY",  # Must be here as the downstream task of the "send_slack_message_success" task!
    ],
    "second": [
        "EMPTY",
        "start",
        "extract_second_only",
        "send_slack_message_success",
        "EMPTY",
    ],  # Add any additional paths below here if necessary
}



def tasks_to_test(path_name):
    """Create a list of Task tuples with the task, its upstream task,
    its downstream task, and a name (for a pretty test identifier).
    """
    task_list = task_paths[path_name]  # Raises KeyError for an unknown path name
    tasks = task_list[1:-1]  # Creates a list that doesn't start/end with "EMPTY"
    upstream_tasks = task_list[:-1]  # Starts with "EMPTY" but doesn't end with it
    downstream_tasks = task_list[2:]  # Ends with "EMPTY" but doesn't start with it
    # zip() pairs each task with the item just before and just after it in the
    # path; it stops at the shortest list, so upstream_tasks' last item is ignored.
    return [
        Task(t, u, d, f"{path_name}_{t}")
        for t, u, d in zip(tasks, upstream_tasks, downstream_tasks)
    ]


@mark.parametrize(
    "task_id, upstream_task_id, downstream_task_id",
    [
        param(task.task, task.upstream_task, task.downstream_task, id=task.name)
        for task_path in task_paths
        for task in tasks_to_test(task_path)
    ],
)
def test_task_dependencies(task_id, upstream_task_id, downstream_task_id):
    """Verify the upstream/downstream task_ids for tasks in DAG.

    If there is no upstream or downstream task_id, it will be "EMPTY".

    Where a task has several upstream or downstream tasks (i.e. where paths
    branch or merge), the test only checks that the expected task_id is
    one of them.
    """
    dag_task = dag.get_task(task_id)
    # Fall back to a one-item list (not a bare string) so that `in` checks
    # list membership rather than substring matching.
    upstream_task_ids = [task.task_id for task in dag_task.upstream_list] or ["EMPTY"]
    downstream_task_ids = [
        task.task_id for task in dag_task.downstream_list
    ] or ["EMPTY"]
    assert upstream_task_id in upstream_task_ids
    assert downstream_task_id in downstream_task_ids

Based on the non-“EMPTY” values in both lists, this generates a total of 9 tests: 6 for the “first” path and 3 for the “second” path. A test will fail if any upstream or downstream task is not properly connected to its task.
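To see why the assertions use `in` rather than `==`, here is a minimal, Airflow-free sketch. It assumes the DAG's edges are exactly the ones implied by the two paths above and writes them as a plain dict (task_id to a set of downstream task_ids), a hypothetical stand-in for `downstream_list`/`upstream_list` on a real DAG. Because "start" fans out to two tasks and "send_slack_message_success" fans in from two, an exact-equality check would fail for one path or the other, while a membership check passes for both. The helper name `check_path` is illustrative only.

```python
# Assumed DAG shape: the edges implied by the "first" and "second" paths,
# as task_id -> set of downstream task_ids (stand-in for a real Airflow DAG).
edges = {
    "start": {"idempotency_check", "extract_second_only"},
    "idempotency_check": {"extract_first"},
    "extract_first": {"transform_first"},
    "transform_first": {"load_first"},
    "load_first": {"send_slack_message_success"},
    "extract_second_only": {"send_slack_message_success"},
    "send_slack_message_success": set(),
}

# Invert the edges to get each task's upstream task_ids.
upstream = {task_id: set() for task_id in edges}
for task_id, children in edges.items():
    for child in children:
        upstream[child].add(task_id)

paths = {
    "first": ["EMPTY", "start", "idempotency_check", "extract_first",
              "transform_first", "load_first", "send_slack_message_success", "EMPTY"],
    "second": ["EMPTY", "start", "extract_second_only",
               "send_slack_message_success", "EMPTY"],
}


def check_path(path):
    """Check every (task, upstream, downstream) triple in one path.

    Returns the number of triples checked; raises AssertionError on a bad edge.
    """
    checked = 0
    # Same slices as tasks_to_test: tasks, shifted-back upstreams, shifted-forward downstreams.
    for task, up, down in zip(path[1:-1], path[:-1], path[2:]):
        # A task with no neighbours is treated as {"EMPTY"}, like the pytest test.
        assert up in (upstream[task] or {"EMPTY"}), (task, up)
        assert down in (edges[task] or {"EMPTY"}), (task, down)
        checked += 1
    return checked


total = sum(check_path(path) for path in paths.values())
print(total)  # 9
```

Swapping a wrong edge into a path, for example `["EMPTY", "start", "load_first", "EMPTY"]`, makes `check_path` raise, which is the same failure the parametrized test reports for a miswired task.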

The real magic happens in tasks_to_test, which slices each list into a list of tasks, a list of upstream tasks, and a list of downstream tasks, and then zips them all together into Task named tuples. The list comprehension in the @mark.parametrize() decorator of test_task_dependencies then finishes putting all the task data into the list of parameters needed for each test.
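A small standalone sketch, using only the "second" path, shows how the three slices line up index by index before they are zipped. The sliding-window comparison at the end is an alternative added here for illustration, not part of the original test: `zip(path, path[1:], path[2:])` walks the padded path three items at a time and yields the same triples.

```python
from collections import namedtuple

Task = namedtuple("Task", ["task", "upstream_task", "downstream_task", "name"])

# The "second" path, padded with "EMPTY" at both ends.
path = ["EMPTY", "start", "extract_second_only", "send_slack_message_success", "EMPTY"]

tasks = path[1:-1]           # the real tasks only
upstream_tasks = path[:-1]   # index i holds the item just before tasks[i]
downstream_tasks = path[2:]  # index i holds the item just after tasks[i]

# zip() stops at the shortest list, so upstream_tasks' extra last item is ignored.
triples = list(zip(tasks, upstream_tasks, downstream_tasks))

# The same (task, upstream, downstream) triples from a width-3 sliding window.
window = [(t, u, d) for u, t, d in zip(path, path[1:], path[2:])]
assert triples == window

# Wrap each triple in a Task with a readable test id.
named = [Task(t, u, d, f"second_{t}") for t, u, d in triples]
for task in named:
    print(task.name, task.upstream_task, task.downstream_task)
```

The sliding window needs no separate `tasks` list, but the three named slices arguably make the padding with "EMPTY" easier to reason about when a path is edited later.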
