Parcourir la source

[Feat] Support ReduceOnPlateau (#107)

Lin Manhui il y a 2 ans
Parent
commit
3af03678d4
1 fichiers modifiés avec 6 ajouts et 1 suppressions
  1. 6 1
      paddlers/tasks/base.py

+ 6 - 1
paddlers/tasks/base.py

@@ -374,7 +374,12 @@ class BaseModel(metaclass=ModelMeta):
                 lr = self.optimizer.get_lr()
                 if isinstance(self.optimizer._learning_rate,
                               paddle.optimizer.lr.LRScheduler):
-                    self.optimizer._learning_rate.step()
+                    # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
+                    if isinstance(self.optimizer._learning_rate,
+                                  paddle.optimizer.lr.ReduceOnPlateau):
+                        self.optimizer._learning_rate.step(loss.item())
+                    else:
+                        self.optimizer._learning_rate.step()
 
                 train_avg_metrics.update(outputs)
                 outputs['lr'] = lr