|
@@ -13,6 +13,7 @@
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
|
|
|
|
import math
|
|
import math
|
|
|
|
+import os
|
|
import os.path as osp
|
|
import os.path as osp
|
|
from collections import OrderedDict
|
|
from collections import OrderedDict
|
|
from operator import attrgetter
|
|
from operator import attrgetter
|
|
@@ -545,6 +546,88 @@ class BaseChangeDetector(BaseModel):
|
|
}
|
|
}
|
|
return prediction
|
|
return prediction
|
|
|
|
|
|
|
|
+ def slider_predict(self, img_file, save_dir, block_size, overlap=36, transforms=None):
|
|
|
|
+ """
|
|
|
|
+ Do inference.
|
|
|
|
+ Args:
|
|
|
|
+ Args:
|
|
|
|
+ img_file(List[str]):
|
|
|
|
+ List of image paths.
|
|
|
|
+ save_dir(str):
|
|
|
|
+ Directory that contains saved geotiff file.
|
|
|
|
+ block_size(List[int] or Tuple[int], int):
|
|
|
|
+ The size of block.
|
|
|
|
+ overlap(List[int] or Tuple[int], int):
|
|
|
|
+ The overlap between two blocks. Defaults to 36.
|
|
|
|
+ transforms(paddlers.transforms.Compose or None, optional):
|
|
|
|
+ Transforms for inputs. If None, the transforms for evaluation process will be used. Defaults to None.
|
|
|
|
+ """
|
|
|
|
+ try:
|
|
|
|
+ from osgeo import gdal
|
|
|
|
+ except:
|
|
|
|
+ import gdal
|
|
|
|
+
|
|
|
|
+ if len(img_file) != 2:
|
|
|
|
+ raise ValueError("`img_file` must be a list of length 2.")
|
|
|
|
+ if isinstance(block_size, int):
|
|
|
|
+ block_size = (block_size, block_size)
|
|
|
|
+ elif isinstance(block_size, (tuple, list)) and len(block_size) == 2:
|
|
|
|
+ block_size = tuple(block_size)
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError("`block_size` must be a tuple/list of length 2 or an integer.")
|
|
|
|
+ if isinstance(overlap, int):
|
|
|
|
+ overlap = (overlap, overlap)
|
|
|
|
+ elif isinstance(overlap, (tuple, list)) and len(overlap) == 2:
|
|
|
|
+ overlap = tuple(overlap)
|
|
|
|
+ else:
|
|
|
|
+ raise ValueError("`overlap` must be a tuple/list of length 2 or an integer.")
|
|
|
|
+
|
|
|
|
+ src1_data = gdal.Open(img_file[0])
|
|
|
|
+ src2_data = gdal.Open(img_file[1])
|
|
|
|
+ width = src1_data.RasterXSize
|
|
|
|
+ height = src1_data.RasterYSize
|
|
|
|
+ bands = src1_data.RasterCount
|
|
|
|
+
|
|
|
|
+ driver = gdal.GetDriverByName("GTiff")
|
|
|
|
+ file_name = osp.splitext(osp.normpath(img_file[0]).split(os.sep)[-1])[0] + ".tif"
|
|
|
|
+ if not osp.exists(save_dir):
|
|
|
|
+ os.makedirs(save_dir)
|
|
|
|
+ save_file = osp.join(save_dir, file_name)
|
|
|
|
+ dst_data = driver.Create(save_file, width, height, 1, gdal.GDT_Byte)
|
|
|
|
+ dst_data.SetGeoTransform(src1_data.GetGeoTransform())
|
|
|
|
+ dst_data.SetProjection(src1_data.GetProjection())
|
|
|
|
+ band = dst_data.GetRasterBand(1)
|
|
|
|
+ band.WriteArray(255 * np.ones((height, width), dtype="uint8"))
|
|
|
|
+
|
|
|
|
+ step = np.array(block_size) - np.array(overlap)
|
|
|
|
+ for yoff in range(0, height, step[1]):
|
|
|
|
+ for xoff in range(0, width, step[0]):
|
|
|
|
+ xsize, ysize = block_size
|
|
|
|
+ if xoff + xsize > width:
|
|
|
|
+ xsize = int(width - xoff)
|
|
|
|
+ if yoff + ysize > height:
|
|
|
|
+ ysize = int(height - yoff)
|
|
|
|
+ im1 = src1_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
|
|
|
|
+ im2 = src2_data.ReadAsArray(int(xoff), int(yoff), xsize, ysize).transpose((1, 2, 0))
|
|
|
|
+ # fill
|
|
|
|
+ h, w = im1.shape[:2]
|
|
|
|
+ im1_fill = np.zeros((block_size[1], block_size[0], bands), dtype=im1.dtype)
|
|
|
|
+ im2_fill = im1_fill.copy()
|
|
|
|
+ im1_fill[:h, :w, :] = im1
|
|
|
|
+ im2_fill[:h, :w, :] = im2
|
|
|
|
+ im_fill = (im1_fill, im2_fill)
|
|
|
|
+ # predict
|
|
|
|
+ pred = self.predict(im_fill, transforms)["label_map"].astype("uint8")
|
|
|
|
+ # overlap
|
|
|
|
+ rd_block = band.ReadAsArray(int(xoff), int(yoff), xsize, ysize)
|
|
|
|
+ mask = (rd_block == pred[:h, :w]) | (rd_block == 255)
|
|
|
|
+ temp = pred[:h, :w].copy()
|
|
|
|
+ temp[mask == False] = 0
|
|
|
|
+ band.WriteArray(temp, int(xoff), int(yoff))
|
|
|
|
+ dst_data.FlushCache()
|
|
|
|
+ dst_data = None
|
|
|
|
+ print("GeoTiff saved in {}.".format(save_file))
|
|
|
|
+
|
|
def _preprocess(self, images, transforms, to_tensor=True):
|
|
def _preprocess(self, images, transforms, to_tensor=True):
|
|
arrange_transforms(
|
|
arrange_transforms(
|
|
model_type=self.model_type, transforms=transforms, mode='test')
|
|
model_type=self.model_type, transforms=transforms, mode='test')
|