Calmcode - pytest tricks: combine

Combining Fixtures and Parametrized Decorators

1 2 3 4 5 6 7 8 9 10

Let's now consider a more elaborate example. It tests two functions: `normalize` and `threshold`. Both functions share some behavior (the output must have the same shape as the input in both cases), but each also has behavior that needs to be tested separately. That means we can do something clever by combining fixtures and parametrized decorators.

import pytest
import numpy as np

def normalize(X):
    """Linearly rescale ``X`` so its smallest element maps to 0 and its largest to 1.

    The result has the same shape as ``X``. If every element of ``X`` is equal,
    the denominator ``max - min`` is zero, so the result is NaN (NumPy emits a
    RuntimeWarning for the 0/0 division).
    """
    lo, hi = X.min(), X.max()
    span = hi - lo
    return (X - lo) / span

def threshold(X, min_val=-1, max_val=1):
    """Clamp the elements of ``X`` into the interval ``[min_val, max_val]``.

    First, elements ``<= min_val`` are replaced by ``min_val``. Then elements
    ``>= max_val`` are replaced by ``max_val``. Because the upper bound is
    applied second, it wins when ``min_val > max_val``. A new array is returned
    and ``X`` itself is left untouched.
    """
    floored = np.where(X <= min_val, min_val, X)
    capped = np.where(floored >= max_val, max_val, floored)
    return capped

@pytest.fixture(params=[(1,1), (2,2), (3,3), (4,4)], ids=lambda d: f"rows: {d[0]} cols: {d[1]}")
def random_numpy_array(request):
    """Provide a standard-normal array shaped ``(rows, cols)`` for each param.

    ``request.param`` must be passed as ``size=``. The first positional
    parameter of ``np.random.normal`` is ``loc`` (the mean), so passing the
    tuple positionally would return a 1-D array of length 2 for every param.
    That would make the ``rows``/``cols`` test ids meaningless.
    """
    return np.random.normal(size=request.param)

# This test checks behavior that overlaps between both functions.
# We use the fixture to generate the random numpy arrays, but we
# use the parametrize decorator to loop over the functions to check.
# Pytest combines the two, so every function runs against every shape.
@pytest.mark.parametrize("func", [normalize, threshold], ids=lambda d: d.__name__)
def test_shape_same(func, random_numpy_array):
    X_norm = func(random_numpy_array)
    assert random_numpy_array.shape == X_norm.shape

# This test only checks specific behavior for the `normalize` function.
def test_min_max_normalise(random_numpy_array):
    # normalize divides by (max - min). A constant array (the 1x1 param
    # always is one) makes that 0/0, which yields NaN, so the min/max
    # contract does not apply. Skip rather than fail on NaN.
    if random_numpy_array.max() == random_numpy_array.min():
        pytest.skip("normalize is undefined for a constant array (max == min)")
    X_norm = normalize(random_numpy_array)
    assert X_norm.min() == 0.0
    assert X_norm.max() == 1.0

# This test only checks specific behavior for the `threshold` function.
# Again, nothing stops us from using parametrize to build a grid of
# values to check. Stacked decorators multiply with each other and with
# the parametrized fixture: 3 min values x 3 max values x 4 shapes.
# In practice, be careful not to build grids so large that the test
# suite becomes slow.
@pytest.mark.parametrize("min_val", [-3, -2, -1], ids=lambda x: f"min_val:{x}")
@pytest.mark.parametrize("max_val", [3, 2, 1], ids=lambda x: f"max_val:{x}")
def test_min_max_threshold(random_numpy_array, min_val, max_val):
    clipped = threshold(random_numpy_array, min_val, max_val)
    lowest, highest = clipped.min(), clipped.max()
    assert lowest >= min_val
    assert highest <= max_val