|
@@ -180,22 +180,28 @@ class DecodeImg(Transform):
|
|
|
decode_sar (bool, optional): If True, automatically interpret a two-channel
|
|
|
geo image (e.g. geotiff images) as a SAR image, set this argument to
|
|
|
True. Defaults to True.
|
|
|
+ read_geo_info (bool, optional): If True, read geographical information from
|
|
|
+ the image. Deafults to False.
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
to_rgb=True,
|
|
|
to_uint8=True,
|
|
|
decode_bgr=True,
|
|
|
- decode_sar=True):
|
|
|
+ decode_sar=True,
|
|
|
+ read_geo_info=False):
|
|
|
super(DecodeImg, self).__init__()
|
|
|
self.to_rgb = to_rgb
|
|
|
self.to_uint8 = to_uint8
|
|
|
self.decode_bgr = decode_bgr
|
|
|
self.decode_sar = decode_sar
|
|
|
+ self.read_geo_info = False
|
|
|
|
|
|
def read_img(self, img_path):
|
|
|
img_format = imghdr.what(img_path)
|
|
|
name, ext = os.path.splitext(img_path)
|
|
|
+ geo_trans, geo_proj = None, None
|
|
|
+
|
|
|
if img_format == 'tiff' or ext == '.img':
|
|
|
try:
|
|
|
import gdal
|
|
@@ -209,7 +215,7 @@ class DecodeImg(Transform):
|
|
|
|
|
|
dataset = gdal.Open(img_path)
|
|
|
if dataset == None:
|
|
|
- raise IOError('Can not open', img_path)
|
|
|
+ raise IOError('Cannot open', img_path)
|
|
|
im_data = dataset.ReadAsArray()
|
|
|
if im_data.ndim == 2 and self.decode_sar:
|
|
|
im_data = to_intensity(im_data) # is read SAR
|
|
@@ -217,26 +223,38 @@ class DecodeImg(Transform):
|
|
|
else:
|
|
|
if im_data.ndim == 3:
|
|
|
im_data = im_data.transpose((1, 2, 0))
|
|
|
- return im_data
|
|
|
+ if self.read_geo_info:
|
|
|
+ geo_trans = dataset.GetGeoTransform()
|
|
|
+ geo_proj = dataset.GetGeoProjection()
|
|
|
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
|
|
|
if self.decode_bgr:
|
|
|
- return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
- cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
|
|
|
+ im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
+ cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
|
|
|
else:
|
|
|
- return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
- cv2.IMREAD_ANYCOLOR)
|
|
|
+ im_data = cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
|
|
|
+ cv2.IMREAD_ANYCOLOR)
|
|
|
elif ext == '.npy':
|
|
|
- return np.load(img_path)
|
|
|
+ im_data = np.load(img_path)
|
|
|
else:
|
|
|
raise TypeError("Image format {} is not supported!".format(ext))
|
|
|
|
|
|
+ if self.read_geo_info:
|
|
|
+ return im_data, geo_trans, geo_proj
|
|
|
+ else:
|
|
|
+ return im_data
|
|
|
+
|
|
|
def apply_im(self, im_path):
|
|
|
if isinstance(im_path, str):
|
|
|
try:
|
|
|
- image = self.read_img(im_path)
|
|
|
+ data = self.read_img(im_path)
|
|
|
except:
|
|
|
raise ValueError("Cannot read the image file {}!".format(
|
|
|
im_path))
|
|
|
+ if self.read_geo_info:
|
|
|
+ image, geo_trans, geo_proj = data
|
|
|
+ geo_info_dict = {'geo_trans': geo_trans, 'geo_proj': geo_proj}
|
|
|
+ else:
|
|
|
+ image = data
|
|
|
else:
|
|
|
image = im_path
|
|
|
|
|
@@ -246,7 +264,10 @@ class DecodeImg(Transform):
|
|
|
if self.to_uint8:
|
|
|
image = to_uint8(image)
|
|
|
|
|
|
- return image
|
|
|
+ if self.read_geo_info:
|
|
|
+ return image, geo_info_dict
|
|
|
+ else:
|
|
|
+ return image
|
|
|
|
|
|
def apply_mask(self, mask):
|
|
|
try:
|
|
@@ -269,15 +290,37 @@ class DecodeImg(Transform):
|
|
|
"""
|
|
|
|
|
|
if 'image' in sample:
|
|
|
- sample['image_ori'] = copy.deepcopy(sample['image'])
|
|
|
- sample['image'] = self.apply_im(sample['image'])
|
|
|
+ if self.read_geo_info:
|
|
|
+ image, geo_info_dict = self.apply_im(sample['image'])
|
|
|
+ sample['image'] = image
|
|
|
+ sample['geo_info_dict'] = geo_info_dict
|
|
|
+ else:
|
|
|
+ sample['image'] = self.apply_im(sample['image'])
|
|
|
+
|
|
|
if 'image2' in sample:
|
|
|
- sample['image2'] = self.apply_im(sample['image2'])
|
|
|
+ if self.read_geo_info:
|
|
|
+ image2, geo_info_dict2 = self.apply_im(sample['image2'])
|
|
|
+ sample['image2'] = image2
|
|
|
+ sample['geo_info_dict2'] = geo_info_dict2
|
|
|
+ else:
|
|
|
+ sample['image2'] = self.apply_im(sample['image2'])
|
|
|
+
|
|
|
if 'image_t1' in sample and not 'image' in sample:
|
|
|
if not ('image_t2' in sample and 'image2' not in sample):
|
|
|
raise ValueError
|
|
|
- sample['image'] = self.apply_im(sample['image_t1'])
|
|
|
- sample['image2'] = self.apply_im(sample['image_t2'])
|
|
|
+ if self.read_geo_info:
|
|
|
+ image, geo_info_dict = self.apply_im(sample['image_t1'])
|
|
|
+ sample['image'] = image
|
|
|
+ sample['geo_info_dict'] = geo_info_dict
|
|
|
+ else:
|
|
|
+ sample['image'] = self.apply_im(sample['image_t1'])
|
|
|
+ if self.read_geo_info:
|
|
|
+ image2, geo_info_dict2 = self.apply_im(sample['image_t2'])
|
|
|
+ sample['image2'] = image2
|
|
|
+ sample['geo_info_dict2'] = geo_info_dict2
|
|
|
+ else:
|
|
|
+ sample['image2'] = self.apply_im(sample['image_t2'])
|
|
|
+
|
|
|
if 'mask' in sample:
|
|
|
sample['mask_ori'] = copy.deepcopy(sample['mask'])
|
|
|
sample['mask'] = self.apply_mask(sample['mask'])
|
|
@@ -286,6 +329,7 @@ class DecodeImg(Transform):
|
|
|
if im_height != se_height or im_width != se_width:
|
|
|
raise ValueError(
|
|
|
"The height or width of the image is not same as the mask.")
|
|
|
+
|
|
|
if 'aux_masks' in sample:
|
|
|
sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks'])
|
|
|
sample['aux_masks'] = list(
|
|
@@ -295,6 +339,7 @@ class DecodeImg(Transform):
|
|
|
sample['im_shape'] = np.array(
|
|
|
sample['image'].shape[:2], dtype=np.float32)
|
|
|
sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
|
|
|
+
|
|
|
return sample
|
|
|
|
|
|
|