|
@@ -49,7 +49,12 @@ class TrainingStats(object):
|
|
|
for k in stats.keys()
|
|
|
}
|
|
|
for k, v in self.meters.items():
|
|
|
- v.update(stats[k].numpy())
|
|
|
+ stat = stats[k]
|
|
|
+ if stat.ndim == 0:
|
|
|
+ stat = float(stat)
|
|
|
+ else:
|
|
|
+ stat = stat.numpy()
|
|
|
+ v.update(stat)
|
|
|
|
|
|
def get(self, extras=None):
|
|
|
stats = collections.OrderedDict()
|