Pip安装:
pip install graphviz
pip install torchviz
(或pip install git+https://github.com/szagoruyko/pytorchviz)
Graphviz 是 AT&T 开发的一款开源图形可视化软件,可以根据用 DOT 脚本语言描述的图(有向图或无向图,用于表示对象之间的关系)绘制出直观的图形,例如树形图。torchviz 正是借助 Graphviz 来绘制 PyTorch 模型的计算图。
Graphviz 在 Windows 中的安装需要下载 Release 安装包,并配置环境变量,否则运行时会因找不到 dot 可执行文件而报错:
下载:
https://graphviz.gitlab.io/_pages/Download/Download_windows.html(当前官方下载页:https://graphviz.org/download/)
安装:建议在 Anaconda 目录下新建 graphviz 文件夹,并将 Graphviz 安装到该文件夹中
设置环境变量:将 ..\graphviz\bin 添加到系统环境变量 Path 中
验证:打开 cmd,运行 dot -V(注意是普通连字符 "-"),若输出版本信息则说明安装成功
使用graphviz
用法:
源码:
# Rodrigo Caye Daudt
# https://rcdaudt.github.io/
# Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d
class Unet(nn.Module):
    """EF segmentation network.

    Early-fusion U-Net: the two inputs are concatenated along the channel
    dimension and passed through a 4-stage encoder (conv/BN/ReLU/dropout
    blocks separated by 2x2 max-pooling) and a mirrored 4-stage decoder
    with skip connections. The output is a per-pixel LogSoftmax over
    ``label_nbr`` classes.

    Args:
        input_nbr: number of channels of the *concatenated* input, i.e.
            ``x1.size(1) + x2.size(1)`` as passed to :meth:`forward`.
        label_nbr: number of output classes (output channels).
    """

    def __init__(self, input_nbr, label_nbr):
        super().__init__()

        self.input_nbr = input_nbr

        # NOTE: layer creation order and attribute names are kept stable so
        # parameter / state_dict keys (e.g. "conv11.weight") do not change.

        # ---- Encoder ----
        # Stage 1: input_nbr -> 16 channels.
        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        # Stage 2: 16 -> 32 channels.
        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        # Stage 3: 32 -> 64 channels.
        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        # Stage 4: 64 -> 128 channels.
        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        # ---- Decoder ----
        # upconvN: kernel 3, stride 2, padding 1, output_padding 1 doubles
        # H and W. convNNd: ConvTranspose2d with stride 1, kernel 3,
        # padding 1 keeps H and W unchanged. Input channels of the first conv
        # in each stage are doubled because of the skip concatenation.

        # Stage 4d
        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        # Stage 3d
        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        # Stage 2d
        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        # Stage 1d
        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        # Log-probabilities over the class (channel) dimension.
        self.sm = nn.LogSoftmax(dim=1)

    @staticmethod
    def _pad_and_concat(up, skip):
        """Replicate-pad ``up`` on the right/bottom to ``skip``'s H, W, then concat on dim 1."""
        # Max-pooling floors odd sizes, so the upsampled map can be smaller
        # than the skip tensor. Pad tuple is (left, right, top, bottom).
        pad = ReplicationPad2d((0, skip.size(3) - up.size(3), 0, skip.size(2) - up.size(2)))
        return torch.cat((pad(up), skip), 1)

    def forward(self, x1, x2):
        """Forward method.

        Args:
            x1, x2: input tensors, concatenated along dim 1. They must agree
                in every dimension except dim 1. BatchNorm2d requires the
                result to be 4-D (N, C, H, W), and C must equal ``input_nbr``.

        Returns:
            Tensor of shape (N, label_nbr, H, W): per-pixel log-probabilities
            (LogSoftmax over dim 1), with the same H, W as the inputs.
        """
        x = torch.cat((x1, x2), 1)

        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self._pad_and_concat(self.upconv4(x4p), x43)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self._pad_and_concat(self.upconv3(x41d), x33)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self._pad_and_concat(self.upconv2(x31d), x22)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self._pad_and_concat(self.upconv1(x21d), x12)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)
from torchviz import make_dot
from torch.autograd import Variable
def main():
    """Build a Unet on random inputs and render its autograd graph with torchviz.

    ``vis_graph.view()`` uses the Graphviz ``dot`` executable to render the
    graph, so Graphviz must be installed and on PATH.
    """
    # Unet concatenates its two inputs along dim 1, so two 3-channel images
    # give input_nbr = 3 + 3 = 6; label_nbr = 2 output classes.
    # Plain tensors are used directly: torch.autograd.Variable is deprecated
    # and has been a no-op wrapper since PyTorch 0.4.
    x1 = torch.randn(1, 3, 256, 256)
    x2 = torch.randn(1, 3, 256, 256)
    model = Unet(6, 2)
    y = model(x1, x2)
    # Reduce the output to a scalar so the graph has a single root; passing
    # named parameters lets make_dot label parameter nodes by name.
    vis_graph = make_dot(y.mean(), params=dict(model.named_parameters()))
    vis_graph.view()


if __name__ == '__main__':
    main()
作者:zuo668