Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4297f49
init
yliu2-sc Feb 23, 2026
a6f758d
Fix AttributeError in add_node_attr when node type has no x attribute
yliu2-sc Feb 23, 2026
dc45b7e
updates
yliu2-sc Feb 24, 2026
cc90c6f
updates
yliu2-sc Feb 24, 2026
d4e97a1
rm claude
yliu2-sc Feb 24, 2026
8d366a3
update to sparse operations
yliu2-sc Feb 25, 2026
00667ee
update hop distance
yliu2-sc Feb 26, 2026
c34d0d4
simplify
yliu2-sc Feb 26, 2026
d34e700
transform
yliu2-sc Feb 27, 2026
88c9ca6
todo
yliu2-sc Feb 27, 2026
26688f6
some optimization
yliu2-sc Feb 27, 2026
b29f8d6
optim 2
yliu2-sc Feb 27, 2026
7171098
hop dist memory optim
yliu2-sc Mar 5, 2026
e017e6c
Add GraphTransformerEncoder adapted from RelGT's LocalModule
yliu2-sc Mar 6, 2026
5cb0450
comments
yliu2-sc Mar 6, 2026
7b0cb07
update with anchor based hop distance
yliu2-sc Mar 7, 2026
5d3c434
Merge branch 'yliu2/heterodata_to_seq' of github.com:Snapchat/GiGL in…
yliu2-sc Mar 7, 2026
506fe7d
update sequence to only seed
yliu2-sc Mar 7, 2026
7a83200
update
yliu2-sc Mar 7, 2026
57e8265
update anchor computation
yliu2-sc Mar 12, 2026
8e0e219
add pe, attention bias
yliu2-sc Mar 17, 2026
cd2e97d
rename variables, add comments, format
yliu2-sc Mar 18, 2026
21d2ba0
merge main
yliu2-sc Mar 20, 2026
81c5fea
address comments
yliu2-sc Mar 20, 2026
4bb14b0
address comments, type check, format
yliu2-sc Mar 23, 2026
4fb7703
Merge branch 'main' into gt_encoder
yliu2-sc Mar 23, 2026
2495a2e
Merge branch 'main' of github.com:Snapchat/GiGL into gt_encoder
yliu2-sc Mar 24, 2026
a338524
Merge branch 'gt_encoder' of github.com:Snapchat/GiGL into gt_encoder
yliu2-sc Mar 24, 2026
f4e3ea6
update tests
yliu2-sc Mar 24, 2026
2b3177a
type check
yliu2-sc Mar 24, 2026
b05c310
fix unit
yliu2-sc Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
829 changes: 829 additions & 0 deletions gigl/src/common/models/graph_transformer/graph_transformer.py

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions gigl/transforms/add_positional_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __repr__(self) -> str:

@functional_transform("add_hetero_hop_distance_encoding")
class AddHeteroHopDistanceEncoding(BaseTransform):
r"""Adds hop distance positional encoding as relative encoding (sparse).
r"""Adds hop distance positional encoding as relative encoding (sparse CSR).

For each pair of nodes (vi, vj), computes the shortest path distance p(vi, vj).
This captures structural proximity and can be used with a learnable embedding
Expand All @@ -236,12 +236,12 @@ class AddHeteroHopDistanceEncoding(BaseTransform):
Based on the approach from `"Do Transformers Really Perform Bad for Graph
Representation?" <https://arxiv.org/abs/2106.05234>`_ (Graphormer).

The output is a **sparse matrix** where:
The output is a **sparse CSR matrix** where:
- Reachable pairs (i, j) within h_max hops have value = hop distance (1 to h_max)
- Unreachable pairs have value = 0 (not stored in sparse tensor)
- Self-loops (diagonal) are not stored (distance to self is implicitly 0)

This sparse representation avoids GPU memory blowup for large graphs.
CSR format is used for efficient row-based lookups during sequence building.

Args:
h_max (int): Maximum hop distance to consider. Distances > h_max
Expand Down Expand Up @@ -278,12 +278,13 @@ def forward(self, data: HeteroData) -> HeteroData:
num_edges = edge_index.size(1)

if num_nodes == 0 or num_edges == 0:
# Handle empty graph case - return empty sparse tensor
empty_sparse = torch.sparse_coo_tensor(
torch.zeros((2, 0), dtype=torch.long),
# Handle empty graph case - return empty sparse CSR tensor
empty_sparse = torch.sparse_csr_tensor(
torch.zeros(num_nodes + 1, dtype=torch.long),
torch.zeros(0, dtype=torch.long),
torch.zeros(0, dtype=torch.float),
size=(num_nodes, num_nodes),
).coalesce()
)
data[self.attr_name] = empty_sparse
return data

Expand Down Expand Up @@ -420,19 +421,21 @@ def forward(self, data: HeteroData) -> HeteroData:
dist_cols = torch.zeros(0, dtype=torch.long, device=device)
dist_vals = torch.zeros(0, dtype=torch.float, device=device)

# Create sparse distance matrix
# Create sparse distance matrix in CSR format directly
# CSR is more efficient for row-based lookups in _lookup_csr_values
# Unreachable pairs have value 0 (not stored)
# Reachable pairs have value = hop distance (1 to h_max)
dist_sparse = torch.sparse_coo_tensor(
dist_coo = torch.sparse_coo_tensor(
torch.stack([dist_rows, dist_cols]),
dist_vals,
size=(num_nodes, num_nodes),
).coalesce()
dist_sparse = dist_coo.to_sparse_csr()
del dist_coo

# Store sparse pairwise distance matrix as graph-level attribute
# Access via: data.hop_distance or data['hop_distance']
# Usage in attention: dist = data.hop_distance.to_dense() for small graphs,
# or use sparse indexing for memory efficiency
# Usage in attention: use sparse indexing for memory efficiency
# Note: Node ordering follows data.to_homogeneous() order (by node_type alphabetically)
data[self.attr_name] = dist_sparse

Expand Down
Loading