Поведение train_test_split() от Scikit-learn

Мне любопытно, как метод train_test_split() Scikit-learn будет вести себя в следующем сценарии:

Мнимый набор данных:

id, count, size
1, 4, 8
2, 5, 9
3, 6, 0

скажем, я бы разделил его на два отдельных набора, как это (сохраняя "id" в обоих):

id, count      |       id, size
1, 4           |       1, 8
2, 5           |       2, 9
3, 6           |       3, 0

И разделить их обоих с train_test_split() с тем же random_state из 0, Будет ли порядок обоих одинаковым с 'id' в качестве ссылки? (поскольку вы перетасовываете один и тот же набор данных, но с разными частями)

Мне интересно, как это работает, потому что у меня есть две модели. Первый обучается с набором данных и добавляет его результаты в набор данных, часть которого затем используется для обучения второй модели.

При этом важно, чтобы при тестировании обобщения второй модели не использовались точки данных, которые также использовались для обучения первой модели. Это потому, что данные были "видны раньше", и модель будет знать, что с ними делать, поэтому вы больше не будете проверять обобщение на новые данные.

Было бы здорово, если train_test_split() будет перетасовывать его так же, так как тогда не нужно будет отслеживать, какие данные использовались для обучения первого алгоритма, чтобы предотвратить загрязнение результатов теста.

1 ответ

Решение

Они должны иметь одинаковые результирующие индексы, если вы используете те же random_state параметр в каждом вызове.

Однако вы также можете просто изменить порядок действий. Вызовите тест / разделение поезда на родительский набор данных, а затем создайте два подмножества как из набора тестов, так и из набора поездов.

Пример:

print(df)
   id  count  size
0   1      4     8
1   2      5     9
2   3      6     0

from sklearn.model_selection import train_test_split
dfa = df[['id', 'count']].copy()
dfb = df[['id', 'size']].copy()
rstate = 123
traina, testa = train_test_split(dfa, random_state=123)
trainb, testb = train_test_split(dfb, random_state=123)
assert traina.index.equals(trainb.index)
# True
Другие вопросы по тегам