Node Classification with GAT
>>> from deepgnn.graph_engine.data.cora import CoraFull
>>> # Instantiate CoraFull on /tmp/cora/, the directory later passed as --data_dir.
>>> # NOTE(review): presumably this prepares the Cora graph files there — confirm.
>>> CoraFull("/tmp/cora/")
<deepgnn.graph_engine.data.cora.CoraFull object at ...>
>>> import argparse
>>> import numpy as np
>>> import tensorflow as tf
>>> from dataclasses import dataclass
>>> from typing import Dict, List, Union, Callable, Any, Tuple
>>> from contextlib import closing
>>> from deepgnn import str2list_int, setup_default_logging_config
>>> from deepgnn.graph_engine import Graph, graph_ops
>>> from deepgnn.graph_engine import (
... SamplingStrategy,
... GENodeSampler,
... RangeNodeSampler,
... FileNodeSampler,
... BackendOptions,
... create_backend,
... )
>>> from deepgnn.tf import common
>>> from deepgnn.tf.nn.gat_conv import GATConv
>>> from deepgnn.tf.nn.metrics import masked_accuracy, masked_softmax_cross_entropy
>>> from deepgnn.tf.common.dataset import create_tf_dataset, get_distributed_dataset
>>> from deepgnn.tf.common.trainer_factory import get_trainer
>>> @dataclass
... class GATQueryParameter:
... neighbor_edge_types: np.array
... feature_idx: int
... feature_dim: int
... label_idx: int
... label_dim: int
... feature_type: np.dtype = np.float32
... label_type: np.dtype = np.float32
... num_hops: int = 2
>>> class GATQuery:
... """Graph Query: get sub graph for GAT training"""
...
... def __init__(self, param: GATQueryParameter):
... self.param = param
... self.label_meta = np.array([[param.label_idx, param.label_dim]], np.int32)
... self.feat_meta = np.array([[param.feature_idx, param.feature_dim]], np.int32)
...
... def query_training(
... self, graph: Graph, inputs: np.array, return_shape: bool = False
... ):
... nodes, edges, src_idx = graph_ops.sub_graph(
... graph=graph,
... src_nodes=inputs,
... edge_types=self.param.neighbor_edge_types,
... num_hops=self.param.num_hops,
... self_loop=True,
... undirected=True,
... return_edges=True,
... )
... input_mask = np.zeros(nodes.size, np.bool_)
... input_mask[src_idx] = True
...
... feat = graph.node_features(nodes, self.feat_meta, self.param.feature_type)
... label = graph.node_features(nodes, self.label_meta, self.param.label_type)
... label = label.astype(np.int32)
...
... edges_value = np.ones(edges.shape[0], np.float32)
... adj_shape = np.array([nodes.size, nodes.size], np.int64)
... graph_tensor = (nodes, feat, input_mask, label, edges, edges_value, adj_shape)
... if return_shape:
... # fmt: off
... # N is the number of `nodes`, which is variable because `inputs` nodes are different.
... N = None
... shapes = (
... [N], # Nodes
... [N, self.param.feature_dim], # feat
... [N], # input_mask
... [N, self.param.label_dim], # label
... [None, 2], # edges
... [None], # edges_value
... [2] # adj_shape
... )
... # fmt: on
... return graph_tensor, shapes
...
... return graph_tensor
>>> class GAT(tf.keras.Model):
... """ GAT Model (supervised)"""
...
... def __init__(
... self,
... head_num: List[int] = [8, 1],
... hidden_dim: int = 8,
... num_classes: int = -1,
... ffd_drop: float = 0.0,
... attn_drop: float = 0.0,
... l2_coef: float = 0.0005,
... ):
... super().__init__()
... self.num_classes = num_classes
... self.l2_coef = l2_coef
...
... self.out_dim = num_classes
...
... self.input_layer = GATConv(
... attn_heads=head_num[0],
... out_dim=hidden_dim,
... act=tf.nn.elu,
... in_drop=ffd_drop,
... coef_drop=attn_drop,
... attn_aggregate="concat",
... )
... ## TODO: support hidden layer
... assert len(head_num) == 2
... self.out_layer = GATConv(
... attn_heads=head_num[1],
... out_dim=self.out_dim,
... act=None,
... in_drop=ffd_drop,
... coef_drop=attn_drop,
... attn_aggregate="average",
... )
...
... def forward(self, feat, bias_mat, training):
... h_1 = self.input_layer([feat, bias_mat], training=training)
... out = self.out_layer([h_1, bias_mat], training=training)
... #tf.compat.v1.logging.info("h_1 {}, out shape {}".format(h_1.shape, out.shape))
... return out
...
... def call(self, inputs, training=True):
... # inputs: nodes feat mask labels edges edges_value adj_shape
... # shape: [N] [N, F] [N] [N] [num_e, 2] [num_e] [2]
... nodes, feat, mask, labels, edges, edges_value, adj_shape = inputs
...
... # bias_mat = -1e9 * (1.0 - adj)
... sp_adj = tf.SparseTensor(edges, edges_value, adj_shape)
... logits = self.forward(feat, sp_adj, training)
...
... ## embedding results
... self.src_emb = tf.boolean_mask(logits, mask)
... self.src_nodes = tf.boolean_mask(nodes, mask)
...
... labels = tf.one_hot(labels, self.num_classes)
... logits = tf.reshape(logits, [-1, self.num_classes])
... labels = tf.reshape(labels, [-1, self.num_classes])
... mask = tf.reshape(mask, [-1])
...
... ## loss
... xent_loss = masked_softmax_cross_entropy(logits, labels, mask)
... loss = xent_loss + self.l2_loss()
...
... ## metric
... acc = masked_accuracy(logits, labels, mask)
... return logits, loss, {"accuracy": acc}
...
... def l2_loss(self):
... vs = []
... for v in self.trainable_variables:
... vs.append(tf.nn.l2_loss(v))
... lossL2 = tf.add_n(vs) * self.l2_coef
... return lossL2
...
... def train_step(self, data: dict):
... """override base train_step."""
... with tf.GradientTape() as tape:
... _, loss, metrics = self(data, training=True)
...
... grads = tape.gradient(loss, self.trainable_variables)
... self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
... result = {"loss": loss}
... result.update(metrics)
... return result
...
... def test_step(self, data: dict):
... """override base test_step."""
... _, loss, metrics = self(data, training=False)
... result = {"loss": loss}
... result.update(metrics)
... return result
...
... def predict_step(self, data: dict):
... """override base predict_step."""
... self(data, training=False)
... return [self.src_nodes, self.src_emb]
>>> def build_model(param):
... p = GATQueryParameter(
... neighbor_edge_types=np.array(param.neighbor_edge_types, np.int32),
... feature_idx=param.feature_idx,
... feature_dim=param.feature_dim,
... label_idx=param.label_idx,
... label_dim=param.label_dim,
... num_hops=len(param.head_num),
... )
... query_obj = GATQuery(p)
...
... model = GAT(
... head_num=param.head_num,
... hidden_dim=param.hidden_dim,
... num_classes=param.num_classes,
... ffd_drop=param.ffd_drop,
... attn_drop=param.attn_drop,
... l2_coef=param.l2_coef,
... )
...
... return model, query_obj
>>> def define_param_gat(parser):
... parser.add_argument("--batch_size", type=int, default=16, help="mini-batch size")
... parser.add_argument("--epochs", type=int, default=200, help="num of epochs for training")
... parser.add_argument("--learning_rate", type=float, default=0.005, help="learning rate")
...
... # GAT Model Parameters.
... parser.add_argument("--head_num", type=str2list_int, default="8,1", help="the number of attention headers.")
... parser.add_argument("--hidden_dim", type=int, default=8, help="hidden layer dimension.")
... parser.add_argument("--num_classes", type=int, default=-1, help="number of classes for category")
... parser.add_argument("--ffd_drop", type=float, default=0.0, help="feature dropout rate.")
... parser.add_argument("--attn_drop", type=float, default=0.0, help="attention layer dropout rate.")
... parser.add_argument("--l2_coef", type=float, default=0.0005, help="l2 loss")
...
... ## training node types.
... parser.add_argument("--node_types", type=str2list_int, default="0", help="Graph Node for training.")
... ## evaluate node files.
... parser.add_argument("--evaluate_node_files", type=str, help="evaluate node file list.")
... ## inference node id
... parser.add_argument("--inf_min_id", type=int, default=0, help="inferece min node id.")
... parser.add_argument("--inf_max_id", type=int, default=-1, help="inference max node id.")
...
... parser.add_argument(
... "--distributed_strategy",
... type=str,
... default=None,
... choices=[None, "Mirrored", "MultiWorkerMirrored"],
... help="Distributed strategies to use.",
... )
... def register_gat_query_param(parser):
... group = parser.add_argument_group("GAT Query Parameters")
... group.add_argument("--neighbor_edge_types", type=str2list_int, default="0", help="Graph Edge for attention encoder.",)
... group.add_argument("--feature_idx", type=int, default=0, help="feature index.")
... group.add_argument("--feature_dim", type=int, default=16, help="feature dim.")
... group.add_argument("--label_idx", type=int, default=1, help="label index.")
... group.add_argument("--label_dim", type=int, default=1, help="label dim.")
... register_gat_query_param(parser)
>>> def run_train(param, trainer, query, model, tf1_mode, backend):
... tf_dataset, steps_per_epoch = create_tf_dataset(
... sampler_class=GENodeSampler,
... query_fn=query.query_training,
... backend=backend,
... node_types=np.array(param.node_types, dtype=np.int32),
... batch_size=param.batch_size,
... num_workers=trainer.worker_size,
... worker_index=trainer.task_index,
... strategy=SamplingStrategy.RandomWithoutReplacement,
... )
...
... distributed_dataset = get_distributed_dataset(
... # NOTE: here we flatten all the epochs into 1 to increase performance.
... lambda ctx: tf_dataset.repeat(param.epochs)
... )
...
... # we need to make sure the steps_per_epoch are provided in distributed dataset.
... assert steps_per_epoch is not None or param.steps_per_epoch is not None
... # Since we flatten the dataset to len(dataset) * param.epochs,
... # we alos need to update steps_per_epoch.
... steps_per_epoch = param.epochs * (steps_per_epoch or param.steps_per_epoch)
...
... if tf1_mode:
... opt = tf.compat.v1.train.AdamOptimizer(param.learning_rate * trainer.lr_scaler)
... else:
... opt = tf.keras.optimizers.Adam(
... learning_rate=param.learning_rate * trainer.lr_scaler
... )
...
... trainer.train(
... dataset=distributed_dataset,
... model=model,
... optimizer=opt,
... epochs=1,
... steps_per_epoch=steps_per_epoch,
... )
>>> # Bind define_param_base to the unwrapped define_param_gat only if the name
>>> # is not defined yet; if it already exists it is left untouched.
>>> # NOTE(review): presumably this keeps a re-run from wrapping the wrapper — confirm.
>>> try:
... define_param_base
... except NameError:
... define_param_base = define_param_gat
>>> # Model directory with a random numeric suffix, passed below as --model_dir.
>>> # NOTE(review): the path is relative ("tmp/...") while the data lives under
>>> # "/tmp/cora" — confirm this is intended.
>>> MODEL_DIR = f"tmp/gat_{np.random.randint(9999999)}"
>>> # Arguments that define_param_wrap feeds to parser.parse_args.
>>> arg_list = [
... "--data_dir", "/tmp/cora",
... "--mode", "train",
... # "--trainer", "hvd",
... "--seed", "123",
... "--eager",
... "--log_save_steps", "1",
... "--backend", "snark",
... "--graph_type", "local",
... "--converter", "skip",
... # "--sample_file", "/tmp/cora/train.nodes",
... # "--node_type", "0",
... "--neighbor_edge_types", "0",
... "--feature_idx", "0",
... "--feature_dim", "1433",
... "--label_idx", "1",
... "--label_dim", "1",
... "--num_classes", "7",
... "--batch_size", "140",
... "--epochs", "20",
... "--learning_rate", "0.005",
... "--l2_coef", "0.0005",
... "--attn_drop", "0.6",
... "--ffd_drop", "0.6",
... "--head_num", "8,1",
... "--hidden_dim", "8",
... "--model_dir", MODEL_DIR,
... # "--metric_dir", MODEL_DIR,
... # "--save_path", MODEL_DIR,
... ]
>>> def define_param_wrap(define_param):
... def define_param_new(parser):
... define_param(parser)
... parse_args = parser.parse_args
... parser.parse_args = lambda: parse_args(arg_list)
... return define_param_new
>>> # Rebind define_param_gat to a wrapped version of define_param_base, so the
>>> # parser configured by it parses arg_list when parse_args() is called.
>>> define_param_gat = define_param_wrap(define_param_base)
>>> def _main():
... # setup default logging component.
... setup_default_logging_config(enable_telemetry=True)
...
... parser = argparse.ArgumentParser(
... formatter_class=argparse.ArgumentDefaultsHelpFormatter, allow_abbrev=False
... )
... common.args.import_default_parameters(parser)
... define_param_gat(parser)
...
... param = parser.parse_args()
... common.args.log_all_parameters(param)
...
... trainer = get_trainer(param)
...
... backend = create_backend(BackendOptions(param), is_leader=(trainer.task_index == 0))
...
... def run(tf1_mode=False):
... model, query = build_model(param)
... if param.mode == common.args.TrainMode.TRAIN:
... run_train(param, trainer, query, model, tf1_mode, backend)
... elif param.mode == common.args.TrainMode.EVALUATE:
... run_eval(param, trainer, query, model, backend)
... elif param.mode == common.args.TrainMode.INFERENCE:
... run_inference(param, trainer, query, model, backend)
...
... with closing(backend):
... if param.eager:
... strategy = None
... if param.distributed_strategy == "Default":
... strategy = tf.distribute.get_strategy()
... elif param.distributed_strategy == "Mirrored":
... strategy = tf.distribute.MirroredStrategy()
... elif param.distributed_strategy == "MultiWorkerMirrored":
... strategy = tf.distribute.MultiWorkerMirroredStrategy()
...
... if strategy:
... with strategy.scope():
... run()
... else:
... run()
... else:
... with tf.Graph().as_default():
... trainer.set_random_seed(param.seed)
... with trainer.tf_device():
... run(tf1_mode=True)
>>> # Run the example; the wrapped define_param_gat makes parse_args() read arg_list.
>>> _main()