Bobholamovic 2 vuotta sitten
vanhempi
commit
0a57bc9b1e

+ 2 - 0
docs/apis/transforms.md

@@ -36,10 +36,12 @@ from paddlers.datasets import CDDataset
 
 
 train_transforms = T.Compose([
+    T.DecodeImg(),
     T.Resize(target_size=512),
     T.RandomHorizontalFlip(),
     T.Normalize(
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    T.ArrangeChangeDetector('train')
 ])
 
 train_dataset = CDDataset(

+ 3 - 3
paddlers/deploy/predictor.py

@@ -258,9 +258,9 @@ class Predictor(object):
             Args:
                 img_file(list[str | tuple | np.ndarray] | str | tuple | np.ndarray): For scene classification, image restoration, 
                     object detection and semantic segmentation tasks, `img_file` should be either the path of the image to predict
-                    , a decoded image (a `np.ndarray`, which should be consistent with what you get from passing image path to
-                    `paddlers.transforms.decode_image()`), or a list of image paths or decoded images. For change detection tasks,
-                    `img_file` should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
+                    , a decoded image (a np.ndarray, which should be consistent with what you get from passing image path to
+                    paddlers.transforms.decode_image()), or a list of image paths or decoded images. For change detection tasks,
+                    img_file should be a tuple of image paths, a tuple of decoded images, or a list of tuples.
                 topk(int, optional): Top-k values to reserve in a classification result. Defaults to 1.
                 transforms (paddlers.transforms.Compose | None, optional): Pipeline of data preprocessing. If None, load transforms
                     from `model.yml`. Defaults to None.

+ 1 - 1
tests/data/data_utils.py

@@ -325,7 +325,7 @@ class ConstrDetSample(ConstrSample):
 
 def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
     """
-    Construct a list of dictionaries from file. Each dict in the list can be used as the input to `paddlers.transforms.Transform` objects.
+    Construct a list of dictionaries from file. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
 
     Args:
         file_list (str): Path of file_list.

+ 1 - 1
tests/download_test_data.sh

@@ -4,7 +4,7 @@ function remove_dir_if_exist() {
     local dir="$1"
     if [ -d "${dir}" ]; then
         rm -rf "${dir}"
-        echo "\033[0;31mDirectory ${dir} has been removed.\033[0m"
+        echo -e "\033[0;31mDirectory ${dir} has been removed.\033[0m"
     fi
 }
 

+ 2 - 2
tests/test_tutorials.py

@@ -29,7 +29,7 @@ class TestTutorial(CpuCommonTest):
     @classmethod
     def setUpClass(cls):
         cls._td = tempfile.TemporaryDirectory(dir='./')
-        # Recursively copy the content of `cls.SUBDIR` to td.
+        # Recursively copy the content of cls.SUBDIR to td.
         # This is necessary for running scripts in td.
         cls._TSUBDIR = osp.join(cls._td.name, osp.basename(cls.SUBDIR))
         shutil.copytree(cls.SUBDIR, cls._TSUBDIR)
@@ -47,7 +47,7 @@ class TestTutorial(CpuCommonTest):
 
         def _test_tutorial(script_name):
             def _test_tutorial_impl(self):
-                # Set working directory to `cls._TSUBDIR` such that the 
+                # Set working directory to cls._TSUBDIR such that the 
                 # files generated by the script will be automatically cleaned.
                 run_script(f"python {script_name}", wd=cls._TSUBDIR)
 

+ 1 - 1
tools/raster2geotiff.py

@@ -46,7 +46,7 @@ def convert_data(image_path, geojson_path):
             geo_points = geo["coordinates"][0][0]
         else:
             raise TypeError(
-                "Geometry type must be `Polygon` or `MultiPolygon`, not {}.".
+                "Geometry type must be 'Polygon' or 'MultiPolygon', not {}.".
                 format(geo["type"]))
         xy_points = np.array([
             _gt_convert(point[0], point[1], raster.geot) for point in geo_points

+ 2 - 2
tools/raster2vector.py

@@ -76,7 +76,7 @@ def raster2vector(srcimg_path, mask_path, save_path, ignore_index=255):
     vec_ext = save_path.split(".")[-1].lower()
     if vec_ext not in ["json", "geojson", "shp"]:
         raise ValueError(
-            "The ext of `save_path` must be `json/geojson` or `shp`, not {}.".
+            "The extension of `save_path` must be 'json/geojson' or 'shp', not {}.".
             format(vec_ext))
     ras_ext = srcimg_path.split(".")[-1].lower()
     if osp.exists(srcimg_path) and ras_ext in ["tif", "tiff", "geotiff", "img"]:
@@ -93,7 +93,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--mask_path", type=str, required=True, \
                     help="Path of mask data.")
 parser.add_argument("--save_path", type=str, required=True, \
-                    help="Path to save the shape file (the file suffix is `.json/geojson` or `.shp`).")
+                    help="Path to save the shape file (the extension is .json/geojson or .shp).")
 parser.add_argument("--srcimg_path", type=str, default="", \
                     help="Path of original data with geoinfo. Default to empty.")
 parser.add_argument("--ignore_index", type=int, default=255, \

+ 1 - 1
tools/split.py

@@ -75,7 +75,7 @@ parser.add_argument("--mask_path", type=str, default=None, \
 parser.add_argument("--block_size", type=int, default=512, \
                     help="Size of image block. Default value is 512.")
 parser.add_argument("--save_dir", type=str, default="dataset", \
-                    help="Directory to save the results. Default value is `dataset`.")
+                    help="Directory to save the results. Default value is 'dataset'.")
 
 if __name__ == "__main__":
     args = parser.parse_args()

+ 2 - 2
tools/utils/raster.py

@@ -42,7 +42,7 @@ def _get_type(type_name: str) -> int:
     elif type_name == "complex64":
         gdal_type = gdal.GDT_CFloat64
     else:
-        raise TypeError("Non-suported data type `{}`.".format(type_name))
+        raise TypeError("Non-suported data type {}.".format(type_name))
     return gdal_type
 
 
@@ -76,7 +76,7 @@ class Raster:
                         # https://www.osgeo.cn/gdal/drivers/raster/index.html
                         self._src_data = gdal.Open(path)
                     except:
-                        raise TypeError("Unsupported data format: `{}`".format(
+                        raise TypeError("Unsupported data format: {}".format(
                             self.ext_type))
             else:
                 raise ValueError("The path {0} not exists.".format(path))