AI Что делает shuffle=True и как не сломать порядок

AI

Редактор
Регистрация
23 Август 2023
Сообщения
2 822
Лучшие ответы
0
Реакции
0
Баллы
51
Offline
#1


Привет, Хабр! Сегодня рассмотрим невинный на первый взгляд параметр shuffle=True в train_test_split.

Под «перемешать» подразумевается применение псевдо-рандомного пермутационного алгоритма (обычно Fisher–Yates) к индексам выборки до того, как мы режем её на train/test. Цель — заставить train-и-test быть независимыми и одинаково распределёнными (i.i.d.). В scikit-learn эта логика зашита в параметр shuffle почти всех сплиттеров. В train_test_split он True по умолчанию, что прямо сказано в документации — «shuffle bool, default=True».

train_test_split


from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
features,
target,
test_size=0.2,
random_state=42, # надо для репликабельности
shuffle=True # смело тасуем
)


Когда shuffle=True, функция:


  1. Генерирует случайную перестановку индексов (учитывая random_state).


  2. По ней делит данные.

Если присвоить shuffle=False, она просто берет “с головы” train_size строк, а “хвост” — test. Расклад очевиден из исходников и подтверждён в официальном доке.

Когда shuffle=False — обязательное условие

Временные ряды


Time-series — классика, где порядок — закон. Если мы перемешаем, то модель увидит будущее раньше прошлого и станет гадалкой. С точки зрения статистики это “look-ahead bias”. На том же Cross-Validated прямым текстом: в time-series нужно держать хронологию и юзать TimeSeriesSplit.

Зависимости внутри групп


Клинические данные, где несколько записей на одного пациента; логи пользователей, где одна сессия раскидана на десятки строк. Если рандомно раскидать строки по split’ам, то в test попадут “следы” тех же юзеров, что и в train, а это утечка через id-коррелированные признаки.

Продуктовые AB-эксперименты и всё, где важен session-level split


Тут групповая целостность — must-have. Мы шейкаем между группами, но не внутри.

Где случайность вредит обучению


  • Look-ahead bias — когда модель учится на будущей информации.


  • Target leakage — признак сформирован на основе целевой переменной или будущих значений.


  • Temporal leakage — метки пакуются по календарю: например, is_holiday. Если их перетасовать, тест узнает праздники раньше времени.

leakage — в целом сам по себе самый популярный баг ML-систем. Утечка часто выглядит невинно: добавили total_sales_next_month как фичу для модели, предсказывающей спрос — и получили 99 % R².

Как делать GroupShuffle или TimeSeriesSplit

GroupShuffleSplit


from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=user_id))

X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

GroupShuffleSplit гарантирует, что у каждого user_id ровно один сплит: либо train, либо test. Под капотом он рандомно тасует сами группы, а не записи.

TimeSeriesSplit


from sklearn.model_selection import TimeSeriesSplit

tscv = TimeSeriesSplit(n_splits=5, test_size=24*7) # неделя в часах
for fold, (train_idx, test_idx) in enumerate(tscv.split(X)):
model.fit(X.iloc[train_idx], y.iloc[train_idx])
y_pred = model.predict(X.iloc[test_idx])
...

Сплиттер отдаёт растающий train и скользящее окно test. Шейка нет вовсе — порядок священен.

Кейс на примере магазина котиков


У нас зоомаркет Purrfect Shop. В базе лежат четыре ключевых таблицы:

Таблица​

Что внутри​

Гранулярность​

customers​

customer_id, демография, дата регистрации​

1 строка на владельца​

cats​

cat_id, порода, цена, дата поступления​

1 строка на котика​

orders​

order_id, customer_id, order_dt, total_sum​

1 строка на чек​

order_lines​

order_id, cat_id, qty​

1 строка на позицию в чеке​



Мы хотим решить две задачи:


  1. Churn-классификация: предсказать, уйдёт ли клиент в течение 30 дней.


  2. Прогноз оборота на следующие 7 дней из тайм-серии.

Плюс обучаем CNN, которая по фото угадывает породу для автозаполнения карточек.

Churn-модель: где shuffle обязателен и где нельзя


Наивный, но опасный подход:

from sklearn.model_selection import train_test_split

X = features_df # собрали признаки на уровне *клиента*
y = labels_df['will_churn']

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, shuffle=True, random_state=42
)

Каждая строка — клиент, а значит зависимостей во времени нет. Shuffle здесь уместен: классы «уйдет/останется» раскиданы равномерно — модель не видит паттерна “первые 75 % клиентов — новые, последние — старые”.

Мы решаем “чуть улучшить” датасет и переходим на строки-чек. Один клиент = десятки чеков:

orders_df['will_churn'] = ...
X = orders_df.drop('will_churn', axis=1)
y = orders_df['will_churn']

# те же 5 строк кода:
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, shuffle=True, random_state=42
)

Теперь половина чеков Петя-Котолюб попала в train, половина — в test.

Утечка: признаки, вроде «total_sum_last_3_orders», пересекаются. Результат — AUC = 0.97 на тесте, но в проде падаем до 0.68 и ловим.

Правильно: GroupShuffleSplit

from sklearn.model_selection import GroupShuffleSplit

gss = GroupShuffleSplit(test_size=0.25, n_splits=1, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=orders_df['customer_id']))

X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

Теперь каждый клиент живёт только в одном сплите. AUC честно падает до 0.79 — зато в проде всё стабильно.

Прогноз оборота: shuffle=False, иначе бабах


Как ломается time series:

from sklearn.model_selection import train_test_split

# aggregated_df: daily revenue, lag-features, holidays, etc.
X = aggregated_df.drop('revenue', axis=1)
y = aggregated_df['revenue']

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, shuffle=True, random_state=42
)

Модель случайно видит 2025-06-01 в train, а 2025-05-20 в test.

Используем TimeSeriesSplit:

from sklearn.model_selection import TimeSeriesSplit
import lightgbm as lgb
import numpy as np

tscv = TimeSeriesSplit(n_splits=5, test_size=7)
scores = []

for fold, (tr, val) in enumerate(tscv.split(X)):
model = lgb.LGBMRegressor(n_estimators=500, learning_rate=0.03)
model.fit(X.iloc

, y.iloc

)
preds = model.predict(X.iloc[val])
rmse = np.sqrt(((preds - y.iloc[val])**2).mean())
scores.append(rmse)
print(f'Fold {fold}: RMSE={rmse:.2f}')

print(f'Mean CV RMSE: {np.mean(scores):.2f}')

Без shuffle, с растущим окном.

Когда обучаем сверточку, порядок картинок не важен; наоборот, shuffle помогает стохастическому градиенту быстрее и стабильнее сходиться.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_ds = datasets.ImageFolder(
root='cats/train',
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
)

train_loader = DataLoader(
train_ds,
batch_size=32,
shuffle=True, # критично!
num_workers=4,
pin_memory=True
)

val_loader = DataLoader(
val_ds,
batch_size=32,
shuffle=False, # чтоб метрики не «плясали»
num_workers=4
)

На примере нашего котомаркета видим три сценария:


  • Табличные i.i.d. данные — мешаем, чтобы избежать систематической ошибки.


  • Группы / время — бережём порядок, потому что утечка дороже.


  • Vision / NLP — мешаем внутри эпохи, но держим валидацию детерминированной.

Если сомневаетесь — не шейкайте.

Если чувствуете запах временной или групповой зависимости — уберите руку от shuffle=True и достаньте правильный сплиттер.


Готовите данные для моделей машинного обучения? Тогда знаете: неправильный сплит — и модель учит будущее, «подглядывает» в тест и в итоге проваливается в проде.

Если вам близки такие темы, как предотвращение утечек, GroupShuffle, TimeSeriesSplit, честные A/B‑тесты и грамотная работа с временными рядами — в Otus пройдут скоро открытые уроки, которые рекомендуем посетить:

Хотите больше? Загляните в — там есть всё: от ML‑специализации до продвинутого Python.А чтобы ничего не пропустить, добавьте — пусть он напомнит вам, когда стоит подключиться к трансляции.
 
Сверху Снизу