utils.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import os
  16. import os.path as osp
  17. import time
  18. import math
  19. import imghdr
  20. import chardet
  21. import json
  22. import numpy as np
  23. from . import logging
  24. import platform
  25. import paddlers
  26. def seconds_to_hms(seconds):
  27. h = math.floor(seconds / 3600)
  28. m = math.floor((seconds - h * 3600) / 60)
  29. s = int(seconds - h * 3600 - m * 60)
  30. hms_str = "{}:{}:{}".format(h, m, s)
  31. return hms_str
  32. def get_encoding(path):
  33. f = open(path, 'rb')
  34. data = f.read()
  35. file_encoding = chardet.detect(data).get('encoding')
  36. f.close()
  37. return file_encoding
  38. def get_single_card_bs(batch_size):
  39. card_num = paddlers.env_info['num']
  40. place = paddlers.env_info['place']
  41. if batch_size % card_num == 0:
  42. return int(batch_size // card_num)
  43. elif batch_size == 1:
  44. # Evaluation of detection task only supports single card with batch size 1
  45. return batch_size
  46. else:
  47. raise Exception("Please support correct batch_size, \
  48. which can be divided by available cards({}) in {}"
  49. .format(card_num, place))
  50. def dict2str(dict_input):
  51. out = ''
  52. for k, v in dict_input.items():
  53. try:
  54. v = '{:8.6f}'.format(float(v))
  55. except:
  56. pass
  57. out = out + '{}={}, '.format(k, v)
  58. return out.strip(', ')
  59. def path_normalization(path):
  60. win_sep = "\\"
  61. other_sep = "/"
  62. if platform.system() == "Windows":
  63. path = win_sep.join(path.split(other_sep))
  64. else:
  65. path = other_sep.join(path.split(win_sep))
  66. return path
  67. def is_pic(img_path):
  68. valid_suffix = [
  69. 'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
  70. ]
  71. suffix = img_path.split('.')[-1]
  72. if suffix in valid_suffix:
  73. return True
  74. img_format = imghdr.what(img_path)
  75. _, ext = osp.splitext(img_path)
  76. if img_format == 'tiff' or ext == '.img':
  77. return True
  78. return False
  79. class MyEncoder(json.JSONEncoder):
  80. def default(self, obj):
  81. if isinstance(obj, np.integer):
  82. return int(obj)
  83. elif isinstance(obj, np.floating):
  84. return float(obj)
  85. elif isinstance(obj, np.ndarray):
  86. return obj.tolist()
  87. else:
  88. return super(MyEncoder, self).default(obj)
  89. class EarlyStop:
  90. def __init__(self, patience, thresh):
  91. self.patience = patience
  92. self.counter = 0
  93. self.score = None
  94. self.max = 0
  95. self.thresh = thresh
  96. if patience < 1:
  97. raise Exception("Argument patience should be a positive integer.")
  98. def __call__(self, current_score):
  99. if self.score is None:
  100. self.score = current_score
  101. return False
  102. elif current_score > self.max:
  103. self.counter = 0
  104. self.score = current_score
  105. self.max = current_score
  106. return False
  107. else:
  108. if (abs(self.score - current_score) < self.thresh or
  109. current_score < self.score):
  110. self.counter += 1
  111. self.score = current_score
  112. logging.debug("EarlyStopping: %i / %i" %
  113. (self.counter, self.patience))
  114. if self.counter >= self.patience:
  115. logging.info("EarlyStopping: Stop training")
  116. return True
  117. return False
  118. else:
  119. self.counter = 0
  120. self.score = current_score
  121. return False
  122. class DisablePrint(object):
  123. def __enter__(self):
  124. self._original_stdout = sys.stdout
  125. sys.stdout = open(os.devnull, 'w')
  126. def __exit__(self, exc_type, exc_val, exc_tb):
  127. sys.stdout.close()
  128. sys.stdout = self._original_stdout
  129. class Times(object):
  130. def __init__(self):
  131. self.time = 0.
  132. # start time
  133. self.st = 0.
  134. # end time
  135. self.et = 0.
  136. def start(self):
  137. self.st = time.time()
  138. def end(self, iter_num=1, accumulative=True):
  139. self.et = time.time()
  140. if accumulative:
  141. self.time += (self.et - self.st) / iter_num
  142. else:
  143. self.time = (self.et - self.st) / iter_num
  144. def reset(self):
  145. self.time = 0.
  146. self.st = 0.
  147. self.et = 0.
  148. def value(self):
  149. return round(self.time, 4)
  150. class Timer(Times):
  151. def __init__(self):
  152. super(Timer, self).__init__()
  153. self.preprocess_time_s = Times()
  154. self.inference_time_s = Times()
  155. self.postprocess_time_s = Times()
  156. self.img_num = 0
  157. self.repeats = 0
  158. def info(self, average=False):
  159. total_time = self.preprocess_time_s.value(
  160. ) * self.img_num + self.inference_time_s.value(
  161. ) + self.postprocess_time_s.value() * self.img_num
  162. total_time = round(total_time, 4)
  163. print("------------------ Inference Time Info ----------------------")
  164. print("total_time(ms): {}, img_num: {}, batch_size: {}".format(
  165. total_time * 1000, self.img_num, self.img_num))
  166. preprocess_time = round(
  167. self.preprocess_time_s.value() / self.repeats,
  168. 4) if average else self.preprocess_time_s.value()
  169. postprocess_time = round(
  170. self.postprocess_time_s.value() / self.repeats,
  171. 4) if average else self.postprocess_time_s.value()
  172. inference_time = round(self.inference_time_s.value() / self.repeats,
  173. 4) if average else self.inference_time_s.value()
  174. average_latency = total_time / self.repeats
  175. print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
  176. average_latency * 1000, 1 / average_latency))
  177. print("preprocess_time_per_im(ms): {:.2f}, "
  178. "inference_time_per_batch(ms): {:.2f}, "
  179. "postprocess_time_per_im(ms): {:.2f}".format(
  180. preprocess_time * 1000, inference_time * 1000,
  181. postprocess_time * 1000))
  182. def report(self, average=False):
  183. dic = {}
  184. dic['preprocess_time_s'] = round(
  185. self.preprocess_time_s.value() / self.repeats,
  186. 4) if average else self.preprocess_time_s.value()
  187. dic['postprocess_time_s'] = round(
  188. self.postprocess_time_s.value() / self.repeats,
  189. 4) if average else self.postprocess_time_s.value()
  190. dic['inference_time_s'] = round(
  191. self.inference_time_s.value() / self.repeats,
  192. 4) if average else self.inference_time_s.value()
  193. dic['img_num'] = self.img_num
  194. total_time = self.preprocess_time_s.value(
  195. ) + self.inference_time_s.value() + self.postprocess_time_s.value()
  196. dic['total_time_s'] = round(total_time, 4)
  197. dic['batch_size'] = self.img_num / self.repeats
  198. return dic
  199. def reset(self):
  200. self.preprocess_time_s.reset()
  201. self.inference_time_s.reset()
  202. self.postprocess_time_s.reset()
  203. self.img_num = 0
  204. self.repeats = 0