Lucas Wu

Understanding Pandas Groupby

The groupby method in pandas splits a DataFrame into groups based on the values of one or more columns. You can then compute basic statistics, run more complex aggregations, apply transformations, or filter on each group.

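Common use cases fall into two groups: basic statistics and advanced operations (aggregation, transformation, and filtering).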

Basic Statistics

Consider a small e-commerce dataset of daily sales across three product categories: Fruit, Vegetable, and Beverage. We can compute group-wise statistics in several ways.

Example 1: Summing Sales by Category

import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage', 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170]
}

df = pd.DataFrame(data)

grouped = df.groupby('Category')['Sell'].sum()

print(grouped)
# Output:
# Category
# Beverage     370
# Fruit        630
# Vegetable    480
# Name: Sell, dtype: int64
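
By default, groupby sorts the group keys, which is why the categories come out alphabetically rather than in the order they appear in the data. As a minimal sketch, assuming you would rather keep first-appearance order, the sort=False parameter turns that sorting off. The variable names sorted_totals and unsorted_totals are only illustrative.

```python
import pandas as pd

# Same category and sales columns as the example above
df = pd.DataFrame({
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
})

# Default behaviour: group keys are sorted
sorted_totals = df.groupby('Category')['Sell'].sum()

# sort=False: groups appear in order of first occurrence
unsorted_totals = df.groupby('Category', sort=False)['Sell'].sum()

print(list(sorted_totals.index))    # ['Beverage', 'Fruit', 'Vegetable']
print(list(unsorted_totals.index))  # ['Fruit', 'Vegetable', 'Beverage']
```

The totals are identical either way; only the row order of the result changes.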

Example 2: Grouping by Two Columns

Here, sales are grouped by both date and category:

grouped = df.groupby(['Date', 'Category'])['Sell'].sum()

print(grouped)
# Output:
# Date        Category
# 2023-01-01  Fruit        200
#             Vegetable    150
# 2023-01-02  Beverage     180
#             Fruit        220
# 2023-01-03  Beverage     190
#             Vegetable    160
# 2023-01-04  Fruit        210
#             Vegetable    170
# Name: Sell, dtype: int64
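
Grouping by two columns returns a Series with a two-level index (Date, Category). If a flat table is easier to work with downstream, one possible approach is to call reset_index() on the result, as sketched here. The name flat is only illustrative.

```python
import pandas as pd

df = pd.DataFrame({
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
})

# Group by both columns, then move the index levels back into ordinary columns
flat = df.groupby(['Date', 'Category'])['Sell'].sum().reset_index()

print(list(flat.columns))  # ['Date', 'Category', 'Sell']
print(len(flat))           # 8
```

Each (Date, Category) pair occurs once in this dataset, so the flat result has the same eight rows as the input, ordered by the sorted group keys.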

Advanced Operations

Advanced operations commonly involve aggregation, transformation, and filtering.

Aggregation

Aggregation applies one or more statistical functions (e.g., sum, mean, max, min, count) to each group. Use the agg method for this.

grouped = df.groupby(['Category'])['Sell'].agg(['sum', 'max', 'min'])

print(grouped)
# Output:
#            sum  max  min
# Category
# Beverage   370  190  180
# Fruit      630  220  200
# Vegetable  480  170  150
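
When the result columns need descriptive names, agg also accepts keyword arguments of the form new_name=(column, function), known as named aggregation. The sketch below is one way to use it on the same data; the output column names total_sales, average_sale, and order_count are made up for this illustration.

```python
import pandas as pd

df = pd.DataFrame({
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
})

# Named aggregation: each keyword becomes a column in the result
summary = df.groupby('Category').agg(
    total_sales=('Sell', 'sum'),
    average_sale=('Sell', 'mean'),
    order_count=('Sell', 'count'),
)

print(summary)
# Output:
#            total_sales  average_sale  order_count
# Category
# Beverage           370         185.0            2
# Fruit              630         210.0            3
# Vegetable          480         160.0            3
```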

Transformation

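A transformation applies a function within each group and returns a result with the same length and index as the original data. Use the transform method for this. The example below adds 1 to every value; because that operation is element-wise, the grouping does not affect the numbers, and the example only shows that the output keeps the original shape.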

grouped = df.groupby(['Category'])['Sell'].transform(lambda x: x + 1)

print(grouped)
# Output:
# 0    201
# 1    151
# 2    221
# 3    181
# 4    161
# 5    191
# 6    211
# 7    171
# Name: Sell, dtype: int64
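
Transform is more useful when the function depends on the group. As a rough sketch, the code below broadcasts each category's total back onto its rows and then computes each sale's share of that total. The Share column and the category_total name are assumptions added for illustration.

```python
import pandas as pd

df = pd.DataFrame({
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
})

# Each row receives the total of its own category (same length as df)
category_total = df.groupby('Category')['Sell'].transform('sum')

# Share of each sale within its category (hypothetical extra column)
df['Share'] = df['Sell'] / category_total

print(category_total.tolist())
# [630, 480, 630, 370, 480, 370, 630, 480]
```

Because the result is aligned with the original index, it can be assigned straight back to the DataFrame as a new column.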

Filtering

Filtering removes groups that do not meet certain conditions, returning only data from the remaining groups. Use the filter method for this.

For example, keep only categories whose total sales exceed 400. Beverage, with a total of 370, is dropped:

grouped = df.groupby(['Category']).filter(lambda x: x['Sell'].sum() > 400)

print(grouped)
# Output:
#          Date    Category  Sell
# 0  2023-01-01       Fruit   200
# 1  2023-01-01   Vegetable   150
# 2  2023-01-02       Fruit   220
# 4  2023-01-03   Vegetable   160
# 6  2023-01-04       Fruit   210
# 7  2023-01-04   Vegetable   170

Original dataset for comparison:

print(df)
# Output:
#          Date    Category  Sell
# 0  2023-01-01       Fruit   200
# 1  2023-01-01   Vegetable   150
# 2  2023-01-02       Fruit   220
# 3  2023-01-02    Beverage   180
# 4  2023-01-03   Vegetable   160
# 5  2023-01-03    Beverage   190
# 6  2023-01-04       Fruit   210
# 7  2023-01-04   Vegetable   170
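
The filtering step above can also be expressed with transform and a boolean mask, which some readers find easier to combine with other row-level conditions. This is a sketch of that alternative, not a replacement for filter; the names mask and kept are illustrative.

```python
import pandas as pd

df = pd.DataFrame({
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
})

# True for every row whose category total exceeds 400
mask = df.groupby('Category')['Sell'].transform('sum') > 400

# Keep only rows from the qualifying categories
kept = df[mask]

print(kept.index.tolist())                 # [0, 1, 2, 4, 6, 7]
print(sorted(kept['Category'].unique()))   # ['Fruit', 'Vegetable']
```

The surviving rows match the filter output shown earlier: the two Beverage rows (index 3 and 5) are removed and the original row order is preserved.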