graphviz-可视化pytorch网络模型

Cybill ·
更新时间:2024-09-21
· 597 次阅读

anaconda安装和配置graphviz

Pip安装:

pip install graphviz

pip install torchviz

(或pip install git+https://github.com/szagoruyko/pytorchviz)

Graphviz 是 AT&T 实验室开发的一款开源图形可视化软件,可以根据用 DOT 脚本语言描述的图(有向图或无向图,用于表示对象之间的关系)自动布局,画出直观的树形图等结构图。
在 Windows 上,除了用 pip 安装 Python 包之外,还需要单独下载 Graphviz 的 Release 安装包并配置环境变量;否则渲染时会因为找不到 dot 可执行文件而报错(ExecutableNotFound)。

下载:

https://graphviz.gitlab.io/_pages/Download/Download_windows.html

安装:建议在 Anaconda 安装目录下新建一个 graphviz 文件夹,并将其作为 Graphviz 的安装路径。

设置环境变量:将 `..\graphviz\bin`(即上述 graphviz 安装目录下的 bin 文件夹)添加到系统环境变量 Path 中。

验证:打开 cmd,执行 `dot -V`,若输出 Graphviz 的版本信息则说明配置成功。

 

使用graphviz

用法:

 

 

源码:

# Rodrigo Caye Daudt
# https://rcdaudt.github.io/
# Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks
# for change detection". In 2018 25th IEEE International Conference on Image
# Processing (ICIP) (pp. 4063-4067). IEEE.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d


class Unet(nn.Module):
    """EF (early-fusion) segmentation network for change detection.

    The two input images are concatenated along the channel dimension before
    the first convolution, then passed through a 4-level encoder/decoder with
    skip connections. Every conv layer is followed by BatchNorm, ReLU and
    Dropout2d(p=0.2), except the final output layer.

    Args:
        input_nbr: Number of channels of ``torch.cat((x1, x2), 1)``, i.e. the
            channel count of ``x1`` plus that of ``x2``.
        label_nbr: Number of output channels (classes).
    """

    def __init__(self, input_nbr, label_nbr):
        super().__init__()

        self.input_nbr = input_nbr

        # NOTE: attribute names and their registration order are kept as-is so
        # that state_dict keys (and parameter order) stay compatible with
        # weights saved from the original implementation.

        # Encoder, stage 1: input_nbr -> 16 channels
        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        # Encoder, stage 2: 16 -> 32 channels
        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        # Encoder, stage 3: 32 -> 64 channels
        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        # Encoder, stage 4: 64 -> 128 channels
        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        # Decoder, stage 4d: upsample x2, concat with x43 (128 + 128 = 256)
        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        # Decoder, stage 3d: upsample x2, concat with x33 (64 + 64 = 128)
        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        # Decoder, stage 2d: upsample x2, concat with x22 (32 + 32 = 64)
        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        # Decoder, stage 1d: upsample x2, concat with x12 (16 + 16 = 32)
        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        # Output layer: no BN / ReLU / dropout, followed by log-softmax over channels
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        self.sm = nn.LogSoftmax(dim=1)

    def forward(self, x1, x2):
        """Forward method.

        Args:
            x1, x2: 4-D tensors laid out as (N, C, H, W) with identical N, H and
                W; their channel counts must add up to ``input_nbr``.

        Returns:
            Tensor of shape (N, label_nbr, H, W) holding log-probabilities
            (LogSoftmax over dim 1).
        """
        # Early fusion: stack both images along the channel axis.
        x = torch.cat((x1, x2), 1)

        # Stage 1
        x11 = self._conv_block(x, self.conv11, self.bn11, self.do11)
        x12 = self._conv_block(x11, self.conv12, self.bn12, self.do12)
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        # Stage 2
        x21 = self._conv_block(x1p, self.conv21, self.bn21, self.do21)
        x22 = self._conv_block(x21, self.conv22, self.bn22, self.do22)
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        # Stage 3
        x31 = self._conv_block(x2p, self.conv31, self.bn31, self.do31)
        x32 = self._conv_block(x31, self.conv32, self.bn32, self.do32)
        x33 = self._conv_block(x32, self.conv33, self.bn33, self.do33)
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        # Stage 4
        x41 = self._conv_block(x3p, self.conv41, self.bn41, self.do41)
        x42 = self._conv_block(x41, self.conv42, self.bn42, self.do42)
        x43 = self._conv_block(x42, self.conv43, self.bn43, self.do43)
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self._pad_and_concat(self.upconv4(x4p), x43)
        x43d = self._conv_block(x4d, self.conv43d, self.bn43d, self.do43d)
        x42d = self._conv_block(x43d, self.conv42d, self.bn42d, self.do42d)
        x41d = self._conv_block(x42d, self.conv41d, self.bn41d, self.do41d)

        # Stage 3d
        x3d = self._pad_and_concat(self.upconv3(x41d), x33)
        x33d = self._conv_block(x3d, self.conv33d, self.bn33d, self.do33d)
        x32d = self._conv_block(x33d, self.conv32d, self.bn32d, self.do32d)
        x31d = self._conv_block(x32d, self.conv31d, self.bn31d, self.do31d)

        # Stage 2d
        x2d = self._pad_and_concat(self.upconv2(x31d), x22)
        x22d = self._conv_block(x2d, self.conv22d, self.bn22d, self.do22d)
        x21d = self._conv_block(x22d, self.conv21d, self.bn21d, self.do21d)

        # Stage 1d
        x1d = self._pad_and_concat(self.upconv1(x21d), x12)
        x12d = self._conv_block(x1d, self.conv12d, self.bn12d, self.do12d)
        x11d = self.conv11d(x12d)

        return self.sm(x11d)

    @staticmethod
    def _conv_block(x, conv, bn, do):
        """Apply conv -> batch norm -> ReLU -> dropout to ``x``."""
        return do(F.relu(bn(conv(x))))

    @staticmethod
    def _pad_and_concat(up, skip):
        """Pad ``up`` to ``skip``'s spatial size, then concat ``(up, skip)`` on dim 1."""
        # max_pool2d floors odd sizes, so the upsampled map can be smaller than
        # the skip tensor; replicate-pad on the right (dim 3) and bottom
        # (dim 2) only so the two spatial sizes match.
        pad = ReplicationPad2d((0, skip.size(3) - up.size(3), 0, skip.size(2) - up.size(2)))
        return torch.cat((pad(up), skip), 1)


def main():
    """Build a Unet for two 3-channel inputs and render its autograd graph.

    Requires the third-party ``torchviz`` package and the Graphviz ``dot``
    executable on PATH; ``view()`` renders the graph and opens it.
    """
    # Imported here so that the Unet class stays importable without torchviz.
    from torchviz import make_dot

    # input_nbr = 3 + 3 = 6 channels after concatenation, label_nbr = 2.
    # Plain tensors are used: torch.autograd.Variable is a deprecated no-op wrapper.
    x1 = torch.randn(1, 3, 256, 256)
    x2 = torch.randn(1, 3, 256, 256)
    model = Unet(6, 2)
    y = model(x1, x2)
    # Reduce to a scalar so the graph has a single root node.
    vis_graph = make_dot(y.mean(), params=dict(model.named_parameters()))
    vis_graph.view()


if __name__ == '__main__':
    main()

 


作者:zuo668



graphviz pytorch 模型

需要 登录 后方可回复, 如果你还没有账号请 注册新账号
相关文章