|
@@ -20,6 +20,7 @@ import math
|
|
|
import json
|
|
|
from functools import partial, wraps
|
|
|
from inspect import signature
|
|
|
+from typing import Optional
|
|
|
|
|
|
import yaml
|
|
|
import paddle
|
|
@@ -29,6 +30,7 @@ from paddleslim.analysis import flops
|
|
|
from paddleslim import L1NormFilterPruner, FPGMFilterPruner
|
|
|
|
|
|
import paddlers
|
|
|
+from paddlers.transforms.operators import Compose, DecodeImg, Arrange
|
|
|
import paddlers.utils.logging as logging
|
|
|
from paddlers.utils import (
|
|
|
seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
|
|
@@ -61,7 +63,7 @@ class ModelMeta(type):
|
|
|
|
|
|
|
|
|
class BaseModel(metaclass=ModelMeta):
|
|
|
-
|
|
|
+ _arrange: Optional[Arrange] = None
|
|
|
find_unused_parameters = False
|
|
|
|
|
|
def __init__(self, model_type):
|
|
@@ -204,13 +206,11 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
else:
|
|
|
attr = op.__dict__
|
|
|
info['Transforms'].append({name: attr})
|
|
|
- arrange = self.test_transforms.arrange
|
|
|
- if arrange is not None:
|
|
|
- info['Transforms'].append({
|
|
|
- arrange.__class__.__name__: {
|
|
|
- 'mode': 'test'
|
|
|
- }
|
|
|
- })
|
|
|
+ info['Transforms'].append({
|
|
|
+ self._arrange.__name__: {
|
|
|
+ 'mode': 'test'
|
|
|
+ }
|
|
|
+ })
|
|
|
info['completed_epochs'] = self.completed_epochs
|
|
|
return info
|
|
|
|
|
@@ -267,11 +267,28 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
open(osp.join(save_dir, '.success'), 'w').close()
|
|
|
logging.info("Model saved in {}.".format(save_dir))
|
|
|
|
|
|
+ def _build_transforms(self, trans, mode):
|
|
|
+ if isinstance(trans, list):
|
|
|
+ trans = Compose(trans)
|
|
|
+ if not isinstance(trans.transforms[0], DecodeImg):
|
|
|
+ trans.transforms.insert(0, DecodeImg())
|
|
|
+ if self._arrange is Arrange or not issubclass(self._arrange, Arrange):
|
|
|
+ raise ValueError(
|
|
|
+ "`self._arrange` must be set to a concrete Arrange type.")
|
|
|
+ if trans.arrange is None:
|
|
|
+ # For backward compatibility, we only set `trans.arrange`
|
|
|
+ # when it is not set by user.
|
|
|
+ trans.arrange = self._arrange(mode)
|
|
|
+ return trans
|
|
|
+
|
|
|
def build_data_loader(self,
|
|
|
dataset,
|
|
|
batch_size,
|
|
|
mode='train',
|
|
|
collate_fn=None):
|
|
|
+ # NOTE: Append `Arrange` to transforms
|
|
|
+ dataset.transforms = self._build_transforms(dataset.transforms, mode)
|
|
|
+
|
|
|
if dataset.num_samples < batch_size:
|
|
|
raise ValueError(
|
|
|
'The volume of dataset({}) must be larger than batch size({}).'
|
|
@@ -315,7 +332,7 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
early_stop=False,
|
|
|
early_stop_patience=5,
|
|
|
use_vdl=True):
|
|
|
- self._check_transforms(train_dataset.transforms, 'train')
|
|
|
+ self._check_transforms(train_dataset.transforms)
|
|
|
|
|
|
# XXX: Hard-coding
|
|
|
if self.model_type == 'detector' and 'RCNN' in self.__class__.__name__ and train_dataset.pos_num < len(
|
|
@@ -351,6 +368,7 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
|
|
|
self.train_data_loader = self.build_data_loader(
|
|
|
train_dataset, batch_size=train_batch_size, mode='train')
|
|
|
+ self._check_arrange(self.train_data_loader.dataset.transforms, 'train')
|
|
|
|
|
|
if eval_dataset is not None:
|
|
|
self.test_transforms = copy.deepcopy(eval_dataset.transforms)
|
|
@@ -493,7 +511,7 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
|
|
|
assert criterion in {'l1_norm', 'fpgm'}, \
|
|
|
"Pruning criterion {} is not supported. Please choose from {'l1_norm', 'fpgm'}."
|
|
|
- self._check_transforms(dataset.transforms, 'eval')
|
|
|
+ self._check_transforms(dataset.transforms)
|
|
|
# XXX: Hard-coding
|
|
|
if self.model_type == 'detector':
|
|
|
self.net.eval()
|
|
@@ -681,16 +699,19 @@ class BaseModel(metaclass=ModelMeta):
|
|
|
|
|
|
return outputs
|
|
|
|
|
|
- def _check_transforms(self, transforms, mode):
|
|
|
- # NOTE: Check transforms and transforms.arrange and give user-friendly error messages.
|
|
|
- if not isinstance(transforms, paddlers.transforms.Compose):
|
|
|
- raise TypeError("`transforms` must be paddlers.transforms.Compose.")
|
|
|
+ def _check_transforms(self, transforms):
|
|
|
+ # NOTE: Check transforms
|
|
|
+ if not isinstance(transforms, Compose):
|
|
|
+ raise TypeError(
|
|
|
+ "`transforms` must be `paddlers.transforms.Compose`.")
|
|
|
+
|
|
|
+ def _check_arrange(self, transforms, mode):
|
|
|
arrange_obj = transforms.arrange
|
|
|
- if not isinstance(arrange_obj, paddlers.transforms.operators.Arrange):
|
|
|
+ if not isinstance(arrange_obj, Arrange):
|
|
|
raise TypeError("`transforms.arrange` must be an Arrange object.")
|
|
|
if arrange_obj.mode != mode:
|
|
|
raise ValueError(
|
|
|
- f"Incorrect arrange mode! Expected {mode} but got {arrange_obj.mode}."
|
|
|
+ f"Incorrect arrange mode! Expected {repr(mode)} but got {repr(arrange_obj.mode)}."
|
|
|
)
|
|
|
|
|
|
def run(self, net, inputs, mode):
|