Introduction

In my experience, developing testable Spark application code is not an easy task for data practitioners. I am not going to discuss the underlying reasons here.

In this post, I present my reasoning while developing a testable batch Spark application. The text is organized in two sections. In the first section, TDD - Developing code from the tests, I show an example of how to develop code that is modular, readable, comprehensive, testable, and easy to maintain. In the last section, More than a method to produce pretty code - it’s about building organizational knowledge, I highlight the benefits of using TDD in data projects, based on my experience and on other sources that may help you understand this methodology.

This post is not a tutorial. Still, it provides some guidance to help professionals develop testable code for Apache Spark applications. Implementation details, like preparing the environment to run tests locally and how to use Pytest, are not covered here.


TDD - Developing code from the tests

Framing the Problem

Let’s suppose we have a task to develop code that reads CSV files, adapts the column names to comply with internal guidelines, and chooses proper data types. We don’t have much more information, like expected workload, value ranges for numerical columns, or expected values for categorical columns. The Product Owner, however, provided us with a sample and guaranteed that all files will have the same column names. Let’s start.


Thinking about the repo and test files structure

The repo structure we can use is presented below:

spark-unit-test-example
 ┣ notebooks
 ┃ ┗ ...
 ┣ tests
 ┃ ┗ ...
 ┗ README.md

The repo spark-unit-test-example has two folders and one README.md file. The notebooks folder will contain the source code for the Spark application we need to develop. The tests folder will contain the source code to test the functions of our application.

Regarding the Python test files, we can start with a structure like the snippet below:

import sys
import pytest
from pyspark.sql import SparkSession
from pyspark_test import assert_pyspark_df_equal

@pytest.fixture(scope="module")
def spark():
    ...

@pytest.fixture(scope="module")
def mock_base(spark):
    ...

@pytest.fixture(scope="module")
def mock_expected_change_column_names(spark):
    ...

@pytest.fixture(scope="module")
def mock_expected_change_column_types(spark):
    ...

def test_mock_base(mock_base):
    assert mock_base.count() > 0

@pytest.mark.skip(reason="TBD")
def test_change_column_names(mock_base, mock_expected_change_column_names):
    sys.path.append('./notebooks')
    from ... import change_column_names
    ...

@pytest.mark.skip(reason="TBD")
def test_change_column_types(mock_base, mock_expected_change_column_types):
    sys.path.append('./notebooks')
    from ... import change_column_types
    ...

The code above lives in the file hop_raw2bronze_test.py. It is just a starting point and has the following content:

  • General imports:

    • sys: gives access to Python runtime variables. In particular, we are interested in appending the notebooks folder path to the Python module search path.
    • pytest: provides fixtures that create data to be reused across tests in a reliable way. We can also use pytest decorators to skip tests that we haven’t built yet.
    • SparkSession: the entry point to Spark, used here to create a local session and mocked Spark DataFrames.
    • assert_pyspark_df_equal: asserts that two Spark DataFrames are equal.
  • Mocks and fixtures:

    • spark: creates the Spark mock connection.
    • mock_base: creates the initial mock data. Replaces the result of the DataFrameReader when reading the CSV files in the application code.
    • mock_expected: each mock_expected creates the expected result of the corresponding transformation.
  • Tests:

    • test_mock_base: evaluates if the initial DataFrame was created and is not empty.
    • other tests: each test is going to evaluate its corresponding transformation function.

The repo structure now is:

spark-unit-test-example
 ┣ notebooks
 ┃ ┗ ...
 ┣ tests
 ┃ ┗ hop_raw2bronze_test.py
 ┗ README.md

Since we have an initial structure, we can start to develop and implement tests.

Implementing the spark and mock_base fixtures

First, we need to implement the spark and mock_base fixtures. The spark fixture provides the SparkSession our tests will use. The mock_base fixture replaces the result of reading the CSV files in the application code. The snippet below shows the implementations:

...
@pytest.fixture(scope="module")
def spark():
    spark_session = SparkSession.builder \
        .appName("Spark Unit Test") \
        .getOrCreate()
    return spark_session


@pytest.fixture(scope="module")
def mock_base(spark):
    mock_data = [
       {
        "Customer ID": "1",
        "Age": "20",
        "Gender": "Male",
        "Item Purchased": "Blouse",
        "Category": "Clothing",
        "Purchase Amount (USD)": "53",
        "Location": "Montana",
        "Size": "L",
        "Color":"Blue",
        "Season": "Spring",
        "Review Rating": "3.1",
        "Subscription Status": "Yes",
        "Payment Method": "Credit Card",
        "Shipping Type": "Free Shipping",
        "Discount Applied": "No",
        "Promo Code Used": "No",
        "Previous Purchases": "5",
        "Preferred Payment Method": "Credit Card",
        "Frequency of Purchases": "Weekly"
       },
    ]

    schema = """
         `Customer ID` string,
        `Age` string,
        `Gender` string,
        `Item Purchased` string,
        `Category` string,
        `Purchase Amount (USD)` string,
        `Location` string,
        `Size` string,
        `Color` string,
        `Season` string,
        `Review Rating` string,
        `Subscription Status` string,
        `Payment Method` string,
        `Shipping Type` string,
        `Discount Applied` string,
        `Promo Code Used` string,
        `Previous Purchases` string,
        `Preferred Payment Method` string,
        `Frequency of Purchases` string
    """

    df = spark.createDataFrame(mock_data, schema=schema)
    df.show()
    return df
...

def test_mock_base(mock_base):
    assert mock_base.count() == 1

We are mocking just one row of the data source, but we can add new ones as needed. Since we expect to read CSV files, all column types would be String. Now we can start running tests with Pytest. We are going to get something similar to the output below:

...
+-----------+---+------+--------------+--------+---------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+
|Customer ID|Age|Gender|Item Purchased|Category|Purchase Amount (USD)|Location|Size|Color|Season|Review Rating|Subscription Status|Payment Method|Shipping Type|Discount Applied|Promo Code Used|Previous Purchases|Preferred Payment Method|Frequency of Purchases|
+-----------+---+------+--------------+--------+---------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+
|          1| 20|  Male|        Blouse|Clothing|                   53| Montana|   L| Blue|Spring|          3.1|                Yes|   Credit Card|Free Shipping|              No|             No|                 5|             Credit Card|                Weekly|
+-----------+---+------+--------------+--------+---------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+

PASSED

=============================== warnings summary ===============================
...

After that, we can remove the df.show() call in the mock_base fixture and proceed to the next step.
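
As a side note, the spark fixture above relies on whatever master is configured in the environment. A minimal sketch of a variant, assuming we want to force local mode and release resources when the module’s tests finish (the extra options are assumptions, not requirements of the example):

@pytest.fixture(scope="module")
def spark():
    # Force a local master so the tests don't depend on an external cluster,
    # and keep shuffle partitions low to speed up tiny test DataFrames.
    spark_session = (
        SparkSession.builder
        .master("local[1]")
        .appName("Spark Unit Test")
        .config("spark.sql.shuffle.partitions", "1")
        .getOrCreate()
    )
    yield spark_session
    spark_session.stop()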

Implementing the change_column_names function

In order to implement the change_column_names function, we need a fixture to mock the expected result and develop the test. Both are presented below:

@pytest.fixture(scope="module")
def mock_expected_change_column_names(spark):
    mock_data = [
       {
        "customer_id": "1",
        "age": "20",
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": "53",
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": "3.1",
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": "5",
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
    ]

    schema = """
        customer_id string,
        age string,
        gender string,
        item_purchased string,
        category string,
        purchase_amount_usd string,
        location string,
        size string,
        color string,
        season string,
        review_rating string,
        subscription_status string,
        payment_method string,
        shipping_type string,
        discount_applied string,
        promo_code_used string,
        previous_purchases string,
        preferred_payment_method string,
        frequency_of_purchases string
    """

    df = spark.createDataFrame(mock_data, schema=schema)
    return df
...

def test_change_column_names(mock_base, mock_expected_change_column_names):
    sys.path.append('./notebooks')
    from hop_raw2bronze import change_column_names

    df_transformed = change_column_names(mock_base)
    df_transformed.show()

    assert_pyspark_df_equal(df_transformed, mock_expected_change_column_names)

This is the first application function test we are building. We call the sys.path.append method to add the application files’ location to the Python module search path. After that, we create the df_transformed DataFrame by calling the change_column_names function with the mock_base DataFrame as its argument. Finally, we check whether the transformed data df_transformed is equal to the expected result mock_expected_change_column_names.
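
As an aside, instead of calling sys.path.append inside every test, we could configure the import path once for the whole test directory. A minimal sketch, assuming a hypothetical conftest.py placed in the tests folder:

# tests/conftest.py
import os
import sys

# Make the notebooks folder importable by every test collected in this directory.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "notebooks"))

Pytest imports conftest.py before collecting the tests, so the from hop_raw2bronze import ... statements keep working without the explicit sys.path.append calls.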

Now that we have the test implemented, we can start writing the application code in the hop_raw2bronze file inside the notebooks folder. The updated file structure is presented below:

spark-unit-test-example
 ┣ notebooks
 ┃ ┗ hop_raw2bronze.py
 ┣ tests
 ┃ ┗ hop_raw2bronze_test.py
 ┗ README.md

We can write anything in the hop_raw2bronze.py file, as long as it makes test_change_column_names pass. There are many ways to achieve this result. One of them is presented below:

from pyspark.sql import DataFrame

def change_column_names(df: DataFrame) -> DataFrame:
    return df.selectExpr(
        "`Customer ID` as customer_id",
        "`Age` as age",
        "`Gender` as gender",
        "`Item Purchased` as item_purchased",
        "`Category` as category",
        "`Purchase Amount (USD)` as purchase_amount_usd",
        "`Location` as location",
        "`Size` as size",
        "`Color` as color",
        "`Season` as season",
        "`Review Rating` as review_rating",
        "`Subscription Status` as subscription_status",
        "`Payment Method` as payment_method",
        "`Shipping Type` as shipping_type",
        "`Discount Applied` as discount_applied",
        "`Promo Code Used` as promo_code_used",
        "`Previous Purchases` as previous_purchases",
        "`Preferred Payment Method` as preferred_payment_method",
        "`Frequency of Purchases` as frequency_of_purchases"
        )

The next step is to run Pytest. In the real world, we may fail the test on the first run and keep fixing the application function until it passes. The terminal output now looks something like this:

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
...
PASSED
spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_names +-----------+---+------+--------------+--------+-------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+
|customer_id|age|gender|item_purchased|category|purchase_amount_usd|location|size|color|season|review_rating|subscription_status|payment_method|shipping_type|discount_applied|promo_code_used|previous_purchases|preferred_payment_method|frequency_of_purchases|
+-----------+---+------+--------------+--------+-------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+
|          1| 20|  Male|        Blouse|Clothing|                 53| Montana|   L| Blue|Spring|          3.1|                Yes|   Credit Card|Free Shipping|              No|             No|                 5|             Credit Card|                Weekly|
+-----------+---+------+--------------+--------+-------------------+--------+----+-----+------+-------------+-------------------+--------------+-------------+----------------+---------------+------------------+------------------------+----------------------+

PASSED
...

We can remove the df_transformed.show() call in the test and move on to the next function.
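
As a side note, if the project runs on PySpark 3.5 or later, the built-in testing helper from pyspark.testing could play the same role as the third-party pyspark_test package. A minimal sketch of the same test rewritten with it, assuming PySpark >= 3.5:

from pyspark.testing import assertDataFrameEqual

def test_change_column_names(mock_base, mock_expected_change_column_names):
    sys.path.append('./notebooks')
    from hop_raw2bronze import change_column_names

    df_transformed = change_column_names(mock_base)
    # Raises an informative AssertionError when schemas or rows differ.
    assertDataFrameEqual(df_transformed, mock_expected_change_column_names)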

Implementing the change_column_types function

As we did in the previous section, we start by developing the fixture to mock the expected data and then implement the test. But first we need better definitions in order to build an application that makes better use of computing resources and doesn’t compromise data quality.

First, the column age could be a tinyint, since a human age fits comfortably in this type’s range and it uses 1 byte to represent integers from -128 to 127. For a String representation of a two-digit age, we would use 42 bytes, given the following formula:

byte_size = 2 * (num_of_characters) + 38

Apart from the remaining columns, whose values are plain text, we can’t choose the proper data types without better requirements. In addition to the data size in memory, data quality may be compromised if we keep these columns as String. Encoding problems could slip into numerical strings like customer_id, purchase_amount_usd, and review_rating, and we could propagate those problems to other processes that consume data from the table we are feeding. Moreover, we wouldn’t be able to inform the responsible team about the problem and improve the data pipeline.
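
To illustrate how such an issue can slip through silently, here is a minimal sketch, assuming an existing SparkSession named spark and Spark’s default (non-ANSI) SQL mode, where a malformed numerical string becomes null after a cast instead of raising an error:

df = spark.createDataFrame([{"customer_id": "1O3"}], "customer_id string")  # letter O instead of zero
df.selectExpr("cast(customer_id as bigint) as customer_id").show()
# +-----------+
# |customer_id|
# +-----------+
# |       null|
# +-----------+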

We often choose the data types in the bronze table to match the origin. However, this may not apply to the field purchase_amount_usd. We don’t usually have aggregations in transactional systems, so the source data type for this column could be some kind of floating point. But aggregations are common in analytical systems, and if we sum a floating-point column over many rows, the result can vary considerably due to floating-point arithmetic. We should instead define a fixed-precision type with an agreed precision and scale.
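
To make the point concrete, here is a small pure-Python sketch (independent of Spark) showing how repeated floating-point additions accumulate error, while fixed-precision decimal arithmetic stays exact:

from decimal import Decimal

total_float = 0.0
total_decimal = Decimal("0")
for _ in range(1_000_000):
    total_float += 0.1               # 0.1 has no exact binary floating-point representation
    total_decimal += Decimal("0.1")  # fixed-precision decimal arithmetic

print(total_float)    # close to, but not exactly, 100000.0
print(total_decimal)  # exactly 100000.0

The same effect shows up when summing a double column over many rows, which is why we agree on a decimal type with a fixed precision and scale for purchase_amount_usd.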

Let’s pretend we have hammered out the details with the business team, leading us to the following:

...
@pytest.fixture(scope="module")
def mock_expected_change_column_types(spark):
    from decimal import Decimal
    mock_data = [
       {
        "customer_id": 1,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": 3.1,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": 5,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
    ]

    schema = """
        customer_id bigint,
        age int,
        gender string,
        item_purchased string,
        category string,
        purchase_amount_usd decimal(10,5),
        location string,
        size string,
        color string,
        season string,
        review_rating double,
        subscription_status string,
        payment_method string,
        shipping_type string,
        discount_applied string,
        promo_code_used string,
        previous_purchases int,
        preferred_payment_method string,
        frequency_of_purchases string
    """

    df = spark.createDataFrame(mock_data, schema=schema)
    return df
...
def test_change_column_types(mock_base, mock_expected_change_column_types):
    sys.path.append('./notebooks')
    from hop_raw2bronze import change_column_names, change_column_types

    df_names = change_column_names(mock_base)
    df_transformed = change_column_types(df_names)

    assert_pyspark_df_equal(df_transformed, mock_expected_change_column_types)

Now, we can implement the function change_column_types in the module hop_raw2bronze. One of the possibilities is presented below:

...
def change_column_types(df: DataFrame) -> DataFrame:
    return df.selectExpr(
        "cast(customer_id as bigint) customer_id",
        "cast(age as int) age",
        "cast(gender as string) gender",
        "cast(item_purchased as string) item_purchased",
        "cast(category as string) category",
        "cast(purchase_amount_usd as decimal(10,5)) purchase_amount_usd",
        "cast(location as string) location",
        "cast(size as string) size",
        "cast(color as string) color",
        "cast(season as string) season",
        "cast(review_rating as double) review_rating",
        "cast(subscription_status as string) subscription_status",
        "cast(payment_method as string) payment_method",
        "cast(shipping_type as string) shipping_type",
        "cast(discount_applied as string) discount_applied",
        "cast(promo_code_used as string) promo_code_used",
        "cast(previous_purchases as int) previous_purchases",
        "cast(preferred_payment_method as string) preferred_payment_method",
        "cast(frequency_of_purchases as string) frequency_of_purchases",
    )
...

Executing Pytest, we are going to get the following output:

============================= test session starts ==============================
...
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_mock_base ps: unrecognized option: p
...
PASSED [ 33%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_names PASSED [ 66%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_types PASSED [100%]

Implementing the check_constraints function

In order to implement the check_constraints function, we need to create the mock fixtures and the test. To do that, we need to define the data quality constraints and the expected behavior in case they are violated. There are some possibilities:

  • When a constraint is violated, the offending rows are removed from the DataFrame and we proceed with the application execution. Engineering and business teams have agreed that the amount of bad data is too small to impact the use cases of this table.
  • When a constraint is violated, the process is interrupted and an exception is raised, showing the total number of violations as well as a small sample of them. Engineering and business teams have agreed that if we have a problem, we should provide enough information for the responsible team to identify its root cause.
  • When a constraint is violated, we raise a warning and proceed without filtering. Engineering and business teams have agreed that some bad data is expected in this table, and a data quality tool monitors it and allows us to block the use of the table if quality falls below a defined threshold.

Whatever decision is taken, it should be aligned with the business. Here, we are going to proceed with the second option and consider the following constraints:

  • The customer_id field can’t be null.
  • The review_rating and previous_purchases fields can’t be negative or null.

In order to check the function’s behavior when those rules are violated, we need to create a mock containing at least one row that violates each rule, like the snippet below:

@pytest.fixture(scope="module")
def mock_check_constraints_fail(spark):
    from decimal import Decimal
    mock_data = [
       {
        # customer_id is Null
        "customer_id": None,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": 3.1,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": 5,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
        {
        # review_rating is Negative
        "customer_id": 1,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": -0.1,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": 5,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
        {
        # review_rating is Null
        "customer_id": 1,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": None,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": 5,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
        {
        # previous_purchases is Negative
        "customer_id": 1,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": 3.1,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": -5,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
            {
        # previous_purchases is Null
        "customer_id": 1,
        "age": 20,
        "gender": "Male",
        "item_purchased": "Blouse",
        "category": "Clothing",
        "purchase_amount_usd": Decimal(53.0),
        "location": "Montana",
        "size": "L",
        "color":"Blue",
        "season": "Spring",
        "review_rating": 3.1,
        "subscription_status": "Yes",
        "payment_method": "Credit Card",
        "shipping_type": "Free Shipping",
        "discount_applied": "No",
        "promo_code_used": "No",
        "previous_purchases": None,
        "preferred_payment_method": "Credit Card",
        "frequency_of_purchases": "Weekly"
       },
    ]

    schema = """
        customer_id bigint,
        age int,
        gender string,
        item_purchased string,
        category string,
        purchase_amount_usd decimal(10,5),
        location string,
        size string,
        color string,
        season string,
        review_rating double,
        subscription_status string,
        payment_method string,
        shipping_type string,
        discount_applied string,
        promo_code_used string,
        previous_purchases int,
        preferred_payment_method string,
        frequency_of_purchases string
    """

    df = spark.createDataFrame(mock_data, schema=schema)
    return df

In the mocked data above, we create a 5-row DataFrame in which each row violates one of the constraints. Now, we can proceed with the tests:

...
def test_check_contraints_ok(mock_expected_change_column_types):
    sys.path.append('./notebooks')
    from hop_raw2bronze import check_constraints

    check_ok = check_constraints(mock_expected_change_column_types)
    assert check_ok


def test_check_constraints_fail(mock_check_constraints_fail):
    sys.path.append('./notebooks')
    from hop_raw2bronze import check_constraints

    with pytest.raises(ValueError) as info:
        check_constraints(mock_check_constraints_fail)

    bad_rows = info.value.args[1]
    assert bad_rows == 5
    ...

In the snippet above, we are testing two scenarios. In the first one, we test the function’s behavior when no constraints are violated. To keep things simple, we reuse the mock we built for change_column_types, but we can change it if needed. In the second test, we capture the raised ValueError exception. Its second argument holds the number of rows in the DataFrame that violate at least one of the constraints. Now, we can implement the check_constraints function in the hop_raw2bronze module. One of the possibilities is presented in the snippet below:

...
def check_constraints(df: DataFrame) -> bool:
    df_check = df.filter(
        """
        (customer_id is null)
            or (review_rating < 0)
            or (review_rating is null)
            or (previous_purchases < 0)
            or (previous_purchases is null)
        """
    )

    check = df_check.isEmpty()

    if check:
        return True

    df_string = df_check.limit(10).toPandas().to_string()
    df_count = df_check.count()
    error_msg = f"""
    [VALUE ERROR] Constraints were violated.
        - customer_id must not be null.
        - review_rating must be equal or greater than zero.
        - previous_purchases must be equal or greater than zero.

    Showing up to 10 rows that violates the rules above:

    {df_string}
    """
    print(df_string)
    raise ValueError("Total Lines:", df_count, error_msg)
...

After running the tests with Pytest, something similar will be shown in the terminal:

============================= test session starts ==============================
...
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_names ps: unrecognized option: p
...
WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
PASSED [ 25%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_types PASSED [ 50%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_check_contraints_ok PASSED [ 75%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_check_constraints_fail    customer_id  age gender item_purchased  category purchase_amount_usd location size color  season  review_rating subscription_status payment_method  shipping_type discount_applied promo_code_used  previous_purchases preferred_payment_method frequency_of_purchases
0          NaN   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
1          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring           -0.1                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
2          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            NaN                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
3          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                -5.0              Credit Card                 Weekly
4          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                 NaN              Credit Card                 Weekly
PASSED [100%]

=============================== warnings summary ===============================
...
======================== 4 passed, 2 warnings in 20.56s ========================

As new constraints are needed, we add violation examples to the mock and adapt both the tests and the application function. The test file works as documentation of the business logic: we know what is expected, and we can even document new cases as they appear.
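
As a hypothetical sketch of how new violation cases could be added without duplicating whole fixture rows, we could parametrize a test that takes a known-good row and breaks exactly one field per case (the names and cases below are illustrative, not part of the module shown above):

CONSTRAINT_VIOLATIONS = [
    ("customer_id_null", {"customer_id": None}),
    ("review_rating_negative", {"review_rating": -0.1}),
    ("previous_purchases_null", {"previous_purchases": None}),
]

@pytest.mark.parametrize(
    "override",
    [case for _, case in CONSTRAINT_VIOLATIONS],
    ids=[name for name, _ in CONSTRAINT_VIOLATIONS],
)
def test_check_constraints_single_violation(spark, mock_expected_change_column_types, override):
    sys.path.append('./notebooks')
    from hop_raw2bronze import check_constraints

    # Start from a valid row and violate exactly one constraint.
    good_row = mock_expected_change_column_types.limit(1).collect()[0].asDict()
    df_bad = spark.createDataFrame(
        [{**good_row, **override}],
        schema=mock_expected_change_column_types.schema,
    )

    with pytest.raises(ValueError) as info:
        check_constraints(df_bad)
    assert info.value.args[1] == 1

Each new business rule then becomes one extra entry in CONSTRAINT_VIOLATIONS plus the corresponding filter condition in check_constraints.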

We can still write one more function in our application that calls the other transformations in a defined order.

Implementing the transformations function

The transformations function is expected to do the following:

  • Receives the raw DataFrame created after reading the CSV files.
  • Passes the raw DataFrame to the change_column_names function, leading to the df_change_names DataFrame.
  • After that, it passes df_change_names to the change_column_types function. This results in the final DataFrame df_result.
  • Then, we provide df_result to the check_constraints function, which returns True if df_result doesn’t violate the constraints, or raises a ValueError exception otherwise.

As we did before, we have a happy path and an exception path. For the happy path, we can use the mock_base fixture, and for the exception path, we create the mock_base_fail fixture presented below:

@pytest.fixture(scope="module")
def mock_base_fail(spark):
    mock_data = [
       {
        "Customer ID": "1",
        "Age": "20",
        "Gender": "Male",
        "Item Purchased": "Blouse",
        "Category": "Clothing",
        "Purchase Amount (USD)": "53",
        "Location": "Montana",
        "Size": "L",
        "Color":"Blue",
        "Season": "Spring",
        "Review Rating": "-3.1",
        "Subscription Status": "Yes",
        "Payment Method": "Credit Card",
        "Shipping Type": "Free Shipping",
        "Discount Applied": "No",
        "Promo Code Used": "No",
        "Previous Purchases": "5",
        "Preferred Payment Method": "Credit Card",
        "Frequency of Purchases": "Weekly"
       },
    ]

    schema = """
         `Customer ID` string,
        `Age` string,
        `Gender` string,
        `Item Purchased` string,
        `Category` string,
        `Purchase Amount (USD)` string,
        `Location` string,
        `Size` string,
        `Color` string,
        `Season` string,
        `Review Rating` string,
        `Subscription Status` string,
        `Payment Method` string,
        `Shipping Type` string,
        `Discount Applied` string,
        `Promo Code Used` string,
        `Previous Purchases` string,
        `Preferred Payment Method` string,
        `Frequency of Purchases` string
    """

    df = spark.createDataFrame(mock_data, schema=schema)
    return df

Now, we can create the tests to evaluate the mock creation and to check the transformations function behavior:

...
def test_mock_base_fail(mock_base_fail):
    assert mock_base_fail.count() == 1

def test_transformations_ok(mock_base, mock_expected_change_column_types):
    sys.path.append('./notebooks')
    from hop_raw2bronze import transformations

    df_transformed = transformations(mock_base)

    assert_pyspark_df_equal(df_transformed, mock_expected_change_column_types)


def test_transformations_fail(mock_base_fail):
    sys.path.append('./notebooks')
    from hop_raw2bronze import transformations

    with pytest.raises(ValueError) as info:
        transformations(mock_base_fail)

    bad_rows = info.value.args[1]
    assert bad_rows == 1

Next, we can implement the transformations application function in the hop_raw2bronze module. One of the possible implementations is presented below:

...
def transformations(df: DataFrame) -> DataFrame:
    df_change_names = change_column_names(df)
    df_result = change_column_types(df_change_names)
    check = check_constraints(df_result)

    if check:
        return df_result
...

After running Pytest, we should see something like the output below:

============================= test session starts ==============================
...
WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
PASSED [ 14%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_names PASSED [ 28%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_change_column_types PASSED [ 42%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_check_contraints_ok PASSED [ 57%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_check_constraints_fail    customer_id  age gender item_purchased  category purchase_amount_usd location size color  season  review_rating subscription_status payment_method  shipping_type discount_applied promo_code_used  previous_purchases preferred_payment_method frequency_of_purchases
0          NaN   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
1          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring           -0.1                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
2          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            NaN                 Yes    Credit Card  Free Shipping               No              No                 5.0              Credit Card                 Weekly
3          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                -5.0              Credit Card                 Weekly
4          1.0   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring            3.1                 Yes    Credit Card  Free Shipping               No              No                 NaN              Credit Card                 Weekly
PASSED [ 71%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_transformations_ok PASSED [ 85%]
app/spark-unit-test-example/tests/hop_raw2bronze_test.py::test_transformations_fail    customer_id  age gender item_purchased  category purchase_amount_usd location size color  season  review_rating subscription_status payment_method  shipping_type discount_applied promo_code_used  previous_purchases preferred_payment_method frequency_of_purchases
0            1   20   Male         Blouse  Clothing            53.00000  Montana    L  Blue  Spring           -3.1                 Yes    Credit Card  Free Shipping               No              No                   5              Credit Card                 Weekly
PASSED [100%]

=============================== warnings summary ===============================
...
======================== 7 passed, 4 warnings in 22.43s ========================

The whole hop_raw2bronze module application code

In the snippet below, we can see the code for the hop_raw2bronze module:

from pyspark.sql import DataFrame

APP_NAME = "[BATCH] RAW2BRONZE - customer_shopping"

def change_column_names(df: DataFrame) -> DataFrame:
    return df.selectExpr(
        "`Customer ID` as customer_id",
        "`Age` as age",
        "`Gender` as gender",
        "`Item Purchased` as item_purchased",
        "`Category` as category",
        "`Purchase Amount (USD)` as purchase_amount_usd",
        "`Location` as location",
        "`Size` as size",
        "`Color` as color",
        "`Season` as season",
        "`Review Rating` as review_rating",
        "`Subscription Status` as subscription_status",
        "`Payment Method` as payment_method",
        "`Shipping Type` as shipping_type",
        "`Discount Applied` as discount_applied",
        "`Promo Code Used` as promo_code_used",
        "`Previous Purchases` as previous_purchases",
        "`Preferred Payment Method` as preferred_payment_method",
        "`Frequency of Purchases` as frequency_of_purchases"
        )


def change_column_types(df: DataFrame) -> DataFrame:
    return df.selectExpr(
        "cast(customer_id as bigint) customer_id",
        "cast(age as int) age",
        "cast(gender as string) gender",
        "cast(item_purchased as string) item_purchased",
        "cast(category as string) category",
        "cast(purchase_amount_usd as decimal(10,5)) purchase_amount_usd",
        "cast(location as string) location",
        "cast(size as string) size",
        "cast(color as string) color",
        "cast(season as string) season",
        "cast(review_rating as double) review_rating",
        "cast(subscription_status as string) subscription_status",
        "cast(payment_method as string) payment_method",
        "cast(shipping_type as string) shipping_type",
        "cast(discount_applied as string) discount_applied",
        "cast(promo_code_used as string) promo_code_used",
        "cast(previous_purchases as int) previous_purchases",
        "cast(preferred_payment_method as string) preferred_payment_method",
        "cast(frequency_of_purchases as string) frequency_of_purchases",
    )


def check_constraints(df: DataFrame) -> bool:
    df_check = df.filter(
        """
        (customer_id is null)
            or (review_rating < 0)
            or (review_rating is null)
            or (previous_purchases < 0)
            or (previous_purchases is null)
        """
    )

    check = df_check.isEmpty()

    if check:
        return True

    df_string = df_check.limit(10).toPandas().to_string()
    df_count = df_check.count()
    error_msg = f"""
    [VALUE ERROR] Constraints were violated.
        - customer_id must not be null.
        - review_rating must be equal or greater than zero.
        - previous_purchases must be equal or greater than zero.

    Showing up to 10 rows that violates the rules above:

    {df_string}
    """
    print(df_string)
    raise ValueError("Total Lines:", df_count, error_msg)

def transformations(df: DataFrame) -> DataFrame:
    df_change_names = change_column_names(df)
    df_result = change_column_types(df_change_names)
    check = check_constraints(df_result)

    if check:
        return df_result


def main():
    spark = ...
    df_raw = spark.read.csv('/path/to/csv/files', header=True)

    df_final = transformations(df_raw)

    if check_constraints(df_final):
        spark.sparkContext.setJobGroup(APP_NAME, "Hop Raw to Bronze")
        df_final.write.mode("append").saveAsTable("bronze.customer_shopping")


if __name__ == "__main__":
    main()

Notice that the code has a well-defined structure. The main function injects df_raw into the lower-level application functions. If, at some point, we need to add more functions to this application, it’s easy to split the hop_raw2bronze module and refactor the tests. For example, we could create a hop_raw2bronze_utils module containing all the transformations and checks we need. In this scenario, hop_raw2bronze would look something like the snippet below:

from pyspark.sql import DataFrame
from hop_raw2bronze_utils import transformations, check_constraints

APP_NAME = "[BATCH] RAW2BRONZE - customer_shopping"

def main():
    spark = ...
    df_raw = spark.read.csv('/path/to/csv/files', header=True)

    df_final = transformations(df_raw)

    if check_constraints(df_final):
        spark.sparkContext.setJobGroup(APP_NAME, "Hop Raw to Bronze")
        df_final.write.mode("append").saveAsTable("bronze.customer_shopping")


if __name__ == "__main__":
    main()

And for every test, we would replace the imports from hop_raw2bronze with hop_raw2bronze_utils. But the biggest advantage, in my opinion, is not that we have produced pretty code. We are building a base of business knowledge embedded in the code.


More than a method to produce pretty code - it’s about building organizational knowledge

In the example above, we have seen how to produce code that is easy to read and maintain. But the benefits go beyond that. In my opinion, the main benefits of Test Driven Development are:

  • It makes the developer think about the business requirements. It’s not uncommon to receive loosely defined requirements from business people. If developers don’t do their part in understanding the business before producing code, the results can be harmful: the developer may end up working overtime to fix a problem caused by a loosely defined description in a Jira card. Using this method, programmers can ask better questions of the business team and get better requirements. From the business perspective, having well-defined requirements at the beginning of a project reduces time to market and the risk of delivering a bad product to a client.
  • It embeds documentation in the process of developing code. We often develop the code first and, if there is enough time, take care of the documentation afterwards. When developing tests, we are also documenting the business requirements, the expected behavior of the application, and even what our data sources look like. From a managerial perspective, this means developers learn the context they are working in faster and start executing their tasks earlier.

Multiple companies and tech professionals advocate for unit tests and TDD, from both technical and business perspectives. Here are some examples:

  • Microsoft says that unit tests improve processes and the quality of their outcomes.
  • As Harsh Patel, Aditi Rajnish, and Raj Pathak have posted in AWS Machine Learning Blog, ".. TDD facilitates collaboration and knowledge sharing among teams, because tests serve as living documentation and a shared understanding of the expected behavior and constraints"
  • Google also mentions the importance of the TDD approach in order to produce better code.

And there are countless other sources that mention the benefits of TDD, such as the books Test Driven Development: By Example and The DevOps Handbook.

In my opinion, applying TDD benefits everyone involved in developing a Spark application. The business team can deliver better products faster, and developers find it easier to produce, understand, and maintain the codebase. This can lead to faster financial outcomes, as well as reduce the chances of engineers having to spend their time firefighting.


References