Browse Source

Merge branch 'develop' into add_fft

Bobholamovic 2 years ago
parent
commit
1278ce7226

+ 13 - 14
docs/intro/indices.md

@@ -58,17 +58,16 @@
 
 ## 波段名称与描述
 
-|    波段名称    |     描述    |
-|---------------|-------------|
-|     `'b'`     | Blue        |
-|     `'g'`     | Green       |
-|     `'r'`     | Red         |
-|    `'re1'`    | Red Edge 1  |
-|    `'re2'`    | Red Edge 2  |
-|    `'re3'`    | Red Edge 3  |
-|     `'n'`     | NIR         |
-|    `'n2'`     | NIR 2       |
-|    `'s1'`     | SWIR 1      |
-|    `'s2'`     | SWIR 2      |
-|    `'t1'`     | Thermal 1   |
-|    `'t2'`     | Thermal 2   |
+|    波段名称    |     描述    | 参考波长范围 (μm) |  \*参考波长来源  |
+|---------------|-------------|-------------|-------------|
+|     `'b'`     | Blue        | *0.450 - 0.515* | *Landsat8* |
+|     `'g'`     | Green       | *0.525 - 0.600* | *Landsat8* |
+|     `'r'`     | Red         | *0.630 - 0.680* | *Landsat8* |
+|    `'re1'`    | Red Edge 1  | *0.698 - 0.713* | *Sentinel2* |
+|    `'re2'`    | Red Edge 2  | *0.733 - 0.748* | *Sentinel2* |
+|    `'re3'`    | Red Edge 3  | *0.773 - 0.793* | *Sentinel2* |
+|     `'n'`     | NIR         | *0.845 - 0.885* | *Landsat8* |
+|    `'s1'`     | SWIR 1      | *1.560 - 1.660* | *Landsat8* |
+|    `'s2'`     | SWIR 2      | *2.100 - 2.300* | *Landsat8* |
+|    `'t1'`     | Thermal 1   | *10.60 - 11.19* | *Landsat8* |
+|    `'t2'`     | Thermal 2   | *11.50 - 12.51* | *Landsat8* |

+ 45 - 20
paddlers/transforms/indices.py

@@ -27,27 +27,31 @@ __all__ = [
 ]
 
 EPS = 1e-32
-
-# | Band name | Description |
-# |-----------|-------------|
-# |     b     | Blue        |
-# |     g     | Green       |
-# |     r     | Red         |
-# |    re1    | Red Edge 1  |
-# |    re2    | Red Edge 2  |
-# |    re3    | Red Edge 3  |
-# |     n     | NIR         |
-# |    n2     | NIR 2       |
-# |    s1     | SWIR 1      |
-# |    s2     | SWIR 2      |
-# |    t1     | Thermal 1   |
-# |    t2     | Thermal 2   |
+BAND_NAMES = ["b", "g", "r", "re1", "re2", "re3", "n", "s1", "s2", "t1", "t2"]
+
+# | Band name | Description | Wavelength (μm) | Satellite |
+# |-----------|-------------|-----------------|-----------|
+# |     b     | Blue        |   0.450-0.515   | Landsat8  |
+# |     g     | Green       |   0.525-0.600   | Landsat8  |
+# |     r     | Red         |   0.630-0.680   | Landsat8  |
+# |    re1    | Red Edge 1  |   0.698-0.713   | Sentinel2 |
+# |    re2    | Red Edge 2  |   0.733-0.748   | Sentinel2 |
+# |    re3    | Red Edge 3  |   0.773-0.793   | Sentinel2 |
+# |     n     | NIR         |   0.845-0.885   | Landsat8  |
+# |    s1     | SWIR 1      |   1.560-1.660   | Landsat8  |
+# |    s2     | SWIR 2      |   2.100-2.300   | Landsat8  |
+# |    t1     | Thermal 1   |   10.60-11.19   | Landsat8  |
+# |    t2     | Thermal 2   |   11.50-12.51   | Landsat8  |
 
 
 class RSIndex(metaclass=abc.ABCMeta):
     def __init__(self, band_indices):
         super(RSIndex, self).__init__()
         self.band_indices = band_indices
+        self.required_band_names = iintersection(
+            self._compute.__code__.co_varnames[1:],  # strip self 
+            BAND_NAMES  # only save band names
+        )
 
     @abc.abstractmethod
     def _compute(self, *args, **kwargs):
@@ -55,19 +59,40 @@ class RSIndex(metaclass=abc.ABCMeta):
 
     def __call__(self, image):
         bands = self.select_bands(image)
+        now_band_names = tuple(bands.keys())
+        if not iequal(now_band_names, self.required_band_names):
+            raise LackBandError("Lack of bands: {}.".format(
+                isubtraction(self.required_band_names, now_band_names)))
         return self._compute(**bands)
 
     def select_bands(self, image, to_float32=True):
         bands = {}
         for name, idx in self.band_indices.items():
-            if idx == 0:
-                raise ValueError("Band index starts from 1.")
-            bands[name] = image[..., idx - 1]
-            if to_float32:
-                bands[name] = bands[name].astype('float32')
+            if name in self.required_band_names:
+                if idx == 0:
+                    raise ValueError("Band index starts from 1.")
+                bands[name] = image[..., idx - 1]
+                if to_float32:
+                    bands[name] = bands[name].astype('float32')
         return bands
 
 
+class LackBandError(Exception):
+    pass
+
+
+def iintersection(iter1, iter2):
+    return tuple(set(iter1) & set(iter2))
+
+
+def isubtraction(iter1, iter2):
+    return tuple(set(iter1) - set(iter2))
+
+
+def iequal(iter1, iter2):
+    return set(iter1) == set(iter2)
+
+
 def compute_normalized_difference_index(band1, band2):
     return (band1 - band2) / (band1 + band2 + EPS)
 

+ 19 - 4
paddlers/transforms/operators.py

@@ -29,6 +29,7 @@ from joblib import load
 import paddlers
 import paddlers.transforms.functions as F
 import paddlers.transforms.indices as indices
+import paddlers.transforms.satellites as satellites
 
 __all__ = [
     "Compose",
@@ -1953,15 +1954,29 @@ class AppendIndex(Transform):
         index_type (str): Type of remote sensinng index. See supported 
             index types in 
             https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
-        band_indices (dict): Mapping of band names to band indices 
+        band_indices (dict, optional): Mapping of band names to band indices 
             (starting from 1). See band names in 
-            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py . 
+            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/indices.py .
+            Default: None.
+        satellite (str, optional): Type of satellite. If set, 
+            band indices will be automatically determined accordingly. See supported satellites in 
+            https://github.com/PaddlePaddle/PaddleRS/tree/develop/paddlers/transforms/satellites.py .
+            Default: None.
     """
 
-    def __init__(self, index_type, band_indices, **kwargs):
+    def __init__(self, index_type, band_indices=None, satellite=None, **kwargs):
         super(AppendIndex, self).__init__()
         cls = getattr(indices, index_type)
-        self._compute_index = cls(band_indices, **kwargs)
+        if satellite is not None:
+            satellite_bands = getattr(satellites, satellite)
+            self._compute_index = cls(satellite_bands, **kwargs)
+        else:
+            if band_indices is None:
+                raise ValueError(
+                    "At least one of `band_indices` and `satellite` must not be None."
+                )
+            else:
+                self._compute_index = cls(band_indices, **kwargs)
 
     def apply_im(self, image):
         index = self._compute_index(image)

+ 157 - 0
paddlers/transforms/satellites.py

@@ -0,0 +1,157 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = [
+    "Sentinel_2", "Landsat_457", "Landsat_89", "MODIS", "SPOT_15", "SPOT_67",
+    "Quickbird", "WorldView_23", "WorldView_4", "IKONOS", "GF_1_WFV",
+    "GF_6_WFV", "GF_16_PMS", "GF_24", "ZY_3", "CBERS_4", "SJ_9A"
+]
+
+# The rules for names
+# eg. GF_1_WFV: [satellite name]_[model]_[sensor name]
+
+Sentinel_2 = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "re1": 5,
+    "re2": 6,
+    "re3": 7,
+    "n": 8,
+    "s1": 12,  # 11 + 1 (due to 8A)
+    "s2": 13,  # 12 + 1 (due to 8A)
+}
+
+Landsat_457 = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+    "s1": 5,
+    "s2": 7,
+    "t1": 6,
+}
+
+Landsat_89 = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+    "s1": 6,
+    "s2": 7,
+    "t1": 10,
+    "t2": 11,
+}
+
+MODIS = {
+    "b": 3,
+    "g": 4,
+    "r": 1,
+    "n": 2,
+    "s1": 6,
+    "s2": 7,
+}
+
+SPOT_15 = {
+    "g": 1,
+    "r": 2,
+    "n": 3,
+}
+
+SPOT_67 = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+}
+
+Quickbird = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+}
+
+WorldView_23 = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+    "re1": 7,
+}
+
+WorldView_4 = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+}
+
+IKONOS = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+}
+
+GF_1_WFV = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+}
+
+GF_6_WFV = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+    "re1": 5,
+    "re2": 6,
+}
+
+GF_16_PMS = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+}
+
+GF_24 = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+}
+
+ZY_3 = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+}
+
+CBERS_4 = {
+    "b": 1,
+    "g": 2,
+    "r": 3,
+    "n": 4,
+}
+
+SJ_9A = {
+    "b": 2,
+    "g": 3,
+    "r": 4,
+    "n": 5,
+}

+ 0 - 1
tests/transforms/test_indices.py

@@ -28,7 +28,6 @@ NAME_MAPPING = {
     're2': 'RE2',
     're3': 'RE3',
     'n': 'N',
-    'n2': 'N2',
     's1': 'S1',
     's2': 'S2',
     't1': 'T1',

+ 10 - 0
tests/transforms/test_operators.py

@@ -382,6 +382,16 @@ class TestTransform(CpuCommonTest):
             c3=1.5,
             _filter=_filter_only_multispectral)
         test_evi(self)
+        test_evi_from_satellite = make_test_func(
+            T.AppendIndex,
+            'EVI',
+            satellite='Landsat_89',
+            c0=1.0,
+            c1=0.5,
+            c2=1.0,
+            c3=1.5,
+            _filter=_filter_only_multispectral)
+        test_evi_from_satellite(self)
 
     def test_MatchRadiance(self):
         test_hist = make_test_func(