visualize.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. warnings.filterwarnings("ignore", category=DeprecationWarning)
  16. import os.path as osp
  17. from pathlib import Path
  18. from typing import List, Tuple, Union, Optional
  19. import webbrowser
  20. import numpy as np
  21. from folium import folium, Map, LayerControl
  22. from folium.raster_layers import TileLayer, ImageOverlay
  23. from paddlers.transforms.functions import to_uint8
  24. try:
  25. from osgeo import gdal, osr
  26. except:
  27. import gdal
  28. import osr
  29. CHINA_TILES = (
  30. "GeoQ China Community",
  31. "GeoQ China Street",
  32. "AMAP China",
  33. "TencentMap China",
  34. "BaiduMaps China", )
  35. def map_display(mask_path: str,
  36. img_path: Optional[str]=None,
  37. band_list: Union[List[int], Tuple[int, ...], None]=None,
  38. save_path: Optional[str]=None,
  39. tiles: str="GeoQ China Community") -> folium.Map:
  40. """
  41. Show mask (and original image) on an online map.
  42. Args:
  43. mask_path (str): Path of predicted or ground-truth masks.
  44. img_path (str|None, optional): Path of the original image. Defaults to None.
  45. band_list (list[int]|tuple[int]|None, optional):
  46. Bands to select from the original image for display (the band index starts from 1).
  47. If None, use all bands. Defaults to None.
  48. save_path (str, optional): Path of the .html file to save the visualization results.
  49. In Jupyter Notebook environments,
  50. leave `save_path` as None to display the result immediately in the notebook.
  51. Defaults to None.
  52. tiles (str): Map tileset to use. Chosen from the following list:
  53. - "GeoQ China Community", "GeoQ China Street" (from http://www.geoq.cn/)
  54. - "AMAP China" (from https://www.amap.com/)
  55. - "TencentMap China" (from https://map.qq.com/)
  56. - "BaiduMaps China" (from https://map.baidu.com/)
  57. Defaults to "GeoQ China Community".
  58. * All tilesets have been corrected through public algorithms from the Internet.
  59. * Please read the relevant terms of use carefully:
  60. - GeoQ [GISUNI] (http://geoq.cn/useragreement.html)
  61. - AMap [AutoNavi] (https://wap.amap.com/doc/serviceitem.html)
  62. - Tencent Map (https://ugc.map.qq.com/AppBox/Landlord/serveagreement.html)
  63. - Baidu Map (https://map.baidu.com/zt/client/service/index.html)
  64. Returns:
  65. folium.Map: An example of folium map.
  66. """
  67. if tiles not in CHINA_TILES:
  68. raise ValueError("The `tiles` must in {}, not {}.".format(CHINA_TILES,
  69. tiles))
  70. fmap = Map(
  71. tiles=tiles,
  72. min_zoom=1,
  73. max_zoom=24, )
  74. if img_path is not None:
  75. layer, _ = Raster(img_path, band_list).get_layer()
  76. layer.add_to(fmap)
  77. layer, center = Raster(mask_path).get_layer()
  78. layer.add_to(fmap)
  79. if center is not None:
  80. fmap.location = center
  81. fmap.fit_bounds(layer.bounds)
  82. LayerControl().add_to(fmap)
  83. if save_path:
  84. fmap.save(save_path)
  85. webbrowser.open(save_path)
  86. return fmap
  87. class OpenAsEPSG4326Error(Exception):
  88. pass
  89. class Raster:
  90. def __init__(
  91. self,
  92. path: str,
  93. band_list: Union[List[int], Tuple[int, ...], None]=None) -> None:
  94. self.src_data = Converter.open_as_WGS84(path)
  95. if self.src_data is None:
  96. raise OpenAsEPSG4326Error("Faild to open {} in EPSG:4326.".format(
  97. path))
  98. self.name = Path(path).stem
  99. self.set_bands(band_list)
  100. self._get_info()
  101. def get_layer(self) -> ImageOverlay:
  102. layer = ImageOverlay(self._get_array(), self.wgs_range, name=self.name)
  103. return layer, self.wgs_center
  104. def set_bands(self,
  105. band_list: Union[List[int], Tuple[int, ...], None]) -> None:
  106. self.bands = self.src_data.RasterCount
  107. if band_list is None:
  108. if self.bands == 3:
  109. band_list = [1, 2, 3]
  110. else:
  111. band_list = [1]
  112. band_list_lens = len(band_list)
  113. if band_list_lens not in (1, 3):
  114. raise ValueError("The lenght of band_list must be 1 or 3, not {}.".
  115. format(str(band_list_lens)))
  116. if max(band_list) > self.bands or min(band_list) < 1:
  117. raise ValueError("The range of band_list must within [1, {}].".
  118. format(str(self.bands)))
  119. self.band_list = band_list
  120. def _get_info(self) -> None:
  121. self.width = self.src_data.RasterXSize
  122. self.height = self.src_data.RasterYSize
  123. self.geotf = self.src_data.GetGeoTransform()
  124. self.proj = self.src_data.GetProjection() # WGS84
  125. self.wgs_range = self._get_WGS84_range()
  126. self.wgs_center = self._get_WGS84_center()
  127. def _get_WGS84_range(self) -> List[List[float]]:
  128. converter = Converter(self.proj, self.geotf)
  129. lat1, lon1 = converter.xy2latlon(self.height - 1, 0)
  130. lat2, lon2 = converter.xy2latlon(0, self.width - 1)
  131. return [[lon1, lat1], [lon2, lat2]]
  132. def _get_WGS84_center(self) -> List[float]:
  133. clat = (self.wgs_range[0][0] + self.wgs_range[1][0]) / 2
  134. clon = (self.wgs_range[0][1] + self.wgs_range[1][1]) / 2
  135. return [clat, clon]
  136. def _get_array(self) -> np.ndarray:
  137. band_array = []
  138. for b in self.band_list:
  139. band_i = self.src_data.GetRasterBand(b).ReadAsArray()
  140. band_array.append(band_i)
  141. ima = np.stack(band_array, axis=0)
  142. if self.bands == 1:
  143. # the type is complex means this is a SAR data
  144. if isinstance(type(ima[0, 0]), complex):
  145. ima = abs(ima)
  146. ima = ima.squeeze()
  147. else:
  148. ima = ima.transpose((1, 2, 0))
  149. ima = to_uint8(ima, True)
  150. return ima
  151. class Converter:
  152. def __init__(self, proj: str, geotf: tuple) -> None:
  153. # source data
  154. self.source = osr.SpatialReference()
  155. self.source.ImportFromWkt(proj)
  156. self.geotf = geotf
  157. # target data
  158. self.target = osr.SpatialReference()
  159. self.target.ImportFromEPSG(4326)
  160. @classmethod
  161. def open_as_WGS84(self, path: str) -> gdal.Dataset:
  162. if not osp.exists(path):
  163. raise FileNotFoundError("{} not found.".format(path))
  164. result = gdal.Warp("", path, dstSRS="EPSG:4326", format="VRT")
  165. return result
  166. def xy2latlon(self, row: int, col: int) -> List[float]:
  167. px = self.geotf[0] + col * self.geotf[1] + row * self.geotf[2]
  168. py = self.geotf[3] + col * self.geotf[4] + row * self.geotf[5]
  169. ct = osr.CoordinateTransformation(self.source, self.target)
  170. coords = ct.TransformPoint(px, py)
  171. return coords[:2]