diff --git a/graphgps/encoder/graphormer_encoder.py b/graphgps/encoder/graphormer_encoder.py index d56ce6a9..f2af15e8 100644 --- a/graphgps/encoder/graphormer_encoder.py +++ b/graphgps/encoder/graphormer_encoder.py @@ -201,7 +201,7 @@ def add_graph_token(data, token): data.batch = torch.cat( [torch.arange(0, B, device=data.x.device, dtype=torch.long), data.batch] ) - data.batch, sort_idx = torch.sort(data.batch) + data.batch, sort_idx = torch.sort(data.batch, stable=True) data.x = data.x[sort_idx] return data