test_indices.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import numpy as np
  16. import paddlers.transforms as T
  17. from testing_utils import CpuCommonTest
  18. __all__ = ['TestIndex']
  19. NAME_MAPPING = {
  20. 'b': 'B',
  21. 'g': 'G',
  22. 'r': 'R',
  23. 're1': 'RE1',
  24. 're2': 'RE2',
  25. 're3': 'RE3',
  26. 'n': 'N',
  27. 's1': 'S1',
  28. 's2': 'S2',
  29. 't': 'T',
  30. 't1': 'T1',
  31. 't2': 'T2'
  32. }
  33. def add_index_tests(cls):
  34. """
  35. Automatically patch testing functions for remote sensing indices.
  36. """
  37. def _make_test_func(index_name, index_class):
  38. def __test_func(self):
  39. bands = {}
  40. cnt = 0
  41. for key in inspect.signature(index_class._compute).parameters:
  42. if key == 'self':
  43. continue
  44. elif key.startswith('c'):
  45. # key 'c*' stands for a constant
  46. raise RuntimeError(
  47. f"Cannot automatically process key '{key}'!")
  48. else:
  49. cnt += 1
  50. bands[key] = cnt
  51. dummy = constr_dummy_image(cnt)
  52. index1 = index_class(bands)(dummy)
  53. params = constr_spyndex_params(dummy, bands)
  54. index2 = compute_spyndex_index(index_name, params)
  55. self.check_output(index1, index2)
  56. return __test_func
  57. for index_name in T.indices.__all__:
  58. index_class = getattr(T.indices, index_name)
  59. attr_name = 'test_' + index_name
  60. if hasattr(cls, attr_name):
  61. continue
  62. setattr(cls, attr_name, _make_test_func(index_name, index_class))
  63. return cls
  64. def constr_spyndex_params(image, bands, consts=None):
  65. params = {}
  66. for k, v in bands.items():
  67. k = NAME_MAPPING[k]
  68. v = image[..., v - 1]
  69. params[k] = v
  70. if consts is not None:
  71. params.update(consts)
  72. return params
  73. def compute_spyndex_index(name, params):
  74. import spyndex
  75. index = spyndex.computeIndex(index=[name], params=params)
  76. return index
  77. def constr_dummy_image(c):
  78. return np.random.uniform(0, 65536, size=(256, 256, c))
  79. @add_index_tests
  80. class TestIndex(CpuCommonTest):
  81. def check_output(self, result, expected_result):
  82. mask = np.isfinite(expected_result)
  83. diff = np.abs(result[mask] - expected_result[mask])
  84. cnt = (diff > (1.e-2 * diff + 0.1)).sum()
  85. self.assertLess(cnt / diff.size, 0.005)
  86. def test_ARVI(self):
  87. dummy = constr_dummy_image(3)
  88. bands = {'b': 1, 'r': 2, 'n': 3}
  89. gamma = 0.1
  90. arvi = T.indices.ARVI(bands, gamma)
  91. index1 = arvi(dummy)
  92. index2 = compute_spyndex_index(
  93. 'ARVI', constr_spyndex_params(dummy, bands, {'gamma': gamma}))
  94. self.check_output(index1, index2)
  95. def test_BWDRVI(self):
  96. dummy = constr_dummy_image(2)
  97. bands = {'b': 1, 'n': 2}
  98. alpha = 0.1
  99. bwdrvi = T.indices.BWDRVI(bands, alpha)
  100. index1 = bwdrvi(dummy)
  101. index2 = compute_spyndex_index(
  102. 'BWDRVI', constr_spyndex_params(dummy, bands, {'alpha': alpha}))
  103. self.check_output(index1, index2)
  104. def test_EVI(self):
  105. dummy = constr_dummy_image(3)
  106. bands = {'b': 1, 'r': 2, 'n': 3}
  107. g = 2.5
  108. C1 = 6.0
  109. C2 = 7.5
  110. L = 1.0
  111. evi = T.indices.EVI(bands, g, C1, C2, L)
  112. index1 = evi(dummy)
  113. index2 = compute_spyndex_index(
  114. 'EVI',
  115. constr_spyndex_params(dummy, bands,
  116. {'g': g,
  117. 'C1': C1,
  118. 'C2': C2,
  119. 'L': L}))
  120. self.check_output(index1, index2)
  121. def test_EVI2(self):
  122. dummy = constr_dummy_image(2)
  123. bands = {'r': 1, 'n': 2}
  124. g = 2.5
  125. L = 1.0
  126. evi2 = T.indices.EVI2(bands, g, L)
  127. index1 = evi2(dummy)
  128. index2 = compute_spyndex_index('EVI2',
  129. constr_spyndex_params(dummy, bands,
  130. {'g': g,
  131. 'L': L}))
  132. self.check_output(index1, index2)
  133. def test_MNLI(self):
  134. dummy = constr_dummy_image(2)
  135. bands = {'r': 1, 'n': 2}
  136. L = 1.0
  137. mnli = T.indices.MNLI(bands, L)
  138. index1 = mnli(dummy)
  139. index2 = compute_spyndex_index(
  140. 'MNLI', constr_spyndex_params(dummy, bands, {'L': L}))
  141. self.check_output(index1, index2)
  142. def test_SAVI(self):
  143. dummy = constr_dummy_image(2)
  144. bands = {'r': 1, 'n': 2}
  145. L = 1.0
  146. savi = T.indices.SAVI(bands, L)
  147. index1 = savi(dummy)
  148. index2 = compute_spyndex_index(
  149. 'SAVI', constr_spyndex_params(dummy, bands, {'L': L}))
  150. self.check_output(index1, index2)