airflow

code
Author

Vincenzo Palazeti

Published

December 2, 2023

this dag has an expanding training window with a maximum length.

we found that three years' worth of data was preferable; anything beyond that degraded performance.

also, the backtest needed to start in 2011, which was only 1 year after the training data began.

so, for the first two years the training window expands until it reaches three years; after that it slides forward, retaining only the most recent three years.
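here's a minimal standalone sketch of that window arithmetic, using the same 1096-day cap and 2010-01-01 floor as the dag below:

from datetime import date, timedelta

DATA_START = date(2010, 1, 1)        # first day of available data
MAX_WINDOW = timedelta(days=1096)    # ~three years, the cap on the window

def training_window(run_date: date) -> tuple[date, date]:
    # training ends the day before the prediction period begins
    end = run_date - timedelta(days=1)
    # reach back up to three years, but never before the data starts
    start = max(run_date - MAX_WINDOW, DATA_START)
    return start, end

print(training_window(date(2011, 1, 1)))  # 2010-01-01 .. 2010-12-31 (clamped, ~1 year)
print(training_window(date(2015, 1, 1)))  # 2012-01-01 .. 2014-12-31 (full three years)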


what makes this difficult is that airflow doesn't let you run arbitrary python against the run date inside the dag definition; the file is parsed before any run exists, so the run date is only available as a jinja template at runtime.

to get around this, I wrote the functions outside the dag & used the built-in airflow utilities for date manipulation & function execution.

PythonOperator executed the functions, and macros.ds_add() shifted the training & prediction windows.
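for example, with a run date (ds) of 2011-01-01, the templated values work out like this:

# how the templates render for a run with ds = "2011-01-01"
#   "{{ macros.ds_add(ds, -1096) }}"  ->  "2008-01-01"  (three years back)
#   check_start_date clamps that to      "2010-01-01"  (no data before this)
#   "{{ macros.ds_add(ds, 99) }}"     ->  "2011-04-10"  (end of the 100-day interval)
#   "{{ macros.ds_add(ds, -1) }}"     ->  "2010-12-31"  (training ends the day before ds)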

these windows were set right here in the file, but I probably should've used a CLI like click.
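a click version of the pull script's interface might have looked something like this (just a sketch; the real argument parsing in scripts/data_pull.py isn't shown here):

import click

# hypothetical click interface for scripts/data_pull.py
@click.command()
@click.option("--aws_user", required=True, help="aws profile to pull with")
@click.option("--start_date", required=True, help="YYYY-MM-DD, inclusive")
@click.option("--end_date", required=True, help="YYYY-MM-DD, inclusive")
def data_pull(aws_user: str, start_date: str, end_date: str) -> None:
    """Pull raw data for the given window."""
    click.echo(f"pulling {start_date} .. {end_date} as {aws_user}")
    # ... the actual pull would happen here ...

if __name__ == "__main__":
    data_pull()

anyway, here's the full dag: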

from datetime import timedelta

import pendulum
from airflow import DAG
from airflow.models import Variable
from airflow.operators.python import PythonOperator
from airflow.providers.docker.operators.docker import DockerOperator
from docker.types import Mount

default_arguments = {
    "owner": "airflow",
    "depends_on_past": True,
    "retries": 0,
}

def check_start_date(**kwargs):
    """Clamp the start date so it never precedes the first day of data."""
    # ISO-formatted date strings compare correctly as plain strings
    return max(kwargs["date"], "2010-01-01")

def check_end_date(**kwargs):
    """Clamp the end date so it never runs past the last day of data."""
    return min(kwargs["date"], "2020-12-31")

with DAG(
    dag_id="backfill_data_pull",
    description="Backfill DAG for data pulls",
    schedule_interval=timedelta(days=100),
    start_date=pendulum.datetime(2011, 1, 1, tz="UTC"),
    catchup=False,
    default_args=default_arguments,
) as dag:


    # ds_add shifts the run date; the callables clamp the result to the data range
    format_start_date = PythonOperator(
        task_id="format_start_date",
        python_callable=check_start_date,
        op_kwargs={"date": "{{ macros.ds_add(ds, -1096) }}"},  # ~three years back
    )
    format_end_date = PythonOperator(
        task_id="format_end_date",
        python_callable=check_end_date,
        op_kwargs={"date": "{{ macros.ds_add(ds, 99) }}"},  # end of the 100-day interval
    )

    # the clamped dates come back through XCom; the rest are plain templates
    training_date_start = "{{ ti.xcom_pull(task_ids='format_start_date') }}"
    training_date_end = "{{ macros.ds_add(ds, -1) }}"  # day before the prediction period

    prediction_period_start = "{{ ds }}"
    prediction_period_end = "{{ ti.xcom_pull(task_ids='format_end_date') }}"


    docker_url = Variable.get("DOCKER_URL", deserialize_json=True)
    aws_user = Variable.get("AWS_USERNAME", deserialize_json=True)

    # aws_user is filled in when the file is parsed; the jinja left inside the
    # f-string is rendered by airflow at runtime (command is a templated field)
    training_data_pull = DockerOperator(
        task_id="training_data_pull",
        image="image_name:latest",
        command=f"python scripts/data_pull.py \
            --aws_user {aws_user} \
            --start_date {training_date_start} \
            --end_date {training_date_end}",
        network_mode="host",
        docker_url=docker_url,
        auto_remove=True,
        mounts=[
            Mount(target="/home/myuser/.aws", source="/home/airflow/.aws", type="bind"),
            Mount(target="/home/myuser/code/scripts", source="/home/airflow/projects", type="bind"),
        ],
    )

    prediction_data_pull = DockerOperator(
        task_id="prediction_data_pull",
        image="image_name:latest",
        command=f"python scripts/data_pull.py \
            --aws_user {aws_user} \
            --start_date {prediction_period_start} \
            --end_date {prediction_period_end}",
        network_mode="host",
        docker_url=docker_url,
        auto_remove=True,
        mounts=[
            Mount(target="/home/myuser/.aws", source="/home/airflow/.aws", type="bind"),
            Mount(target="/home/myuser/code/scripts", source="/home/airflow/projects", type="bind"),
        ],
    )

    format_start_date >> format_end_date >> training_data_pull >> prediction_data_pull
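since the clamping helpers are plain python, they're easy to sanity-check outside airflow:

# quick checks for the clamping helpers (run outside airflow)
assert check_start_date(date="2008-01-01") == "2010-01-01"  # clamped to first day of data
assert check_start_date(date="2012-06-15") == "2012-06-15"  # inside the range, untouched
assert check_end_date(date="2021-03-01") == "2020-12-31"    # clamped to last day of data

from there, one way to drive the historical runs is airflow's backfill CLI, e.g. airflow dags backfill -s 2011-01-01 -e 2020-12-31 backfill_data_pull.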