Словарь Python - объединение листовых узлов ниже порога

Ниже приведен упрощенный пример дерева решений (dict()) что я тренировался на Python:

tree= {'Age': {'> 55': 0.4, '< 18': {'Income': {'high': 0, 'low': 0.2}}, 
               '18-35': 0.25, '36-55': {'Marital_Status': {'single': {'Income': 
               {'high': 0, 'low': 0.1}}, 'married': 0.05}}}}

Числа в конечных узлах (прямоугольниках) представляют вероятность появления метки класса (например, ИСТИНА) в этом узле. Визуально дерево выглядит так:

Я пытаюсь закодировать общий алгоритм после сокращения, который объединяет узлы, значения которых меньше 0.3 к их родительским узлам. Таким образом, полученное дерево с 0.3 порог будет выглядеть следующим образом при построении графика:

На втором рисунке обратите внимание, что Income узел в Age<18 теперь объединены с корневым узлом Age, И Age=36-55, Marital_Staus был объединен в Age поскольку сумма всех его листовых узлов (на нескольких уровнях) меньше 0,3.

Это неполный псевдокод, который я придумал (пока):

def post_prune  (dictionary, threshold):

    for k in dictionary.keys():

        if isinstance(dictionary[k], dict): # interim node

            post_prune(dictionary[k], threshold)

        else: # leaf node

            if dictionary[k]> threshold:
                pass
            else:
                to_do = 'delete this node'

Хотел опубликовать вопрос, так как я чувствую, что это должно было быть решено много раз.

Спасибо.

PS: я не собираюсь использовать конечный результат для классификации, поэтому обрезка таким образом (косметически) работает.

1 ответ

Решение

Вы можете попробовать что-то вроде этого:

def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child-nodes are leafs and smaller than threshold -> return max
    if all(isinstance(child, str) and float(child) <= threshold 
           for child in tree.values()):
        return max(tree.values(), key=float)
    # else return tree itself
    return tree

Пример:

>>> tree= {'Age': {'> 55': '0.4', '18-35': '0', \
                   '< 18': {'Income': {'high': '0', 'low': '0.2'}}, \
                   '36-55': {'Marital_Status': {'single': {'Income': {'high': '0', 'low': '0.1'}}, \
                                                'married': '0.3'}}}}
>>> simplify(tree, 0.2)
{'Age': {'> 55': '0.4', '< 18': '0.2', '18-35': '0', 
         '36-55': {'Marital_Status': {'single': '0.1', 'married': '0.3'}}}}

Обновление: Похоже, я неправильно понял ваш вопрос: вы хотите, чтобы упрощенное дерево содержало суммы листьев, если их сумма меньше порога! Ваше предлагаемое редактирование было слегка отклонено. Попробуй это:

def simplify(tree, threshold):
    # simplify tree bottom-up
    for key, child in tree.items():
        if isinstance(child, dict):
            tree[key] = simplify(child, threshold)
    # all child-nodes are leafs and sum smaller than threshold -> return sum
    if all(isinstance(child, str) for child in tree.values()) \
           and sum(map(float, tree.values())) <= threshold:
        return str(sum(map(float, tree.values())))
    # else return tree itself
    return tree
Другие вопросы по тегам