|
@@ -20,14 +20,14 @@ import random
|
|
from collections import OrderedDict
|
|
from collections import OrderedDict
|
|
|
|
|
|
import numpy as np
|
|
import numpy as np
|
|
-from paddle.io import Dataset
|
|
|
|
|
|
|
|
-from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
|
|
|
|
|
|
+from .base import BaseDataset
|
|
|
|
+from paddlers.utils import logging, get_encoding, path_normalization, is_pic
|
|
from paddlers.transforms import DecodeImg, MixupImage
|
|
from paddlers.transforms import DecodeImg, MixupImage
|
|
from paddlers.tools import YOLOAnchorCluster
|
|
from paddlers.tools import YOLOAnchorCluster
|
|
|
|
|
|
|
|
|
|
-class COCODetection(Dataset):
|
|
|
|
|
|
+class COCODetection(BaseDataset):
|
|
"""读取COCO格式的检测数据集,并对样本进行相应的处理。
|
|
"""读取COCO格式的检测数据集,并对样本进行相应的处理。
|
|
|
|
|
|
Args:
|
|
Args:
|
|
@@ -35,7 +35,7 @@ class COCODetection(Dataset):
|
|
image_dir (str): 描述数据集图片文件路径。
|
|
image_dir (str): 描述数据集图片文件路径。
|
|
anno_path (str): COCO标注文件路径。
|
|
anno_path (str): COCO标注文件路径。
|
|
label_list (str): 描述数据集包含的类别信息文件路径。
|
|
label_list (str): 描述数据集包含的类别信息文件路径。
|
|
- transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
|
|
|
|
|
|
+ transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
|
|
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
|
|
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
|
|
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
|
|
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
|
|
一半。
|
|
一半。
|
|
@@ -60,10 +60,10 @@ class COCODetection(Dataset):
|
|
import matplotlib
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
matplotlib.use('Agg')
|
|
from pycocotools.coco import COCO
|
|
from pycocotools.coco import COCO
|
|
- super(COCODetection, self).__init__()
|
|
|
|
- self.data_dir = data_dir
|
|
|
|
|
|
+ super(COCODetection, self).__init__(data_dir, label_list, transforms,
|
|
|
|
+ num_workers, shuffle)
|
|
|
|
+
|
|
self.data_fields = None
|
|
self.data_fields = None
|
|
- self.transforms = copy.deepcopy(transforms)
|
|
|
|
self.num_max_boxes = 50
|
|
self.num_max_boxes = 50
|
|
|
|
|
|
self.use_mix = False
|
|
self.use_mix = False
|
|
@@ -76,8 +76,6 @@ class COCODetection(Dataset):
|
|
break
|
|
break
|
|
|
|
|
|
self.batch_transforms = None
|
|
self.batch_transforms = None
|
|
- self.num_workers = get_num_workers(num_workers)
|
|
|
|
- self.shuffle = shuffle
|
|
|
|
self.allow_empty = allow_empty
|
|
self.allow_empty = allow_empty
|
|
self.empty_ratio = empty_ratio
|
|
self.empty_ratio = empty_ratio
|
|
self.file_list = list()
|
|
self.file_list = list()
|