# classifier.py
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. from collections import OrderedDict
  17. from operator import itemgetter
  18. import numpy as np
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddle.static import InputSpec
  22. import paddlers
  23. import paddlers.models.ppcls as ppcls
  24. import paddlers.rs_models.clas as cmcls
  25. import paddlers.utils.logging as logging
  26. from paddlers.utils import get_single_card_bs, DisablePrint
  27. from paddlers.models.ppcls.metric import build_metrics
  28. from paddlers.models import clas_losses
  29. from paddlers.models.ppcls.data.postprocess import build_postprocess
  30. from paddlers.utils.checkpoint import cls_pretrain_weights_dict
  31. from paddlers.transforms import Resize, decode_image
  32. from .base import BaseModel
  33. __all__ = ["ResNet50_vd", "MobileNetV3", "HRNet", "CondenseNetV2"]
  34. class BaseClassifier(BaseModel):
  35. def __init__(self,
  36. model_name,
  37. in_channels=3,
  38. num_classes=2,
  39. use_mixed_loss=False,
  40. losses=None,
  41. **params):
  42. self.init_params = locals()
  43. if 'with_net' in self.init_params:
  44. del self.init_params['with_net']
  45. super(BaseClassifier, self).__init__('classifier')
  46. if not hasattr(ppcls.arch.backbone, model_name) and \
  47. not hasattr(cmcls, model_name):
  48. raise ValueError("ERROR: There is no model named {}.".format(
  49. model_name))
  50. self.model_name = model_name
  51. self.in_channels = in_channels
  52. self.num_classes = num_classes
  53. self.use_mixed_loss = use_mixed_loss
  54. self.metrics = None
  55. self.losses = losses
  56. self.labels = None
  57. self.postprocess = None
  58. if params.get('with_net', True):
  59. params.pop('with_net', None)
  60. self.net = self.build_net(**params)
  61. self.find_unused_parameters = True
  62. def build_net(self, **params):
  63. with paddle.utils.unique_name.guard():
  64. model = dict(ppcls.arch.backbone.__dict__,
  65. **cmcls.__dict__)[self.model_name]
  66. # TODO: Determine whether there is in_channels
  67. try:
  68. net = model(
  69. class_num=self.num_classes,
  70. in_channels=self.in_channels,
  71. **params)
  72. except:
  73. net = model(class_num=self.num_classes, **params)
  74. self.in_channels = 3
  75. return net
  76. def _build_inference_net(self):
  77. infer_net = self.net
  78. infer_net.eval()
  79. return infer_net
  80. def _fix_transforms_shape(self, image_shape):
  81. if hasattr(self, 'test_transforms'):
  82. if self.test_transforms is not None:
  83. has_resize_op = False
  84. resize_op_idx = -1
  85. normalize_op_idx = len(self.test_transforms.transforms)
  86. for idx, op in enumerate(self.test_transforms.transforms):
  87. name = op.__class__.__name__
  88. if name == 'Normalize':
  89. normalize_op_idx = idx
  90. if 'Resize' in name:
  91. has_resize_op = True
  92. resize_op_idx = idx
  93. if not has_resize_op:
  94. self.test_transforms.transforms.insert(
  95. normalize_op_idx, Resize(target_size=image_shape))
  96. else:
  97. self.test_transforms.transforms[resize_op_idx] = Resize(
  98. target_size=image_shape)
  99. def _get_test_inputs(self, image_shape):
  100. if image_shape is not None:
  101. if len(image_shape) == 2:
  102. image_shape = [1, 3] + image_shape
  103. self._fix_transforms_shape(image_shape[-2:])
  104. else:
  105. image_shape = [None, 3, -1, -1]
  106. self.fixed_input_shape = image_shape
  107. input_spec = [
  108. InputSpec(
  109. shape=image_shape, name='image', dtype='float32')
  110. ]
  111. return input_spec
  112. def run(self, net, inputs, mode):
  113. net_out = net(inputs[0])
  114. if mode == 'test':
  115. return self.postprocess(net_out)
  116. outputs = OrderedDict()
  117. label = paddle.to_tensor(inputs[1], dtype="int64")
  118. if mode == 'eval':
  119. label = paddle.unsqueeze(label, axis=-1)
  120. metric_dict = self.metrics(net_out, label)
  121. outputs['top1'] = metric_dict["top1"]
  122. outputs['top5'] = metric_dict["top5"]
  123. if mode == 'train':
  124. loss_list = self.losses(net_out, label)
  125. outputs['loss'] = loss_list['loss']
  126. return outputs
  127. def default_metric(self):
  128. default_config = [{"TopkAcc": {"topk": [1, 5]}}]
  129. return build_metrics(default_config)
  130. def default_loss(self):
  131. # TODO: use mixed loss and other loss
  132. default_config = [{"CELoss": {"weight": 1.0}}]
  133. return clas_losses.build_loss(default_config)
  134. def default_optimizer(self,
  135. parameters,
  136. learning_rate,
  137. num_epochs,
  138. num_steps_each_epoch,
  139. last_epoch=-1,
  140. L2_coeff=0.00007):
  141. decay_step = num_epochs * num_steps_each_epoch
  142. lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
  143. learning_rate, T_max=decay_step, eta_min=0, last_epoch=last_epoch)
  144. optimizer = paddle.optimizer.Momentum(
  145. learning_rate=lr_scheduler,
  146. parameters=parameters,
  147. momentum=0.9,
  148. weight_decay=paddle.regularizer.L2Decay(L2_coeff))
  149. return optimizer
  150. def default_postprocess(self, class_id_map_file):
  151. default_config = {
  152. "name": "Topk",
  153. "topk": 1,
  154. "class_id_map_file": class_id_map_file
  155. }
  156. return build_postprocess(default_config)
  157. def build_postprocess_from_labels(self, topk=1):
  158. label_dict = dict()
  159. for i, label in enumerate(self.labels):
  160. label_dict[i] = label
  161. self.postprocess = build_postprocess({
  162. "name": "Topk",
  163. "topk": topk,
  164. "class_id_map_file": None
  165. })
  166. # Add class_id_map from model.yml
  167. self.postprocess.class_id_map = label_dict
  168. def train(self,
  169. num_epochs,
  170. train_dataset,
  171. train_batch_size=2,
  172. eval_dataset=None,
  173. optimizer=None,
  174. save_interval_epochs=1,
  175. log_interval_steps=2,
  176. save_dir='output',
  177. pretrain_weights='IMAGENET',
  178. learning_rate=0.1,
  179. lr_decay_power=0.9,
  180. early_stop=False,
  181. early_stop_patience=5,
  182. use_vdl=True,
  183. resume_checkpoint=None):
  184. """
  185. Train the model.
  186. Args:
  187. num_epochs (int): Number of epochs.
  188. train_dataset (paddlers.datasets.ClasDataset): Training dataset.
  189. train_batch_size (int, optional): Total batch size among all cards used in
  190. training. Defaults to 2.
  191. eval_dataset (paddlers.datasets.ClasDataset|None, optional): Evaluation dataset.
  192. If None, the model will not be evaluated during training process.
  193. Defaults to None.
  194. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  195. training. If None, a default optimizer will be used. Defaults to None.
  196. save_interval_epochs (int, optional): Epoch interval for saving the model.
  197. Defaults to 1.
  198. log_interval_steps (int, optional): Step interval for printing training
  199. information. Defaults to 2.
  200. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  201. pretrain_weights (str|None, optional): None or name/path of pretrained
  202. weights. If None, no pretrained weights will be loaded.
  203. Defaults to 'IMAGENET'.
  204. learning_rate (float, optional): Learning rate for training.
  205. Defaults to .1.
  206. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  207. early_stop (bool, optional): Whether to adopt early stop strategy.
  208. Defaults to False.
  209. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  210. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  211. process. Defaults to True.
  212. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  213. training from. If None, no training checkpoint will be resumed. At most
  214. Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
  215. Defaults to None.
  216. """
  217. if self.status == 'Infer':
  218. logging.error(
  219. "Exported inference model does not support training.",
  220. exit=True)
  221. if pretrain_weights is not None and resume_checkpoint is not None:
  222. logging.error(
  223. "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
  224. exit=True)
  225. self.labels = train_dataset.labels
  226. if self.losses is None:
  227. self.losses = self.default_loss()
  228. self.metrics = self.default_metric()
  229. self.postprocess = self.default_postprocess(train_dataset.label_list)
  230. if optimizer is None:
  231. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  232. self.optimizer = self.default_optimizer(
  233. self.net.parameters(), learning_rate, num_epochs,
  234. num_steps_each_epoch, lr_decay_power)
  235. else:
  236. self.optimizer = optimizer
  237. if pretrain_weights is not None:
  238. if not osp.exists(pretrain_weights):
  239. if self.model_name not in cls_pretrain_weights_dict:
  240. logging.warning(
  241. "Path of `pretrain_weights` ('{}') does not exist!".
  242. format(pretrain_weights))
  243. pretrain_weights = None
  244. elif pretrain_weights not in cls_pretrain_weights_dict[
  245. self.model_name]:
  246. logging.warning(
  247. "Path of `pretrain_weights` ('{}') does not exist!".
  248. format(pretrain_weights))
  249. pretrain_weights = cls_pretrain_weights_dict[
  250. self.model_name][0]
  251. logging.warning(
  252. "`pretrain_weights` is forcibly set to '{}'. "
  253. "If you don't want to use pretrained weights, "
  254. "set `pretrain_weights` to None.".format(
  255. pretrain_weights))
  256. else:
  257. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  258. logging.error(
  259. "Invalid pretrained weights. Please specify a .pdparams file.",
  260. exit=True)
  261. pretrained_dir = osp.join(save_dir, 'pretrain')
  262. is_backbone_weights = False
  263. self.net_initialize(
  264. pretrain_weights=pretrain_weights,
  265. save_dir=pretrained_dir,
  266. resume_checkpoint=resume_checkpoint,
  267. is_backbone_weights=is_backbone_weights)
  268. self.train_loop(
  269. num_epochs=num_epochs,
  270. train_dataset=train_dataset,
  271. train_batch_size=train_batch_size,
  272. eval_dataset=eval_dataset,
  273. save_interval_epochs=save_interval_epochs,
  274. log_interval_steps=log_interval_steps,
  275. save_dir=save_dir,
  276. early_stop=early_stop,
  277. early_stop_patience=early_stop_patience,
  278. use_vdl=use_vdl)
  279. def quant_aware_train(self,
  280. num_epochs,
  281. train_dataset,
  282. train_batch_size=2,
  283. eval_dataset=None,
  284. optimizer=None,
  285. save_interval_epochs=1,
  286. log_interval_steps=2,
  287. save_dir='output',
  288. learning_rate=0.0001,
  289. lr_decay_power=0.9,
  290. early_stop=False,
  291. early_stop_patience=5,
  292. use_vdl=True,
  293. resume_checkpoint=None,
  294. quant_config=None):
  295. """
  296. Quantization-aware training.
  297. Args:
  298. num_epochs (int): Number of epochs.
  299. train_dataset (paddlers.datasets.ClasDataset): Training dataset.
  300. train_batch_size (int, optional): Total batch size among all cards used in
  301. training. Defaults to 2.
  302. eval_dataset (paddlers.datasets.ClasDataset|None, optional): Evaluation dataset.
  303. If None, the model will not be evaluated during training process.
  304. Defaults to None.
  305. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  306. training. If None, a default optimizer will be used. Defaults to None.
  307. save_interval_epochs (int, optional): Epoch interval for saving the model.
  308. Defaults to 1.
  309. log_interval_steps (int, optional): Step interval for printing training
  310. information. Defaults to 2.
  311. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  312. learning_rate (float, optional): Learning rate for training.
  313. Defaults to .0001.
  314. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  315. early_stop (bool, optional): Whether to adopt early stop strategy.
  316. Defaults to False.
  317. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  318. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  319. process. Defaults to True.
  320. quant_config (dict|None, optional): Quantization configuration. If None,
  321. a default rule of thumb configuration will be used. Defaults to None.
  322. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  323. quantization-aware training from. If None, no training checkpoint will
  324. be resumed. Defaults to None.
  325. """
  326. self._prepare_qat(quant_config)
  327. self.train(
  328. num_epochs=num_epochs,
  329. train_dataset=train_dataset,
  330. train_batch_size=train_batch_size,
  331. eval_dataset=eval_dataset,
  332. optimizer=optimizer,
  333. save_interval_epochs=save_interval_epochs,
  334. log_interval_steps=log_interval_steps,
  335. save_dir=save_dir,
  336. pretrain_weights=None,
  337. learning_rate=learning_rate,
  338. lr_decay_power=lr_decay_power,
  339. early_stop=early_stop,
  340. early_stop_patience=early_stop_patience,
  341. use_vdl=use_vdl,
  342. resume_checkpoint=resume_checkpoint)
  343. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  344. """
  345. Evaluate the model.
  346. Args:
  347. eval_dataset (paddlers.datasets.ClasDataset): Evaluation dataset.
  348. batch_size (int, optional): Total batch size among all cards used for
  349. evaluation. Defaults to 1.
  350. return_details (bool, optional): Whether to return evaluation details.
  351. Defaults to False.
  352. Returns:
  353. If `return_details` is False, return collections.OrderedDict with
  354. key-value pairs:
  355. {"top1": acc of top1,
  356. "top5": acc of top5}.
  357. """
  358. self._check_transforms(eval_dataset.transforms, 'eval')
  359. self.net.eval()
  360. nranks = paddle.distributed.get_world_size()
  361. local_rank = paddle.distributed.get_rank()
  362. if nranks > 1:
  363. # Initialize parallel environment if not done.
  364. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  365. ):
  366. paddle.distributed.init_parallel_env()
  367. if batch_size > 1:
  368. logging.warning(
  369. "Classifier only supports single card evaluation with batch_size=1 "
  370. "during evaluation, so batch_size is forcibly set to 1.")
  371. batch_size = 1
  372. if nranks < 2 or local_rank == 0:
  373. self.eval_data_loader = self.build_data_loader(
  374. eval_dataset, batch_size=batch_size, mode='eval')
  375. logging.info(
  376. "Start to evaluate(total_samples={}, total_steps={})...".format(
  377. eval_dataset.num_samples, eval_dataset.num_samples))
  378. top1s = []
  379. top5s = []
  380. with paddle.no_grad():
  381. for step, data in enumerate(self.eval_data_loader):
  382. data.append(eval_dataset.transforms.transforms)
  383. outputs = self.run(self.net, data, 'eval')
  384. top1s.append(outputs["top1"])
  385. top5s.append(outputs["top5"])
  386. top1 = np.mean(top1s)
  387. top5 = np.mean(top5s)
  388. eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
  389. if return_details:
  390. # TODO: Add details
  391. return eval_metrics, None
  392. return eval_metrics
  393. def predict(self, img_file, transforms=None):
  394. """
  395. Do inference.
  396. Args:
  397. img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
  398. image data, which also could constitute a list, meaning all images to be
  399. predicted as a mini-batch.
  400. transforms (paddlers.transforms.Compose|None, optional): Transforms for
  401. inputs. If None, the transforms for evaluation process will be used.
  402. Defaults to None.
  403. Returns:
  404. If `img_file` is a string or np.array, the result is a dict with the
  405. following key-value pairs:
  406. class_ids_map (np.ndarray): IDs of predicted classes.
  407. scores_map (np.ndarray): Scores of predicted classes.
  408. label_names_map (np.ndarray): Names of predicted classes.
  409. If `img_file` is a list, the result is a list composed of dicts with the
  410. above keys.
  411. """
  412. if transforms is None and not hasattr(self, 'test_transforms'):
  413. raise ValueError("transforms need to be defined, now is None.")
  414. if transforms is None:
  415. transforms = self.test_transforms
  416. if isinstance(img_file, (str, np.ndarray)):
  417. images = [img_file]
  418. else:
  419. images = img_file
  420. batch_im, batch_origin_shape = self.preprocess(images, transforms,
  421. self.model_type)
  422. self.net.eval()
  423. data = (batch_im, batch_origin_shape, transforms.transforms)
  424. if self.postprocess is None:
  425. self.build_postprocess_from_labels()
  426. outputs = self.run(self.net, data, 'test')
  427. class_ids = map(itemgetter('class_ids'), outputs)
  428. scores = map(itemgetter('scores'), outputs)
  429. label_names = map(itemgetter('label_names'), outputs)
  430. if isinstance(img_file, list):
  431. prediction = [{
  432. 'class_ids_map': l,
  433. 'scores_map': s,
  434. 'label_names_map': n,
  435. } for l, s, n in zip(class_ids, scores, label_names)]
  436. else:
  437. prediction = {
  438. 'class_ids_map': next(class_ids),
  439. 'scores_map': next(scores),
  440. 'label_names_map': next(label_names)
  441. }
  442. return prediction
  443. def preprocess(self, images, transforms, to_tensor=True):
  444. self._check_transforms(transforms, 'test')
  445. batch_im = list()
  446. batch_ori_shape = list()
  447. for im in images:
  448. if isinstance(im, str):
  449. im = decode_image(im, to_rgb=False)
  450. ori_shape = im.shape[:2]
  451. sample = {'image': im}
  452. im = transforms(sample)
  453. batch_im.append(im)
  454. batch_ori_shape.append(ori_shape)
  455. if to_tensor:
  456. batch_im = paddle.to_tensor(batch_im)
  457. else:
  458. batch_im = np.asarray(batch_im)
  459. return batch_im, batch_ori_shape
  460. @staticmethod
  461. def get_transforms_shape_info(batch_ori_shape, transforms):
  462. batch_restore_list = list()
  463. for ori_shape in batch_ori_shape:
  464. restore_list = list()
  465. h, w = ori_shape[0], ori_shape[1]
  466. for op in transforms:
  467. if op.__class__.__name__ == 'Resize':
  468. restore_list.append(('resize', (h, w)))
  469. h, w = op.target_size
  470. elif op.__class__.__name__ == 'ResizeByShort':
  471. restore_list.append(('resize', (h, w)))
  472. im_short_size = min(h, w)
  473. im_long_size = max(h, w)
  474. scale = float(op.short_size) / float(im_short_size)
  475. if 0 < op.max_size < np.round(scale * im_long_size):
  476. scale = float(op.max_size) / float(im_long_size)
  477. h = int(round(h * scale))
  478. w = int(round(w * scale))
  479. elif op.__class__.__name__ == 'ResizeByLong':
  480. restore_list.append(('resize', (h, w)))
  481. im_long_size = max(h, w)
  482. scale = float(op.long_size) / float(im_long_size)
  483. h = int(round(h * scale))
  484. w = int(round(w * scale))
  485. elif op.__class__.__name__ == 'Pad':
  486. if op.target_size:
  487. target_h, target_w = op.target_size
  488. else:
  489. target_h = int(
  490. (np.ceil(h / op.size_divisor) * op.size_divisor))
  491. target_w = int(
  492. (np.ceil(w / op.size_divisor) * op.size_divisor))
  493. if op.pad_mode == -1:
  494. offsets = op.offsets
  495. elif op.pad_mode == 0:
  496. offsets = [0, 0]
  497. elif op.pad_mode == 1:
  498. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  499. else:
  500. offsets = [target_h - h, target_w - w]
  501. restore_list.append(('padding', (h, w), offsets))
  502. h, w = target_h, target_w
  503. batch_restore_list.append(restore_list)
  504. return batch_restore_list
  505. def _check_transforms(self, transforms, mode):
  506. super()._check_transforms(transforms, mode)
  507. if not isinstance(transforms.arrange,
  508. paddlers.transforms.ArrangeClassifier):
  509. raise TypeError(
  510. "`transforms.arrange` must be an ArrangeClassifier object.")
  511. def build_data_loader(self, dataset, batch_size, mode='train'):
  512. if dataset.num_samples < batch_size:
  513. raise ValueError(
  514. 'The volume of dataset({}) must be larger than batch size({}).'
  515. .format(dataset.num_samples, batch_size))
  516. if mode != 'train':
  517. return paddle.io.DataLoader(
  518. dataset,
  519. batch_size=batch_size,
  520. shuffle=dataset.shuffle,
  521. drop_last=False,
  522. collate_fn=dataset.batch_transforms,
  523. num_workers=dataset.num_workers,
  524. return_list=True,
  525. use_shared_memory=False)
  526. else:
  527. return super(BaseClassifier, self).build_data_loader(
  528. dataset, batch_size, mode)
  529. class ResNet50_vd(BaseClassifier):
  530. def __init__(self,
  531. num_classes=2,
  532. use_mixed_loss=False,
  533. losses=None,
  534. **params):
  535. super(ResNet50_vd, self).__init__(
  536. model_name='ResNet50_vd',
  537. num_classes=num_classes,
  538. use_mixed_loss=use_mixed_loss,
  539. losses=losses,
  540. **params)
  541. class MobileNetV3(BaseClassifier):
  542. def __init__(self,
  543. num_classes=2,
  544. use_mixed_loss=False,
  545. losses=None,
  546. **params):
  547. super(MobileNetV3, self).__init__(
  548. model_name='MobileNetV3_small_x1_0',
  549. num_classes=num_classes,
  550. use_mixed_loss=use_mixed_loss,
  551. losses=losses,
  552. **params)
  553. class HRNet(BaseClassifier):
  554. def __init__(self,
  555. num_classes=2,
  556. use_mixed_loss=False,
  557. losses=None,
  558. **params):
  559. super(HRNet, self).__init__(
  560. model_name='HRNet_W18_C',
  561. num_classes=num_classes,
  562. use_mixed_loss=use_mixed_loss,
  563. losses=losses,
  564. **params)
  565. class CondenseNetV2(BaseClassifier):
  566. def __init__(self,
  567. num_classes=2,
  568. use_mixed_loss=False,
  569. losses=None,
  570. in_chnanels=3,
  571. arch='A',
  572. **params):
  573. if arch not in ('A', 'B', 'C'):
  574. raise ValueError("{} is not a supported architecture.".format(arch))
  575. model_name = 'CondenseNetV2_' + arch
  576. super(CondenseNetV2, self).__init__(
  577. model_name=model_name,
  578. num_classes=num_classes,
  579. use_mixed_loss=use_mixed_loss,
  580. losses=losses,
  581. in_channels=in_channels,
  582. **params)