Example Code for Image Dehazing in Python

2024-04-02 19:04:59 八月长安


Modifications

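I used this code for a dehazing task and both added to and trimmed the original code. In particular, I removed the stages that extract faces and enhance the extracted faces, as shown in the figure below.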

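I also added and changed some data-processing code: Create_Bigfile2.py and Load_Bigfilev2 were added for this specific task. Data handling here still follows the original method, which packs the training data into a single file and loads it all at once, so it may exhaust memory. The dehazing version is shown below.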

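In addition, to save memory, you can skip the original method. I rewrote online_dataset_for_old_photos.py for my haze-synthesis paper; with this version the original Create_Bigfile and Load_bigfile are no longer needed. The code is as follows: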

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
 
import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from data.Load_Bigfile import BigFileMemoryLoader
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
 
 
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
 
import random
import cv2
from io import BytesIO
 
# Convert a PIL image to a NumPy array
def pil_to_np(img_PIL):
    '''Converts image in PIL format to np.array.
    From H x W x C [0...255] to C x H x W [0..1]
    '''
    ar = np.array(img_PIL)
 
    if len(ar.shape) == 3:
        ar = ar.transpose(2, 0, 1)
    else:
        ar = ar[None, ...]
 
    return ar.astype(np.float32) / 255.
 
# Convert a NumPy array back to a PIL image
def np_to_pil(img_np):
    '''Converts image in np.array format to PIL image.
    From C x H x W [0..1] to H x W x C [0...255]
    '''
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
 
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        ar = ar.transpose(1, 2, 0)
 
    return Image.fromarray(ar)
##
# The functions below synthesize noisy images
##
def synthesize_salt_pepper(image,amount,salt_vs_pepper):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    out = img_pil.copy()
    p = amount
    q = salt_vs_pepper
    flipped = np.random.choice([True, False], size=img_pil.shape,
                               p=[p, 1 - p])
    salted = np.random.choice([True, False], size=img_pil.shape,
                              p=[q, 1 - q])
    peppered = ~salted
    out[flipped & salted] = 1
    out[flipped & peppered] = 0.
    noisy = np.clip(out, 0, 1).astype(np.float32)
 
 
    return np_to_pil(noisy)
 
def synthesize_gaussian(image,std_l,std_r):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss
    noisy=np.clip(noisy,0,1).astype(np.float32)
 
    return np_to_pil(noisy)
 
def synthesize_speckle(image,std_l,std_r):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss*img_pil
    noisy=np.clip(noisy,0,1).astype(np.float32)
 
    return np_to_pil(noisy)
 
# Simulate low resolution: downscale the image, then resize it back to its original size
def synthesize_low_resolution(img):
    w,h=img.size
 
    new_w=random.randint(int(w/2),w)
    new_h=random.randint(int(h/2),h)
 
    img=img.resize((new_w,new_h),Image.BICUBIC)
 
    if random.uniform(0,1)<0.5:
        img=img.resize((w,h),Image.NEAREST)
    else:
        img = img.resize((w, h), Image.BILINEAR)
 
    return img
 
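# Round-trip the image through JPEG compression
def convertToJpeg(im,quality):
    # Read and write the bytes in memory
    with BytesIO() as f:
        im.save(f, format='JPEG',quality=quality)
        f.seek(0)
        # Reopen with Image.open and convert to RGB, dropping the alpha channel
        return Image.open(f).convert('RGB')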
 
# Apply Gaussian blur with a random kernel size and standard deviation
def blur_image_v2(img):
 
 
    x=np.array(img)
    kernel_size_candidate=[(3,3),(5,5),(7,7)]
    kernel_size=random.sample(kernel_size_candidate,1)[0]
    std=random.uniform(1.,5.)
 
    #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
    blur=cv2.GaussianBlur(x,kernel_size,std)
 
    return Image.fromarray(blur.astype(np.uint8))
# Randomly apply the degradation functions above, in random order, to produce a degraded image
def online_add_degradation_v2(img):
 
    task_id=np.random.permutation(4)
 
    for x in task_id:
        if x==0 and random.uniform(0,1)<0.7:
            img = blur_image_v2(img)
        if x==1 and random.uniform(0,1)<0.7:
            flag = random.choice([1, 2, 3])
            if flag == 1:
                img = synthesize_gaussian(img, 5, 50)
            if flag == 2:
                img = synthesize_speckle(img, 5, 50)
            if flag == 3:
                img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
        if x==2 and random.uniform(0,1)<0.7:
            img=synthesize_low_resolution(img)
 
        if x==3 and random.uniform(0,1)<0.7:
            img=convertToJpeg(img,random.randint(40,100))
 
    return img
 
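# Synthesize an image with creases from a mask
# The original paper handles some complex creases poorly; this is an attempt to improve on that rather than simply overlaying the mask.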
def irregular_hole_synthesize(img,mask):
 
    img_np=np.array(img).astype('uint8')
    mask_np=np.array(mask).astype('uint8')
    mask_np=mask_np/255
    img_new=img_np*(1-mask_np)+mask_np*255
 
 
    hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")
    # "L" is grayscale mode
    return hole_img,mask.convert("L")
# Create an all-black three-channel mask image
def zero_mask(size):
    x=np.zeros((size,size,3)).astype('uint8')
    mask=Image.fromarray(x).convert("RGB")
    return mask
#########################################  my  ################################
 
class UnPairOldPhotos_SRv2(BaseDataset):  ## Synthetic + Real Old
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        # Load VOC plus the real grayscale and color photos
        # domain A
        if self.isImage:
            path_clear = r'/home/vip/shy/ots/clear_images/' ##self.opt.path_clear
            path_old = r'/home/vip/shy/Bringing-Old-Photos-Back-to-Life_v1/voc2007/Real_RGB_old' ##self.opt.path_old
            path_haze = r'/home/vip/shy/ots/hazy/' ##self.opt.path_haze
            #self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
            self.load_img_dir_RGB_old=path_old
            self.load_img_dir_clean=path_clear
            self.load_img_dir_Synhaze=path_haze
 
            self.img_dir_Synhaze = os.listdir(self.load_img_dir_Synhaze)
            self.loaded_imgs_Synhaze=[os.path.join(self.load_img_dir_Synhaze,img) for img in self.img_dir_Synhaze]
            self.img_dir_RGB_old = os.listdir(self.load_img_dir_RGB_old)
            self.loaded_imgs_RGB_old = [os.path.join(self.load_img_dir_RGB_old,img) for img in self.img_dir_RGB_old]
            self.loaded_imgs_clean = []
            for path_i in self.loaded_imgs_Synhaze:
                    p,n = os.path.split(path_i)
                    pre,ex = os.path.splitext(n)
                    clear_pre = pre.split('_')[0]
                    clear_path = os.path.join(path_clear,clear_pre+ex)
                    self.loaded_imgs_clean.append(clear_path)
            print('________________filter whose size <256')
            self.filtered_imgs_clean = []
            self.filtered_imgs_Synhaze = []
            self.filtered_imgs_old = []
            print('________________now filter syn and clean size <256')
            for i in range(len(self.loaded_imgs_Synhaze)):
                img_name_syn = self.loaded_imgs_Synhaze[i]
                img = Image.open(img_name_syn)
                w, h = img.size  # PIL's size is (width, height)
                img_name_clear = self.loaded_imgs_clean[i]
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append(img_name_clear)
                self.filtered_imgs_Synhaze.append(img_name_syn)
            print('________________now filter old size <256')
            for i in range(len(self.loaded_imgs_RGB_old)):
                img_name_old = self.loaded_imgs_RGB_old[i]
                img = Image.open(img_name_old)
                w, h = img.size  # PIL's size is (width, height)
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_old.append(img_name_old)
 
        # domain B: if 'domainA' is not in the experiment name, load the clean images by default
        else:
            path_clear = r'/home/vip/shy/ots/clear_images/' ##self.opt.path_clear
            self.load_img_dir_clean=path_clear
            self.loaded_imgs_clean = []
            self.img_dir_clean = os.listdir(self.load_img_dir_clean)
            self.loaded_imgs_clean = [os.path.join(self.load_img_dir_clean, img) for img in self.img_dir_clean]
            print('________________now filter clean size <256')
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name_clean = self.loaded_imgs_clean[i]
                img = Image.open(img_name_clean)
                w, h = img.size  # PIL's size is (width, height)
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append(img_name_clean)
        ####
        print("-------------Filter the imgs whose size <256 finished -------------")
 
        self.pid = os.getpid()
 
    def __getitem__(self, index):
 
 
        is_real_old=0
 
        sampled_dataset=None
        degradation=None
        # Randomly sample one image (from the synthetic old photos or the real old photos)
        if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old
            P=random.uniform(0,2)
            if P>=0 and P<1:
                sampled_dataset=self.filtered_imgs_old
                self.load_img_dir=self.load_img_dir_RGB_old
                self.Num = len(sampled_dataset)
                is_real_old=1
            if P>=1 and P<2:
                sampled_dataset=self.filtered_imgs_Synhaze
                self.load_img_dir=self.load_img_dir_Synhaze
                self.Num = len(sampled_dataset)
                degradation=1
        # domain B
        else:
            # Use the images that remain after filtering out those smaller than 256
            sampled_dataset=self.filtered_imgs_clean
            self.load_img_dir=self.load_img_dir_clean
            self.Num = len(sampled_dataset)
 
        index=random.randint(0,self.Num-1)
        img_name = sampled_dataset[index]
        A = Image.open(img_name)
        path = img_name
        #########################################################################
        # i, j, h, w = tfs.RandomCrop.get_params(A, output_size=(256, 256))
        # A = FF.crop(A, i, j, h, w)
        # A = A.convert("RGB")
        # A_tensor = #tfs.ToTensor()(A)
        #########################################################################
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        A_tensor = A_transform(A.convert("RGB"))
 
        B_tensor = inst_tensor = feat_tensor = 0
        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                        'feat': feat_tensor, 'path': path}
        return input_dict
 
    def __len__(self):
        return  len(self.filtered_imgs_clean)## actually, this is useless, since the selected index is just a random number
                                        #control the epoch through the iters =len(loaded_imgs_clean)
    def name(self):
        return 'UnPairOldPhotos_SR'
 
 
 
 
# ###################################################################################3
# # Unpaired old-photo image loader (the synthetic old photos and the real old photos are not matched; the synthetic ones are generated from the VOC dataset)
# class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
#     def initialize(self, opt):
#         self.opt = opt
#         self.isImage = 'domainA' in opt.name
#         self.task = 'old_photo_restoration_training_vae'
#         self.dir_AB = opt.dataroot
#         # Load VOC plus the real grayscale and color photos
#         # domain A
#         if self.isImage:
#
#             #self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
#             self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
#             self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
#             self.load_img_dir_Synhaze=os.path.join(self.dir_AB,"VOC_RGB_Synhaze.bigfile")
#
#             #self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)
#             self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
#             self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
#             self.loaded_imgs_Synhaze=BigFileMemoryLoader(self.load_img_dir_Synhaze)
#
#         #dominB: if dominA not in experiment's name ,load VOC defultly
#         else:
#             # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
#             self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
#             self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
#             self.load_img_dir_Synhaze=os.path.join(self.dir_AB,"VOC_RGB_Synhaze.bigfile")
#             self.loaded_imgs_Synhaze=BigFileMemoryLoader(self.load_img_dir_Synhaze)
#
#         ####
#         print("-------------Filter the imgs whose size <256 in VOC-------------")
#         self.filtered_imgs_clean=[]
#         self.filtered_imgs_Synhaze=[]
#
#         # Filter out the VOC images smaller than 256
#         for i in range(len(self.loaded_imgs_clean)):
#             img_name,img=self.loaded_imgs_clean[i]
#             synimg_name,synimg=self.loaded_imgs_Synhaze[i]
#
#             h,w=img.size
#             if h<256 or w<256:
#                 continue
#             self.filtered_imgs_clean.append((img_name,img))
#             self.filtered_imgs_Synhaze.append((synimg_name,synimg))
#
#
#         print("--------Origin image num is [%d], filtered result is [%d]--------" % (
#         len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
#         ## Filter these images whose size is less than 256
#
#         # self.img_list=os.listdir(load_img_dir)
#         self.pid = os.getpid()
#
#     def __getitem__(self, index):
#
#
#         is_real_old=0
#
#         sampled_dataset=None
#         degradation=None
#         # Randomly sample one image (from the synthetic old photos or the real old photos)
#         if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old
#             P=random.uniform(0,2)
#             if P>=0 and P<1:
#                 if random.uniform(0,1)<0.5:
#                     # sampled_dataset=self.loaded_imgs_L_old
#                     # self.load_img_dir=self.load_img_dir_L_old
#
#                     sampled_dataset=self.loaded_imgs_RGB_old
#                     self.load_img_dir=self.load_img_dir_RGB_old
#                 else:
#                     sampled_dataset=self.loaded_imgs_RGB_old
#                     self.load_img_dir=self.load_img_dir_RGB_old
#                 is_real_old=1
#             if P>=1 and P<2:
#                 sampled_dataset=self.filtered_imgs_Synhaze
#                 self.load_img_dir=self.load_img_dir_Synhaze
#
#                 degradation=1
#         # domain B
#         else:
#             # Use the images that remain after filtering out those smaller than 256
#             sampled_dataset=self.filtered_imgs_clean
#             self.load_img_dir=self.load_img_dir_clean
#
#         sampled_dataset_len=len(sampled_dataset)
#
#         index=random.randint(0,sampled_dataset_len-1)
#
#         img_name,img = sampled_dataset[index]
#
#         #already old
#         #if degradation is not None:
#         #    # Degrade the image to make it look old
#         #    img=online_add_degradation_v2(img)
#
#         path=os.path.join(self.load_img_dir,img_name)
#
#         # AB = Image.open(path).convert('RGB')
#         # split AB image into A and B
#
#         # apply the same transform to both A and B
#         # Randomly convert the image to grayscale
#         if random.uniform(0,1) <0.1:
#             img=img.convert("L")
#             img=img.convert("RGB")
#             ## Give a probability P, we convert the RGB image into L
#
#         # Resize
#         A=img
#         w,h=A.size
#         if w<256 or h<256:
#             A=transforms.Scale(256,Image.BICUBIC)(A)
#         ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.
#         transform_params = get_params(self.opt, A.size)
#         A_transform = get_transform(self.opt, transform_params)
#
#         B_tensor = inst_tensor = feat_tensor = 0
#         A_tensor = A_transform(A)
#
#         # Store the results in a dictionary
#         #A_tensor  :     old or Syn imgtensor;
#         #is_real_old:     1:old ; 0:Syn
#         #feat       :     0
#         input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
#                         'feat': feat_tensor, 'path': path}
#         return input_dict
#
#     def __len__(self):
#         return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number
#
#     def name(self):
#         return 'UnPairOldPhotos_SR'
#################################    my   ####################
# Paired image loader (original image and its synthetic old version)
# mapping
class PairOldPhotosv2(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagan' in opt.name  # actually unused
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        # Training mode: load the training data
        if opt.isTrain:
            path_clear = r'/home/vip/shy/ots/clear_images/'
            path_haze = r'/home/vip/shy/ots/hazy/'
            self.load_img_dir_clean=path_clear
            self.load_img_dir_Synhaze=path_haze
 
            self.img_dir_Synhaze = os.listdir(self.load_img_dir_Synhaze)
            self.loaded_imgs_Synhaze=[os.path.join(self.load_img_dir_Synhaze,img) for img in self.img_dir_Synhaze]
            self.loaded_imgs_clean = []
            for path_i in self.loaded_imgs_Synhaze:
                    p,n = os.path.split(path_i)
                    pre,ex = os.path.splitext(n)
                    clear_pre = pre.split('_')[0]
                    clear_path = os.path.join(path_clear,clear_pre+ex)
                    self.loaded_imgs_clean.append(clear_path)
            print('________________filter whose size <256')
            self.filtered_imgs_clean = []
            self.filtered_imgs_Synhaze = []
            print('________________now filter syn and clean size <256')
            for i in range(len(self.loaded_imgs_Synhaze)):
                img_name_syn = self.loaded_imgs_Synhaze[i]
                img = Image.open(img_name_syn)
                w, h = img.size  # PIL's size is (width, height)
                img_name_clear = self.loaded_imgs_clean[i]
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append(img_name_clear)
                self.filtered_imgs_Synhaze.append(img_name_syn)
 
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        # Test mode: load only the test set
        else:
            if self.opt.test_on_synthetic:
                ############valset#########
                path_val_clear = r'/home/vip/shy/SOTS/outdoor/gt' ######none###############self.opt.path_clear
                path_val_haze = r'/home/vip/shy/SOTS/outdoor/hazy' #########none#############self.opt.path_haze
                self.load_img_dir_clean = path_val_clear
                self.load_img_dir_Synhaze = path_val_haze
 
                self.img_dir_Synhaze = os.listdir(self.load_img_dir_Synhaze)
                self.loaded_imgs_Synhaze = [os.path.join(self.load_img_dir_Synhaze, img) for img in
                                            self.img_dir_Synhaze]
                self.loaded_imgs_clean = []
                for path_i in self.loaded_imgs_Synhaze:
                    p, n = os.path.split(path_i)
                    pre, ex = os.path.splitext(n)
                    clear_pre = pre.split('_')[0]
                    clear_path = os.path.join(self.load_img_dir_clean, clear_pre + ex)
                    self.loaded_imgs_clean.append(clear_path)
                print('________________filter whose size <256')
                self.filtered_val_imgs_clean = []
                self.filtered_val_imgs_Synhaze = []
                print('________________now filter val syn and clean size <256')
                for i in range(len(self.loaded_imgs_Synhaze)):
                    img_name_syn = self.loaded_imgs_Synhaze[i]
                    img = Image.open(img_name_syn)
                    w, h = img.size  # PIL's size is (width, height)
                    img_name_clear = self.loaded_imgs_clean[i]
                    if h < 256 or w < 256:
                        continue
                    self.filtered_val_imgs_clean.append(img_name_clear)
                    self.filtered_val_imgs_Synhaze.append(img_name_syn)
                print('________________finished filter val syn and clean ')
 
            else:
                ############testset#########
                path_test_clear = r'/home/vip/shy/SOTS/outdoor/gt' ##################self.opt.path_test_clear
                path_test_haze = r'/home/vip/shy/SOTS/outdoor/hazy' ###################self.opt.path_test_haze
                self.load_img_dir_clean=path_test_clear
                self.load_img_dir_Synhaze=path_test_haze
 
                self.img_dir_Synhaze = os.listdir(self.load_img_dir_Synhaze)
                self.loaded_imgs_Synhaze=[os.path.join(self.load_img_dir_Synhaze,img) for img in self.img_dir_Synhaze]
                self.loaded_imgs_clean = []
                for path_i in self.loaded_imgs_Synhaze:
                        p,n = os.path.split(path_i)
                        pre,ex = os.path.splitext(n)
                        clear_pre = pre.split('_')[0]
                        clear_path = os.path.join(self.load_img_dir_clean,clear_pre+ex)
                        self.loaded_imgs_clean.append(clear_path)
                print('________________filter whose size <256')
                self.filtered_test_imgs_clean = []
                self.filtered_test_imgs_Synhaze = []
                print('________________now filter testset syn and clean size <256')
                for i in range(len(self.loaded_imgs_Synhaze)):
                    img_name_syn = self.loaded_imgs_Synhaze[i]
                    img = Image.open(img_name_syn)
                    w, h = img.size  # PIL's size is (width, height)
                    img_name_clear = self.loaded_imgs_clean[i]
                    if h < 256 or w < 256:
                        continue
                    self.filtered_test_imgs_clean.append(img_name_clear)
                    self.filtered_test_imgs_Synhaze.append(img_name_syn)
                print('________________finished filter testset syn and clean ')
 
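            # test_on_synthetic fills the filtered_val_* lists; otherwise the filtered_test_* lists exist
            filtered = (self.filtered_val_imgs_Synhaze if self.opt.test_on_synthetic
                        else self.filtered_test_imgs_Synhaze)
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_Synhaze), len(filtered)))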
 
 
        self.pid = os.getpid()
 
    def __getitem__(self, index):
 
 
        # Training mode
        if self.opt.isTrain:
            # (B is the clean image)
            img_name_clean = self.filtered_imgs_clean[index]
            B = Image.open(img_name_clean)
            img_name_synhaze = self.filtered_imgs_Synhaze[index]
            S = Image.open(img_name_synhaze)
            path = os.path.join(img_name_clean)
            # Build the image pair (B is the clean image, A is the corresponding degraded image)
            A=S
 
        ### Remind: A is the input and B is corresponding GT
        # test code not modified #####################################################
        else:
            # Test mode
            # (B is the clean image, A is the corresponding degraded image)

            if self.opt.test_on_synthetic:
                # validation set: initialize() fills the filtered_val_* lists in this case
                img_name_B = self.filtered_val_imgs_clean[index]
                B = Image.open(img_name_B)
                img_name_A=self.filtered_val_imgs_Synhaze[index]
                A = Image.open(img_name_A)
                path = os.path.join(img_name_A)
            else:
                # test set: initialize() fills the filtered_test_* lists in this case
                img_name_B = self.filtered_test_imgs_clean[index]
                B = Image.open(img_name_B)
                img_name_A=self.filtered_test_imgs_Synhaze[index]
                A = Image.open(img_name_A)
                path = os.path.join(img_name_A)
 
        # Drop the alpha channel
        # if random.uniform(0,1)<0.1 and self.opt.isTrain:
        #     A=A.convert("L")
        #     B=B.convert("L")
        A=A.convert("RGB")
        B=B.convert("RGB")
 
        # apply the same transform to both A and B
        # Get the transform parameters
        transform_params = get_params(self.opt, A.size)
        # Apply the transforms (data augmentation)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)
 
        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)
 
        # input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
        #             'feat': feat_tensor, 'path': path}
        input_dict = {'label': B_tensor, 'inst': inst_tensor, 'image': A_tensor,
                    'feat': feat_tensor, 'path': path}
 
        return input_dict
 
    def __len__(self):
 
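        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        elif self.opt.test_on_synthetic:
            # initialize() fills the filtered_val_* lists when test_on_synthetic is set
            return len(self.filtered_val_imgs_clean)
        else:
            return len(self.filtered_test_imgs_clean)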
 
    def name(self):
        return 'PairOldPhotos'
#
 
 
#
#
# # Paired image loader (original image and its synthetic old version)
# # mapping
# class PairOldPhotos(BaseDataset):
#     def initialize(self, opt):
#         self.opt = opt
#         self.isImage = 'imagan' in opt.name #actually ,useless ;
#         self.task = 'old_photo_restoration_training_mapping'
#         self.dir_AB = opt.dataroot
#         # Training mode: load VOC
#         if opt.isTrain:
#             self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
#             self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
#
#             self.load_img_dir_Synhaze= os.path.join(self.dir_AB, "VOC_RGB_Synhaze.bigfile")
#             self.loaded_imgs_Synhaze = BigFileMemoryLoader(self.load_img_dir_Synhaze)
#
#             print("-------------Filter the imgs whose size <256 in VOC-------------")
#             # Filter out the VOC images smaller than 256
#             self.filtered_imgs_clean = []
#             self.filtered_imgs_Synhaze = []
#
#             for i in range(len(self.loaded_imgs_clean)):
#                 img_name, img = self.loaded_imgs_clean[i]
#                 synhazeimg_name, synhazeimg = self.loaded_imgs_clean[i]
#
#                 h, w = img.size
#                 if h < 256 or w < 256:
#                     continue
#                 self.filtered_imgs_clean.append((img_name, img))
#                 self.filtered_imgs_Synhaze.append((synhazeimg_name, synhazeimg))
#
#             print("--------Origin image num is [%d], filtered result is [%d]--------" % (
#             len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
#         # Test mode: load only the test set
#         else:
#             self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
#             self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
#
#         self.pid = os.getpid()
#
#     def __getitem__(self, index):
#
#
#         # Training mode
#         if self.opt.isTrain:
#             # (B is the clean VOC image)
#             img_name_clean,B = self.filtered_imgs_clean[index]
#             img_name_synhaze,S = self.filtered_imgs_Synhaze[index]
#
#             path = os.path.join(self.load_img_dir_clean, img_name_clean)
#             # Build the image pair (B is the clean VOC image, A is the corresponding noisy image)
#             if self.opt.use_v2_degradation:
#                 A=S
#             ### Remind: A is the input and B is corresponding GT
#         # test code not modified #####################################################
#         else:
#             # Test mode
#             # (B is the clean VOC image, A is the corresponding noisy image)
#             if self.opt.test_on_synthetic:
#
#                 img_name_B,B=self.loaded_imgs[index]
#                 A=online_add_degradation_v2(B)
#                 img_name_A=img_name_B
#                 path = os.path.join(self.load_img_dir, img_name_A)
#             else:
#                 img_name_A,A=self.loaded_imgs[index]
#                 img_name_B,B=self.loaded_imgs[index]
#                 path = os.path.join(self.load_img_dir, img_name_A)
#
#         # Drop the alpha channel
#         if random.uniform(0,1)<0.1 and self.opt.isTrain:
#             A=A.convert("L")
#             B=B.convert("L")
#             A=A.convert("RGB")
#             B=B.convert("RGB")
#         ## In P, we convert the RGB into L
#
#
#         ##test on L
#
#         # split AB image into A and B
#         # w, h = img.size
#         # w2 = int(w / 2)
#         # A = img.crop((0, 0, w2, h))
#         # B = img.crop((w2, 0, w, h))
#         w,h=A.size
#         if w<256 or h<256:
#             A=transforms.Scale(256,Image.BICUBIC)(A)
#             B=transforms.Scale(256, Image.BICUBIC)(B)
#
#         # apply the same transform to both A and B
#         # Get the transform parameters
#         transform_params = get_params(self.opt, A.size)
#         # Apply the transforms (data augmentation)
#         A_transform = get_transform(self.opt, transform_params)
#         B_transform = get_transform(self.opt, transform_params)
#
#         B_tensor = inst_tensor = feat_tensor = 0
#         A_tensor = A_transform(A)
#         B_tensor = B_transform(B)
#
#         input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
#                     'feat': feat_tensor, 'path': path}
#         return input_dict
#
#     def __len__(self):
#
#         if self.opt.isTrain:
#             return len(self.filtered_imgs_clean)
#         else:
#             return len(self.loaded_imgs)
#
#     def name(self):
#         return 'PairOldPhotos'
# #####################################################################
# # Paired image loader for images with creases
# class PairOldPhotos_with_hole(BaseDataset):
#     def initialize(self, opt):
#         self.opt = opt
#         self.isImage = 'imagegan' in opt.name
#         self.task = 'old_photo_restoration_training_mapping'
#         self.dir_AB = opt.dataroot
#         # Training mode: load paired synthetic images with cracks
#         if opt.isTrain:
#             self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
#             self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
#
#             print("-------------Filter the imgs whose size <256 in VOC-------------")
#             # Filter out the images smaller than 256
#             self.filtered_imgs_clean = []
#             for i in range(len(self.loaded_imgs_clean)):
#                 img_name, img = self.loaded_imgs_clean[i]
#                 h, w = img.size
#                 if h < 256 or w < 256:
#                     continue
#                 self.filtered_imgs_clean.append((img_name, img))
#
#             print("--------Origin image num is [%d], filtered result is [%d]--------" % (
#             len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
#
#         else:
#             self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
#             self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
#         # Load the irregular masks
#         self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)
#
#         self.pid = os.getpid()
#
#     def __getitem__(self, index):
#
#
#
#         if self.opt.isTrain:
#             img_name_clean,B = self.filtered_imgs_clean[index]
#             path = os.path.join(self.load_img_dir_clean, img_name_clean)
#
#
#             B=transforms.RandomCrop(256)(B)
#             A=online_add_degradation_v2(B)
#             ### Remind: A is the input and B is corresponding GT
#
#         else:
#             img_name_A,A=self.loaded_imgs[index]
#             img_name_B,B=self.loaded_imgs[index]
#             path = os.path.join(self.load_img_dir, img_name_A)
#
#             #A=A.resize((256,256))
#             A=transforms.CenterCrop(256)(A)
#             B=A
#
#         if random.uniform(0,1)<0.1 and self.opt.isTrain:
#             A=A.convert("L")
#             B=B.convert("L")
#             A=A.convert("RGB")
#             B=B.convert("RGB")
#         ## In P, we convert the RGB into L
#
#         if self.opt.isTrain:
#             #载入mask
#             mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
#         else:
#             # 载入mask
#             mask_name, mask = self.loaded_masks[index%100]
#         # Resize the mask
#         mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)
#
#         if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
#             mask=zero_mask(256)
#
#         if self.opt.no_hole:
#             mask=zero_mask(256)
#
#         # Use the mask to synthesize an image with creases
#         A,_=irregular_hole_synthesize(A,mask)
#
#         if not self.opt.isTrain and self.opt.hole_image_no_mask:
#             mask=zero_mask(256)
#         # Get the transform parameters
#         transform_params = get_params(self.opt, A.size)
#         A_transform = get_transform(self.opt, transform_params)
#         B_transform = get_transform(self.opt, transform_params)
#         # Apply the same horizontal flip to the mask
#         if transform_params['flip'] and self.opt.isTrain:
#             mask=mask.transpose(Image.FLIP_LEFT_RIGHT)
#         # Convert the mask to a tensor scaled to [0, 1]
#         mask_tensor = transforms.ToTensor()(mask)
#
#
#         B_tensor = inst_tensor = feat_tensor = 0
#         A_tensor = A_transform(A)
#         B_tensor = B_transform(B)
#
#         input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
#                     'feat': feat_tensor, 'path': path}
#         return input_dict
#
#     def __len__(self):
#
#         if self.opt.isTrain:
#             return len(self.filtered_imgs_clean)
#
#         else:
#             return len(self.loaded_imgs)
#
#     def name(self):
#         return 'PairOldPhotos_with_hole'

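For dehazing, I rewrote the code as shown below. I added code that generates a hazy image from a clear image and its corresponding depth map, and merged it into online_dataset_for_old_photos.py from the original source. The full file follows.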

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
 
import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
import torchvision.transforms as transforms
from data.Load_Bigfile import BigFileMemoryLoader
from data.Load_Bigfilev2 import BigFileMemoryLoaderv2
 
from io import BytesIO
import os
import glob
import cv2, math
import random
import numpy as np
import h5py
from PIL import Image
import scipy.io
 
def pil_to_np(img_PIL):
    '''Converts image in PIL format to np.array.
    From W x H x C [0...255] to C x W x H [0..1]
    '''
    ar = np.array(img_PIL)
 
    if len(ar.shape) == 3:
        ar = ar.transpose(2, 0, 1)
    else:
        ar = ar[None, ...]
 
    return ar.astype(np.float32) / 255.
 
 
def np_to_pil(img_np):
    '''Converts image in np.array format to PIL image.
    From C x W x H [0..1] to  W x H x C [0...255]
    '''
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
 
    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        ar = ar.transpose(1, 2, 0)
 
    return Image.fromarray(ar)
 
def synthesize_salt_pepper(image,amount,salt_vs_pepper):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    out = img_pil.copy()
    p = amount
    q = salt_vs_pepper
    flipped = np.random.choice([True, False], size=img_pil.shape,
                               p=[p, 1 - p])
    salted = np.random.choice([True, False], size=img_pil.shape,
                              p=[q, 1 - q])
    peppered = ~salted
    out[flipped & salted] = 1
    out[flipped & peppered] = 0.
    noisy = np.clip(out, 0, 1).astype(np.float32)
 
 
    return np_to_pil(noisy)
 
def synthesize_gaussian(image,std_l,std_r):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss
    noisy=np.clip(noisy,0,1).astype(np.float32)
 
    return np_to_pil(noisy)
 
def synthesize_speckle(image,std_l,std_r):
 
    ## Give PIL, return the noisy PIL
 
    img_pil=pil_to_np(image)
 
    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss*img_pil
    noisy=np.clip(noisy,0,1).astype(np.float32)
 
    return np_to_pil(noisy)
 
 
def synthesize_low_resolution(img):
    w,h=img.size
 
    new_w=random.randint(int(w/2),w)
    new_h=random.randint(int(h/2),h)
 
    img=img.resize((new_w,new_h),Image.BICUBIC)
 
    if random.uniform(0,1)<0.5:
        img=img.resize((w,h),Image.NEAREST)
    else:
        img = img.resize((w, h), Image.BILINEAR)
 
    return img
 
 
 
 
def convertToJpeg(im,quality):
    with BytesIO() as f:
        im.save(f, format='JPEG',quality=quality)
        f.seek(0)
        return Image.open(f).convert('RGB')
 
 
def blur_image_v2(img):
 
 
    x=np.array(img)
    kernel_size_candidate=[(3,3),(5,5),(7,7)]
    kernel_size=random.sample(kernel_size_candidate,1)[0]
    std=random.uniform(1.,5.)
 
    #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
    blur=cv2.GaussianBlur(x,kernel_size,std)
 
    return Image.fromarray(blur.astype(np.uint8))
def perlin_noise(im,varargin):
    """
        This is the function for adding perlin noise to the depth map. It is a
    simplified implementation of the paper:
    An Image Synthesizer
    Ken Perlin, SIGGRAPH, Jul. 1985
    The bicubic interpolation is used, compared to the original version.
    Reference:
    HAZERD: an outdoor scene dataset and benchmark for single image dehazing
    IEEE International Conference on Image Processing, Sep 2017
    The paper and additional information on the project are available at:
    https://labsites.rochester.edu/gsharma/research/computer-vision/hazerd/
    If you use this code, please cite our paper.
    Input:
      im: depth map
      varargin{1}: decay term
    Output:
      im: result of transmission with perlin noise added
    Authors:
      Yanfu Zhang: yzh185@ur.rochester.edu
      Li Ding: l.ding@rochester.edu
      Gaurav Sharma: gaurav.sharma@rochester.edu
    Last update: May 2017
    :return:
    """
    # (h, w, c) = im.shape
    # i = 1
    # if nargin == 1:
    #     decay = 2
    # else:
    #     decay = varargin{1}
    # l_bound = min(h,w)
    # while i <= l_bound:
    #     d = imresize(randn(i, i)*decay, im.shape, 'bicubic')
    #     im = im+d
    #     i = i*2
    # im = c(im);
    # return im
    pass
 
def srgb2lrgb(I0):
    gamma = ((I0 + 0.055) / 1.055)**2.4
    scale = I0 / 12.92
    return np.where (I0 > 0.04045, gamma, scale)
 
def lrgb2srgb(I1):
    gamma =  1.055*I1**(1/2.4)-0.055
    scale = I1 * 12.92
    return np.where (I1 > 0.0031308, gamma, scale)
 
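# Returns the depth matrix as a NumPy array
def get_depth(depth_or_trans_name):
    # depth_or_trans_name: path to a .mat file (the original note also allows an
    # image path, but scipy.io.loadmat only reads .mat files)
    data = scipy.io.loadmat(depth_or_trans_name)
    depths = data['imDepth']  # depth variable stored in the .mat file
    #print(data.keys())  # print every variable in the .mat file
    depths = np.array(depths)
    return depths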
 
 
 
 
 
def irregular_hole_synthesize(img,mask):
 
    img_np=np.array(img).astype('uint8')
    mask_np=np.array(mask).astype('uint8')
    mask_np=mask_np/255
    img_new=img_np*(1-mask_np)+mask_np*255
 
    hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")
 
    return hole_img,mask.convert("L")
 
def zero_mask(size):
    x=np.zeros((size,size,3)).astype('uint8')
    mask=Image.fromarray(x).convert("RGB")
    return mask
 
def hazy_simu(img_name,depth_or_trans_name,airlight=0.76,is_imdepth=1): ##for outdoor
    """
    This is the function for haze simulation with the parameters given by
    the paper:
    HAZERD: an outdoor scene dataset and benchmark for single image dehazing
    IEEE International Conference on Image Processing, Sep 2017
    The paper and additional information on the project are available at:
    https://labsites.rochester.edu/gsharma/research/computer-vision/hazerd/
    If you use this code, please cite our paper.
    IMPORTANT NOTE: The code uses the convention that pixel locations with a
    depth value of 0 correspond to objects that are very far and for the
    simulation of haze these are placed a distance of 2 times the visual
    range.
    Authors:
    Yanfu Zhang: yzh185@ur.rochester.edu
    Li Ding: l.ding@rochester.edu
    Gaurav Sharma: gaurav.sharma@rochester.edu
    Last update: May 2017
    python version update : Aug 2021
    Authors :
    Haoying Sun : 1913434222@qq.com
    :param img_name: haze-free RGB image as a PIL image (despite the name,
                     this is the image itself, not a file path)
    :param depth_or_trans_name: depth map in meters as a NumPy array when
                     is_imdepth is true, otherwise a transmission map as a
                     PIL image
    :param airlight: atmospheric light in the range [0,1], default 0.76
    :param is_imdepth: 1 if depth_or_trans_name is a depth map, 0 if it is
                     a transmission map
    :return: (hazy PIL image, airlight, beta)
    """
    # if random.uniform(0, 1) < 0.5:
    visual_range = [0.05, 0.1, 0.2, 0.5, 1]  # visual range in km; adjust as needed, or sample from an interval instead (beta_param would then need changing, which I have not looked into yet)
    beta_param = 3.912     #Default beta parameter corresponding to visual range of 1000m
 
    A = airlight
    #print('Simulating hazy image for:{}'.format(img_name))
    VR = random.choice(visual_range)
 
    #print('Viusal value: {} km'.format(VR) )
    #im1 = cv2.imread(img_name)
    img_pil = pil_to_np(img_name)
 
    #convert sRGB to linear RGB
    I = srgb2lrgb(img_pil)
 
    # Scattering coefficient for the sampled visual range. It is computed before
    # the branch so that beta_return is also defined when is_imdepth is false.
    beta_return = beta_param / VR
 
    if is_imdepth:
        depths = depth_or_trans_name
 
        d = depths/1000   # convert meter to kilometer
        if depths.max()==0:
            d = np.where(d == 0,0.01, d) ####
        else:
            d = np.where(d==0,2*VR,d)
        #Set regions where depth value is set to 0 to indicate no valid depth to
        #a distance of two times the visual range. These regions typically
        #correspond to sky areas
 
        #convert depth map to transmission
        beta = np.ones(d.shape) * beta_return
        transmission = np.exp((-beta*d))
        transmission_3 = np.array([transmission,transmission,transmission])
 
        #Obtain simulated linear RGB hazy image.Eq. 3 in the HazeRD paper
        Ic = transmission_3 * I + (1 - transmission_3) * A
    else:
        Ic = pil_to_np(depth_or_trans_name) * I + (1 - pil_to_np(depth_or_trans_name)) * A
 
    # convert linear RGB to sRGB
    I2 = lrgb2srgb(Ic)
    haze_img = np_to_pil(I2)
    # haze_img = np.asarray(haze_img)
    # haze_img = cv2.cvtColor(haze_img, cv2.COLOR_RGB2BGR)
    # haze_img = Image.fromarray(haze_img)
    return haze_img,airlight,beta_return
 
def hazy_reside_training(img_name,depth_or_trans_name,is_imdepth=1):
    """
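    RESIDE training-set settings: A in (0.7, 1.0), beta in (0.6, 1.8)
    :param img_name: haze-free RGB image as a PIL image
    :param depth_or_trans_name: depth map (NumPy array) if is_imdepth is true,
                                otherwise a transmission map as a PIL image
    :param is_imdepth: 1 for a depth map, 0 for a transmission map
    :return: (hazy PIL image, airlight, beta)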
    """
    beta = random.uniform(0.6, 1.8)
    beta_return = beta
    airlight = random.uniform(0.7, 1.0)
 
    A = airlight
 
    #print('Viusal value: {} km'.format(VR) )
    #im1 = cv2.imread(img_name)
    img_pil = pil_to_np(img_name)
 
    #convert sRGB to linear RGB
    I = srgb2lrgb(img_pil)
 
    if is_imdepth:
        depths = depth_or_trans_name
 
        #convert depth map to transmission
        if depths.max()==0:
            d = np.where(depths == 0,1, depths)
        else:
            d = depths / depths.max()
            d = np.where(d == 0, 1, d)
 
        beta = np.ones(d.shape) * beta
        transmission = np.exp((-beta*d))
        transmission_3 = np.array([transmission,transmission,transmission])
 
        #Obtain simulated linear RGB hazy image.Eq. 3 in the HazeRD paper
        Ic = transmission_3 * I + (1 - transmission_3) * A
 
    else:
        Ic = pil_to_np(depth_or_trans_name) * I + (1 - pil_to_np(depth_or_trans_name)) * A
 
    # convert linear RGB to sRGB
    I2 = lrgb2srgb(Ic)
    #I2 = cv2.cvtColor(I2, cv2.COLOR_BGR2RGB)
 
    haze_img = np_to_pil(I2)
    # haze_img = np.asarray(haze_img)
    # haze_img = cv2.cvtColor(haze_img, cv2.COLOR_RGB2BGR)
    # haze_img = Image.fromarray(haze_img)
    return haze_img,airlight,beta_return
 
def hazy_reside_OTS(img_name,depth_or_trans_name,is_imdepth=1):
    """
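    RESIDE OTS settings: A from [0.8, 0.85, 0.9, 0.95, 1], beta from [0.04, 0.06, 0.08, 0.1, 0.12, 0.16, 0.2]
    :param img_name: haze-free RGB image as a PIL image
    :param depth_or_trans_name: depth map (NumPy array) if is_imdepth is true,
                                otherwise a transmission map as a PIL image
    :param is_imdepth: 1 for a depth map, 0 for a transmission map
    :return: (hazy PIL image, airlight, beta)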
    """
    beta = random.choice([0.04, 0.06, 0.08, 0.1, 0.12, 0.16, 0.2])
    beta_return = beta
    airlight = random.choice([0.8, 0.85, 0.9, 0.95, 1])
    #print(beta)
    #print(airlight)
    A = airlight
 
    #print('Viusal value: {} km'.format(VR) )
    #im1 = cv2.imread(img_name)
 
    #img = cv2.cvtColor(np.asarray(img_name), cv2.COLOR_RGB2BGR)
    img_pil = pil_to_np(img_name)
 
    #convert sRGB to linear RGB
    I = srgb2lrgb(img_pil)
 
    if is_imdepth:
        depths = depth_or_trans_name
        #convert depth map to transmission
        if depths.max()==0:
            d = np.where(depths == 0, 1, depths)
        else:
            d = depths/(depths.max())
            d = np.where(d == 0, 1, d)
 
        beta = np.ones(d.shape) * beta
        transmission = np.exp((-beta*d))
        transmission_3 = np.array([transmission,transmission,transmission])
 
        #Obtain simulated linear RGB hazy image.Eq. 3 in the HazeRD paper
        Ic = transmission_3 * I + (1 - transmission_3) * A
 
    else:
        Ic = pil_to_np(depth_or_trans_name) * I + (1 - pil_to_np(depth_or_trans_name)) * A
 
    # convert linear RGB to sRGB
    I2 = lrgb2srgb(Ic)
    haze_img = np_to_pil(I2)
 
    #haze_img = np.asarray(haze_img)
    #haze_img = cv2.cvtColor(haze_img, cv2.COLOR_RGB2BGR)
    #haze_img = Image.fromarray(haze_img)
    return haze_img,airlight,beta_return
def online_add_degradation_v2(img,depth_or_trans):
    noise = 0
    task_id=np.random.permutation(4)
    if random.uniform(0,1)<0.3:
        noise = 1
        #print('noise')
        for x in task_id:
            # For more variety, each degradation is randomly skipped 30% of the time (it is applied when the draw is < 0.7)
            if x==0 and random.uniform(0,1)<0.7:
                img = blur_image_v2(img)
            if x==1 and random.uniform(0,1)<0.7:
                flag = random.choice([1, 2, 3])
                if flag == 1:
                    img = synthesize_gaussian(img, 5, 50) # Gaussian white noise with σ ∈ [5,50]
                if flag == 2:
                    img = synthesize_speckle(img, 5, 50)
                if flag == 3:
                    img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
            if x==2 and random.uniform(0,1)<0.7:
                img=synthesize_low_resolution(img)
 
            if x==3 and random.uniform(0,1)<0.7:
                img=convertToJpeg(img,random.randint(40,100))
                #JPEG compression whose level is in the range of [40,100]
    add_haze = random.choice([1,2,3])
    if add_haze == 1:
        img, airlight, beta  = hazy_reside_OTS(img, depth_or_trans)
    elif add_haze  == 2:
        img, airlight, beta  = hazy_simu(img, depth_or_trans)
    else:
        img, airlight, beta  = hazy_reside_training(img, depth_or_trans)
    # else:
    #     if add_haze < 0.1:
    #         img = hazy_reside_OTS(img, depth_or_trans)
    #     elif add_haze > 0.1 and add_haze < 0.2:
    #         img = hazy_simu(img, depth_or_trans)
    #     else:
    #         img = hazy_reside_training(img, depth_or_trans)
    return img#,noise,airlight,beta
 
 
class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        if self.isImage:
 
            self.load_npy_dir_depth=os.path.join(self.dir_AB,"VOC_RGB_Depthnpy.bigfile")
            self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
 
            self.loaded_npys_depth=BigFileMemoryLoaderv2(self.load_npy_dir_depth)
            self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
 
        else:
            # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
 
            self.load_npy_dir_depth=os.path.join(self.dir_AB,"VOC_RGB_Depthnpy.bigfile")
            self.loaded_npys_depth=BigFileMemoryLoaderv2(self.load_npy_dir_depth)
 
        ####
        print("-------------Filter the imgs whose size <256 in VOC-------------")
        self.filtered_imgs_clean=[]
        self.filtered_npys_depth = []
        for i in range(len(self.loaded_imgs_clean)):
            img_name,img=self.loaded_imgs_clean[i]
            npy_name, npy = self.loaded_npys_depth[i]
            w,h=img.size  # PIL's size is (width, height)
            if h<256 or w<256:
                continue
            self.filtered_imgs_clean.append((img_name,img))
            self.filtered_npys_depth.append((npy_name, npy))
        print("--------Origin image num is [%d], filtered result is [%d]--------" % (
        len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        ## Filter these images whose size is less than 256
 
        # self.img_list=os.listdir(load_img_dir)
        self.pid = os.getpid()
 
    def __getitem__(self, index):
 
 
        is_real_old=0
 
        sampled_dataset=None
        sampled_depthdataset = None
        degradation=None
        if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old
            P=random.uniform(0,2)
            if P>=0 and P<1:
                #if random.uniform(0,1)<0.5:
                    # grayscale old photos are not used
                    #sampled_dataset=self.loaded_imgs_L_old
                    #self.load_img_dir=self.load_img_dir_L_old
                    sampled_dataset = self.loaded_imgs_RGB_old
                    self.load_img_dir = self.load_img_dir_RGB_old
 
 
                # else:
                #     sampled_dataset=self.loaded_imgs_RGB_old
                #     self.load_img_dir=self.load_img_dir_RGB_old
 
 
                    is_real_old=1
            if P>=1 and P<2:
                sampled_dataset=self.filtered_imgs_clean
                self.load_img_dir=self.load_img_dir_clean
 
                sampled_depthdataset=self.filtered_npys_depth
                self.load_npy_dir=self.load_npy_dir_depth
 
                degradation=1
        else:
 
            sampled_dataset=self.filtered_imgs_clean
            self.load_img_dir=self.load_img_dir_clean
 
            sampled_depthdataset = self.filtered_npys_depth
            self.load_npy_dir = self.load_npy_dir_depth
 
        sampled_dataset_len=len(sampled_dataset)
        #print('sampled_dataset_len::::',sampled_dataset_len)
        index=random.randint(0,sampled_dataset_len-1)
 
        img_name,img = sampled_dataset[index]
        # print(img_name)
        # print(img)
        # print(index)
 
        #print(npy_name)
        #print(npy)
        if degradation is not None:
            npy_name, npy = sampled_depthdataset[index]
            img=online_add_degradation_v2(img,npy)
        path=os.path.join(self.load_img_dir,img_name)
 
        # AB = Image.open(path).convert('RGB')
        # split AB image into A and B
 
        # apply the same transform to both A and B
 
        # if random.uniform(0,1) <0.1:
        #     img=img.convert("L")
        #     img=img.convert("RGB")
        #     ## Give a probability P, we convert the RGB image into L
 
 
        A=img
        w,h=A.size
        if w<256 or h<256:
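            A=transforms.Resize(256,Image.BICUBIC)(A)  # transforms.Scale is deprecated (and absent from newer torchvision releases); Resize is its replacement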
        ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.
 
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
 
        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
 
 
        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                        'feat': feat_tensor, 'path': path}
        return input_dict
 
    def __len__(self):
        return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number
 
    def name(self):
        return 'UnPairOldPhotos_SR'
 
 
class PairOldPhotos(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
 
            self.load_npy_dir_depth= os.path.join(self.dir_AB, "VOC_RGB_Depthnpy.bigfile")
            self.loaded_npys_depth = BigFileMemoryLoaderv2(self.load_npy_dir_depth)
 
            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            self.filtered_npys_depth = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                npy_name, npy = self.loaded_npys_depth[i]
                w, h = img.size  # PIL's size is (width, height)
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
                self.filtered_npys_depth.append((npy_name, npy))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
 
        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
 
            self.load_depth_dir = os.path.join(self.dir_AB, opt.test_depthdataset)
            self.loaded_npys = BigFileMemoryLoaderv2(self.load_depth_dir)
 
 
        self.pid = os.getpid()
 
    def __getitem__(self, index):
 
 
 
        if self.opt.isTrain:
            img_name_clean,B = self.filtered_imgs_clean[index]
            npy_name_depth,D = self.filtered_npys_depth[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            if self.opt.use_v2_degradation:
                A=online_add_degradation_v2(B,D)
            ### Remind: A is the input and B is corresponding GT
        else:
 
            if self.opt.test_on_synthetic:
 
                img_name_B,B=self.loaded_imgs[index]
                npy_name_D,D=self.loaded_npys[index]
                A=online_add_degradation_v2(B,D)
                A.save('../mybig_data/' + str(index) + '.jpg')  # index is an int, so it must be converted before concatenation
 
                img_name_A=img_name_B
                path = os.path.join(self.load_img_dir, img_name_A)
            else:
                img_name_A,A=self.loaded_imgs[index]
                img_name_B,B=self.loaded_imgs[index]
                path = os.path.join(self.load_img_dir, img_name_A)
 
 
        # if random.uniform(0,1)<0.1 and self.opt.isTrain:
        #     A=A.convert("L")
        #     B=B.convert("L")
        #     A=A.convert("RGB")
        #     B=B.convert("RGB")
        # ## In P, we convert the RGB into L
 
 
        ##test on L
 
        # split AB image into A and B
        # w, h = img.size
        # w2 = int(w / 2)
        # A = img.crop((0, 0, w2, h))
        # B = img.crop((w2, 0, w, h))
        w,h=A.size
        if w<256 or h<256:
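            # transforms.Scale is deprecated (and absent from newer torchvision releases); Resize is its replacement
            A=transforms.Resize(256,Image.BICUBIC)(A)
            B=transforms.Resize(256, Image.BICUBIC)(B)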
 
        # apply the same transform to both A and B
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)
 
        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)
 
        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict
 
    def __len__(self):
 
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        else:
            return len(self.loaded_imgs)
 
    def name(self):
        return 'PairOldPhotos'
 
#del
class PairOldPhotos_with_hole(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        if opt.isTrain:
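            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
 
            # The depth maps must be loaded here as well: the filtering loop below reads self.loaded_npys_depth
            self.load_npy_dir_depth= os.path.join(self.dir_AB, "VOC_RGB_Depthnpy.bigfile")
            self.loaded_npys_depth = BigFileMemoryLoaderv2(self.load_npy_dir_depth)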
 
            print("-------------Filter the imgs whose size <256 in VOC-------------")
            self.filtered_imgs_clean = []
            self.filtered_npys_depth = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                npy_name, npy = self.loaded_npys_depth[i]
                w, h = img.size  # PIL's size is (width, height)
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))
                self.filtered_npys_depth.append((npy_name, npy))
            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
 
        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
 
            self.load_depth_dir = os.path.join(self.dir_AB, opt.test_depthdataset)
            self.loaded_npys = BigFileMemoryLoaderv2(self.load_depth_dir)
 
        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)
 
        self.pid = os.getpid()
 
    def __getitem__(self, index):
 
 
 
        if self.opt.isTrain:
            img_name_clean,B = self.filtered_imgs_clean[index]
            npy_name_depth, D = self.filtered_npys_depth[index]
 
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
 
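            # Crop B and its depth map D with the same window before adding haze, so that
            # input A and ground truth B stay aligned (assumes D is an H x W array matching B)
            i, j, h, w = transforms.RandomCrop.get_params(B, (256, 256))
            B = B.crop((j, i, j + w, i + h))
            D = D[i:i + h, j:j + w]
            A=online_add_degradation_v2(B,D)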
 
            ### Remind: A is the input and B is corresponding GT
 
        else:
            img_name_A,A=self.loaded_imgs[index]
            img_name_B,B=self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)
 
            #A=A.resize((256,256))
            A=transforms.CenterCrop(256)(A)
            B=A
 
        if random.uniform(0,1)<0.1 and self.opt.isTrain:
            A=A.convert("L")
            B=B.convert("L")
            A=A.convert("RGB")
            B=B.convert("RGB")
        ## In P, we convert the RGB into L
 
        if self.opt.isTrain:
            mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
        else:
            mask_name, mask = self.loaded_masks[index%100]
        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)
 
        if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
            mask=zero_mask(256)
 
        if self.opt.no_hole:
            mask=zero_mask(256)
 
 
        A,_=irregular_hole_synthesize(A,mask)
 
        if not self.opt.isTrain and self.opt.hole_image_no_mask:
            mask=zero_mask(256)
 
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)
 
        if transform_params['flip'] and self.opt.isTrain:
            mask=mask.transpose(Image.FLIP_LEFT_RIGHT)
 
        mask_tensor = transforms.ToTensor()(mask)
 
 
        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)
 
        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict
 
    def __len__(self):
 
        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
 
        else:
            return len(self.loaded_imgs)
 
    def name(self):
        return 'PairOldPhotos_with_hole'

These are the more important changes; some of them were already mentioned in my earlier blog posts.
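
The three haze functions in the listing above share one core step: the clear image is converted to linear RGB, blended with the airlight according to a depth-dependent transmission, and converted back to sRGB. The following is a minimal NumPy-only sketch of that step, not the code from the listing. The names `srgb_to_linear`, `linear_to_srgb` and `add_haze` and the toy arrays are assumptions made for illustration; the constant 3.912 and the airlight 0.76 are the defaults used in `hazy_simu`.

```python
import numpy as np

def srgb_to_linear(x):
    """sRGB -> linear RGB for values in [0, 1]."""
    return np.where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92)

def linear_to_srgb(x):
    """Linear RGB -> sRGB for values in [0, 1]."""
    x = np.clip(x, 0.0, 1.0)
    return np.where(x > 0.0031308, 1.055 * x ** (1 / 2.4) - 0.055, x * 12.92)

def add_haze(clean, depth, beta, airlight):
    """Hypothetical helper: clean is (3, H, W) sRGB in [0, 1], depth is (H, W)."""
    transmission = np.exp(-beta * depth)           # t = exp(-beta * d), shape (H, W)
    linear = srgb_to_linear(clean)
    # I = J * t + A * (1 - t); the (H, W) map broadcasts over the 3 channels
    hazy_linear = linear * transmission + airlight * (1.0 - transmission)
    return linear_to_srgb(hazy_linear), transmission

# Toy data (assumed): a random 4 x 4 image with every pixel 0.1 km away
rng = np.random.default_rng(0)
clean = rng.uniform(0.0, 1.0, size=(3, 4, 4))
depth_km = np.full((4, 4), 0.1)

visual_range_km = 0.5
beta = 3.912 / visual_range_km                     # same relation as in hazy_simu
hazy, transmission = add_haze(clean, depth_km, beta, airlight=0.76)
print(hazy.shape, float(transmission[0, 0]))
```

With a depth of zero the transmission is 1 everywhere, so the image comes back unchanged, which is a quick sanity check for the two color-space conversions.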

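Training and Testing

The top of run.py holds three groups of commented-out commands for training, testing, and data preparation (for data preparation, see the code in the data folder; this step is optional). Adjust them to your own setup. They are as follows: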

############test############
#python run.py --input_folder /home/vip/shy/HBDH/haze --output_folder /home/vip/shy/HBDH/l1-feat30-01 --GPU 0
 
############dataset prepare#################
# python Create_Bigfile.py
 
############train A, B, mapping############
#python train_domain_A.py --use_v2_degradation --continue_train --training_dataset domain_A --name domainA_SR_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot ../mybig_data/ --no_instance --resize_or_crop crop_only --batchSize 48 --no_html --gpu_ids 0,1 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir your_output_folder --checkpoints_dir your_ckpt_folder
#python train_domain_B.py --continue_train --training_dataset domain_B --name domainB_old_photos --label_nc 0 --loadSize 256 --fineSize 256 --dataroot ../mybig_data/  --no_instance --resize_or_crop crop_only --batchSize 48 --no_html --gpu_ids 0,1 --self_gen --nThreads 4 --n_downsample_global 3 --k_size 4 --use_v2 --mc 64 --start_r 1 --kl 1 --no_cgan --outputs_dir your_output_folder  --checkpoints_dir your_ckpt_folder
#python train_mapping.py --use_v2_degradation --training_dataset mapping --use_vae_which_epoch latest --continue_train --name mapping_quality --label_nc 0 --loadSize 256 --fineSize 256 --dataroot ../mybig_data/ --no_instance --resize_or_crop crop_only --batchSize 16 --no_html --gpu_ids 0,1 --nThreads 8 --load_pretrainA ./your_ckpt_folder/domainA_SR_old_photos --load_pretrainB ./your_ckpt_folder/domainB_old_photos --l2_feat 60 --n_downsample_global 3 --mc 64 --k_size 4 --start_r 1 --mapping_n_block 6 --map_mc 512 --use_l1_feat --outputs_dir your_output_folder --checkpoints_dir your_ckpt_folder
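
Because the paths in these commands are machine-specific, it can help to assemble the test command from variables instead of editing the comment by hand. The sketch below does that for the test command only. `build_test_command` and the placeholder paths are hypothetical; the flags `--input_folder`, `--output_folder` and `--GPU` are the ones shown in the command above.

```python
import shlex

def build_test_command(input_folder, output_folder, gpu=0):
    """Hypothetical helper: return the run.py test call as an argument list."""
    return [
        "python", "run.py",
        "--input_folder", input_folder,
        "--output_folder", output_folder,
        "--GPU", str(gpu),
    ]

# Placeholder paths (assumed); replace them with your own folders
cmd = build_test_command("/path/to/hazy_images", "/path/to/results", gpu=0)
print(shlex.join(cmd))
# To launch it, pass the list to subprocess.run(cmd, check=True)
```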

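Datasets

Note: the following four files are the training sets I packaged following the original paper. There are two versions of VOC_RGB_Depthnpy.bigfile: the 1.6 GB one holds the depth matrices from NYUv2 (1,399 images), and the 2.9 GB one holds the NYUv2 depth matrices plus the depth matrices of the images I added (504 cropped images from the HAZERD dataset and 85 sky images I collected). VOC_RGB_JPEGImages.bigfile holds the real images that the depth matrices correspond to.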
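
Since every depth matrix has to stay paired with its image, any filtering must drop both entries together, which is what the dataset classes in the listing do when they discard images smaller than 256 pixels. Below is a minimal sketch of that idea under simplified assumptions: NumPy arrays stand in for the PIL images and depth matrices stored in the bigfiles, and `filter_pairs` is a hypothetical name.

```python
import numpy as np

def filter_pairs(images, depths, min_side=256):
    """Keep only the (image, depth) pairs whose image is at least min_side on both sides.

    images: list of (name, H x W x 3 array); depths: list of (name, H x W array).
    """
    if len(images) != len(depths):
        raise ValueError("image and depth lists differ in length")
    kept_images, kept_depths = [], []
    for (img_name, img), (npy_name, depth) in zip(images, depths):
        height, width = img.shape[:2]
        if depth.shape[:2] != (height, width):
            raise ValueError(f"{npy_name} does not match the size of {img_name}")
        if height < min_side or width < min_side:
            continue  # too small for a 256 x 256 crop, so drop image and depth together
        kept_images.append((img_name, img))
        kept_depths.append((npy_name, depth))
    return kept_images, kept_depths

# Toy data (assumed): one pair that is large enough and one that is too small
images = [("a.jpg", np.zeros((300, 400, 3))), ("b.jpg", np.zeros((200, 400, 3)))]
depths = [("a.npy", np.ones((300, 400))), ("b.npy", np.ones((200, 400)))]
kept_images, kept_depths = filter_pairs(images, depths)
print([name for name, _ in kept_images], [name for name, _ in kept_depths])
```

The size check against the depth array is an extra safeguard added in this sketch; the listing itself relies on the two bigfiles being packed in the same order.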

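Screenshots of the contents of these bigfiles are shown below:

VOC_RGB_JPEGImages.bigfile (1.2 GB):

VOC_RGB_JPEGImages.bigfile (1.69 GB), which contains, in addition to the NYUv2 images above, the images I processed (HAZERD and the sky images I collected):

The depth matrices are the depth .npy files corresponding to the images above.

Real hazy images are needed as well. You have to download them yourself, and you may want to filter out images of very low quality. I used the open-source RESIDE beta dataset.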

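Download Links

Download what you need; some of the files are very large.

I have packaged and published my dehazing code at the following address:

Get the link  Extraction code: Haze

The code for adding haze is actually simple: swap the input and the output, and reverse the one corresponding line of code in the align part.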
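
As a rough illustration of the swap, the paired dataset in the listing returns a dictionary in which 'label' is the network input and 'image' is the target. The sketch below exchanges the two. `swap_input_and_target` is a hypothetical helper, plain strings stand in for tensors, and the change to the align code mentioned above is not shown because the article does not include that line.

```python
def swap_input_and_target(sample):
    """Return a copy of a paired sample with 'label' (input) and 'image' (target) exchanged."""
    swapped = dict(sample)
    swapped["label"], swapped["image"] = sample["image"], sample["label"]
    return swapped

# Stand-in strings instead of tensors (assumed data)
dehaze_sample = {"label": "hazy", "image": "clear", "path": "example.jpg"}
haze_sample = swap_input_and_target(dehaze_sample)
print(haze_sample["label"], haze_sample["image"])
```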


Article link: https://www.lsjlt.com/news/138307.html (please cite the source link when reposting)

    目录1 Kmean图像分割2 流程3 实现1 Kmean图像分割 按照Kmean原理,对图像像素进行聚类。优点:此方法原理简单,效果显著。缺点:实践发现对于前景和背景颜色相近或者颜色...
    99+
    2022-11-13
  • Python实现一键抠图的示例代码
    目录需求来源实现方法需求来源 好友 A:橡皮擦,可否提供网页,上传带人像的图片,然后可以直接抠图,最好直接生成 PNG 图片下载。 橡皮擦:每天需要调用多少次? 好友 A:大概 10...
    99+
    2022-11-11
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作