广告
返回顶部
首页 > 资讯 > 后端开发 > Python >TransUnet官方代码测试自己的数据集(已训练完毕)
  • 925
分享到

TransUnet官方代码测试自己的数据集(已训练完毕)

深度学习pytorchpython人工智能卷积神经网络 2023-09-04 13:09:42 925人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

*************************************************** 码字不易,收藏之余,别忘了给我点个赞吧! ***************************

***************************************************

码字不易,收藏之余,别忘了给我点个赞吧!

***************************************************

---------Start

首先参考上一篇的训练过程,这是测试过程,需要用到训练过程的权重。

1. TransUnet训练完毕之后,会生成权重文件(默认保存位置如下),snapshot_path为保存权重的路径。

在这里插入图片描述
权重文件
在这里插入图片描述

2. 修改test.py文件

调整数据集路径。
在这里插入图片描述
训练和测试时的图像设置相同大小,并设置主干模型的名称同训练时一致。
在这里插入图片描述

配置数据集相关信息。
在这里插入图片描述
手动添加权重。
在这里插入图片描述

3. 设置DataLoader

设置DataLoader中参数num_workers=0。
在这里插入图片描述

4. 修改utils.py文件

替换utils.py中的test_single_volume函数,原网络输出的是0,1,2,3,4像素的图片,分别代表5个类别,直接显示均呈黑色。对此,我们通过像素调整,使每个类别呈现不同的颜色。

def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()    _,x, y = image.shape    if x != patch_size[0] or y != patch_size[1]:        #缩放图像符合网络输入        image = zoom(image, (1,patch_size[0] / x, patch_size[1] / y), order=3)    input = torch.from_numpy(image).unsqueeze(0).float().cuda()    net.eval()    with torch.no_grad():        out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)        out = out.cpu().detach().numpy()        if x != patch_size[0] or y != patch_size[1]:            #缩放图像至原始大小            prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)        else:            prediction = out    metric_list = []    for i in range(1, classes):        metric_list.append(calculate_metric_percase(prediction == i, label == i))    if test_save_path is not None:        a1 = copy.deepcopy(prediction)        a2 = copy.deepcopy(prediction)        a3 = copy.deepcopy(prediction)        a1[a1 == 1] = 255        a1[a1 == 2] = 0        a1[a1 == 3] = 255        a1[a1 == 4] = 20        a2[a2 == 1] = 255        a2[a2 == 2] = 255        a2[a2 == 3] = 0        a2[a2 == 4] = 10        a3[a3 == 1] = 255        a3[a3 == 2] = 77        a3[a3 == 3] = 0        a3[a3 == 4] = 120        a1 = Image.fromarray(np.uint8(a1)).convert('L')        a2 = Image.fromarray(np.uint8(a2)).convert('L')        a3 = Image.fromarray(np.uint8(a3)).convert('L')        prediction = Image.merge('RGB', [a1, a2, a3])        prediction.save(test_save_path+'/'+case+'.png')    return metric_list

**方便小伙伴理解这部分代码,特意做了个图,a1,a2,a3分别代表RGB三个通道,开始它们的值通过deepcopy函数直接赋值,故三者的值都是一样的。
这里拿类别1举例:a1[a12]=0代表R通道中输出结果为2的赋值0,
a2[a2
2]=255代表G通道中输出结果为2的赋值255,
a3[a3==2]=77代表B通道中输出结果为2的赋值77,(0,255,77)对应就是绿色,类别2就是绿色(轮子)。
然后通过Image.merge(‘RGB’, [a1, a2, a3])函数合并三个通道,此时prediction就成了三通道彩色图。

在这里插入图片描述
在这里插入图片描述

至此,设置完毕,右键run运行。

5. 测试结束

测试结束后,会在根目录下生成predictions文件夹,文件夹的内容如下。
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

来源地址:https://blog.csdn.net/qq_37652891/article/details/123470578

--结束END--

本文标题: TransUnet官方代码测试自己的数据集(已训练完毕)

本文链接: https://www.lsjlt.com/news/393629.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作