Как визуализировать сеть в Pytorch?
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot
batch_size = 3
learning_rate =0.0002
epoch = 50
resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)
Я хочу визуализировать resnet
из моделей pytorch. Как мне это сделать? Я пытался использовать torchviz
но это дает ошибку:
'ResNet' object has no attribute 'grad_fn'
8 ответов
make_dot
ожидает переменную (т. е. тензор с grad_fn
), а не сама модель.
пытаться:
x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out) # plot graph of variable, not of a nn.Module
Вот три различных визуализации графиков с использованием разных инструментов.
Чтобы создать примеры визуализаций, я буду использовать простую RNN для анализа настроений, взятую из онлайн-руководства:
class RNN(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.rnn = nn.RNN(embedding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, text):
embedding = self.embedding(text)
output, hidden = self.rnn(embedding)
return self.fc(hidden.squeeze(0))
Вот результат, если вы print()
модель.
RNN(
(embedding): Embedding(25002, 100)
(rnn): RNN(100, 256)
(fc): Linear(in_features=256, out_features=1, bias=True)
)
Ниже приведены результаты трех различных инструментов визуализации.
Для всех из них у вас должен быть фиктивный ввод, который может проходить через модель forward()
метод. Простой способ получить этот ввод - получить пакет из вашего Dataloader, например:
batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().
Торчвиз
https://github.com/szagoruyko/pytorchviz
Я считаю, что этот инструмент генерирует свой график с помощью обратного прохода, поэтому все блоки используют компоненты PyTorch для обратного распространения.
from torchviz import make_dot
make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")
Этот инструмент создает следующий выходной файл:
Это единственный вывод, в котором четко упоминаются три слоя в моей модели, embedding
, rnn
, а также fc
. Имена операторов взяты из обратного прохода, поэтому некоторые из них трудно понять.
Скрытый слой
https://github.com/waleedka/hiddenlayer
Я считаю, что этот инструмент использует прямой проход.
import hiddenlayer as hl
transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.
graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')
Вот результат. Мне нравится оттенок синего.
Я считаю, что вывод содержит слишком много деталей и запутывает мою архитектуру. Например, почемуunsqueeze
упоминалось так много раз?
Нетрон
https://github.com/lutzroeder/netron
Этот инструмент представляет собой настольное приложение для Mac, Windows и Linux. Он основан на том, что модель сначала экспортируется в формат ONNX. Затем приложение считывает файл ONNX и отображает его. Затем есть возможность экспортировать модель в файл изображения.
input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)
Вот как модель выглядит в приложении. Я думаю, что этот инструмент довольно ловкий: вы можете масштабировать и панорамировать, а также можете детализировать слои и операторов. Единственный минус, который я обнаружил, - это то, что он делает только вертикальные макеты.
Это может быть поздний ответ. Но, особенно с__torch_function__
развита, можно получить лучшую визуализацию. Вы можете попробовать мой проект здесь, torchview
Для вашего примера resnet50 вы можете проверить блокнот colab, здесь я демонстрирую визуализацию модели resnet18. Образ resnet18 создается следующим кодом
import torchvision
from torchview import draw_graph
model_graph = draw_graph(resnet18(), input_size=(1,3,224,224), expand_nested=True)
model_graph.visual_graph
Он также принимает широкий спектр типов вывода/ввода (например, список, словарь).
Вы можете взглянуть на PyTorchViz ( https://github.com/szagoruyko/pytorchviz), "Небольшой пакет для создания визуализаций графиков и трассировок PyTorch".
Вот как вы это делаете с torchviz
если вы хотите сохранить изображение:
# http://www.bnikolic.co.uk/blog/pytorch-detach.html
import torch
from torchviz import make_dot
x=torch.ones(10, requires_grad=True)
weights = {'x':x}
y=x**2
z=x**3
r=(y+z).sum()
make_dot(r).render("attached", format="png")
снимок экрана полученного изображения:
ht tps:https://stackru.com/images/6cc36b0aa93aaacdab344e80abe06a098c6b90d8.png
источник: ht tp://www.bnikolic.co.uk/blog/pytorch-detach.html
Если я могу беззастенчиво подключиться, я написал пакет TorchLens, который может визуализировать график модели PyTorch всего за одну строку кода (он должен работать для любой произвольной модели PyTorch, но дайте мне знать, если это не сработает для вашей модели).
Вы можете использовать TensorBoard для визуализации. TensorBoard теперь полностью поддерживается в PyTorch версии 1.2.0. Дополнительная информация: https://pytorch.org/docs/stable/tensorboard.html
Вы также можете использовать эту библиотеку tenorboarX, которая позволит вам использовать тензорную доску для множества полезных вещей наряду с визуализацией графиков.