|
@@ -21,6 +21,7 @@ import paddle
|
|
|
import paddle.nn.functional as F
|
|
|
from paddle.static import InputSpec
|
|
|
import paddlers.models.ppseg as paddleseg
|
|
|
+import paddlers.models.seg as seg
|
|
|
import paddlers
|
|
|
from paddlers.transforms import arrange_transforms
|
|
|
from paddlers.utils import get_single_card_bs, DisablePrint
|
|
@@ -44,7 +45,7 @@ class BaseSegmenter(BaseModel):
|
|
|
del self.init_params['with_net']
|
|
|
super(BaseSegmenter, self).__init__('segmenter')
|
|
|
if not hasattr(paddleseg.models, model_name) and \
|
|
|
- not hasattr(paddleseg.rs_models, model_name):
|
|
|
+ not hasattr(seg.models, model_name):
|
|
|
raise Exception("ERROR: There's no model named {}.".format(
|
|
|
model_name))
|
|
|
self.model_name = model_name
|
|
@@ -60,7 +61,7 @@ class BaseSegmenter(BaseModel):
|
|
|
def build_net(self, **params):
|
|
|
# TODO: when using paddle.utils.unique_name.guard,
|
|
|
# DeepLabv3p and HRNet will raise a error
|
|
|
- net = dict(paddleseg.models.__dict__, **paddleseg.rs_models.__dict__)[self.model_name](
|
|
|
+ net = dict(paddleseg.models.__dict__, **seg.models.__dict__)[self.model_name](
|
|
|
num_classes=self.num_classes, **params)
|
|
|
return net
|
|
|
|