浏览代码

[Fix] Fix mixed loss and rename changedetector (#47)

Lin Manhui 3 年之前
父节点
当前提交
d0f7293481

+ 2 - 2
paddlers/tasks/__init__.py

@@ -14,7 +14,7 @@
 
 from .object_detector import *
 from .segmenter import *
-from .changedetector import *
+from .change_detector import *
 from .classifier import *
 from .load_model import load_model
-from .imagerestorer import *
+from .image_restorer import *

+ 3 - 3
paddlers/tasks/changedetector.py → paddlers/tasks/change_detector.py

@@ -828,7 +828,7 @@ class DSIFN(BaseChangeDetector):
                 'coef': [1.0] * 5
             }
         else:
-            return super().default_loss()
+            raise ValueError(f"Currently `use_mixed_loss` must be set to False for {self.__class__}")
 
 
 class DSAMNet(BaseChangeDetector):
@@ -860,7 +860,7 @@ class DSAMNet(BaseChangeDetector):
                 'coef': [1.0, 0.05, 0.05]
             }
         else:
-            return super().default_loss()
+            raise ValueError(f"Currently `use_mixed_loss` must be set to False for {self.__class__}")
 
 
 class ChangeStar(BaseChangeDetector):
@@ -892,4 +892,4 @@ class ChangeStar(BaseChangeDetector):
                 'coef': [1.0] * 4
             }
         else:
-            return super().default_loss()
+            raise ValueError(f"Currently `use_mixed_loss` must be set to False for {self.__class__}")

+ 0 - 0
paddlers/tasks/imagerestorer.py → paddlers/tasks/image_restorer.py


+ 10 - 2
paddlers/tasks/utils/seg_metrics.py

@@ -15,13 +15,18 @@
 import numpy as np
 import paddle
 
+import paddlers.models.ppseg as paddleseg
+
 
 def loss_computation(logits_list, labels, losses):
     loss_list = []
     for i in range(len(logits_list)):
         logits = logits_list[i]
         loss_i = losses['types'][i]
-        loss_list.append(losses['coef'][i] * loss_i(logits, labels))
+        if isinstance(loss_i, paddleseg.models.MixedLoss):
+            loss_list.append(losses['coef'][i] * sum(loss_i(logits, labels)))
+        else:
+            loss_list.append(losses['coef'][i] * loss_i(logits, labels))
 
     return loss_list
 
@@ -32,7 +37,10 @@ def multitask_loss_computation(logits_list, labels_list, losses):
         logits = logits_list[i]
         labels = labels_list[i]
         loss_i = losses['types'][i]
-        loss_list.append(losses['coef'][i] * loss_i(logits, labels))
+        if isinstance(loss_i, paddleseg.models.MixedLoss):
+            loss_list.append(losses['coef'][i] * sum(loss_i(logits, labels)))
+        else:
+            loss_list.append(losses['coef'][i] * loss_i(logits, labels))
 
     return loss_list
 

+ 1 - 1
tutorials/train/change_detection/bit.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建BIT模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.BIT()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/cdnet.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建CDNet模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.CDNet()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/dsamnet.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建DSAMNet模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.DSAMNet()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/dsifn.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建DSIFN模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.DSIFN()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/fc_ef.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建FC-EF模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.FCEarlyFusion()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/fc_siam_conc.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建FC-Siam-conc模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.FCSiamConc()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/fc_siam_diff.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建FC-Siam-diff模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.FCSiamDiff()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/snunet.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建SNUNet模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.SNUNet()
 
 # 执行模型训练

+ 1 - 1
tutorials/train/change_detection/stanet.py

@@ -67,7 +67,7 @@ eval_dataset = pdrs.datasets.CDDataset(
 
 # 使用默认参数构建STANet模型
 # 目前已支持的模型请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/docs/apis/model_zoo.md
-# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/changedetector.py
+# 模型输入参数请参考:https://github.com/PaddleCV-SIG/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
 model = pdrs.tasks.STANet()
 
 # 执行模型训练