test_indices.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import numpy as np
  16. import paddlers.transforms as T
  17. from testing_utils import CpuCommonTest
  18. __all__ = ['TestIndex']
  19. NAME_MAPPING = {
  20. 'b': 'B',
  21. 'g': 'G',
  22. 'r': 'R',
  23. 're1': 'RE1',
  24. 're2': 'RE2',
  25. 're3': 'RE3',
  26. 'n': 'N',
  27. 's1': 'S1',
  28. 's2': 'S2',
  29. 't1': 'T1',
  30. 't2': 'T2'
  31. }
  32. def add_index_tests(cls):
  33. """
  34. Automatically patch testing functions for remote sensing indices.
  35. """
  36. def _make_test_func(index_name, index_class):
  37. def __test_func(self):
  38. bands = {}
  39. cnt = 0
  40. for key in inspect.signature(index_class._compute).parameters:
  41. if key == 'self':
  42. continue
  43. elif key.startswith('c'):
  44. # key 'c*' stands for a constant
  45. raise RuntimeError(
  46. f"Cannot automatically process key '{key}'!")
  47. else:
  48. cnt += 1
  49. bands[key] = cnt
  50. dummy = constr_dummy_image(cnt)
  51. index1 = index_class(bands)(dummy)
  52. params = constr_spyndex_params(dummy, bands)
  53. index2 = compute_spyndex_index(index_name, params)
  54. self.check_output(index1, index2)
  55. return __test_func
  56. for index_name in T.indices.__all__:
  57. index_class = getattr(T.indices, index_name)
  58. attr_name = 'test_' + index_name
  59. if hasattr(cls, attr_name):
  60. continue
  61. setattr(cls, attr_name, _make_test_func(index_name, index_class))
  62. return cls
  63. def constr_spyndex_params(image, bands, consts=None):
  64. params = {}
  65. for k, v in bands.items():
  66. k = NAME_MAPPING[k]
  67. v = image[..., v - 1]
  68. params[k] = v
  69. if consts is not None:
  70. params.update(consts)
  71. return params
  72. def compute_spyndex_index(name, params):
  73. import spyndex
  74. index = spyndex.computeIndex(index=[name], params=params)
  75. return index
  76. def constr_dummy_image(c):
  77. return np.random.uniform(0, 65536, size=(256, 256, c))
  78. @add_index_tests
  79. class TestIndex(CpuCommonTest):
  80. def check_output(self, result, expected_result):
  81. mask = np.isfinite(expected_result)
  82. diff = np.abs(result[mask] - expected_result[mask])
  83. cnt = (diff > (1.e-2 * diff + 0.1)).sum()
  84. self.assertLess(cnt / diff.size, 0.005)
  85. def test_ARVI(self):
  86. dummy = constr_dummy_image(3)
  87. bands = {'b': 1, 'r': 2, 'n': 3}
  88. gamma = 0.1
  89. arvi = T.indices.ARVI(bands, gamma)
  90. index1 = arvi(dummy)
  91. index2 = compute_spyndex_index(
  92. 'ARVI', constr_spyndex_params(dummy, bands, {'gamma': gamma}))
  93. self.check_output(index1, index2)
  94. def test_BWDRVI(self):
  95. dummy = constr_dummy_image(2)
  96. bands = {'b': 1, 'n': 2}
  97. alpha = 0.1
  98. bwdrvi = T.indices.BWDRVI(bands, alpha)
  99. index1 = bwdrvi(dummy)
  100. index2 = compute_spyndex_index(
  101. 'BWDRVI', constr_spyndex_params(dummy, bands, {'alpha': alpha}))
  102. self.check_output(index1, index2)
  103. def test_EVI(self):
  104. dummy = constr_dummy_image(3)
  105. bands = {'b': 1, 'r': 2, 'n': 3}
  106. g = 2.5
  107. C1 = 6.0
  108. C2 = 7.5
  109. L = 1.0
  110. evi = T.indices.EVI(bands, g, C1, C2, L)
  111. index1 = evi(dummy)
  112. index2 = compute_spyndex_index(
  113. 'EVI',
  114. constr_spyndex_params(dummy, bands,
  115. {'g': g,
  116. 'C1': C1,
  117. 'C2': C2,
  118. 'L': L}))
  119. self.check_output(index1, index2)
  120. def test_EVI2(self):
  121. dummy = constr_dummy_image(2)
  122. bands = {'r': 1, 'n': 2}
  123. g = 2.5
  124. L = 1.0
  125. evi2 = T.indices.EVI2(bands, g, L)
  126. index1 = evi2(dummy)
  127. index2 = compute_spyndex_index('EVI2',
  128. constr_spyndex_params(dummy, bands,
  129. {'g': g,
  130. 'L': L}))
  131. self.check_output(index1, index2)
  132. def test_MNLI(self):
  133. dummy = constr_dummy_image(2)
  134. bands = {'r': 1, 'n': 2}
  135. L = 1.0
  136. mnli = T.indices.MNLI(bands, L)
  137. index1 = mnli(dummy)
  138. index2 = compute_spyndex_index(
  139. 'MNLI', constr_spyndex_params(dummy, bands, {'L': L}))
  140. self.check_output(index1, index2)
  141. def test_SAVI(self):
  142. dummy = constr_dummy_image(2)
  143. bands = {'r': 1, 'n': 2}
  144. L = 1.0
  145. savi = T.indices.SAVI(bands, L)
  146. index1 = savi(dummy)
  147. index2 = compute_spyndex_index(
  148. 'SAVI', constr_spyndex_params(dummy, bands, {'L': L}))
  149. self.check_output(index1, index2)