@@ -30,7 +30,7 @@ def batch_p_dist(x, y, p=2):
"""
x = x.unsqueeze(1)
diff = x - y
- return paddle.norm(diff, p=p, axis=list(range(2, diff.dim())))
+ return paddle.linalg.vector_norm(diff, p=p, axis=list(range(2, diff.dim())))
@register