GRAPH 2
http://geometricdeeplearning.com/
https://github.com/rusty1s/pytorch_geometric
Data handling
A single graph = an instance of torch_geometric.data.Data
data.pos: Node position matrix with shape [num_nodes, num_dimensions]
data.x: Node feature matrix with shape [num_nodes, num_node_features]
data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
data.edge_index: Graph connectivity in COO format with shape [2, num_edges] (dtype torch.long)
data.y: Target to train against (arbitrary shape, e.g. node-level [num_nodes, *] or graph-level [1, *])
Example
import torch
from torch_geometric.data import Data
# Two undirected edges (0-1 and 1-2), each stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)  # 3 nodes, 1 feature each
data = Data(x=x, edge_index=edge_index)
print(data)  # Data(x=[3, 1], edge_index=[2, 4])
If edge_index is defined as a list of index tuples (shape [num_edges, 2]), it must be transposed and made contiguous before it is passed to Data.
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())  # Attention: transpose to [2, num_edges], then make contiguous
print(data)  # Data(x=[3, 1], edge_index=[2, 4])
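Why both steps are needed: .t() gives the right [2, num_edges] shape, but it only returns a strided view of the original [num_edges, 2] memory. Operations such as .view() require a contiguous layout, so .contiguous() copies the data into row-major order. The short plain-PyTorch sketch below shows the difference.

```python
import torch

# Edge list as (source, target) tuples: shape [num_edges, 2].
pairs = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)

view = pairs.t()  # shape [2, num_edges], but still backed by `pairs`' memory
print(view.shape, view.is_contiguous())  # torch.Size([2, 4]) False

edge_index = view.contiguous()  # copies the values into row-major layout
print(edge_index.is_contiguous())  # True
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

The result is the same edge_index as in the first example. Only the way it was written down differs.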
Own dataset
torch_geometric.data.Dataset
Folder
raw_dir (root/raw) = original dataset
processed_dir (root/processed) = processed dataset
Transform
transform = applied dynamically every time a graph is accessed (e.g. data augmentation)
pre_transform = applied before saving to disk (heavy precomputation, done once)
pre_filter = filters out graphs before saving (e.g. keep only graphs of a specific class)