Как визуализировать сеть в Pytorch?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

Я хочу визуализировать resnet из моделей pytorch. Как мне это сделать? Я пытался использовать torchviz но это дает ошибку:

'ResNet' object has no attribute 'grad_fn'

8 ответов

Решение

make_dot ожидает переменную (т. е. тензор с grad_fn), а не сама модель.
пытаться:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

Вот три различных визуализации графиков с использованием разных инструментов.

Чтобы создать примеры визуализаций, я буду использовать простую RNN для анализа настроений, взятую из онлайн-руководства:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

Вот результат, если вы print() модель.

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

Ниже приведены результаты трех различных инструментов визуализации.

Для всех из них у вас должен быть фиктивный ввод, который может проходить через модель forward()метод. Простой способ получить этот ввод - получить пакет из вашего Dataloader, например:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

Торчвиз

https://github.com/szagoruyko/pytorchviz

Я считаю, что этот инструмент генерирует свой график с помощью обратного прохода, поэтому все блоки используют компоненты PyTorch для обратного распространения.

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

Этот инструмент создает следующий выходной файл:

Это единственный вывод, в котором четко упоминаются три слоя в моей модели, embedding, rnn, а также fc. Имена операторов взяты из обратного прохода, поэтому некоторые из них трудно понять.

Скрытый слой

https://github.com/waleedka/hiddenlayer

Я считаю, что этот инструмент использует прямой проход.

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

Вот результат. Мне нравится оттенок синего.

Я считаю, что вывод содержит слишком много деталей и запутывает мою архитектуру. Например, почемуunsqueeze упоминалось так много раз?

Нетрон

https://github.com/lutzroeder/netron

Этот инструмент представляет собой настольное приложение для Mac, Windows и Linux. Он основан на том, что модель сначала экспортируется в формат ONNX. Затем приложение считывает файл ONNX и отображает его. Затем есть возможность экспортировать модель в файл изображения.

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

Вот как модель выглядит в приложении. Я думаю, что этот инструмент довольно ловкий: вы можете масштабировать и панорамировать, а также можете детализировать слои и операторов. Единственный минус, который я обнаружил, - это то, что он делает только вертикальные макеты.

Это может быть поздний ответ. Но, особенно с__torch_function__развита, можно получить лучшую визуализацию. Вы можете попробовать мой проект здесь, torchview

Для вашего примера resnet50 вы можете проверить блокнот colab, здесь я демонстрирую визуализацию модели resnet18. Образ resnet18 создается следующим кодом

      import torchvision
from torchview import draw_graph

model_graph = draw_graph(resnet18(), input_size=(1,3,224,224), expand_nested=True)
model_graph.visual_graph

Он также принимает широкий спектр типов вывода/ввода (например, список, словарь).

Вы можете взглянуть на PyTorchViz ( https://github.com/szagoruyko/pytorchviz), "Небольшой пакет для создания визуализаций графиков и трассировок PyTorch".

Вот как вы это делаете с torchviz если вы хотите сохранить изображение:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

снимок экрана полученного изображения:

ht tps:https://stackru.com/images/6cc36b0aa93aaacdab344e80abe06a098c6b90d8.png

источник: ht tp://www.bnikolic.co.uk/blog/pytorch-detach.html

Если я могу беззастенчиво подключиться, я написал пакет TorchLens, который может визуализировать график модели PyTorch всего за одну строку кода (он должен работать для любой произвольной модели PyTorch, но дайте мне знать, если это не сработает для вашей модели).

Вы можете использовать TensorBoard для визуализации. TensorBoard теперь полностью поддерживается в PyTorch версии 1.2.0. Дополнительная информация: https://pytorch.org/docs/stable/tensorboard.html

Вы также можете использовать эту библиотеку tenorboarX, которая позволит вам использовать тензорную доску для множества полезных вещей наряду с визуализацией графиков.

Другие вопросы по тегам