Словарь Python - объединение листовых узлов ниже порога
Ниже приведен упрощенный пример дерева решений (dict()
) что я тренировался на Python:
tree= {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}},
'18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income':
{'high': 0, 'low': 0.1}}, 'married': 0.05}}}}
Числа в конечных узлах (прямоугольниках) представляют вероятность появления метки класса (например, ИСТИНА) в этом узле. Визуально дерево выглядит так:
Я пытаюсь закодировать общий алгоритм после сокращения, который объединяет узлы, значения которых меньше 0.3
к их родительским узлам. Таким образом, полученное дерево с 0.3
порог будет выглядеть следующим образом при построении графика:
На втором рисунке обратите внимание, что Income
узел в Age<18
теперь объединены с корневым узлом Age
, И Age=36-55, Marital_Staus
был объединен в Age
поскольку сумма всех его листовых узлов (на нескольких уровнях) меньше 0,3.
Это неполный псевдокод, который я придумал (пока):
def post_prune (dictionary, threshold):
for k in dictionary.keys():
if isinstance(dictionary[k], dict): # interim node
post_prune(dictionary[k], threshold)
else: # leaf node
if dictionary[k]> threshold:
pass
else:
to_do = 'delete this node'
Хотел опубликовать вопрос, так как я чувствую, что это должно было быть решено много раз.
Спасибо.
PS: я не собираюсь использовать конечный результат для классификации, поэтому обрезка таким образом (косметически) работает.
1 ответ
Вы можете попробовать что-то вроде этого:
def simplify(tree, threshold):
# simplify tree bottom-up
for key, child in tree.items():
if isinstance(child, dict):
tree[key] = simplify(child, threshold)
# all child-nodes are leafs and smaller than threshold -> return max
if all(isinstance(child, str) and float(child) <= threshold
for child in tree.values()):
return max(tree.values(), key=float)
# else return tree itself
return tree
Пример:
>>> tree= {'Age': {'> 55': '0.4', '18-35': '0', \
'< 18': {'Income': {'high': '0', 'low': '0.2'}}, \
'36-55': {'Marital_Status': {'single': {'Income': {'high': '0', 'low': '0.1'}}, \
'married': '0.3'}}}}
>>> simplify(tree, 0.2)
{'Age': {'> 55': '0.4', '< 18': '0.2', '18-35': '0',
'36-55': {'Marital_Status': {'single': '0.1', 'married': '0.3'}}}}
Обновление: Похоже, я неправильно понял ваш вопрос: вы хотите, чтобы упрощенное дерево содержало суммы листьев, если их сумма меньше порога! Ваше предлагаемое редактирование было слегка отклонено. Попробуй это:
def simplify(tree, threshold):
# simplify tree bottom-up
for key, child in tree.items():
if isinstance(child, dict):
tree[key] = simplify(child, threshold)
# all child-nodes are leafs and sum smaller than threshold -> return sum
if all(isinstance(child, str) for child in tree.values()) \
and sum(map(float, tree.values())) <= threshold:
return str(sum(map(float, tree.values())))
# else return tree itself
return tree