predict_cd.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python
  2. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import os
  17. import os.path as osp
  18. import cv2
  19. import paddle
  20. import paddlers
  21. from tqdm import tqdm
  22. import bootstrap
  23. def read_file_list(file_list, sep=' '):
  24. with open(file_list, 'r') as f:
  25. for line in f:
  26. line = line.strip()
  27. parts = line.split(sep)
  28. yield parts
  29. def parse_args():
  30. parser = argparse.ArgumentParser()
  31. parser.add_argument(
  32. "--model_dir", default=None, type=str, help="Path of saved model.")
  33. parser.add_argument("--data_dir", type=str, help="Path of input dataset.")
  34. parser.add_argument("--file_list", type=str, help="Path of file list.")
  35. parser.add_argument(
  36. "--save_dir",
  37. default='./exp/predict',
  38. type=str,
  39. help="Path of directory to save prediction results.")
  40. parser.add_argument(
  41. "--ext",
  42. default='.png',
  43. type=str,
  44. help="Extension name of the saved image file.")
  45. return parser.parse_args()
  46. if __name__ == '__main__':
  47. args = parse_args()
  48. model = paddlers.tasks.load_model(args.model_dir)
  49. if not osp.exists(args.save_dir):
  50. os.makedirs(args.save_dir)
  51. with paddle.no_grad():
  52. for parts in tqdm(read_file_list(args.file_list)):
  53. im1_path = osp.join(args.data_dir, parts[0])
  54. im2_path = osp.join(args.data_dir, parts[1])
  55. pred = model.predict((im1_path, im2_path))
  56. cm = pred['label_map']
  57. # {0,1} -> {0,255}
  58. cm[cm > 0] = 255
  59. cm = cm.astype('uint8')
  60. if len(parts) > 2:
  61. name = osp.basename(parts[2])
  62. else:
  63. name = osp.basename(im1_path)
  64. name = osp.splitext(name)[0] + args.ext
  65. out_path = osp.join(args.save_dir, name)
  66. cv2.imwrite(out_path, cm)