torch_geometric.Index
- class Index(data: Any, *args: Any, dim_size: Optional[int] = None, is_sorted: bool = False, **kwargs: Any)
Bases: Tensor
A one-dimensional index tensor with additional (meta)data attached.
Index is a torch.Tensor that holds indices of shape [num_indices].
While Index sub-classes a general torch.Tensor, it can hold additional (meta)data, i.e.:
  - dim_size: The size of the underlying sparse vector, i.e., the size of a dimension that can be indexed via index. By default, it is inferred as dim_size=index.max() + 1.
  - is_sorted: Whether indices are sorted in ascending order.
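As a rough illustration of the two metadata fields, the pure-Python sketch below infers dim_size as max + 1 when it is not given and checks ascending order on a plain list. The helpers infer_dim_size and is_sorted_ascending are hypothetical and not part of PyG, which works on tensors; the empty-input behavior is an assumption of this sketch only.

```python
from typing import List, Optional


def infer_dim_size(indices: List[int], dim_size: Optional[int] = None) -> int:
    """Hypothetical helper: return dim_size if given, else max(indices) + 1."""
    if dim_size is not None:
        return dim_size  # an explicitly set size always wins
    # Assumption of this sketch: an empty index addresses nothing, so size 0.
    return max(indices) + 1 if indices else 0


def is_sorted_ascending(indices: List[int]) -> bool:
    """True if the indices are in non-decreasing order (repeats are allowed)."""
    return all(a <= b for a, b in zip(indices, indices[1:]))


indices = [0, 1, 1, 2]
print(infer_dim_size(indices))              # 3
print(infer_dim_size(indices, dim_size=5))  # 5
print(is_sorted_ascending(indices))         # True
print(is_sorted_ascending(indices[::-1]))   # False
```

Note that repeated values such as the two 1s still count as sorted, which is why [0, 1, 1, 2] can carry is_sorted=True.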
Additionally, Index caches data via indptr for fast CSR conversion in case its representation is sorted. Caches are filled on demand (e.g., when calling Index.get_indptr()) or when explicitly requested via Index.fill_cache_(), and are maintained and adjusted over its lifespan.
This representation is designed for efficient computation in GNN message passing schemes while preserving the ease of use of regular COO-based PyG workflows.
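To make the indptr idea concrete, here is a minimal pure-Python sketch of compressing a sorted index into a CSR-style pointer array, plus a small class that computes it lazily and caches it. The method names get_indptr and fill_cache_ mirror the ones named in the description, but SortedIndexSketch and index_to_indptr are hypothetical stand-ins, not PyG's tensor-based implementation.

```python
from typing import List, Optional


def index_to_indptr(index: List[int], dim_size: int) -> List[int]:
    """Compress a sorted index into a pointer array of length dim_size + 1.

    For a sorted index, the entries equal to i occupy
    index[indptr[i]:indptr[i + 1]].
    """
    counts = [0] * dim_size
    for i in index:
        counts[i] += 1  # how many entries point at each position
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)  # running (prefix) sum of counts
    return indptr


class SortedIndexSketch:
    """Hypothetical sorted index that fills its indptr cache on demand."""

    def __init__(self, index: List[int], dim_size: int) -> None:
        self.index = index
        self.dim_size = dim_size
        self._indptr: Optional[List[int]] = None  # empty until first requested

    def get_indptr(self) -> List[int]:
        # Compute once, then serve the cached result on later calls.
        if self._indptr is None:
            self._indptr = index_to_indptr(self.index, self.dim_size)
        return self._indptr

    def fill_cache_(self) -> "SortedIndexSketch":
        # Populate the cache eagerly instead of waiting for the first request.
        self.get_indptr()
        return self


idx = SortedIndexSketch([0, 1, 1, 2], dim_size=3)
print(idx._indptr)       # None
print(idx.get_indptr())  # [0, 1, 3, 4]
```

The pointer array is what makes CSR conversion cheap: once it exists, the block of entries belonging to position i is a constant-time slice rather than a search.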
import torch
from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
# Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
assert index.dim_size == 3
assert index.is_sorted

# Flipping order (returns a new Index; `index` itself is unchanged):
flipped = index.flip(0)
# Index([2, 1, 1, 0], dim_size=3)
assert not flipped.is_sorted

# Filtering with a boolean mask:
mask = torch.tensor([True, True, True, False])
filtered = index[mask]
# Index([0, 1, 1], dim_size=3, is_sorted=True)
assert filtered.is_sorted
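The metadata behavior in this example follows from a simple property: any subsequence of a sorted sequence is still sorted, whereas reversing one generally is not. The pure-Python sketch below checks this on plain lists; apply_mask is a hypothetical helper, not a PyG function.

```python
from typing import List


def is_sorted_ascending(indices: List[int]) -> bool:
    """True if the indices are in non-decreasing order."""
    return all(a <= b for a, b in zip(indices, indices[1:]))


def apply_mask(indices: List[int], mask: List[bool]) -> List[int]:
    # Boolean masking keeps surviving entries in their original order,
    # so the result is a subsequence of the input.
    return [i for i, keep in zip(indices, mask) if keep]


index = [0, 1, 1, 2]
flipped = index[::-1]                                  # [2, 1, 1, 0]
masked = apply_mask(index, [True, True, True, False])  # [0, 1, 1]

print(is_sorted_ascending(flipped))  # False
print(is_sorted_ascending(masked))   # True
```

This is why a mask can safely carry is_sorted=True over, while a flipped index cannot in general assume it.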
- validate() → Index
Validates the Index representation. In particular, it ensures that:
  - it only holds valid indices;
  - the sort order is correctly set.
- Return type: Index
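A minimal pure-Python sketch of the two checks listed for validate() follows. It assumes that "valid indices" means non-negative values below dim_size; validate_index is a hypothetical function and its error messages are illustrative, not PyG's.

```python
from typing import List, Optional


def validate_index(indices: List[int], dim_size: Optional[int], is_sorted: bool) -> None:
    """Hypothetical re-creation of the checks described for validate().

    Raises ValueError if an index is negative or >= dim_size (assumed meaning
    of "valid"), or if is_sorted=True but the indices are out of order.
    """
    for i in indices:
        if i < 0:
            raise ValueError(f"index contains negative value {i}")
        if dim_size is not None and i >= dim_size:
            raise ValueError(f"index {i} is out of range for dim_size={dim_size}")
    if is_sorted and any(a > b for a, b in zip(indices, indices[1:])):
        raise ValueError("index is marked as sorted but is not in ascending order")


validate_index([0, 1, 1, 2], dim_size=3, is_sorted=True)  # passes silently
try:
    validate_index([0, 1, 3], dim_size=3, is_sorted=True)
except ValueError as err:
    print(err)  # index 3 is out of range for dim_size=3
```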
- get_dim_size() → int
The size of the underlying sparse vector. Automatically computed and cached when not explicitly set.
- Return type: int
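The "computed and cached when not explicitly set" behavior is a compute-once pattern. The sketch below shows it on a plain list; DimSizeSketch and its num_computations counter are hypothetical and exist only to make the caching observable.

```python
from typing import List, Optional


class DimSizeSketch:
    """Hypothetical holder showing compute-once caching of dim_size."""

    def __init__(self, indices: List[int], dim_size: Optional[int] = None) -> None:
        self.indices = indices
        self._dim_size = dim_size  # None means "not set yet"
        self.num_computations = 0  # counts how often inference actually ran

    def get_dim_size(self) -> int:
        if self._dim_size is None:
            self.num_computations += 1
            # Infer as max + 1 (0 for an empty list in this sketch) and cache it.
            self._dim_size = max(self.indices) + 1 if self.indices else 0
        return self._dim_size


s = DimSizeSketch([0, 1, 1, 2])
print(s.get_dim_size(), s.get_dim_size())  # 3 3
print(s.num_computations)                  # 1
```

The second call returns the cached value, so the scan over the indices happens only once.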
- dim_resize_(dim_size: Optional[int]) → Index
Assigns or re-assigns the size of the underlying sparse vector.
- Return type: Index
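The description does not spell out how a cached indptr reacts to a resize. Assuming the cache is kept in sync, the pure-Python sketch below shows what that adjustment has to look like for a CSR pointer array: growing appends copies of the last pointer (the new positions are empty), and shrinking truncates, which is only valid if no index reaches the new size. resize_indptr is a hypothetical helper, not PyG's code.

```python
from typing import List


def resize_indptr(indptr: List[int], dim_size: int) -> List[int]:
    """Hypothetical helper: adjust a CSR pointer array to a new dim_size."""
    current = len(indptr) - 1
    if dim_size >= current:
        # New trailing positions hold no entries, so their pointers repeat the end.
        return indptr + [indptr[-1]] * (dim_size - current)
    if indptr[dim_size] != indptr[-1]:
        # Some entries live at positions >= dim_size; truncating would lose them.
        raise ValueError("cannot shrink: some indices are >= the new dim_size")
    return indptr[: dim_size + 1]


indptr = [0, 1, 3, 4]            # sorted index [0, 1, 1, 2] with dim_size=3
print(resize_indptr(indptr, 5))  # [0, 1, 3, 4, 4, 4]
```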