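# Batch prediction script.
#
# Example invocation (the script name and paths are illustrative and should
# be adapted to the actual project layout):
#
#   python predict_cd.py \
#       --model_dir ./exp/best_model \
#       --data_dir ./data \
#       --file_list ./data/test.txt \
#       --save_dir ./exp/predict \
#       --ext .png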

import argparse
import os
import os.path as osp

import cv2
import paddle
import paddlers
from tqdm import tqdm

# Side-effect imports: make the custom model and trainer definitions
# available so that the saved model can be restored below.
import custom_model
import custom_trainer


def read_file_list(file_list, sep=' '):
    """Yield the `sep`-separated fields of each line in the file list."""
    with open(file_list, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines.
                continue
            yield line.split(sep)
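
# Each line of the file list is expected to provide, separated by `sep`
# (a single space by default):
#
#   <path/to/image1> <path/to/image2> [<path/to/label>]
#
# This format is inferred from how `parts` is consumed in the main block:
# only the first two fields are required, and the optional third one is used
# solely to name the saved prediction.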


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", default=None, type=str, help="Path of saved model.")
    parser.add_argument("--data_dir", type=str, help="Path of input dataset.")
    parser.add_argument("--file_list", type=str, help="Path of file list.")
    parser.add_argument(
        "--save_dir",
        default='./exp/predict',
        type=str,
        help="Path of directory to save prediction results.")
    parser.add_argument(
        "--ext",
        default='.png',
        type=str,
        help="Extension name of the saved image file.")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # Load the trained model from `--model_dir`.
    model = paddlers.tasks.load_model(args.model_dir)

    if not osp.exists(args.save_dir):
        os.makedirs(args.save_dir)

    with paddle.no_grad():
        for parts in tqdm(read_file_list(args.file_list)):
            im1_path = osp.join(args.data_dir, parts[0])
            im2_path = osp.join(args.data_dir, parts[1])

            # Predict on the image pair; `label_map` holds the per-pixel
            # class indices.
            pred = model.predict((im1_path, im2_path))
            cm = pred['label_map']
            # Convert the label map into a 0/255 binary map for saving.
            cm[cm > 0] = 255
            cm = cm.astype('uint8')

            # Name the output after the label file if one is given in the
            # file list; otherwise fall back to the first image's name.
            if len(parts) > 2:
                name = osp.basename(parts[2])
            else:
                name = osp.basename(im1_path)
            name = osp.splitext(name)[0] + args.ext
            out_path = osp.join(args.save_dir, name)
            cv2.imwrite(out_path, cm)