Bobholamovic 2 gadi atpakaļ
vecāks
revīzija
a373e11835

+ 1 - 1
docs/intro/transforms.md

@@ -12,7 +12,7 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为
 | RandomResizeByShort  | 随机调整输入影像大小,保持纵横比不变(根据短边计算缩放系数)。 | 所有任务 | ... |
 | ResizeByLong         | 调整输入影像大小,保持纵横比不变(根据长边计算缩放系数)。 | 所有任务 | ... |
 | RandomHorizontalFlip | 随机水平翻转输入影像。 | 所有任务 | ... |
-| RandomVerticalFlip   | 随机直翻转输入影像。 | 所有任务 | ... |
+| RandomVerticalFlip   | 随机直翻转输入影像。 | 所有任务 | ... |
 | Normalize            | 对输入影像应用标准化。 | 所有任务 | ... |
 | CenterCrop           | 对输入影像进行中心裁剪。 | 所有任务 | ... |
 | RandomCrop           | 对输入影像进行随机中心裁剪。 | 所有任务 | ... |

+ 13 - 2
examples/rs_research/custom_trainer.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
+
 import paddle
 import paddlers
 from paddlers.tasks.change_detector import BaseChangeDetector
@@ -27,12 +29,21 @@ def make_trainer(net_type, *args, **kwargs):
                    use_mixed_loss=False,
                    losses=None,
                    **params):
-        super().__init__(
+        sig = inspect.signature(net_type.__init__)
+        net_params = {
+            k: p.default
+            for k, p in sig.parameters.items() if not p.default is p.empty
+        }
+        net_params.pop('self', None)
+        net_params.pop('num_classes', None)
+        net_params.update(params)
+
+        super(trainer_type, self).__init__(
             model_name=net_type.__name__,
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             losses=losses,
-            **params)
+            **net_params)
 
     if not issubclass(net_type, paddle.nn.Layer):
         raise TypeError("net must be a subclass of paddle.nn.Layer")

+ 2 - 3
examples/rs_research/train_cd.py

@@ -76,8 +76,7 @@ test_dataset = pdrs.datasets.CDDataset(
 
 # 构建自定义模型CustomModel并为其自动生成训练器
 # make_trainer()的首个参数为模型类型,剩余参数为模型构造所需参数
-# 这里使用默认参数构造
-model = make_trainer(CustomModel)
+model = make_trainer(CustomModel, in_channels=3)
 
 # 构建学习率调度器
 # 使用定步长学习率衰减策略
@@ -86,7 +85,7 @@ lr_scheduler = paddle.optimizer.lr.StepDecay(
 
 # 构建优化器
 optimizer = paddle.optimizer.Adam(
-    model.net.parameters(), learning_rate=lr_scheduler)
+    parameters=model.net.parameters(), learning_rate=lr_scheduler)
 
 # 执行模型训练
 model.train(

+ 1 - 3
paddlers/tasks/change_detector.py

@@ -52,9 +52,7 @@ class BaseChangeDetector(BaseModel):
         if 'with_net' in self.init_params:
             del self.init_params['with_net']
         super(BaseChangeDetector, self).__init__('change_detector')
-        if model_name not in __all__:
-            raise ValueError("ERROR: There is no model named {}.".format(
-                model_name))
+
         self.model_name = model_name
         self.num_classes = num_classes
         self.use_mixed_loss = use_mixed_loss