From 63acb1001e201804649c7b8cf32285f5b78219ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20M=C3=BCller?= Date: Fri, 5 Jan 2024 10:16:38 +0100 Subject: [PATCH 1/2] [FIX] Stable sorting in graphormer encoder --- graphgps/encoder/graphormer_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphgps/encoder/graphormer_encoder.py b/graphgps/encoder/graphormer_encoder.py index d56ce6a9..af4697f8 100644 --- a/graphgps/encoder/graphormer_encoder.py +++ b/graphgps/encoder/graphormer_encoder.py @@ -201,7 +201,7 @@ def add_graph_token(data, token): data.batch = torch.cat( [torch.arange(0, B, device=data.x.device, dtype=torch.long), data.batch] ) - data.batch, sort_idx = torch.sort(data.batch) + data.batch, sort_idx = torch.sort(data.batch, stable=False) data.x = data.x[sort_idx] return data From 107f1c1ad49af7a19c10637c8bc270dd7e719f45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20M=C3=BCller?= Date: Fri, 5 Jan 2024 10:20:10 +0100 Subject: [PATCH 2/2] [FIX] Stable sorting in graphormer encoder activated --- graphgps/encoder/graphormer_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphgps/encoder/graphormer_encoder.py b/graphgps/encoder/graphormer_encoder.py index af4697f8..f2af15e8 100644 --- a/graphgps/encoder/graphormer_encoder.py +++ b/graphgps/encoder/graphormer_encoder.py @@ -201,7 +201,7 @@ def add_graph_token(data, token): data.batch = torch.cat( [torch.arange(0, B, device=data.x.device, dtype=torch.long), data.batch] ) - data.batch, sort_idx = torch.sort(data.batch, stable=False) + data.batch, sort_idx = torch.sort(data.batch, stable=True) data.x = data.x[sort_idx] return data