from PIL import Image import torch import torchvision.transforms as transforms import numpy as np import matplotlib.pyplot as plt import cv2 def draw_CAM(model,img_path,save_path,resize=227,isSave=False,isShow=False): # 图像加载&预处理 img=Image.open(img_path).convert('RGB') loader = transforms.Compose([transforms.Resize(size=(resize,resize)),transforms.ToTensor()]) img = loader(img).unsqueeze(0) # unsqueeze(0)在第0维增加一个维度 # 获取模型输出的feature/score model.eval() # 测试模式,不启用BatchNormalization和Dropout feature=model.features(img) output=model.classifier(feature.view(1,-1)) # 预测得分最高的那一类对应的输出score pred = torch.argmax(output).item() pred_class = output[:, pred] # 记录梯度值 def hook_grad(grad): global feature_grad feature_grad=grad feature.register_hook(hook_grad) # 计算梯度 pred_class.backward() grads=feature_grad # 获取梯度 pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1)) # adaptive_avg_pool2d自适应平均池化函数,输出大小都为(1,1) # 此处batch size默认为1,所以去掉了第0维(batch size维) pooled_grads = pooled_grads[0] # shape为[batch,通道,size,size],此处batch为1,所以直接取[0]即取第一个batch的元素,就取到了每个batch内的所有元素 features = feature[0] # 取【0】原因同上 ########################## 导数(权重)乘以相应元素 # 512是最后一层feature的通道数 for i in range(len(features)): features[i, ...] *= pooled_grads[i, ...] # features[i, ...]与features[i]效果好像是一样的,都是第i个元素 ########################## # 绘制热力图 heatmap = features.detach().numpy() heatmap = np.mean(heatmap, axis=0) # axis=0,对各列求均值,返回1*n heatmap = np.maximum(heatmap, 0) heatmap /= np.max(heatmap) # 可视化原始热力图 if isShow: plt.matshow(heatmap) plt.show() img = cv2.imread(img_path) # 用cv2加载原始图像 heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) # 将热力图的大小调整为与原始图像相同 heatmap = np.uint8(255 * heatmap) # 将热力图转换为RGB格式 heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) # 将热力图应用于原始图像 superimposed_img = heatmap * 0.4 + img # 这里的0.4是热力图强度因子 # 将图像保存到硬盘 if isSave: cv2.imwrite(save_path, superimposed_img) # 展示图像 if isShow: superimposed_img/=255 plt.imshow(superimposed_img)
python
运行
import torchvision.models as models model=models.alexnet() # draw_CAM(model,'/Users/liuyanzhe/Study/陶瓷研究相关/陶瓷数据/ceramic_data/训练/1牡丹/mudan15.png','/Users/liuyanzhe/Downloads/热力图1.png',isSave=True,isShow=True) draw_CAM(model,'/Users/liuyanzhe/Downloads/2.jpg','/Users/liuyanzhe/Downloads/热力图1.png',isSave=True,isShow=True)
python
运行
import torchvision.models as models model=models.vgg16() # draw_CAM(model,'/Users/liuyanzhe/Study/陶瓷研究相关/陶瓷数据/ceramic_data/训练/1牡丹/mudan15.png','/Users/liuyanzhe/Downloads/热力图2.png',isSave=True,isShow=True) draw_CAM(model,'/Users/liuyanzhe/Downloads/2.jpg','/Users/liuyanzhe/Downloads/热力图2.png',isSave=True,isShow=True)
python
运行
相关知识
Pytorch可视化神经网络热力图
基于卷积神经网络的棉花病虫害识别与防治系统,resnet50,mobilenet模型【pytorch框架+python源码】
Pytorch神经网络【手写数字识别】
神经网络分类鸢尾花数据与可视化展示
PPT中如何制作热力图教程(ppt怎么制作热力图)
鸢尾花数据集可视化分析:柱状图、热力图、散点图与更多
excel热力图怎么花
excel表格怎么制作热力图
pytorch 花朵的分类识别
神经网络与深度学习
网址: Pytorch可视化神经网络热力图 https://m.huajiangbk.com/newsview2215551.html
上一篇: 黄振兴 |
下一篇: 描写梅花的优美句子23句 |