Rolling Means per Group
So far we've only been calculating a rolling mean on a single series. But what
should we do if we're interested in calculating a smoothed line for every state
in our dataset? In that case we'd like our rolling mean to respect the group
boundaries that we'd assign with a .groupby(). So how would we do that?
The .transform() verb is what you need here. Let's give a small example of how
to use it. We'll start by grabbing a subset dataframe that has every state in it.
import pandas as pd

df = pd.read_csv("https://calmcode.io/datasets/birthdays.csv")

subset_df = (df
  .assign(date=lambda d: pd.to_datetime(d['date'], format="%Y-%m-%d"))
  [['state', 'date', 'births']])
Next, we'll combine .groupby() with .transform() to calculate a rolling mean for each state separately.
(subset_df
  .set_index('date')
  .groupby('state')['births']
  .transform(lambda d: d.rolling('20D', min_periods=1).mean()))
Here's what each line does.
- We add a date index with .set_index('date'). A time-based window like '20D' needs a datetime index to work with.
- Next we group our dataset with .groupby('state'). Each group keeps its date index attached, and because we only select the 'births' column we get a grouped-series object rather than a grouped dataframe.
- We're calling .transform(). You may be used to calling .aggregate() (or its alias .agg()) here instead. The main difference is that .agg() reduces each group to a single row of calculated statistics, while .transform() returns a result that's exactly as long as the data going in, with the same index. The lambda receives one state's births at a time, so each rolling window stays within its own group.
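To make the difference concrete, here is a minimal sketch on a small hypothetical `toy` dataframe (made-up numbers, not the birthdays dataset). For brevity it uses a 2-row window instead of a time-based one. It compares .agg() with .transform(), and a plain rolling mean with a grouped one.

```python
import pandas as pd

# Hypothetical toy data: two states with three rows each.
toy = pd.DataFrame({
    "state": ["CA", "CA", "CA", "NY", "NY", "NY"],
    "births": [10, 20, 30, 100, 200, 300],
})

# .agg() reduces every group to a single value: one row per state.
per_state = toy.groupby("state")["births"].agg("mean")
print(per_state.tolist())  # [20.0, 200.0]

# .transform() returns one value per input row, aligned to toy's index.
per_row = toy.groupby("state")["births"].transform("mean")
print(per_row.tolist())  # [20.0, 20.0, 20.0, 200.0, 200.0, 200.0]

# Without grouping, a 2-row window slides straight from CA into NY.
naive = toy["births"].rolling(2, min_periods=1).mean()
print(naive.tolist())  # [10.0, 15.0, 25.0, 65.0, 150.0, 250.0]

# With .transform(), each state's window starts fresh.
grouped = (toy
           .groupby("state")["births"]
           .transform(lambda s: s.rolling(2, min_periods=1).mean()))
print(grouped.tolist())  # [10.0, 15.0, 25.0, 100.0, 150.0, 250.0]
```

The fourth value of `naive` (65.0) averages the last CA row with the first NY row. That is exactly the leak across group boundaries that the grouped version avoids.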
The output is nice, but we'd like to add it as a column to our original dataframe. Let's refactor the code a little first, though, because it's an excellent opportunity to add a helper function.
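def calc_rolling_mean(dataf, column=None, setting='30D'):
    # Rolling mean of `column`, computed separately for each state.
    # `setting` is the window size, e.g. '30D' for a 30-day window.
    return (dataf
            .groupby('state')[column]
            .transform(lambda d: d.rolling(setting, min_periods=1).mean()))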
We now have a convenient calc_rolling_mean function at our disposal. This
function keeps the state group in mind as it calculates rolling means.
A dataframe goes into the function and a series of equal length (with the same
index) comes out. That means we can use it in an .assign() call.
(subset_df
  .set_index('date')
  .assign(rolling_births=lambda d: calc_rolling_mean(d, column='births')))
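Because .transform() puts every value back at its original row position, the new column lines up with the right rows even when the states are interleaved. Here is a minimal sketch on hypothetical toy data (made-up numbers, with `setting='2D'` chosen only so the effect shows up on a few rows). The helper is repeated so the snippet runs on its own.

```python
import pandas as pd

def calc_rolling_mean(dataf, column=None, setting='30D'):
    # Same helper as above: a rolling mean computed within each state.
    return (dataf
            .groupby('state')[column]
            .transform(lambda d: d.rolling(setting, min_periods=1).mean()))

# Hypothetical toy data, interleaved by date like a long-format dataset.
toy = pd.DataFrame({
    "date": pd.to_datetime(["2020-01-01", "2020-01-01", "2020-01-02", "2020-01-02"]),
    "state": ["CA", "NY", "CA", "NY"],
    "births": [10, 100, 30, 300],
})

result = (toy
          .set_index('date')
          .assign(rolling_births=lambda d: calc_rolling_mean(d, column='births', setting='2D')))
print(result)
# rolling_births is 10.0, 100.0, 20.0, 200.0: CA and NY never mix.
```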
If you're interested in checking that this works as expected, you can sort the data and take a subset of a single state to confirm nothing broke.
(subset_df
  .set_index('date')
  .assign(rolling_births=lambda d: calc_rolling_mean(d, column='births'))
  .reset_index()
  .sort_values(["state", "date"])
  .loc[lambda d: d['state'] == 'CA'])
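If you'd rather have the check fail loudly than eyeball it, one option is to recompute the rolling mean for every state in isolation and compare. This is a minimal sketch on hypothetical toy data with `setting='2D'`. It relies on `pd.testing.assert_series_equal`, which raises an `AssertionError` on any mismatch.

```python
import pandas as pd

def calc_rolling_mean(dataf, column=None, setting='30D'):
    # Same helper as above: a rolling mean computed within each state.
    return (dataf
            .groupby('state')[column]
            .transform(lambda d: d.rolling(setting, min_periods=1).mean()))

# Hypothetical toy data: two states, interleaved by date.
toy = pd.DataFrame({
    "date": pd.to_datetime(["2020-01-01"] * 2 + ["2020-01-02"] * 2 + ["2020-01-03"] * 2),
    "state": ["CA", "NY"] * 3,
    "births": [10, 100, 30, 300, 50, 500],
})

with_rolling = (toy
                .set_index('date')
                .assign(rolling_births=lambda d: calc_rolling_mean(d, column='births', setting='2D')))

# Recompute the rolling mean for each state on its own and compare.
for state, part in with_rolling.groupby('state'):
    expected = part['births'].rolling('2D', min_periods=1).mean()
    # check_names=False: the series are named 'rolling_births' and 'births'.
    pd.testing.assert_series_equal(part['rolling_births'], expected, check_names=False)
```

Each state's rows are already in date order here. With real data you'd sort by date within each state first, as the snippet above does.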