|
@@ -185,6 +185,9 @@ class Transform(object):
|
|
|
|
|
|
return sample
|
|
|
|
|
|
+ def get_attrs_for_serialization(self):
|
|
|
+ return self.__dict__
|
|
|
+
|
|
|
|
|
|
class DecodeImg(Transform):
|
|
|
"""
|
|
@@ -1970,17 +1973,21 @@ class AppendIndex(Transform):
|
|
|
|
|
|
def __init__(self, index_type, band_indices=None, satellite=None, **kwargs):
|
|
|
super(AppendIndex, self).__init__()
|
|
|
- cls = getattr(indices, index_type)
|
|
|
- if satellite is not None:
|
|
|
- satellite_bands = getattr(satellites, satellite)
|
|
|
- self._compute_index = cls(satellite_bands, **kwargs)
|
|
|
+ self.index_type = index_type
|
|
|
+ self.band_indices = band_indices
|
|
|
+ self.satellite = satellite
|
|
|
+ self._index_args = kwargs
|
|
|
+ cls = getattr(indices, self.index_type)
|
|
|
+ if self.satellite is not None:
|
|
|
+ satellite_bands = getattr(satellites, self.satellite)
|
|
|
+ self._compute_index = cls(satellite_bands, **self._index_args)
|
|
|
else:
|
|
|
- if band_indices is None:
|
|
|
+ if self.band_indices is None:
|
|
|
raise ValueError(
|
|
|
"At least one of `band_indices` and `satellite` must not be None."
|
|
|
)
|
|
|
else:
|
|
|
- self._compute_index = cls(band_indices, **kwargs)
|
|
|
+ self._compute_index = cls(self.band_indices, **self._index_args)
|
|
|
|
|
|
def apply_im(self, image):
|
|
|
index = self._compute_index(image)
|
|
@@ -1993,6 +2000,15 @@ class AppendIndex(Transform):
|
|
|
sample['image2'] = self.apply_im(sample['image2'])
|
|
|
return sample
|
|
|
|
|
|
+ def get_attrs_for_serialization(self):
|
|
|
+ return {
|
|
|
+ 'index_type': self.index_type,
|
|
|
+ 'band_indices': self.band_indices,
|
|
|
+ 'satellite': self.satellite,
|
|
|
+ **
|
|
|
+ self._index_args
|
|
|
+ }
|
|
|
+
|
|
|
|
|
|
class MatchRadiance(Transform):
|
|
|
"""
|