Sfoglia il codice sorgente

Fix transform export attr error

Bobholamovic 2 anni fa
parent
commit
e7f2817c7a
2 ha cambiato i file con 26 aggiunte e 7 eliminazioni
  1. 4 1
      paddlers/tasks/base.py
  2. 22 6
      paddlers/transforms/operators.py

+ 4 - 1
paddlers/tasks/base.py

@@ -195,7 +195,10 @@ class BaseModel(metaclass=ModelMeta):
                 info['Transforms'] = list()
                 for op in self.test_transforms.transforms:
                     name = op.__class__.__name__
-                    attr = op.__dict__
+                    if hasattr(op, 'get_attrs_for_serialization'):
+                        attr = op.get_attrs_for_serialization()
+                    else:
+                        attr = op.__dict__
                     info['Transforms'].append({name: attr})
                 arrange = self.test_transforms.arrange
                 if arrange is not None:

+ 22 - 6
paddlers/transforms/operators.py

@@ -185,6 +185,9 @@ class Transform(object):
 
         return sample
 
+    def get_attrs_for_serialization(self):
+        return self.__dict__
+
 
 class DecodeImg(Transform):
     """
@@ -1970,17 +1973,21 @@ class AppendIndex(Transform):
 
     def __init__(self, index_type, band_indices=None, satellite=None, **kwargs):
         super(AppendIndex, self).__init__()
-        cls = getattr(indices, index_type)
-        if satellite is not None:
-            satellite_bands = getattr(satellites, satellite)
-            self._compute_index = cls(satellite_bands, **kwargs)
+        self.index_type = index_type
+        self.band_indices = band_indices
+        self.satellite = satellite
+        self._index_args = kwargs
+        cls = getattr(indices, self.index_type)
+        if self.satellite is not None:
+            satellite_bands = getattr(satellites, self.satellite)
+            self._compute_index = cls(satellite_bands, **self._index_args)
         else:
-            if band_indices is None:
+            if self.band_indices is None:
                 raise ValueError(
                     "At least one of `band_indices` and `satellite` must not be None."
                 )
             else:
-                self._compute_index = cls(band_indices, **kwargs)
+                self._compute_index = cls(self.band_indices, **self._index_args)
 
     def apply_im(self, image):
         index = self._compute_index(image)
@@ -1993,6 +2000,15 @@ class AppendIndex(Transform):
             sample['image2'] = self.apply_im(sample['image2'])
         return sample
 
+    def get_attrs_for_serialization(self):
+        return {
+            'index_type': self.index_type,
+            'band_indices': self.band_indices,
+            'satellite': self.satellite,
+            **
+            self._index_args
+        }
+
 
 class MatchRadiance(Transform):
     """