| 
					
				 | 
			
			
				@@ -180,22 +180,28 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         decode_sar (bool, optional): If True, automatically interpret a two-channel  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             geo image (e.g. geotiff images) as a SAR image, set this argument to  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             True. Defaults to True. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        read_geo_info (bool, optional): If True, read geographical information from  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            the image. Deafults to False. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __init__(self, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  to_rgb=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  to_uint8=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                  decode_bgr=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 decode_sar=True): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                 decode_sar=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                 read_geo_info=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super(DecodeImg, self).__init__() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.to_rgb = to_rgb 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.to_uint8 = to_uint8 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.decode_bgr = decode_bgr 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.decode_sar = decode_sar 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.read_geo_info = False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def read_img(self, img_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         img_format = imghdr.what(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         name, ext = os.path.splitext(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        geo_trans, geo_proj = None, None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if img_format == 'tiff' or ext == '.img': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 import gdal 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -209,7 +215,7 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             dataset = gdal.Open(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if dataset == None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                raise IOError('Can not open', img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                raise IOError('Cannot open', img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             im_data = dataset.ReadAsArray() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if im_data.ndim == 2 and self.decode_sar: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 im_data = to_intensity(im_data)  # is read SAR 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -217,26 +223,38 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 if im_data.ndim == 3: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     im_data = im_data.transpose((1, 2, 0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return im_data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                geo_trans = dataset.GetGeoTransform() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                geo_proj = dataset.GetGeoProjection() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if self.decode_bgr: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                     cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                                  cv2.IMREAD_ANYCOLOR) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                     cv2.IMREAD_ANYCOLOR) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         elif ext == '.npy': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            return np.load(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            im_data = np.load(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             raise TypeError("Image format {} is not supported!".format(ext)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return im_data, geo_trans, geo_proj 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return im_data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def apply_im(self, im_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if isinstance(im_path, str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                image = self.read_img(im_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                data = self.read_img(im_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             except: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 raise ValueError("Cannot read the image file {}!".format( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     im_path)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image, geo_trans, geo_proj = data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image = data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             image = im_path 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -246,7 +264,10 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if self.to_uint8: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             image = to_uint8(image) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return image 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return image, geo_info_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return image 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def apply_mask(self, mask): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         try: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -269,15 +290,37 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if 'image' in sample: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sample['image_ori'] = copy.deepcopy(sample['image']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sample['image'] = self.apply_im(sample['image']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image, geo_info_dict = self.apply_im(sample['image']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image'] = image 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['geo_info_dict'] = geo_info_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image'] = self.apply_im(sample['image']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if 'image2' in sample: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sample['image2'] = self.apply_im(sample['image2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image2, geo_info_dict2 = self.apply_im(sample['image2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image2'] = image2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['geo_info_dict2'] = geo_info_dict2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image2'] = self.apply_im(sample['image2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if 'image_t1' in sample and not 'image' in sample: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if not ('image_t2' in sample and 'image2' not in sample): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 raise ValueError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sample['image'] = self.apply_im(sample['image_t1']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sample['image2'] = self.apply_im(sample['image_t2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image, geo_info_dict = self.apply_im(sample['image_t1']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image'] = image 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['geo_info_dict'] = geo_info_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image'] = self.apply_im(sample['image_t1']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if self.read_geo_info: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                image2, geo_info_dict2 = self.apply_im(sample['image_t2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image2'] = image2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['geo_info_dict2'] = geo_info_dict2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sample['image2'] = self.apply_im(sample['image_t2']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if 'mask' in sample: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             sample['mask_ori'] = copy.deepcopy(sample['mask']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             sample['mask'] = self.apply_mask(sample['mask']) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -286,6 +329,7 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if im_height != se_height or im_width != se_width: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 raise ValueError( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     "The height or width of the image is not same as the mask.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if 'aux_masks' in sample: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             sample['aux_masks'] = list( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -295,6 +339,7 @@ class DecodeImg(Transform): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         sample['im_shape'] = np.array( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             sample['image'].shape[:2], dtype=np.float32) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return sample 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |