Understanding Pandas Groupby
The groupby method in Pandas splits data into groups by one or more keys (such as a category column) and then applies an operation to each group: basic statistics, more complex aggregations, transformations, or filtering.
Common use cases include:
- Summary Statistics: Quickly calculate sums, averages, etc., for each group.
- Data Comparison: Compare performance differences between groups.
- Outlier Detection: Identify abnormal data points within groups.
- Feature Engineering: Create new features for machine learning, such as group-level means or counts.
- Data Cleaning: Correct or fill missing values using group-level statistics.
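Here is a minimal sketch of the data-cleaning use case. It fills missing sales values with the mean of each row's own category. The small DataFrame is hypothetical, and combining transform with fillna is one common approach, not the only one.

```python
import numpy as np
import pandas as pd

# Hypothetical data: one Fruit row and one Vegetable row are missing sales
df = pd.DataFrame({
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage', 'Vegetable', 'Fruit'],
    'Sell': [200.0, 150.0, np.nan, 180.0, np.nan, 210.0],
})

# transform('mean') returns each row's group mean, aligned to the original index;
# mean() skips NaN by default, so the Fruit mean is (200 + 210) / 2
group_mean = df.groupby('Category')['Sell'].transform('mean')

# Fill each missing value with the mean of its own category
df['Sell_filled'] = df['Sell'].fillna(group_mean)
print(df)
```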
Basic Statistics
Consider a small e-commerce dataset of daily sales for three product categories (fruit, vegetables, and beverages). We can compute group-wise statistics in several ways.
Example 1: Summing Sales by Category
import pandas as pd
data = {
'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
'2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage', 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
'Sell': [200, 150, 220, 180, 160, 190, 210, 170]
}
df = pd.DataFrame(data)
grouped = df.groupby('Category')['Sell'].sum()
print(grouped)
# Output (group keys are sorted by default):
# Category
# Beverage     370
# Fruit        630
# Vegetable    480
# Name: Sell, dtype: int64
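Conceptually, groupby follows a split-apply-combine pattern. The sketch below rebuilds the same df and reproduces Example 1 with an explicit loop over the GroupBy object, where iteration yields (key, sub-DataFrame) pairs. It also shows the sort parameter. Group keys are sorted by default, and sort=False keeps them in order of first appearance. The loop is for illustration only. The built-in sum is the idiomatic and faster choice.

```python
import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
}
df = pd.DataFrame(data)

# Split-apply-combine by hand: each iteration yields (group key, sub-DataFrame)
manual = {}
for category, group in df.groupby('Category'):
    manual[category] = int(group['Sell'].sum())  # apply: sum this group's sales
print(manual)

# sort=False keeps groups in the order their keys first appear in the data
unsorted = df.groupby('Category', sort=False)['Sell'].sum()
print(unsorted)
```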
Example 2: Grouping by Two Columns
Here, sales are grouped by both date and category:
grouped = df.groupby(['Date', 'Category'])['Sell'].sum()
print(grouped)
# Output (sorted by Date, then by Category within each date):
# Date        Category 
# 2023-01-01  Fruit        200
#             Vegetable    150
# 2023-01-02  Beverage     180
#             Fruit        220
# 2023-01-03  Beverage     190
#             Vegetable    160
# 2023-01-04  Fruit        210
#             Vegetable    170
# Name: Sell, dtype: int64
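A two-level result like this is often easier to read as a table. The minimal sketch below rebuilds the same df and uses unstack to move the Category level into columns. Setting fill_value=0 is an assumption: it treats a missing date/category combination as zero sales.

```python
import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
}
df = pd.DataFrame(data)

# Series with a (Date, Category) MultiIndex, as in Example 2
by_date_cat = df.groupby(['Date', 'Category'])['Sell'].sum()

# unstack() pivots the inner index level (Category) into columns;
# combinations that never occur (e.g. Beverage on 2023-01-01) are filled with 0
table = by_date_cat.unstack(fill_value=0)
print(table)
```

Summing each column of this table gives the per-category totals from Example 1.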
Advanced Operations
Advanced operations commonly involve aggregation, transformation, and filtering.
Aggregation
Aggregation applies one or more statistical functions (e.g., sum, mean, max, min, count) to each group and returns one row per group. Use the agg method for this.
grouped = df.groupby(['Category'])['Sell'].agg(['sum', 'max', 'min'])
print(grouped)
# Output:
#            sum  max  min
# Category
# Beverage   370  190  180
# Fruit      630  220  200
# Vegetable  480  170  150
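pandas also supports named aggregation, where each keyword argument to agg becomes an output column built from a (column, function) pair. The sketch below applies it to the same data. The column names total, average and records are arbitrary choices.

```python
import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
}
df = pd.DataFrame(data)

# Named aggregation: output_column=(input_column, aggregation_function)
summary = df.groupby('Category').agg(
    total=('Sell', 'sum'),      # total sales per category
    average=('Sell', 'mean'),   # mean sale per record
    records=('Date', 'count'),  # number of non-null Date values, i.e. rows here
)
print(summary)
```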
Transformation
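Transformations apply a function to each group and return a result with the same length and index as the original data. Use the transform method for this. The example below adds 1 to every value. Because that function ignores the rest of the group, the output is the same as df['Sell'] + 1.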
grouped = df.groupby(['Category'])['Sell'].transform(lambda x: x + 1)
print(grouped)
# Output:
# 0 201
# 1 151
# 2 221
# 3 181
# 4 161
# 5 191
# 6 211
# 7 171
# Name: Sell, dtype: int64
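The sketch below rebuilds the same df and shows transformations that do depend on the group. Each row gets its category's mean, its deviation from that mean (a simple starting point for the outlier-detection use case), and its share of the category total. The new column names are illustrative.

```python
import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
}
df = pd.DataFrame(data)

sell_by_cat = df.groupby('Category')['Sell']

# Each row receives its category's mean, broadcast back to the original rows
df['Cat_mean'] = sell_by_cat.transform('mean')

# Distance from the category mean; large absolute values suggest within-group outliers
df['Deviation'] = df['Sell'] - df['Cat_mean']

# Fraction of the category's total sales contributed by each row
df['Share'] = df['Sell'] / sell_by_cat.transform('sum')
print(df)
```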
Filtering
Filtering removes entire groups that do not meet a condition and returns the rows of the remaining groups unchanged, with their original index. Use the filter method for this.
For example, keep only categories whose total sales exceed 400. This drops Beverage, whose total is 370:
grouped = df.groupby(['Category']).filter(lambda x: x['Sell'].sum() > 400)
print(grouped)
# Output:
# Date Category Sell
# 0 2023-01-01 Fruit 200
# 1 2023-01-01 Vegetable 150
# 2 2023-01-02 Fruit 220
# 4 2023-01-03 Vegetable 160
# 6 2023-01-04 Fruit 210
# 7 2023-01-04 Vegetable 170
Original dataset for comparison:
print(df)
# Output:
# Date Category Sell
# 0 2023-01-01 Fruit 200
# 1 2023-01-01 Vegetable 150
# 2 2023-01-02 Fruit 220
# 3 2023-01-02 Beverage 180
# 4 2023-01-03 Vegetable 160
# 5 2023-01-03 Beverage 190
# 6 2023-01-04 Fruit 210
# 7 2023-01-04 Vegetable 170
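Filtering can also be written as a boolean mask built with transform, which broadcasts each group's total back to its rows. The sketch below uses the same data to check that both approaches keep the same rows. It also adds a second, hypothetical condition based on group size.

```python
import pandas as pd

data = {
    'Date': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02',
             '2023-01-03', '2023-01-03', '2023-01-04', '2023-01-04'],
    'Category': ['Fruit', 'Vegetable', 'Fruit', 'Beverage',
                 'Vegetable', 'Beverage', 'Fruit', 'Vegetable'],
    'Sell': [200, 150, 220, 180, 160, 190, 210, 170],
}
df = pd.DataFrame(data)

# filter(): keep all rows of each group for which the function returns True
kept = df.groupby('Category').filter(lambda g: g['Sell'].sum() > 400)

# Mask version: each row carries its category total, then compare row by row
category_total = df.groupby('Category')['Sell'].transform('sum')
kept_mask = df[category_total > 400]
print(kept.equals(kept_mask))

# Hypothetical second condition: keep only categories with at least 3 records
frequent = df.groupby('Category').filter(lambda g: len(g) >= 3)
print(frequent)
```

The mask version is a plain row selection, so it is easy to combine with other conditions using & or |.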