数据水平翻转时,需要同步变换 bbox 的 x 坐标:翻转后的 xmin = width − 原 xmax,翻转后的 xmax = width − 原 xmin(y 坐标保持不变)。
代码:
import random
from torchvision.transforms import functional as F


class Compose(object):
    """Chain several ``(image, target)`` transforms and apply them in order."""

    def __init__(self, transforms):
        # Sequence of callables; each takes and returns an (image, target) pair.
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target


class ToTensor(object):
    """Convert the PIL image to a tensor; the target is passed through unchanged."""

    def __call__(self, image, target):
        return F.to_tensor(image), target


class RandomHorizontalFlip(object):
    """Mirror the image left-right with probability ``prob`` and remap its boxes."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        # Guard clause: leave the sample untouched when the coin flip fails.
        if not (random.random() < self.prob):
            return image, target

        _, width = image.shape[-2:]
        image = image.flip(-1)  # reverse the last axis (the width axis)

        # Boxes are [xmin, ymin, xmax, ymax]. After mirroring, the new xmin is
        # width - old xmax and the new xmax is width - old xmin, so the
        # right-hand side reads the columns in [2, 0] order.
        boxes = target["boxes"]
        mirrored_x = width - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = mirrored_x
        target["boxes"] = boxes
        return image, target
对图像及其对应的标注文件(XML格式)进行数据增强,并将增强后的图像和标注文件保存到指定的目录中
- `root`:XML文件所在的目录路径。
- `image_id`:XML文件的名称(不包含扩展名)。注意:`read_xml_annotation` 直接执行 `os.path.join(root, image_id)`,因此调用它时传入的是包含 `.xml` 扩展名的完整文件名。
代码:
import xml.etree.ElementTree as ET
import pickle
import os
from os import getcwd
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as plt
import imgaug as ia
from imgaug import augmenters as iaa

# Seed imgaug's global RNG so augmentation results are reproducible across runs.
ia.seed(1)


def read_xml_annotation(root, image_id):
    """Read every bounding box from a VOC-style XML annotation file.

    Args:
        root: Directory that contains the annotation file.
        image_id: Annotation file name *including* its extension,
            e.g. ``"000001.xml"``; it is joined to ``root`` as-is.

    Returns:
        A list with one ``[xmin, ymin, xmax, ymax]`` list of ints per
        ``<object>`` element, in document order. The list is empty when the
        file contains no ``<object>`` elements.
    """
    # ET.parse opens and closes the file itself, so no file handle leaks.
    tree = ET.parse(os.path.join(root, image_id))
    xml_root = tree.getroot()

    bndboxlist = []
    # ``obj`` rather than ``object`` to avoid shadowing the builtin.
    for obj in xml_root.findall('object'):
        bndbox = obj.find('bndbox')
        # int(float(...)) accepts both "506" and "506.0" style coordinates.
        xmin = int(float(bndbox.find('xmin').text))
        xmax = int(float(bndbox.find('xmax').text))
        ymin = int(float(bndbox.find('ymin').text))
        ymax = int(float(bndbox.find('ymax').text))
        bndboxlist.append([xmin, ymin, xmax, ymax])
    return bndboxlist
def _write_bndbox(bndbox, box):
    """Write ``box`` (indexable as xmin, ymin, xmax, ymax) into a ``<bndbox>`` element."""
    for i, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
        bndbox.find(tag).text = str(box[i])


def change_xml_annotation(root, image_id, new_target, saveroot=None, new_id=None):
    """Replace the bounding box of the first ``<object>`` in an XML annotation.

    Args:
        root: Directory containing ``<image_id>.xml``.
        image_id: Annotation file name without the ``.xml`` extension.
        new_target: New box as ``[xmin, ymin, xmax, ymax]``.
        saveroot: Directory the result is written to; defaults to ``root``.
        new_id: Output file name without extension; defaults to ``image_id``.
            With both defaults the source file is overwritten in place.
    """
    # ET.parse opens and closes the file itself, so no file handle leaks.
    tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
    _write_bndbox(tree.getroot().find('object').find('bndbox'), new_target)
    # The output path used to be built as "%06d" % (str(id) + '.xml') with the
    # builtin ``id``, which always raised TypeError; it is now explicit.
    out_dir = root if saveroot is None else saveroot
    out_name = str(image_id) if new_id is None else str(new_id)
    tree.write(os.path.join(out_dir, out_name + '.xml'))


def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    """Write a copy of an XML annotation with every box replaced.

    Args:
        root: Directory containing the source ``<image_id>.xml``.
        image_id: Source annotation file name without extension.
        new_target: One ``[xmin, ymin, xmax, ymax]`` per ``<object>``, in
            document order.
        saveroot: Directory the new annotation is written to.
        id: New base name (a string). The file is saved as ``<id>.xml`` and
            its ``<filename>`` tag is set to ``<id>.jpg``. The parameter
            shadows the builtin but is kept for keyword callers.

    Raises:
        IndexError: if ``new_target`` has fewer entries than there are
            ``<object>`` elements.
    """
    tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
    filename = tree.find('filename')
    if filename is not None:  # tolerate annotations without a <filename> tag
        filename.text = id + '.jpg'
    for index, obj in enumerate(tree.getroot().findall('object')):
        _write_bndbox(obj.find('bndbox'), new_target[index])
    tree.write(os.path.join(saveroot, id + '.xml'))


def mkdir(path):
    """Create ``path`` (including parents) unless it already exists.

    Leading/trailing whitespace and trailing backslashes are stripped first.
    Returns True if the directory was created, False if it already existed.
    """
    path = path.strip().rstrip("\\")
    if os.path.exists(path):
        print(path + ' 目录已存在')
        return False
    os.makedirs(path)
    print(path + ' 创建成功')
    return True


def _clamp_box(box, width, height):
    """Clip an augmented box (with x1/y1/x2/y2 attributes) to the image as ints.

    Coordinates are clamped to [1, width] and [1, height]. A box that collapsed
    onto the left/top border is widened by one pixel so it is not empty.
    """
    n_x1 = int(max(1, min(width, box.x1)))
    n_y1 = int(max(1, min(height, box.y1)))
    n_x2 = int(max(1, min(width, box.x2)))
    n_y2 = int(max(1, min(height, box.y2)))
    if n_x1 == 1 and n_x1 == n_x2:
        n_x2 += 1
    if n_y1 == 1 and n_y2 == n_y1:
        n_y2 += 1
    return [n_x1, n_y1, n_x2, n_y2]


if __name__ == "__main__":
    IMG_DIR = "VOCdevkit/VOC2007/JPEGImages3"  # source images
    XML_DIR = "VOCdevkit/VOC2007/Annotations3"  # source XML annotations
    AUG_XML_DIR = "VOCdevkit/VOC2007/Annotations"  # output folder for augmented XML files
    AUG_IMG_DIR = "VOCdevkit/VOC2007/JPEGImages"  # output folder for augmented images
    AUGLOOP = 8  # number of augmented copies generated per image

    # Recreate both output folders so every run starts from an empty state.
    for out_dir in (AUG_XML_DIR, AUG_IMG_DIR):
        try:
            shutil.rmtree(out_dir)
        except FileNotFoundError:
            pass
        mkdir(out_dir)

    seq = iaa.Sequential([
        iaa.Flipud(0.5),  # vertically flip 50% of images
        iaa.Fliplr(0.5),  # horizontally mirror 50% of images
        iaa.Multiply((1.2, 1.5)),  # change brightness, doesn't affect BBs
        iaa.GaussianBlur(sigma=(0, 3.0)),  # blur, doesn't affect BBs
        iaa.Affine(
            translate_px={"x": 15, "y": 15},
            scale=(0.8, 0.95),
            rotate=(-30, 30),
        ),  # translate by 15px on x/y, scale to 80-95%, rotate -30..30 deg; affects BBs
    ])

    for xml_dir, _sub_folders, files in os.walk(XML_DIR):
        for name in files:
            stem, ext = os.path.splitext(name)
            if ext.lower() != '.xml':
                continue  # only annotation files can be parsed below

            bndbox = read_xml_annotation(xml_dir, name)
            img_path = os.path.join(IMG_DIR, stem + '.jpg')
            # Keep an unmodified copy of the original sample in the output set.
            shutil.copy(os.path.join(xml_dir, name), AUG_XML_DIR)
            shutil.copy(img_path, AUG_IMG_DIR)

            # The source image is identical for every round: load it once.
            with Image.open(img_path) as pil_img:
                img = np.asarray(pil_img)
            height, width = img.shape[0], img.shape[1]

            for epoch in range(AUGLOOP):
                # Freeze the sampled parameters so the image and its boxes
                # receive exactly the same transform.
                seq_det = seq.to_deterministic()
                new_id = "%06d" % (len(files) * epoch) + stem

                new_bndbox_list = []  # [[x1, y1, x2, y2], ...]
                if bndbox:  # an annotation may contain no objects at all
                    bbs = ia.BoundingBoxesOnImage(
                        [ia.BoundingBox(x1=b[0], y1=b[1], x2=b[2], y2=b[3]) for b in bndbox],
                        shape=img.shape,
                    )
                    bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]
                    for bb in bbs_aug.bounding_boxes:
                        n_box = _clamp_box(bb, width, height)
                        if n_box[0] >= n_box[2] or n_box[1] >= n_box[3]:
                            print('error', name)
                        new_bndbox_list.append(n_box)

                # Save the augmented image directly. The previous code routed it
                # through ``bbs.draw_on_image(..., thickness=0)``, which draws no
                # outline and failed (NameError / stale ``bbs``) for box-less images.
                image_aug = seq_det.augment_images([img])[0]
                Image.fromarray(image_aug).save(os.path.join(AUG_IMG_DIR, new_id + '.jpg'))

                change_xml_list_annotation(xml_dir, stem, new_bndbox_list, AUG_XML_DIR, new_id)
                print(new_id + '.jpg')
代码结构解读:
1. 导入模块
import xml.etree.ElementTree as ET
import pickle
import os
from os import getcwd
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as pltimport imgaug as ia
from imgaug import augmenters as iaa
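- `xml.etree.ElementTree`:用于解析和操作XML文件。
- `numpy` 和 `PIL`:用于图像处理。
- `imgaug`:用于图像增强。
- `os`、`shutil`:用于文件操作和路径管理。`pickle`、`getcwd` 和 `matplotlib.pyplot` 虽被导入,但在本脚本中并未使用。

注意:上面的导入代码中 `import matplotlib.pyplot as plt` 与 `import imgaug as ia` 被粘连在同一行,需要拆分为两行,否则会产生语法错误。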
2. 数据增强的随机种子
设置随机种子,确保每次运行代码时增强操作的一致性。
# Seed imgaug's global random number generator so augmentation results are reproducible across runs.
ia.seed(1)
3. 读取XML标注文件
def read_xml_annotation(root, image_id):
    """Read every bounding box from a VOC-style XML annotation file.

    Args:
        root: Directory that contains the annotation file.
        image_id: Annotation file name *including* its extension,
            e.g. ``"000001.xml"``; it is joined to ``root`` as-is.

    Returns:
        A list with one ``[xmin, ymin, xmax, ymax]`` list of ints per
        ``<object>`` element, in document order. The list is empty when the
        file contains no ``<object>`` elements.
    """
    # ET.parse opens and closes the file itself, so no file handle leaks.
    tree = ET.parse(os.path.join(root, image_id))
    xml_root = tree.getroot()

    bndboxlist = []
    # ``obj`` rather than ``object`` to avoid shadowing the builtin.
    for obj in xml_root.findall('object'):
        bndbox = obj.find('bndbox')
        # int(float(...)) accepts both "506" and "506.0" style coordinates.
        xmin = int(float(bndbox.find('xmin').text))
        xmax = int(float(bndbox.find('xmax').text))
        ymin = int(float(bndbox.find('ymin').text))
        ymax = int(float(bndbox.find('ymax').text))
        bndboxlist.append([xmin, ymin, xmax, ymax])
    return bndboxlist
输入:XML文件所在的目录和文件名。
功能:解析XML文件,提取所有目标对象的边界框坐标。
输出:边界框列表,每个边界框用 `[xmin, ymin, xmax, ymax]` 表示。
4. 更新XML标注文件中的单个边界框(第一个目标)
def change_xml_annotation(root, image_id, new_target, saveroot=None, new_id=None):
    """Replace the bounding box of the first ``<object>`` in an XML annotation.

    Args:
        root: Directory containing ``<image_id>.xml``.
        image_id: Annotation file name without the ``.xml`` extension.
        new_target: New box as ``(xmin, ymin, xmax, ymax)``.
        saveroot: Directory the result is written to; defaults to ``root``.
        new_id: Output file name without extension; defaults to ``image_id``.
            With both defaults the source file is overwritten in place.
    """
    new_xmin, new_ymin, new_xmax, new_ymax = new_target
    # ET.parse opens and closes the file itself, so no file handle leaks.
    tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
    bndbox = tree.getroot().find('object').find('bndbox')
    for tag, value in (('xmin', new_xmin), ('ymin', new_ymin),
                       ('xmax', new_xmax), ('ymax', new_ymax)):
        bndbox.find(tag).text = str(value)
    # The output path used to be built as "%06d" % (str(id) + '.xml') with the
    # builtin ``id``, which always raised TypeError; it is now explicit.
    out_dir = root if saveroot is None else saveroot
    out_name = str(image_id) if new_id is None else str(new_id)
    tree.write(os.path.join(out_dir, out_name + '.xml'))
输入:XML文件所在的目录、文件名和新的边界框坐标。
功能:更新XML文件中第一个目标对象的边界框坐标。
输出:保存更新后的XML文件。
5. 更新XML标注文件中的所有边界框(并另存为新文件)
def change_xml_list_annotation(root, image_id, new_target, saveroot, id):
    """Write a copy of an XML annotation with every box replaced.

    Args:
        root: Directory containing the source ``<image_id>.xml``.
        image_id: Source annotation file name without extension.
        new_target: One ``[xmin, ymin, xmax, ymax]`` per ``<object>``, in
            document order.
        saveroot: Directory the new annotation is written to.
        id: New base name (a string). The file is saved as ``<id>.xml`` and
            its ``<filename>`` tag is set to ``<id>.jpg``. The parameter
            shadows the builtin but is kept for keyword callers.

    Raises:
        IndexError: if ``new_target`` has fewer entries than there are
            ``<object>`` elements.
    """
    # ET.parse opens and closes the file itself, so no file handle leaks.
    tree = ET.parse(os.path.join(root, str(image_id) + '.xml'))
    filename = tree.find('filename')
    if filename is not None:  # tolerate annotations without a <filename> tag
        filename.text = id + '.jpg'
    for index, obj in enumerate(tree.getroot().findall('object')):
        bndbox = obj.find('bndbox')
        box = new_target[index]
        for pos, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
            bndbox.find(tag).text = str(box[pos])
    tree.write(os.path.join(saveroot, id + '.xml'))
输入:原始XML目录、文件名、新的边界框列表、保存目录和新的文件名。
功能:更新XML文件中所有目标对象的边界框坐标。
输出:保存更新后的XML文件。
6. 创建目录
def mkdir(path):
    """Create ``path`` (including parents) unless it already exists.

    Leading/trailing whitespace and trailing backslashes are stripped before
    the check. Returns True when the directory was created and False when it
    already existed; a status message is printed in both cases.
    """
    cleaned = path.strip().rstrip("\\")
    if os.path.exists(cleaned):
        print(cleaned + ' 目录已存在')
        return False
    os.makedirs(cleaned)
    print(cleaned + ' 创建成功')
    return True
输入:目标目录路径。
功能:创建目录(会先去除首尾空格和末尾的 `\` 符号);如果目录已存在,则只打印提示而不重复创建。返回值:新建时返回 True,已存在时返回 False。
7. 主程序
if __name__ == "__main__":IMG_DIR = "VOCdevkit/VOC2007/JPEGImages3"XML_DIR = "VOCdevkit/VOC2007/Annotations3"AUG_XML_DIR = "VOCdevkit/VOC2007/Annotations"try:shutil.rmtree(AUG_XML_DIR)except FileNotFoundError as e:passmkdir(AUG_XML_DIR)AUG_IMG_DIR = "VOCdevkit/VOC2007/JPEGImages"try:shutil.rmtree(AUG_IMG_DIR)except FileNotFoundError as e:passmkdir(AUG_IMG_DIR)AUGLOOP = 8 # 每张影像增强的数量seq = iaa.Sequential([iaa.Flipud(0.5), # 垂直翻转iaa.Fliplr(0.5), # 水平翻转iaa.Multiply((1.2, 1.5)), # 调整亮度iaa.GaussianBlur(sigma=(0, 3.0)), # 高斯模糊iaa.Affine(translate_px={"x": 15, "y": 15},scale=(0.8, 0.95),rotate=(-30, 30)) # 平移、缩放、旋转])for root, sub_folders, files in os.walk(XML_DIR):for name in files:bndbox = read_xml_annotation(XML_DIR, name)shutil.copy(os.path.join(XML_DIR, name), AUG_XML_DIR)shutil.copy(os.path.join(IMG_DIR, name[:-4] + '.jpg'), AUG_IMG_DIR)for epoch in range(AUGLOOP):seq_det = seq.to_deterministic()img = Image.open(os.path.join(IMG_DIR, name[:-4] + '.jpg'))img = np.asarray(img)for i in range(len(bndbox)):bbs = ia.BoundingBoxesOnImage([ia.BoundingBox(x1=bndbox[i][0], y1=bndbox[i][1], x2=bndbox[i][2], y2=bndbox[i][3]),], shape=img.shape)bbs_aug = seq_det.augment_bounding_boxes([bbs])[0]n_x1 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x1)))n_y1 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y1)))n_x2 = int(max(1, min(img.shape[1], bbs_aug.bounding_boxes[0].x2)))n_y2 = int(max(1, min(img.shape[0], bbs_aug.bounding_boxes[0].y2)))if n_x1 == 1 and n_x1 == n_x2:n_x2 += 1if n_y1 == 1 and n_y2 == n_y1:n_y2 += 1if n_x1 >= n_x2 or n_y1 >= n_y2:print('error', name)new_bndbox_list.append([n_x1, n_y1, n_x2, n_y2])image_aug = seq_det.augment_images([img])[0]path = os.path.join(AUG_IMG_DIR, str("%06d" % (len(files) * epoch)) + name