# classifier.py -- PaddleRS scene classification trainers.
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os.path as osp
  16. from collections import OrderedDict
  17. from operator import itemgetter
  18. import numpy as np
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddle.static import InputSpec
  22. import paddlers
  23. import paddlers.models.ppcls as ppcls
  24. import paddlers.rs_models.clas as cmcls
  25. import paddlers.utils.logging as logging
  26. from paddlers.utils import get_single_card_bs, DisablePrint
  27. from paddlers.models.ppcls.metric import build_metrics
  28. from paddlers.models import clas_losses
  29. from paddlers.models.ppcls.data.postprocess import build_postprocess
  30. from paddlers.utils.checkpoint import cls_pretrain_weights_dict
  31. from paddlers.transforms import Resize, decode_image
  32. from .base import BaseModel
  33. __all__ = [
  34. "ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C", "CondenseNetV2_b"
  35. ]
  36. class BaseClassifier(BaseModel):
  37. def __init__(self,
  38. model_name,
  39. in_channels=3,
  40. num_classes=2,
  41. use_mixed_loss=False,
  42. losses=None,
  43. **params):
  44. self.init_params = locals()
  45. if 'with_net' in self.init_params:
  46. del self.init_params['with_net']
  47. super(BaseClassifier, self).__init__('classifier')
  48. if not hasattr(ppcls.arch.backbone, model_name) and \
  49. not hasattr(cmcls, model_name):
  50. raise ValueError("ERROR: There is no model named {}.".format(
  51. model_name))
  52. self.model_name = model_name
  53. self.in_channels = in_channels
  54. self.num_classes = num_classes
  55. self.use_mixed_loss = use_mixed_loss
  56. self.metrics = None
  57. self.losses = losses
  58. self.labels = None
  59. self.postprocess = None
  60. if params.get('with_net', True):
  61. params.pop('with_net', None)
  62. self.net = self.build_net(**params)
  63. self.find_unused_parameters = True
  64. def build_net(self, **params):
  65. with paddle.utils.unique_name.guard():
  66. model = dict(ppcls.arch.backbone.__dict__,
  67. **cmcls.__dict__)[self.model_name]
  68. # TODO: Determine whether there is in_channels
  69. try:
  70. net = model(
  71. class_num=self.num_classes,
  72. in_channels=self.in_channels,
  73. **params)
  74. except:
  75. net = model(class_num=self.num_classes, **params)
  76. self.in_channels = 3
  77. return net
  78. def _build_inference_net(self):
  79. infer_net = self.net
  80. infer_net.eval()
  81. return infer_net
  82. def _fix_transforms_shape(self, image_shape):
  83. if hasattr(self, 'test_transforms'):
  84. if self.test_transforms is not None:
  85. has_resize_op = False
  86. resize_op_idx = -1
  87. normalize_op_idx = len(self.test_transforms.transforms)
  88. for idx, op in enumerate(self.test_transforms.transforms):
  89. name = op.__class__.__name__
  90. if name == 'Normalize':
  91. normalize_op_idx = idx
  92. if 'Resize' in name:
  93. has_resize_op = True
  94. resize_op_idx = idx
  95. if not has_resize_op:
  96. self.test_transforms.transforms.insert(
  97. normalize_op_idx, Resize(target_size=image_shape))
  98. else:
  99. self.test_transforms.transforms[resize_op_idx] = Resize(
  100. target_size=image_shape)
  101. def _get_test_inputs(self, image_shape):
  102. if image_shape is not None:
  103. if len(image_shape) == 2:
  104. image_shape = [1, 3] + image_shape
  105. self._fix_transforms_shape(image_shape[-2:])
  106. else:
  107. image_shape = [None, 3, -1, -1]
  108. self.fixed_input_shape = image_shape
  109. input_spec = [
  110. InputSpec(
  111. shape=image_shape, name='image', dtype='float32')
  112. ]
  113. return input_spec
  114. def run(self, net, inputs, mode):
  115. net_out = net(inputs[0])
  116. if mode == 'test':
  117. return self.postprocess(net_out)
  118. outputs = OrderedDict()
  119. label = paddle.to_tensor(inputs[1], dtype="int64")
  120. if mode == 'eval':
  121. label = paddle.unsqueeze(label, axis=-1)
  122. metric_dict = self.metrics(net_out, label)
  123. outputs['top1'] = metric_dict["top1"]
  124. outputs['top5'] = metric_dict["top5"]
  125. if mode == 'train':
  126. loss_list = self.losses(net_out, label)
  127. outputs['loss'] = loss_list['loss']
  128. return outputs
  129. def default_metric(self):
  130. default_config = [{"TopkAcc": {"topk": [1, 5]}}]
  131. return build_metrics(default_config)
  132. def default_loss(self):
  133. # TODO: use mixed loss and other loss
  134. default_config = [{"CELoss": {"weight": 1.0}}]
  135. return clas_losses.build_loss(default_config)
  136. def default_optimizer(self,
  137. parameters,
  138. learning_rate,
  139. num_epochs,
  140. num_steps_each_epoch,
  141. last_epoch=-1,
  142. L2_coeff=0.00007):
  143. decay_step = num_epochs * num_steps_each_epoch
  144. lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
  145. learning_rate, T_max=decay_step, eta_min=0, last_epoch=last_epoch)
  146. optimizer = paddle.optimizer.Momentum(
  147. learning_rate=lr_scheduler,
  148. parameters=parameters,
  149. momentum=0.9,
  150. weight_decay=paddle.regularizer.L2Decay(L2_coeff))
  151. return optimizer
  152. def default_postprocess(self, class_id_map_file):
  153. default_config = {
  154. "name": "Topk",
  155. "topk": 1,
  156. "class_id_map_file": class_id_map_file
  157. }
  158. return build_postprocess(default_config)
  159. def build_postprocess_from_labels(self, topk=1):
  160. label_dict = dict()
  161. for i, label in enumerate(self.labels):
  162. label_dict[i] = label
  163. self.postprocess = build_postprocess({
  164. "name": "Topk",
  165. "topk": topk,
  166. "class_id_map_file": None
  167. })
  168. # Add class_id_map from model.yml
  169. self.postprocess.class_id_map = label_dict
  170. def train(self,
  171. num_epochs,
  172. train_dataset,
  173. train_batch_size=2,
  174. eval_dataset=None,
  175. optimizer=None,
  176. save_interval_epochs=1,
  177. log_interval_steps=2,
  178. save_dir='output',
  179. pretrain_weights='IMAGENET',
  180. learning_rate=0.1,
  181. lr_decay_power=0.9,
  182. early_stop=False,
  183. early_stop_patience=5,
  184. use_vdl=True,
  185. resume_checkpoint=None):
  186. """
  187. Train the model.
  188. Args:
  189. num_epochs (int): Number of epochs.
  190. train_dataset (paddlers.datasets.ClasDataset): Training dataset.
  191. train_batch_size (int, optional): Total batch size among all cards used in
  192. training. Defaults to 2.
  193. eval_dataset (paddlers.datasets.ClasDataset|None, optional): Evaluation dataset.
  194. If None, the model will not be evaluated during training process.
  195. Defaults to None.
  196. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  197. training. If None, a default optimizer will be used. Defaults to None.
  198. save_interval_epochs (int, optional): Epoch interval for saving the model.
  199. Defaults to 1.
  200. log_interval_steps (int, optional): Step interval for printing training
  201. information. Defaults to 2.
  202. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  203. pretrain_weights (str|None, optional): None or name/path of pretrained
  204. weights. If None, no pretrained weights will be loaded.
  205. Defaults to 'IMAGENET'.
  206. learning_rate (float, optional): Learning rate for training.
  207. Defaults to .1.
  208. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  209. early_stop (bool, optional): Whether to adopt early stop strategy.
  210. Defaults to False.
  211. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  212. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  213. process. Defaults to True.
  214. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  215. training from. If None, no training checkpoint will be resumed. At most
  216. Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
  217. Defaults to None.
  218. """
  219. if self.status == 'Infer':
  220. logging.error(
  221. "Exported inference model does not support training.",
  222. exit=True)
  223. if pretrain_weights is not None and resume_checkpoint is not None:
  224. logging.error(
  225. "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
  226. exit=True)
  227. self.labels = train_dataset.labels
  228. if self.losses is None:
  229. self.losses = self.default_loss()
  230. self.metrics = self.default_metric()
  231. self.postprocess = self.default_postprocess(train_dataset.label_list)
  232. if optimizer is None:
  233. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  234. self.optimizer = self.default_optimizer(
  235. self.net.parameters(), learning_rate, num_epochs,
  236. num_steps_each_epoch, lr_decay_power)
  237. else:
  238. self.optimizer = optimizer
  239. if pretrain_weights is not None:
  240. if not osp.exists(pretrain_weights):
  241. if self.model_name not in cls_pretrain_weights_dict:
  242. logging.warning(
  243. "Path of `pretrain_weights` ('{}') does not exist!".
  244. format(pretrain_weights))
  245. pretrain_weights = None
  246. elif pretrain_weights not in cls_pretrain_weights_dict[
  247. self.model_name]:
  248. logging.warning(
  249. "Path of `pretrain_weights` ('{}') does not exist!".
  250. format(pretrain_weights))
  251. pretrain_weights = cls_pretrain_weights_dict[
  252. self.model_name][0]
  253. logging.warning(
  254. "`pretrain_weights` is forcibly set to '{}'. "
  255. "If you don't want to use pretrained weights, "
  256. "set `pretrain_weights` to None.".format(
  257. pretrain_weights))
  258. else:
  259. if osp.splitext(pretrain_weights)[-1] != '.pdparams':
  260. logging.error(
  261. "Invalid pretrained weights. Please specify a .pdparams file.",
  262. exit=True)
  263. pretrained_dir = osp.join(save_dir, 'pretrain')
  264. is_backbone_weights = False
  265. self.initialize_net(
  266. pretrain_weights=pretrain_weights,
  267. save_dir=pretrained_dir,
  268. resume_checkpoint=resume_checkpoint,
  269. is_backbone_weights=is_backbone_weights)
  270. self.train_loop(
  271. num_epochs=num_epochs,
  272. train_dataset=train_dataset,
  273. train_batch_size=train_batch_size,
  274. eval_dataset=eval_dataset,
  275. save_interval_epochs=save_interval_epochs,
  276. log_interval_steps=log_interval_steps,
  277. save_dir=save_dir,
  278. early_stop=early_stop,
  279. early_stop_patience=early_stop_patience,
  280. use_vdl=use_vdl)
  281. def quant_aware_train(self,
  282. num_epochs,
  283. train_dataset,
  284. train_batch_size=2,
  285. eval_dataset=None,
  286. optimizer=None,
  287. save_interval_epochs=1,
  288. log_interval_steps=2,
  289. save_dir='output',
  290. learning_rate=0.0001,
  291. lr_decay_power=0.9,
  292. early_stop=False,
  293. early_stop_patience=5,
  294. use_vdl=True,
  295. resume_checkpoint=None,
  296. quant_config=None):
  297. """
  298. Quantization-aware training.
  299. Args:
  300. num_epochs (int): Number of epochs.
  301. train_dataset (paddlers.datasets.ClasDataset): Training dataset.
  302. train_batch_size (int, optional): Total batch size among all cards used in
  303. training. Defaults to 2.
  304. eval_dataset (paddlers.datasets.ClasDataset|None, optional): Evaluation dataset.
  305. If None, the model will not be evaluated during training process.
  306. Defaults to None.
  307. optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in
  308. training. If None, a default optimizer will be used. Defaults to None.
  309. save_interval_epochs (int, optional): Epoch interval for saving the model.
  310. Defaults to 1.
  311. log_interval_steps (int, optional): Step interval for printing training
  312. information. Defaults to 2.
  313. save_dir (str, optional): Directory to save the model. Defaults to 'output'.
  314. learning_rate (float, optional): Learning rate for training.
  315. Defaults to .0001.
  316. lr_decay_power (float, optional): Learning decay power. Defaults to .9.
  317. early_stop (bool, optional): Whether to adopt early stop strategy.
  318. Defaults to False.
  319. early_stop_patience (int, optional): Early stop patience. Defaults to 5.
  320. use_vdl (bool, optional): Whether to use VisualDL to monitor the training
  321. process. Defaults to True.
  322. quant_config (dict|None, optional): Quantization configuration. If None,
  323. a default rule of thumb configuration will be used. Defaults to None.
  324. resume_checkpoint (str|None, optional): Path of the checkpoint to resume
  325. quantization-aware training from. If None, no training checkpoint will
  326. be resumed. Defaults to None.
  327. """
  328. self._prepare_qat(quant_config)
  329. self.train(
  330. num_epochs=num_epochs,
  331. train_dataset=train_dataset,
  332. train_batch_size=train_batch_size,
  333. eval_dataset=eval_dataset,
  334. optimizer=optimizer,
  335. save_interval_epochs=save_interval_epochs,
  336. log_interval_steps=log_interval_steps,
  337. save_dir=save_dir,
  338. pretrain_weights=None,
  339. learning_rate=learning_rate,
  340. lr_decay_power=lr_decay_power,
  341. early_stop=early_stop,
  342. early_stop_patience=early_stop_patience,
  343. use_vdl=use_vdl,
  344. resume_checkpoint=resume_checkpoint)
  345. def evaluate(self, eval_dataset, batch_size=1, return_details=False):
  346. """
  347. Evaluate the model.
  348. Args:
  349. eval_dataset (paddlers.datasets.ClasDataset): Evaluation dataset.
  350. batch_size (int, optional): Total batch size among all cards used for
  351. evaluation. Defaults to 1.
  352. return_details (bool, optional): Whether to return evaluation details.
  353. Defaults to False.
  354. Returns:
  355. If `return_details` is False, return collections.OrderedDict with
  356. key-value pairs:
  357. {"top1": acc of top1,
  358. "top5": acc of top5}.
  359. """
  360. self._check_transforms(eval_dataset.transforms, 'eval')
  361. self.net.eval()
  362. nranks = paddle.distributed.get_world_size()
  363. local_rank = paddle.distributed.get_rank()
  364. if nranks > 1:
  365. # Initialize parallel environment if not done.
  366. if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
  367. ):
  368. paddle.distributed.init_parallel_env()
  369. if batch_size > 1:
  370. logging.warning(
  371. "Classifier only supports single card evaluation with batch_size=1 "
  372. "during evaluation, so batch_size is forcibly set to 1.")
  373. batch_size = 1
  374. if nranks < 2 or local_rank == 0:
  375. self.eval_data_loader = self.build_data_loader(
  376. eval_dataset, batch_size=batch_size, mode='eval')
  377. logging.info(
  378. "Start to evaluate(total_samples={}, total_steps={})...".format(
  379. eval_dataset.num_samples, eval_dataset.num_samples))
  380. top1s = []
  381. top5s = []
  382. with paddle.no_grad():
  383. for step, data in enumerate(self.eval_data_loader):
  384. data.append(eval_dataset.transforms.transforms)
  385. outputs = self.run(self.net, data, 'eval')
  386. top1s.append(outputs["top1"])
  387. top5s.append(outputs["top5"])
  388. top1 = np.mean(top1s)
  389. top5 = np.mean(top5s)
  390. eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
  391. if return_details:
  392. # TODO: Add details
  393. return eval_metrics, None
  394. return eval_metrics
  395. def predict(self, img_file, transforms=None):
  396. """
  397. Do inference.
  398. Args:
  399. img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
  400. image data, which also could constitute a list, meaning all images to be
  401. predicted as a mini-batch.
  402. transforms (paddlers.transforms.Compose|None, optional): Transforms for
  403. inputs. If None, the transforms for evaluation process will be used.
  404. Defaults to None.
  405. Returns:
  406. If `img_file` is a string or np.array, the result is a dict with the
  407. following key-value pairs:
  408. class_ids_map (np.ndarray): IDs of predicted classes.
  409. scores_map (np.ndarray): Scores of predicted classes.
  410. label_names_map (np.ndarray): Names of predicted classes.
  411. If `img_file` is a list, the result is a list composed of dicts with the
  412. above keys.
  413. """
  414. if transforms is None and not hasattr(self, 'test_transforms'):
  415. raise ValueError("transforms need to be defined, now is None.")
  416. if transforms is None:
  417. transforms = self.test_transforms
  418. if isinstance(img_file, (str, np.ndarray)):
  419. images = [img_file]
  420. else:
  421. images = img_file
  422. batch_im, batch_origin_shape = self.preprocess(images, transforms,
  423. self.model_type)
  424. self.net.eval()
  425. data = (batch_im, batch_origin_shape, transforms.transforms)
  426. if self.postprocess is None:
  427. self.build_postprocess_from_labels()
  428. outputs = self.run(self.net, data, 'test')
  429. class_ids = map(itemgetter('class_ids'), outputs)
  430. scores = map(itemgetter('scores'), outputs)
  431. label_names = map(itemgetter('label_names'), outputs)
  432. if isinstance(img_file, list):
  433. prediction = [{
  434. 'class_ids_map': l,
  435. 'scores_map': s,
  436. 'label_names_map': n,
  437. } for l, s, n in zip(class_ids, scores, label_names)]
  438. else:
  439. prediction = {
  440. 'class_ids_map': next(class_ids),
  441. 'scores_map': next(scores),
  442. 'label_names_map': next(label_names)
  443. }
  444. return prediction
  445. def preprocess(self, images, transforms, to_tensor=True):
  446. self._check_transforms(transforms, 'test')
  447. batch_im = list()
  448. batch_ori_shape = list()
  449. for im in images:
  450. if isinstance(im, str):
  451. im = decode_image(im, read_raw=True)
  452. ori_shape = im.shape[:2]
  453. sample = {'image': im}
  454. im = transforms(sample)
  455. batch_im.append(im)
  456. batch_ori_shape.append(ori_shape)
  457. if to_tensor:
  458. batch_im = paddle.to_tensor(batch_im)
  459. else:
  460. batch_im = np.asarray(batch_im)
  461. return batch_im, batch_ori_shape
  462. @staticmethod
  463. def get_transforms_shape_info(batch_ori_shape, transforms):
  464. batch_restore_list = list()
  465. for ori_shape in batch_ori_shape:
  466. restore_list = list()
  467. h, w = ori_shape[0], ori_shape[1]
  468. for op in transforms:
  469. if op.__class__.__name__ == 'Resize':
  470. restore_list.append(('resize', (h, w)))
  471. h, w = op.target_size
  472. elif op.__class__.__name__ == 'ResizeByShort':
  473. restore_list.append(('resize', (h, w)))
  474. im_short_size = min(h, w)
  475. im_long_size = max(h, w)
  476. scale = float(op.short_size) / float(im_short_size)
  477. if 0 < op.max_size < np.round(scale * im_long_size):
  478. scale = float(op.max_size) / float(im_long_size)
  479. h = int(round(h * scale))
  480. w = int(round(w * scale))
  481. elif op.__class__.__name__ == 'ResizeByLong':
  482. restore_list.append(('resize', (h, w)))
  483. im_long_size = max(h, w)
  484. scale = float(op.long_size) / float(im_long_size)
  485. h = int(round(h * scale))
  486. w = int(round(w * scale))
  487. elif op.__class__.__name__ == 'Pad':
  488. if op.target_size:
  489. target_h, target_w = op.target_size
  490. else:
  491. target_h = int(
  492. (np.ceil(h / op.size_divisor) * op.size_divisor))
  493. target_w = int(
  494. (np.ceil(w / op.size_divisor) * op.size_divisor))
  495. if op.pad_mode == -1:
  496. offsets = op.offsets
  497. elif op.pad_mode == 0:
  498. offsets = [0, 0]
  499. elif op.pad_mode == 1:
  500. offsets = [(target_h - h) // 2, (target_w - w) // 2]
  501. else:
  502. offsets = [target_h - h, target_w - w]
  503. restore_list.append(('padding', (h, w), offsets))
  504. h, w = target_h, target_w
  505. batch_restore_list.append(restore_list)
  506. return batch_restore_list
  507. def _check_transforms(self, transforms, mode):
  508. super()._check_transforms(transforms, mode)
  509. if not isinstance(transforms.arrange,
  510. paddlers.transforms.ArrangeClassifier):
  511. raise TypeError(
  512. "`transforms.arrange` must be an ArrangeClassifier object.")
  513. def build_data_loader(self, dataset, batch_size, mode='train'):
  514. if dataset.num_samples < batch_size:
  515. raise ValueError(
  516. 'The volume of dataset({}) must be larger than batch size({}).'
  517. .format(dataset.num_samples, batch_size))
  518. if mode != 'train':
  519. return paddle.io.DataLoader(
  520. dataset,
  521. batch_size=batch_size,
  522. shuffle=dataset.shuffle,
  523. drop_last=False,
  524. collate_fn=dataset.batch_transforms,
  525. num_workers=dataset.num_workers,
  526. return_list=True,
  527. use_shared_memory=False)
  528. else:
  529. return super(BaseClassifier, self).build_data_loader(
  530. dataset, batch_size, mode)
  531. class ResNet50_vd(BaseClassifier):
  532. def __init__(self,
  533. num_classes=2,
  534. use_mixed_loss=False,
  535. losses=None,
  536. **params):
  537. super(ResNet50_vd, self).__init__(
  538. model_name='ResNet50_vd',
  539. num_classes=num_classes,
  540. use_mixed_loss=use_mixed_loss,
  541. losses=losses,
  542. **params)
  543. class MobileNetV3_small_x1_0(BaseClassifier):
  544. def __init__(self,
  545. num_classes=2,
  546. use_mixed_loss=False,
  547. losses=None,
  548. **params):
  549. super(MobileNetV3_small_x1_0, self).__init__(
  550. model_name='MobileNetV3_small_x1_0',
  551. num_classes=num_classes,
  552. use_mixed_loss=use_mixed_loss,
  553. losses=losses,
  554. **params)
  555. class HRNet_W18_C(BaseClassifier):
  556. def __init__(self,
  557. num_classes=2,
  558. use_mixed_loss=False,
  559. losses=None,
  560. **params):
  561. super(HRNet_W18_C, self).__init__(
  562. model_name='HRNet_W18_C',
  563. num_classes=num_classes,
  564. use_mixed_loss=use_mixed_loss,
  565. losses=losses,
  566. **params)
  567. class CondenseNetV2_b(BaseClassifier):
  568. def __init__(self,
  569. num_classes=2,
  570. use_mixed_loss=False,
  571. losses=None,
  572. **params):
  573. super(CondenseNetV2_b, self).__init__(
  574. model_name='CondenseNetV2_b',
  575. num_classes=num_classes,
  576. use_mixed_loss=use_mixed_loss,
  577. losses=losses,
  578. **params)