Procházet zdrojové kódy

[Feat] Support ReduceOnPlateau (#107)

Lin Manhui před 2 roky
rodič
revize
3af03678d4
1 změnil soubory, kde provedl 6 přidání a 1 odebrání
  1. 6 1
      paddlers/tasks/base.py

+ 6 - 1
paddlers/tasks/base.py

@@ -374,7 +374,12 @@ class BaseModel(metaclass=ModelMeta):
                 lr = self.optimizer.get_lr()
                 lr = self.optimizer.get_lr()
                 if isinstance(self.optimizer._learning_rate,
                 if isinstance(self.optimizer._learning_rate,
                               paddle.optimizer.lr.LRScheduler):
                               paddle.optimizer.lr.LRScheduler):
-                    self.optimizer._learning_rate.step()
+                    # If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
+                    if isinstance(self.optimizer._learning_rate,
+                                  paddle.optimizer.lr.ReduceOnPlateau):
+                        self.optimizer._learning_rate.step(loss.item())
+                    else:
+                        self.optimizer._learning_rate.step()
 
 
                 train_avg_metrics.update(outputs)
                 train_avg_metrics.update(outputs)
                 outputs['lr'] = lr
                 outputs['lr'] = lr