|  | @@ -20,14 +20,14 @@ import random
 | 
	
		
			
				|  |  |  from collections import OrderedDict
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import numpy as np
 | 
	
		
			
				|  |  | -from paddle.io import Dataset
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
 | 
	
		
			
				|  |  | +from .base import BaseDataset
 | 
	
		
			
				|  |  | +from paddlers.utils import logging, get_encoding, path_normalization, is_pic
 | 
	
		
			
				|  |  |  from paddlers.transforms import DecodeImg, MixupImage
 | 
	
		
			
				|  |  |  from paddlers.tools import YOLOAnchorCluster
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class COCODetection(Dataset):
 | 
	
		
			
				|  |  | +class COCODetection(BaseDataset):
 | 
	
		
			
				|  |  |      """读取COCO格式的检测数据集,并对样本进行相应的处理。
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      Args:
 | 
	
	
		
			
				|  | @@ -35,7 +35,7 @@ class COCODetection(Dataset):
 | 
	
		
			
				|  |  |          image_dir (str): 描述数据集图片文件路径。
 | 
	
		
			
				|  |  |          anno_path (str): COCO标注文件路径。
 | 
	
		
			
				|  |  |          label_list (str): 描述数据集包含的类别信息文件路径。
 | 
	
		
			
				|  |  | -        transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子。
 | 
	
		
			
				|  |  | +        transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子。
 | 
	
		
			
				|  |  |          num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数。默认为'auto'。当设为'auto'时,根据
 | 
	
		
			
				|  |  |              系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8,则`num_workers`为8,否则为CPU核数的
 | 
	
		
			
				|  |  |              一半。
 | 
	
	
		
			
				|  | @@ -60,10 +60,10 @@ class COCODetection(Dataset):
 | 
	
		
			
				|  |  |          import matplotlib
 | 
	
		
			
				|  |  |          matplotlib.use('Agg')
 | 
	
		
			
				|  |  |          from pycocotools.coco import COCO
 | 
	
		
			
				|  |  | -        super(COCODetection, self).__init__()
 | 
	
		
			
				|  |  | -        self.data_dir = data_dir
 | 
	
		
			
				|  |  | +        super(COCODetection, self).__init__(data_dir, label_list, transforms,
 | 
	
		
			
				|  |  | +                                            num_workers, shuffle)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          self.data_fields = None
 | 
	
		
			
				|  |  | -        self.transforms = copy.deepcopy(transforms)
 | 
	
		
			
				|  |  |          self.num_max_boxes = 50
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          self.use_mix = False
 | 
	
	
		
			
				|  | @@ -76,8 +76,6 @@ class COCODetection(Dataset):
 | 
	
		
			
				|  |  |                      break
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          self.batch_transforms = None
 | 
	
		
			
				|  |  | -        self.num_workers = get_num_workers(num_workers)
 | 
	
		
			
				|  |  | -        self.shuffle = shuffle
 | 
	
		
			
				|  |  |          self.allow_empty = allow_empty
 | 
	
		
			
				|  |  |          self.empty_ratio = empty_ratio
 | 
	
		
			
				|  |  |          self.file_list = list()
 |