GRAPH 2

http://geometricdeeplearning.com/

https://github.com/rusty1s/pytorch_geometric

Data handling

A single graph is an instance of torch_geometric.data.Data, which holds the following attributes by default (none of them is required):

  • data.pos: Node position matrix with shape [num_nodes, num_dimensions]

  • data.x: Node feature matrix with shape [num_nodes, num_node_features]

  • data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]

  • data.edge_index: Graph connectivity in COO format with shape [2, num_edges]

  • data.y: Target to train against (arbitrary shape)
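The shape conventions above can be checked with a small dependency-free sketch. The function check_graph_shapes is hypothetical (it is not part of PyG), it represents tensors as nested Python lists, and the edge_attr values are illustrative only; it verifies nothing beyond whether the sizes agree with one another.

```python
def check_graph_shapes(x, edge_index, edge_attr=None, pos=None):
    """Hypothetical helper: validate PyG-style shape conventions on plain lists."""
    num_nodes = len(x)  # x has one row per node
    # edge_index must have exactly two rows (sources, targets) of equal length
    if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
        raise ValueError("edge_index must have shape [2, num_edges]")
    num_edges = len(edge_index[0])
    # every node index must refer to an existing row of x
    for i in edge_index[0] + edge_index[1]:
        if not 0 <= i < num_nodes:
            raise ValueError(f"node index {i} out of range for {num_nodes} nodes")
    # edge_attr has one row per edge
    if edge_attr is not None and len(edge_attr) != num_edges:
        raise ValueError("edge_attr must have shape [num_edges, num_edge_features]")
    # pos has one row per node
    if pos is not None and len(pos) != num_nodes:
        raise ValueError("pos must have shape [num_nodes, num_dimensions]")
    return num_nodes, num_edges


x = [[-1.0], [0.0], [1.0]]                 # [num_nodes=3, num_node_features=1]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]  # [2, num_edges=4]
edge_attr = [[0.5], [0.5], [2.0], [2.0]]   # [num_edges=4, num_edge_features=1], illustrative values
print(check_graph_shapes(x, edge_index, edge_attr))  # (3, 4)
```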

Example

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)
# Data(x=[3, 1], edge_index=[2, 4])
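To see what edge_index encodes, the dependency-free sketch below (plain lists in place of tensors; to_adjacency is a hypothetical helper) turns the COO columns into an adjacency list. Each column (edge_index[0][k], edge_index[1][k]) is one directed edge from source to target, so the two undirected edges of the example graph (between nodes 0 and 1, and between nodes 1 and 2) each appear once per direction, giving four columns.

```python
def to_adjacency(edge_index, num_nodes):
    """Hypothetical helper: map each source node to the list of its targets."""
    sources, targets = edge_index
    adj = {node: [] for node in range(num_nodes)}
    for src, dst in zip(sources, targets):  # one column = one directed edge
        adj[src].append(dst)
    return adj


edge_index = [[0, 1, 1, 2],
              [1, 0, 2, 1]]
print(to_adjacency(edge_index, num_nodes=3))
# {0: [1], 1: [0, 2], 2: [1]}
```

PyG ships its own utilities for related checks, for example torch_geometric.utils.is_undirected.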

If edge_index is instead defined as a list of index tuples (one [source, target] pair per row), it must be transposed and made contiguous before it is passed to Data:

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# transpose [num_edges, 2] -> [2, num_edges], then copy into contiguous memory
data = Data(x=x, edge_index=edge_index.t().contiguous())
print(data)
# Data(x=[3, 1], edge_index=[2, 4])
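The transpose turns a list of [source, target] rows (shape [num_edges, 2]) into the [2, num_edges] layout that Data expects. In PyTorch, .t() only returns a view with swapped strides, which is why .contiguous() is called afterwards to obtain a tensor with a standard row-major memory layout. The dependency-free sketch below mirrors only the reshaping step with plain lists; pairs_to_coo is a hypothetical name.

```python
def pairs_to_coo(pairs):
    """Hypothetical helper: [[src, dst], ...] (num_edges x 2) -> [sources, targets] (2 x num_edges)."""
    if not pairs:
        return [[], []]
    sources, targets = zip(*pairs)  # zip(*rows) transposes a list of rows
    return [list(sources), list(targets)]


pairs = [[0, 1],
         [1, 0],
         [1, 2],
         [2, 1]]
print(pairs_to_coo(pairs))
# [[0, 1, 1, 2], [1, 0, 2, 1]]
```

The result equals the edge_index of the first example, which is why both Data objects print identically.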

Own dataset

Custom datasets are built by subclassing torch_geometric.data.Dataset.

Folders

  • raw_dir: folder holding the original, unprocessed dataset files

  • processed_dir: folder holding the processed dataset, which is what gets loaded afterwards
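The split between raw_dir and processed_dir exists so that the expensive conversion from raw files runs only once. The dependency-free sketch below imitates that caching pattern with the standard library only. The class CachedDataset, the one-number-per-line raw format, and the JSON processed file are all hypothetical; this is not PyG's implementation, only an illustration of the idea.

```python
import json
import os
import tempfile


class CachedDataset:
    """Hypothetical stand-in for the raw/processed caching pattern (not PyG's code)."""

    def __init__(self, root):
        self.raw_dir = os.path.join(root, "raw")
        self.processed_dir = os.path.join(root, "processed")
        self.processed_path = os.path.join(self.processed_dir, "data.json")
        self.process_calls = 0
        if not os.path.exists(self.processed_path):
            self.process()  # runs only when no processed file exists yet
        with open(self.processed_path) as f:
            self.data = json.load(f)

    def process(self):
        """Convert every raw text file (one number per line) into a list of floats."""
        self.process_calls += 1
        os.makedirs(self.processed_dir, exist_ok=True)
        data = []
        for name in sorted(os.listdir(self.raw_dir)):
            with open(os.path.join(self.raw_dir, name)) as f:
                data.append([float(line) for line in f if line.strip()])
        with open(self.processed_path, "w") as f:
            json.dump(data, f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Usage: write two raw files, then build the dataset twice.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "raw"))
    for i, values in enumerate(["1\n2\n", "3\n"]):
        with open(os.path.join(root, "raw", f"graph_{i}.txt"), "w") as f:
            f.write(values)
    first = CachedDataset(root)   # processes the raw files
    second = CachedDataset(root)  # loads the cached processed file instead
    print(len(first), first[0], first.process_calls, second.process_calls)
    # 2 [1.0, 2.0] 1 0
```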

Transform

  • transform: applied on the fly to each data object every time it is accessed (e.g. for data augmentation)

  • pre_transform: applied once, before the data objects are saved to disk (suited to heavy precomputation that only needs to be done once)

  • pre_filter: filters out data objects before they are saved (e.g. to restrict the dataset to objects of a specific class)
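The three hooks differ in when they run: pre_filter and pre_transform are applied once while the dataset is processed, and only their result is stored, whereas transform is applied every time an item is accessed. The sketch below reproduces that ordering with plain Python objects; HookedDataset and the dictionaries used as graphs are hypothetical stand-ins for a PyG Dataset and its Data objects, and the hook functions are illustrative.

```python
import random


class HookedDataset:
    """Hypothetical sketch of where pre_filter, pre_transform and transform run."""

    def __init__(self, raw_items, transform=None, pre_transform=None, pre_filter=None):
        self.transform = transform
        # "Processing" step: runs once, like writing to processed_dir.
        items = list(raw_items)
        if pre_filter is not None:
            items = [d for d in items if pre_filter(d)]  # drop unwanted items
        if pre_transform is not None:
            items = [pre_transform(d) for d in items]  # one-off precomputation
        self.items = items  # what would be saved to disk

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        d = self.items[idx]
        # transform runs on every access, so random augmentations differ each time
        return d if self.transform is None else self.transform(d)


raw = [{"label": 0, "x": [1.0, 2.0]},
       {"label": 1, "x": [3.0, 4.0]},
       {"label": 1, "x": [5.0, 6.0]}]


def keep_class_one(d):
    return d["label"] == 1  # restrict the dataset to a single class


def add_mean(d):
    return {**d, "mean": sum(d["x"]) / len(d["x"])}  # computed once, then stored


def jitter(d):
    # returns a new dict, so the stored item is left unchanged
    return {**d, "x": [v + random.gauss(0.0, 0.1) for v in d["x"]]}


ds = HookedDataset(raw, transform=jitter, pre_transform=add_mean, pre_filter=keep_class_one)
print(len(ds), ds.items[0]["mean"])  # 2 3.5
```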
