Результат rpart является рутом, но данные показывают информационный прирост
У меня есть набор данных с частотой событий менее 3% (то есть около 700 записей с классом 1 и 27000 записей с классом 0).
ID V1 V2 V3 V5 V6 Target
SDataID3 161 ONE 1 FOUR 0 0
SDataID4 11 TWO 2 THREE 2 1
SDataID5 32 TWO 2 FOUR 2 0
SDataID7 13 ONE 1 THREE 2 0
SDataID8 194 TWO 2 FOUR 0 0
SDataID10 63 THREE 3 FOUR 0 1
SDataID11 89 ONE 1 FOUR 0 0
SDataID13 78 TWO 2 FOUR 0 0
SDataID14 87 TWO 2 THREE 1 0
SDataID15 81 ONE 1 THREE 0 0
SDataID16 63 ONE 3 FOUR 0 0
SDataID17 198 ONE 3 THREE 0 0
SDataID18 9 TWO 3 THREE 0 0
SDataID19 196 ONE 2 THREE 2 0
SDataID20 189 TWO 2 ONE 1 0
SDataID21 116 THREE 3 TWO 0 0
SDataID24 104 ONE 1 FOUR 0 0
SDataID25 5 ONE 2 ONE 3 0
SDataID28 173 TWO 3 FOUR 0 0
SDataID29 5 ONE 3 ONE 3 0
SDataID31 87 ONE 3 FOUR 3 0
SDataID32 5 ONE 2 THREE 1 0
SDataID34 45 ONE 1 FOUR 0 0
SDataID35 19 TWO 2 THREE 0 0
SDataID37 133 TWO 2 FOUR 0 0
SDataID38 8 ONE 1 THREE 0 0
SDataID39 42 ONE 1 THREE 0 0
SDataID43 45 ONE 1 THREE 1 0
SDataID44 45 ONE 1 FOUR 0 0
SDataID45 176 ONE 1 FOUR 0 0
SDataID46 63 ONE 1 THREE 3 0
Я пытаюсь найти раскол, используя дерево решений. Но результатом дерева является только 1 корень.
> library(rpart)
> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class")
> printcp(tree)
Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)), method = "class")
Variables actually used in tree construction:
character(0)
Root node error: 749/18239 = 0.041066
n= 18239
CP nsplit rel error xerror xstd
1 0 0 1 0 0
После прочтения большинства ресурсов в Stackru я ослабил / подправил параметры управления, которые дали мне желаемое дерево решений.
> tree <- rpart(Target ~ ., data=subset(train, select=c( -Record.ID) ),method="class" ,control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))
> printcp(tree)
Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),
method = "class", control = rpart.control(minsplit = 1, minbucket = 2,
cp = 2e-05))
Variables actually used in tree construction:
[1] V5 V2 V1
[4] V3 V6
Root node error: 749/18239 = 0.041066
n= 18239
CP nsplit rel error xerror xstd
1 0.00024275 0 1.00000 1.0000 0.035781
2 0.00019073 20 0.99466 1.0267 0.036235
3 0.00016689 34 0.99199 1.0307 0.036302
4 0.00014835 54 0.98798 1.0334 0.036347
5 0.00002000 63 0.98665 1.0427 0.036504
Когда я сократил дерево, это привело к дереву с единственным узлом.
> pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])
> printcp(pruned.tree)
Classification tree:
rpart(formula = Target ~ ., data = subset(train, select = c(-Record.ID)),
method = "class", control = rpart.control(minsplit = 1, minbucket = 2,
cp = 2e-05))
Variables actually used in tree construction:
character(0)
Root node error: 749/18239 = 0.041066
n= 18239
CP nsplit rel error xerror xstd
1 0.00024275 0 1 1 0.035781
Дерево не должно выдавать только корневой узел, потому что математически на данном узле (пример представлен) мы получаем информационный прирост. Я не знаю, делаю ли я ошибку, обрезая или есть проблема с rpart в обработке набора данных с низкой частотой событий?
NODE p 1-p Entropy Weights Ent*Weight # Obs
Node 1 0.032 0.968 0.204324671 0.351398601 0.071799404 10653
Node 2 0.05 0.95 0.286396957 0.648601399 0.185757467 19663
Sum(Ent*wght) 0.257556871
Information gain 0.742443129
1 ответ
Предоставленные вами данные не отражают соотношение двух целевых классов, поэтому я настроил данные, чтобы лучше отразить это (см. Раздел "Данные"):
> prop.table(table(train$Target))
0 1
0.96707581 0.03292419
> 700/27700
[1] 0.02527076
Соотношения сейчас относительно близки...
library(rpart)
tree <- rpart(Target ~ ., data=train, method="class")
printcp(tree)
Результаты в:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class")
Variables actually used in tree construction:
character(0)
Root node error: 912/27700 = 0.032924
n= 27700
CP nsplit rel error xerror xstd
1 0 0 1 0 0
Теперь причина того, что вы видите только корневой узел для своей первой модели, возможно, связана с тем, что у вас чрезвычайно несбалансированные целевые классы, и поэтому ваши независимые переменные не могли предоставить достаточно информации для роста дерева. Мой пример данных имеет 3,3% случаев, но у вас только около 2,5%!
Как вы уже упоминали, есть способ заставить rpart
вырастить дерево. То есть переопределить параметр сложности по умолчанию (cp
). Мера сложности - это сочетание размера дерева и того, насколько хорошо дерево разделяет целевые классы. От ?rpart.control
"Любое разделение, которое не уменьшает общее отсутствие соответствия фактором cp, не предпринимается". Это означает, что ваша модель на данный момент не имеет разделения за корневым узлом, что снижает уровень сложности, достаточный для rpart
принимать во внимание. Мы можем ослабить этот порог того, что считается "достаточным", установив минимум или минус cp
(отрицательный cp
в основном заставляет дерево расти до своего полного размера).
tree <- rpart(Target ~ ., data=train, method="class" ,parms = list(split = 'information'),
control =rpart.control(minsplit = 1,minbucket=2, cp=0.00002))
printcp(tree)
Результаты в:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class", parms = list(split = "information"),
control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))
Variables actually used in tree construction:
[1] ID V1 V2 V3 V5 V6
Root node error: 912/27700 = 0.032924
n= 27700
CP nsplit rel error xerror xstd
1 4.1118e-04 0 1.00000 1.0000 0.032564
2 3.6550e-04 30 0.98355 1.0285 0.033009
3 3.2489e-04 45 0.97807 1.0702 0.033647
4 3.1328e-04 106 0.95504 1.0877 0.033911
5 2.7412e-04 116 0.95175 1.1031 0.034141
6 2.5304e-04 132 0.94737 1.1217 0.034417
7 2.1930e-04 149 0.94298 1.1458 0.034771
8 1.9936e-04 159 0.94079 1.1502 0.034835
9 1.8275e-04 181 0.93640 1.1645 0.035041
10 1.6447e-04 193 0.93421 1.1864 0.035356
11 1.5664e-04 233 0.92654 1.1853 0.035341
12 1.3706e-04 320 0.91228 1.2083 0.035668
13 1.2183e-04 344 0.90899 1.2127 0.035730
14 9.9681e-05 353 0.90789 1.2237 0.035885
15 2.0000e-05 364 0.90680 1.2259 0.035915
Как вы можете видеть, дерево выросло до размера, который снижает уровень сложности на минимум cp
, Следует отметить две вещи:
- В нуле
nsplit
,CP
уже настолько низок, как 0,0004, где по умолчаниюcp
вrpart
установлен на 0,01. - Начиная с
nsplit == 0
ошибка перекрестной проверки (xerror
) увеличивается при увеличении количества разбиений.
Оба из них указывают, что ваша модель соответствует данным на nsplit == 0
и далее, поскольку добавление большего количества независимых переменных в вашу модель не добавляет достаточного количества информации (недостаточное снижение CP), чтобы уменьшить ошибку перекрестной проверки. С учетом вышесказанного, модель корневого узла в этом случае является лучшей моделью, что объясняет, почему ваша исходная модель имеет только корневой узел.
pruned.tree <- prune(tree, cp = tree$cptable[which.min(tree$cptable[,"xerror"]),"CP"])
printcp(pruned.tree)
Результаты в:
Classification tree:
rpart(formula = Target ~ ., data = train, method = "class", parms = list(split = "information"),
control = rpart.control(minsplit = 1, minbucket = 2, cp = 2e-05))
Variables actually used in tree construction:
character(0)
Root node error: 912/27700 = 0.032924
n= 27700
CP nsplit rel error xerror xstd
1 0.00041118 0 1 1 0.032564
Что касается части сокращения, то теперь стало понятнее, почему ваше сокращенное дерево является деревом корневых узлов, поскольку дерево, которое выходит за пределы 0 разбиений, имеет растущую ошибку перекрестной проверки. Взять дерево с минимумом xerror
оставит вас с корневым узлом дерева, как и ожидалось.
Получение информации в основном говорит вам, сколько "информации" добавляется для каждого разделения. Технически, каждое разделение имеет некоторую степень получения информации, поскольку вы добавляете больше переменных в свою модель (получение информации всегда неотрицательно). Вам следует подумать о том, уменьшает ли это дополнительное усиление (или отсутствие усиления) ошибки, чтобы вы могли оправдать более сложную модель. Следовательно, компромисс между смещением и дисперсией.
В этом случае вам не имеет смысла уменьшать cp
и позже обрежьте полученное дерево. с тех пор, установив низкий cp
вы говорите rpart
делать расщепления, даже если они перезаписываются, при этом обрезка "режет" все узлы, которые перезаписываются.
Данные:
Обратите внимание, что я перетасовываю строки для каждого столбца и выборки вместо выборки индексов строк. Это связано с тем, что предоставленные вами данные, вероятно, не являются случайной выборкой вашего исходного набора данных (вероятно, смещенной), поэтому я в основном случайным образом создаю новые наблюдения с комбинациями ваших существующих строк, которые, мы надеемся, уменьшат это смещение.
init_train = structure(list(ID = structure(c(16L, 24L, 29L, 30L, 31L, 1L,
2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L,
17L, 18L, 19L, 20L, 21L, 22L, 23L, 25L, 26L, 27L, 28L), .Label = c("SDataID10",
"SDataID11", "SDataID13", "SDataID14", "SDataID15", "SDataID16",
"SDataID17", "SDataID18", "SDataID19", "SDataID20", "SDataID21",
"SDataID24", "SDataID25", "SDataID28", "SDataID29", "SDataID3",
"SDataID31", "SDataID32", "SDataID34", "SDataID35", "SDataID37",
"SDataID38", "SDataID39", "SDataID4", "SDataID43", "SDataID44",
"SDataID45", "SDataID46", "SDataID5", "SDataID7", "SDataID8"), class = "factor"),
V1 = c(161L, 11L, 32L, 13L, 194L, 63L, 89L, 78L, 87L, 81L,
63L, 198L, 9L, 196L, 189L, 116L, 104L, 5L, 173L, 5L, 87L,
5L, 45L, 19L, 133L, 8L, 42L, 45L, 45L, 176L, 63L), V2 = structure(c(1L,
3L, 3L, 1L, 3L, 2L, 1L, 3L, 3L, 1L, 1L, 1L, 3L, 1L, 3L, 2L,
1L, 1L, 3L, 1L, 1L, 1L, 1L, 3L, 3L, 1L, 1L, 1L, 1L, 1L, 1L
), .Label = c("ONE", "THREE", "TWO"), class = "factor"),
V3 = c(1L, 2L, 2L, 1L, 2L, 3L, 1L, 2L, 2L, 1L, 3L, 3L, 3L,
2L, 2L, 3L, 1L, 2L, 3L, 3L, 3L, 2L, 1L, 2L, 2L, 1L, 1L, 1L,
1L, 1L, 1L), V5 = structure(c(1L, 3L, 1L, 3L, 1L, 1L, 1L,
1L, 3L, 3L, 1L, 3L, 3L, 3L, 2L, 4L, 1L, 2L, 1L, 2L, 1L, 3L,
1L, 3L, 1L, 3L, 3L, 3L, 1L, 1L, 3L), .Label = c("FOUR", "ONE",
"THREE", "TWO"), class = "factor"), V6 = c(0L, 2L, 2L, 2L,
0L, 0L, 0L, 0L, 1L, 0L, 0L, 0L, 0L, 2L, 1L, 0L, 0L, 3L, 0L,
3L, 3L, 1L, 0L, 0L, 0L, 0L, 0L, 1L, 0L, 0L, 3L), Target = c(0L,
1L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L,
0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L
)), .Names = c("ID", "V1", "V2", "V3", "V5", "V6", "Target"
), class = "data.frame", row.names = c(NA, -31L))
set.seed(1000)
train = as.data.frame(lapply(init_train, function(x) sample(x, 27700, replace = TRUE)))