Bobholamovic %!s(int64=2) %!d(string=hai) anos
pai
achega
1ef0067ad9

+ 1 - 0
examples/rs_research/config_utils.py

@@ -133,6 +133,7 @@ def parse_args(*args, **kwargs):
     # Global settings
     parser.add_argument('cmd', choices=['train', 'eval'])
     parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg'])
+    parser.add_argument('--seed', type=int, default=None)
 
     # Data
     parser.add_argument('--datasets', type=dict, default={})

+ 7 - 0
examples/rs_research/run_task.py

@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import os
+import random
 
+import numpy as np
 # Import cv2 and sklearn before paddlers to solve the
 # "ImportError: dlopen: cannot load any more object with static TLS" issue.
 import cv2
@@ -62,6 +64,11 @@ if __name__ == '__main__':
     cfg = parse_args()
     print(format_cfg(cfg))
 
+    if cfg['seed'] is not None:
+        random.seed(cfg['seed'])
+        np.random.seed(cfg['seed'])
+        paddle.seed(cfg['seed'])
+
     # Automatically download data
     if cfg['download_on']:
         paddlers.utils.download_and_decompress(

+ 1 - 0
test_tipc/config_utils.py

@@ -119,6 +119,7 @@ def parse_args(*args, **kwargs):
     # Global settings
     parser.add_argument('cmd', choices=['train', 'eval'])
     parser.add_argument('task', choices=['cd', 'clas', 'det', 'res', 'seg'])
+    parser.add_argument('--seed', type=int, default=None)
 
     # Data
     parser.add_argument('--datasets', type=dict, default={})

+ 7 - 0
test_tipc/run_task.py

@@ -1,7 +1,9 @@
 #!/usr/bin/env python
 
 import os
+import random
 
+import numpy as np
 # Import cv2 and sklearn before paddlers to solve the
 # "ImportError: dlopen: cannot load any more object with static TLS" issue.
 import cv2
@@ -46,6 +48,11 @@ if __name__ == '__main__':
     cfg = parse_args()
     print(format_cfg(cfg))
 
+    if cfg['seed'] is not None:
+        random.seed(cfg['seed'])
+        np.random.seed(cfg['seed'])
+        paddle.seed(cfg['seed'])
+
     # Automatically download data
     if cfg['download_on']:
         paddlers.utils.download_and_decompress(