广告
返回顶部
首页 > 资讯 > 后端开发 > Python >Python Pytorch图像检索实例分析
  • 740
分享到

Python Pytorch图像检索实例分析

2023-06-29 22:06:24 740人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

这篇文章主要介绍“python PyTorch图像检索实例分析”,在日常操作中,相信很多人在Python Pytorch图像检索实例分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Pyt

这篇文章主要介绍“python PyTorch图像检索实例分析”,在日常操作中,相信很多人在Python Pytorch图像检索实例分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Python Pytorch图像检索实例分析”的疑惑有所帮助!接下来,请跟着小编一起来学习吧!

背景

图像检索的基本本质是根据查询图像的特征从集合数据库中查找图像。

大多数情况下,这种特征是图像之间简单的视觉相似性。在一个复杂的问题中,这种特征可能是两幅图像在风格上的相似性,甚至是互补性。

由于原始形式的图像不会在基于像素的数据中反映这些特征,因此我们需要将这些像素数据转换为一个潜空间,在该空间中,图像的表示将反映这些特征。

一般来说,在潜空间中,任何两个相似的图像都会相互靠近,而不同的图像则会相隔很远。这是我们用来训练我们的模型的基本管理规则。一旦我们这样做,检索部分只需搜索潜在空间,在给定查询图像表示的潜在空间中拾取最近的图像。大多数情况下,它是在最近邻搜索的帮助下完成的。

因此,我们可以将我们的方法分为两部分:

  • 图像表现

  • 搜索

我们将在Oxford 102 Flowers数据集上解决这两个部分。

图像表现

我们将使用一种叫做暹罗模型的东西,它本身并不是一种全新的模型,而是一种训练模型的技术。大多数情况下,这是与triplet loss一起使用的。这个技术的基本组成部分是三元组。

三元组是3个独立的数据样本,比如A(锚点),B(阳性)和C(阴性);其中A和B相似或具有相似的特征(可能是同一类),而C与A和B都不相似。这三个样本共同构成了训练数据的一个单元——三元组。

注:任何图像检索任务的90%都体现在暹罗网络、triplet loss和三元组的创建中。如果你成功地完成了这些,那么整个努力的成功或多或少是有保证的。

首先,我们将创建管道的这个组件——数据。下面我们将在PyTorch中创建一个自定义数据集和数据加载器,它将从数据集中生成三元组。

class TripletData(Dataset):    def __init__(self, path, transfORMs, split="train"):         self.path = path        self.split = split    # train or valid        self.cats = 102       # number of cateGories        self.transforms = transforms             def __getitem__(self, idx):         # our positive class for the triplet        idx = str(idx%self.cats + 1)         # choosing our pair of positive images (im1, im2)        positives = os.listdir(os.path.join(self.path, idx))        im1, im2 = random.sample(positives, 2)         # choosing a negative class and negative image (im3)        negative_cats = [str(x+1) for x in range(self.cats)]        negative_cats.remove(idx)        negative_cat = str(random.choice(negative_cats))        negatives = os.listdir(os.path.join(self.path, negative_cat))         im3 = random.choice(negatives)         im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)         im1 = self.transforms(Image.open(im1))         im2 = self.transforms(Image.open(im2))         im3 = self.transforms(Image.open(im3))         return [im1, im2, im3]         # we'll put some value that we want since there can be far too many triplets possible    # multiples of the number of images/ number of categories is a good choice    def __len__(self):        return self.cats*8# Transformstrain_transforms = transforms.Compose([    transforms.Resize((224,224)),    transforms.RandomHorizontalFlip(),    transforms.ToTensor(),    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])val_transforms = transforms.Compose([    transforms.Resize((224, 224)),    transforms.ToTensor(),    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])# Datasets and Dataloaderstrain_data = TripletData(PATH_TRAIN, train_transforms)val_data = TripletData(PATH_VALID, val_transforms)train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

现在我们有了数据,让我们转到暹罗网络。

暹罗网络给人的印象是2个或3个模型,但是它本身是一个单一的模型。所有这些模型共享权重,即只有一个模型。

Python Pytorch图像检索实例分析

如前所述,将整个体系结构结合在一起的关键因素是triplet loss。triplet loss产生了一个目标函数,该函数迫使相似输入对(锚点和正)之间的距离小于不同输入对(锚点和负)之间的距离,并限定一定的阈值。

下面我们来看看triplet loss以及训练管道实现。

class TripletLoss(nn.Module):    def __init__(self, margin=1.0):                super(TripletLoss, self).__init__()        self.margin = margin                    def calc_euclidean(self, x1, x2):        return (x1 - x2).pow(2).sum(1)            # Distances in embedding space is calculated in euclidean    def forward(self, anchor, positive, negative):                distance_positive = self.calc_euclidean(anchor, positive)                distance_negative = self.calc_euclidean(anchor, negative)                losses = torch.relu(distance_positive - distance_negative + self.margin)                return losses.mean()       device = 'cuda' # Our base modelmodel = models.resnet18().cuda()optimizer = optim.Adam(model.parameters(), lr=0.001)triplet_loss = TripletLoss() # Trainingfor epoch in range(epochs):        model.train()    epoch_loss = 0.0        for data in tqdm(train_loader):                optimizer.zero_grad()        x1,x2,x3 = data        e1 = model(x1.to(device))        e2 = model(x2.to(device))        e3 = model(x3.to(device))                 loss = triplet_loss(e1,e2,e3)        epoch_loss += loss        loss.backward()        optimizer.step()            print("Train Loss: {}".format(epoch_loss.item()))         class TripletLoss(nn.Module):    def __init__(self, margin=1.0):                super(TripletLoss, self).__init__()        self.margin = margin                    def calc_euclidean(self, x1, x2):        return (x1 - x2).pow(2).sum(1)            # Distances in embedding space is calculated in euclidean    def forward(self, anchor, positive, negative):                distance_positive = self.calc_euclidean(anchor, positive)                distance_negative = self.calc_euclidean(anchor, negative)                losses = torch.relu(distance_positive - distance_negative + self.margin)                return losses.mean()       device = 'cuda'  # Our base modelmodel = models.resnet18().cuda()optimizer = optim.Adam(model.parameters(), lr=0.001)triplet_loss = TripletLoss()  # Trainingfor epoch in range(epochs):    model.train()    epoch_loss = 0.0    for data in tqdm(train_loader):         optimizer.zero_grad()                x1,x2,x3 = data                e1 = model(x1.to(device))        e2 = model(x2.to(device))        e3 = model(x3.to(device))                 loss = triplet_loss(e1,e2,e3)        epoch_loss += loss        loss.backward()        optimizer.step()            print("Train Loss: {}".format(epoch_loss.item()))

到目前为止,我们的模型已经经过训练,可以将图像转换为一个嵌入空间。接下来,我们进入搜索部分。

搜索

我们可以很容易地使用Scikit Learn提供的最近邻搜索。我们将探索新的更好的东西,而不是走简单的路线。

我们将使用Faiss。这比最近的邻居要快得多,如果我们有大量的图像,这种速度上的差异会变得更加明显。

下面我们将演示如何在给定查询图像时,在存储的图像表示中搜索最近的图像。

#!pip install faiss-gpuimport faiss                            faiss_index = faiss.IndexFlatL2(1000)   # build the index # storing the image representationsim_indices = [] with torch.no_grad():    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):                im = Image.open(f)        im = im.resize((224,224))        im = torch.tensor([val_transforms(im).numpy()]).cuda()            preds = model(im)        preds = np.array([preds[0].cpu().numpy()])        faiss_index.add(preds) #add the representation to index        im_indices.append(f)   #store the image name to find it later on         # Retrieval with a query imagewith torch.no_grad():    for f in os.listdir(PATH_TEST):                # query/test image        im = Image.open(os.path.join(PATH_TEST,f))        im = im.resize((224,224))        im = torch.tensor([val_transforms(im).numpy()]).cuda()            test_embed = model(im).cpu().numpy()                _, I = faiss_index.search(test_embed, 5)        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

这涵盖了基于现代深度学习的图像检索,但不会使其变得太复杂。大多数检索问题都可以通过这个基本管道解决。

到此,关于“Python Pytorch图像检索实例分析”的学习就结束了,希望能够解决大家的疑惑。理论与实践的搭配能更好的帮助大家学习,快去试试吧!若想继续学习更多相关知识,请继续关注编程网网站,小编会继续努力为大家带来更多实用的文章!

--结束END--

本文标题: Python Pytorch图像检索实例分析

本文链接: https://www.lsjlt.com/news/326622.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

本篇文章演示代码以及资料文档资料下载

下载Word文档到电脑,方便收藏和打印~

下载Word文档
猜你喜欢
  • Python Pytorch图像检索实例分析
    这篇文章主要介绍“Python Pytorch图像检索实例分析”,在日常操作中,相信很多人在Python Pytorch图像检索实例分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”Pyt...
    99+
    2023-06-29
  • Python Pytorch学习之图像检索实践
    目录背景图像表现搜索随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。 亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息...
    99+
    2022-11-10
  • python中pytorch图像识别的示例分析
    这篇文章将为大家详细讲解有关python中pytorch图像识别的示例分析,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。一、数据集爬取现在的深度学习对数据集量的需求越来越大了,也有了许多现成的数据集可供大...
    99+
    2023-06-29
  • python OpenCV图像金字塔实例分析
    这篇“python OpenCV图像金字塔实例分析”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“python&nb...
    99+
    2023-07-02
  • python中的opencv图像梯度实例分析
    本文小编为大家详细介绍“python中的opencv图像梯度实例分析”,内容详细,步骤清晰,细节处理妥当,希望这篇“python中的opencv图像梯度实例分析”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。图像梯...
    99+
    2023-06-30
  • matlab图像滤波实例分析
    这篇“matlab图像滤波实例分析”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“matlab图像滤波实例分析”文章吧。mat...
    99+
    2023-07-05
  • Python-OpenCV实现图像缺陷检测的实例
    目录1.实现代码2.运行结果在Jupyter Notebook上使用Python+opencv实现如下图像缺陷检测。关于opencv库的安装可以参考:Python下opencv库的安...
    99+
    2022-11-12
  • Python OpenCV图像识别的示例分析
    小编给大家分享一下Python OpenCV图像识别的示例分析,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!一、人脸识别主要有以下两种实现方法:哈尔(Haar)级联法:专门解决人脸识别而推出的传统算法;实现步骤:...
    99+
    2023-06-29
  • Matlab图像处理的公路裂缝案例检测分析
    本篇内容主要讲解“Matlab图像处理的公路裂缝案例检测分析”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Matlab图像处理的公路裂缝案例检测分析”吧!一、简介1 案例背景随着国家对公路建设的...
    99+
    2023-06-29
  • Python基于Pytorch特征图提取的示例分析
    这篇文章给大家分享的是有关Python基于Pytorch特征图提取的示例分析的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。简述为了方便理解卷积神经网络的运行过程,需要对卷积神经网络的运行结果进行可视化的展示。大致...
    99+
    2023-06-29
  • Python深度学习pytorch实现图像分类数据集
    目录读取数据集读取小批量整合所有组件目前广泛使用的图像分类数据集之一是MNIST数据集。如今,MNIST数据集更像是一个健全的检查,而不是一个基准。 为了提高难度,我们将在接下来的章...
    99+
    2022-11-12
  • python实现人脸检测的实例分析
    这篇文章主要介绍“python实现人脸检测的实例分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“python实现人脸检测的实例分析”文章能帮助大家解决问题。OpenCVOpenCV 是计算机视觉领...
    99+
    2023-06-29
  • Python图算法实例分析
    本文实例讲述了Python图算法。分享给大家供大家参考,具体如下: #encoding=utf-8 import networkx,heapq,sys from matplotlib import py...
    99+
    2022-06-04
    算法 实例 Python
  • Python Matplotlib绘图实例分析
    这篇文章主要介绍“Python Matplotlib绘图实例分析”,在日常操作中,相信很多人在Python Matplotlib绘图实例分析问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”P...
    99+
    2023-07-02
  • Python中图像量化处理的示例分析
    小编给大家分享一下Python中图像量化处理的示例分析,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一.图像量化处理原理量化(Quantization)旨在将图像...
    99+
    2023-06-29
  • opencv 案例05-基于二值图像分析(简单缺陷检测)
    缺陷检测,分为两个部分,一个部分是提取指定的轮廓,第二个部分通过对比实现划痕检测与缺角检测。本次主要搞定第一部分,学会观察图像与提取图像ROI对象轮廓外接矩形与轮廓。 下面是基于二值图像分析的大致流程 读取图像将图像转换为灰度图,并对其进行...
    99+
    2023-08-30
    opencv 人工智能 计算机视觉 目标检测
  • Python+OpenCV图像处理之直方图统计的示例分析
    这篇文章主要为大家展示了“Python+OpenCV图像处理之直方图统计的示例分析”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“Python+OpenCV图像处理之直方图统计的示例分析”这篇文章...
    99+
    2023-06-22
  • python opencv图像处理基本操作的示例分析
    本篇文章给大家分享的是有关python opencv图像处理基本操作的示例分析,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。①读取图像②显示图像该函数中,name是显示窗口的名...
    99+
    2023-06-25
  • python数据分析绘图可视化实例分析
    本篇内容介绍了“python数据分析绘图可视化实例分析”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!前言:数据分析初始阶段,通常都要进行可视...
    99+
    2023-07-02
  • Python中图像算术与逻辑运算的示例分析
    小编给大家分享一下Python中图像算术与逻辑运算的示例分析,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一.图像加法运算图像加法运算主要有两种方法。第一种是调用...
    99+
    2023-06-29
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作