custom_model.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import paddle
  2. import paddle.nn as nn
  3. import paddle.nn.functional as F
  4. import paddlers
  5. from paddlers.rs_models.cd import BIT
  6. from attach_tools import Attach
  7. attach = Attach.to(paddlers.rs_models.cd)
  8. @attach
  9. class IterativeBIT(nn.Layer):
  10. def __init__(self, num_iters=1, gamma=0.1, num_classes=2, bit_kwargs=None):
  11. super().__init__()
  12. if num_iters <= 0:
  13. raise ValueError(
  14. f"`num_iters` should have positive value, but got {num_iters}.")
  15. self.num_iters = num_iters
  16. self.gamma = gamma
  17. if bit_kwargs is None:
  18. bit_kwargs = dict()
  19. if 'num_classes' in bit_kwargs:
  20. raise KeyError("'num_classes' should not be set in `bit_kwargs`.")
  21. bit_kwargs['num_classes'] = num_classes
  22. self.bit = BIT(**bit_kwargs)
  23. def forward(self, t1, t2):
  24. rate_map = self._init_rate_map(t1.shape)
  25. for it in range(self.num_iters):
  26. # Construct inputs
  27. x1 = self._constr_iter_input(t1, rate_map)
  28. x2 = self._constr_iter_input(t2, rate_map)
  29. # Get logits
  30. logits_list = self.bit(x1, x2)
  31. # Construct rate map
  32. prob_map = F.softmax(logits_list[0], axis=1)
  33. rate_map = self._constr_rate_map(prob_map)
  34. return logits_list
  35. def _constr_iter_input(self, im, rate_map):
  36. return paddle.concat([im, rate_map], axis=1)
  37. def _init_rate_map(self, im_shape):
  38. b, _, h, w = im_shape
  39. return paddle.zeros((b, 1, h, w))
  40. def _constr_rate_map(self, prob_map):
  41. if prob_map.shape[1] != 2:
  42. raise ValueError(
  43. f"`prob_map.shape[1]` must be 2, but got {prob_map.shape[1]}.")
  44. return (prob_map[:, 1:2] * self.gamma)