فهرست منبع

[Fix] Fix tar weight loading error when using mp

Bobholamovic 2 سال پیش
والد
کامیت
b250d8bb00
2فایلهای تغییر یافته به همراه9 افزوده شده و 7 حذف شده
  1. 1 7
      .gitignore
  2. 8 0
      paddlers/utils/download.py

+ 1 - 7
.gitignore

@@ -132,11 +132,5 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-# test data
-tutorials/train/change_detection/DataSet/
-tutorials/train/classification/DataSet/
-optic_disc_seg.tar
-optic_disc_seg/
-output/
-
+/tutorials/train/**/output/
 /log

+ 8 - 0
paddlers/utils/download.py

@@ -202,10 +202,13 @@ def download_and_decompress(url, path='.'):
     local_rank = paddle.distributed.get_rank()
     fname = osp.split(url)[-1]
     fullname = osp.join(path, fname)
+    pth_path = fullname + '.path'
 
     if nranks <= 1:
         dst_dir = url2dir(url, path)
         if dst_dir is not None:
+            with open(pth_path, 'w') as f:
+                f.write(dst_dir)
             fullname = dst_dir
     else:
         lock_path = fullname + '.lock'
@@ -215,9 +218,14 @@ def download_and_decompress(url, path='.'):
             if local_rank == 0:
                 dst_dir = url2dir(url, path)
                 if dst_dir is not None:
+                    with open(pth_path, 'w') as f:
+                        f.write(dst_dir)
                     fullname = dst_dir
                 os.remove(lock_path)
             else:
                 while os.path.exists(lock_path):
                     time.sleep(1)
+        if os.path.exists(pth_path):
+            with open(pth_path, 'r') as f:
+                fullname = next(f)
     return fullname