Split Learning for Graph Neural Networks

The following code is for demonstration only. Do not use it directly in production.

Initialization

Create two parties: alice and bob.

[1]:
import secretflow as sf

# Shut down any SecretFlow runtime that may already be running.
sf.shutdown()

sf.init(parties=['alice', 'bob'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')

Prepare the dataset

The Cora dataset

The Cora dataset consists of two files: cora.cites and cora.content.

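  • cora.cites contains the citation records and has two columns: cited_paper_id (the paper being cited) and citing_paper_id (the paper making the citation).

  • cora.content contains the paper records and has 1,435 columns: paper_id (the paper ID), subject (the paper's topic), and 1,433 binary word features.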
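The following parsing sketch makes the two file layouts concrete. It assumes the layout of the original LINQS release. In that release each cora.content row lists paper_id, then the 0/1 word features, then the subject label in the last column. Each cora.cites row lists the cited paper ID followed by the citing paper ID. The helper names and the toy rows (5 features instead of 1,433) are hypothetical.

```python
def parse_content_line(line):
    """Split one cora.content row into (paper_id, binary features, subject)."""
    fields = line.split()  # handles tab- or space-separated columns
    paper_id = fields[0]
    features = [int(v) for v in fields[1:-1]]
    subject = fields[-1]  # assumed: the class label is the last column
    return paper_id, features, subject


def parse_cites_line(line):
    """Split one cora.cites row into (cited_paper_id, citing_paper_id)."""
    cited, citing = line.split()
    return cited, citing


# Hypothetical toy rows with only 5 word features instead of 1,433.
content_row = "31336\t0\t1\t0\t0\t1\tNeural_Networks"
cites_row = "35\t1033"

paper_id, features, subject = parse_content_line(content_row)
print(paper_id, features, subject)  # 31336 [0, 1, 0, 0, 1] Neural_Networks
print(parse_cites_line(cites_row))  # ('35', '1033')
```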

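We use the pre-split Cora dataset that ships with SecretFlow.

  • The training set contains 140 paper IDs.

  • The test set contains 1,000 paper IDs.

  • The validation set contains 500 paper IDs.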
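The loading code further down turns these splits into boolean node masks with a sample_mask helper. The sketch below mirrors that helper. It assumes Cora's 2,708 nodes and, purely for illustration, places the 1,000 test nodes at the end of the node list.

```python
import numpy as np

num_nodes = 2708  # number of papers (nodes) in Cora


def sample_mask(idx, length):
    """Boolean mask that is True at the given node positions."""
    mask = np.zeros(length, dtype=bool)
    mask[idx] = True
    return mask


# First 140 nodes train, next 500 validate; the test block at the end is an
# illustrative assumption.
train_mask = sample_mask(range(140), num_nodes)
val_mask = sample_mask(range(140, 640), num_nodes)
test_mask = sample_mask(range(num_nodes - 1000, num_nodes), num_nodes)

print(train_mask.sum(), val_mask.sum(), test_mask.sum())  # 140 500 1000
# The three splits do not overlap.
print((train_mask & val_mask).any(), (val_mask & test_mask).any())  # False False
```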

Partition the dataset

Partition the dataset as follows:

  • alice holds the first 716 features, and bob holds the remaining 717.

  • alice holds all the labels.

  • Both alice and bob hold all the edges.
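The sketch below shows the vertical feature split. It uses a random 0/1 matrix as a stand-in for Cora's real features, which is an assumption for illustration only. The split rule is the same one the loading code uses, round(F / 2). Python rounds 716.5 to the even number 716, so alice gets 716 columns and bob gets 717.

```python
import numpy as np

num_nodes, num_features = 2708, 1433
rng = np.random.default_rng(0)
# Random 0/1 stand-in for the real bag-of-words feature matrix.
nodes = rng.integers(0, 2, size=(num_nodes, num_features), dtype=np.int8)

# Same rule as the loading code: split the columns at round(F / 2).
split_pos = round(num_features / 2)  # round(716.5) == 716 (round half to even)
nodes_alice, nodes_bob = nodes[:, :split_pos], nodes[:, split_pos:]

print(nodes_alice.shape, nodes_bob.shape)  # (2708, 716) (2708, 717)
# Every node (row) is present on both sides; only the feature columns differ.
assert np.array_equal(np.hstack([nodes_alice, nodes_bob]), nodes)
```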

[2]:
import networkx as nx
import numpy as np
import os
import pickle
import scipy.sparse
import zipfile
import tempfile
from pathlib import Path
from secretflow.utils.simulation.datasets import dataset
from secretflow.data.ndarray import load


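def load_cora():
    """Load Cora, split the node features between alice and bob, and save every array as a .npy file."""
    dataset_zip = dataset('cora')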
    extract_path = str(Path(dataset_zip).parent)
    with zipfile.ZipFile(dataset_zip, 'r') as zip_f:
        zip_f.extractall(extract_path)

    file_names = [
        os.path.join(extract_path, f'ind.cora.{name}')
        for name in ['y', 'tx', 'ty', 'allx', 'ally', 'graph']
    ]

    objects = []
    for name in file_names:
        with open(name, 'rb') as f:
            objects.append(pickle.load(f, encoding='latin1'))

    y, tx, ty, allx, ally, graph = tuple(objects)

    with open(os.path.join(extract_path, "ind.cora.test.index"), 'r') as f:
        test_idx_reorder = f.readlines()
    test_idx_reorder = list(map(lambda s: int(s.strip()), test_idx_reorder))
    test_idx_range = np.sort(test_idx_reorder)

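    # Stack the training and test features, then move each test row back to the
    # position of the node it belongs to (the test index file is not sorted).
    nodes = scipy.sparse.vstack((allx, tx)).tolil()
    nodes[test_idx_reorder, :] = nodes[test_idx_range, :]
    # Dense adjacency matrix with self-loops (A + I).
    edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    edge = edge.toarray() + np.eye(edge.shape[1])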

    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_test = test_idx_range.tolist()
    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)

    def sample_mask(idx, length):
        mask = np.zeros(length)
        mask[idx] = 1
        return np.array(mask, dtype=bool)

    idx_train = sample_mask(idx_train, labels.shape[0])
    idx_val = sample_mask(idx_val, labels.shape[0])
    idx_test = sample_mask(idx_test, labels.shape[0])

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[idx_train, :] = labels[idx_train, :]
    y_val[idx_val, :] = labels[idx_val, :]
    y_test[idx_test, :] = labels[idx_test, :]

    nodes = nodes.toarray()
    # Vertical split: alice keeps the first half of the feature columns, bob the rest.
    features_split_pos = round(nodes.shape[1] / 2)
    nodes_alice, nodes_bob = nodes[:, :features_split_pos], nodes[:, features_split_pos:]
    temp_dir = tempfile.mkdtemp()
    saved_files = [
        os.path.join(temp_dir, name)
        for name in [
            'edge.npy',
            'x_alice.npy',
            'x_bob.npy',
            'y_train.npy',
            'y_val.npy',
            'y_test.npy',
            'idx_train.npy',
            'idx_val.npy',
            'idx_test.npy',
        ]
    ]
    np.save(saved_files[0], edge)
    np.save(saved_files[1], nodes_alice)
    np.save(saved_files[2], nodes_bob)
    np.save(saved_files[3], y_train)
    np.save(saved_files[4], y_val)
    np.save(saved_files[5], y_test)
    np.save(saved_files[6], idx_train)
    np.save(saved_files[7], idx_val)
    np.save(saved_files[8], idx_test)
    return saved_files

saved_files = load_cora()

edge = load({alice: saved_files[0], bob: saved_files[0]})
features = load({alice: saved_files[1], bob: saved_files[2]})
Y_train = load({alice: saved_files[3]})
Y_val = load({alice: saved_files[4]})
Y_test = load({alice: saved_files[5]})
idx_train = load({alice: saved_files[6]})
idx_val = load({alice: saved_files[7]})
idx_test = load({alice: saved_files[8]})

/tmp/ipykernel_754102/2692670243.py:27: DeprecationWarning: Please use `csr_matrix` from the `scipy.sparse` namespace, the `scipy.sparse.csr` namespace is deprecated.
  objects.append(pickle.load(f, encoding='latin1'))
/tmp/ipykernel_754102/2692670243.py:38: FutureWarning: adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.
  edge = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
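The least obvious step in load_cora is the test-row reordering. The sketch below replays it on a hypothetical toy layout, using dense NumPy arrays instead of SciPy sparse matrices. It assumes, as the code implies, that the rows of tx are stored in the order listed in ind.cora.test.index.

```python
import numpy as np

# Hypothetical toy layout: nodes 0-2 come from allx, and the test file lists
# its rows for nodes 4 and 3 in that (unsorted) order.
allx = np.array([[0.0], [1.0], [2.0]])
tx = np.array([[40.0], [30.0]])  # row i belongs to node test_idx_reorder[i]
test_idx_reorder = [4, 3]
test_idx_range = np.sort(test_idx_reorder)  # [3, 4]

# Rows after stacking: nodes 0, 1, 2, then node 4's features, then node 3's.
nodes = np.vstack((allx, tx))
# Move each test row to the position of the node it belongs to.
nodes[test_idx_reorder, :] = nodes[test_idx_range, :]
print(nodes.ravel().tolist())  # [0.0, 1.0, 2.0, 30.0, 40.0]
```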

In fact, SecretFlow already ships with the Cora data, so the single call below gives the same result as the code above.

[3]:
from secretflow.utils.simulation.datasets import load_cora

(edge, features, Y_train, Y_val, Y_test, idx_train, idx_val, idx_test) = load_cora(
    [alice, bob]
)

Build the graph neural network model

Implement the graph attention layer

[4]:
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras import backend as K
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import Dropout, Layer, LeakyReLU


class GraphAttention(Layer):
    def __init__(
        self,
        F_,
        attn_heads=1,
        attn_heads_reduction='average',  # {'concat', 'average'}
        dropout_rate=0.5,
        activation='relu',
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        attn_kernel_initializer='glorot_uniform',
        kernel_regularizer=None,
        bias_regularizer=None,
        attn_kernel_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        attn_kernel_constraint=None,
        **kwargs,
    ):
        if attn_heads_reduction not in {'concat', 'average'}:
            raise ValueError('Possible reduction methods: concat, average')

        self.F_ = F_  # Number of output features (F' in the paper)
        self.attn_heads = attn_heads  # Number of attention heads (K in the paper)
        self.attn_heads_reduction = attn_heads_reduction
        self.dropout_rate = dropout_rate  # Internal dropout rate
        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)
        self.supports_masking = False

        # Populated by build()
        self.kernels = []  # Layer kernels for attention heads
        self.biases = []  # Layer biases for attention heads
        self.attn_kernels = []  # Attention kernels for attention heads

        if attn_heads_reduction == 'concat':
            # Output will have shape (..., K * F')
            self.output_dim = self.F_ * self.attn_heads
        else:
            # Output will have shape (..., F')
            self.output_dim = self.F_

        super(GraphAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) >= 2
        F = input_shape[0][-1]

        # Initialize weights for each attention head
        for head in range(self.attn_heads):
            # Layer kernel
            kernel = self.add_weight(
                shape=(F, self.F_),
                initializer=self.kernel_initializer,
                regularizer=self.kernel_regularizer,
                constraint=self.kernel_constraint,
                name='kernel_{}'.format(head),
            )
            self.kernels.append(kernel)

            # Layer bias
            if self.use_bias:
                bias = self.add_weight(
                    shape=(self.F_,),
                    initializer=self.bias_initializer,
                    regularizer=self.bias_regularizer,
                    constraint=self.bias_constraint,
                    name='bias_{}'.format(head),
                )
                self.biases.append(bias)

            # Attention kernels
            attn_kernel_self = self.add_weight(
                shape=(self.F_, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name='attn_kernel_self_{}'.format(head),
            )
            attn_kernel_neighs = self.add_weight(
                shape=(self.F_, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name='attn_kernel_neigh_{}'.format(head),
            )
            self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])
        self.built = True

    def call(self, inputs):
        X = inputs[0]  # Node features (N x F)
        A = inputs[1]  # Adjacency matrix (N x N)

        outputs = []
        for head in range(self.attn_heads):
            kernel = self.kernels[head]  # W in the paper (F x F')
            attention_kernel = self.attn_kernels[
                head
            ]  # Attention kernel a in the paper (2F' x 1)

            # Compute inputs to attention network
            features = K.dot(X, kernel)  # (N x F')

            # Compute feature combinations
            # Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_j]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]
            attn_for_self = K.dot(
                features, attention_kernel[0]
            )  # (N x 1), [a_1]^T [Wh_i]
            attn_for_neighs = K.dot(
                features, attention_kernel[1]
            )  # (N x 1), [a_2]^T [Wh_j]

            # Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]
            dense = attn_for_self + K.transpose(
                attn_for_neighs
            )  # (N x N) via broadcasting

            # Add nonlinearity (a plain op instead of creating a new layer on every call)
            dense = tf.nn.leaky_relu(dense, alpha=0.2)

            # Mask values before activation (Vaswani et al., 2017)
            mask = -10e9 * (1.0 - A)
            dense += mask

            # Apply softmax to get attention coefficients
            dense = K.softmax(dense)  # (N x N)

            # Apply dropout to features and attention coefficients
            dropout_attn = Dropout(self.dropout_rate)(dense)  # (N x N)
            dropout_feat = Dropout(self.dropout_rate)(features)  # (N x F')

            # Linear combination with neighbors' features
            node_features = K.dot(dropout_attn, dropout_feat)  # (N x F')

            if self.use_bias:
                node_features = K.bias_add(node_features, self.biases[head])

            # Add output of attention head to final output
            outputs.append(node_features)

        # Aggregate the heads' output according to the reduction method
        if self.attn_heads_reduction == 'concat':
            output = K.concatenate(outputs)  # (N x KF')
        else:
            output = K.mean(K.stack(outputs), axis=0)  # (N x F')

        output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        output_shape = input_shape[0][0], self.output_dim
        return output_shape

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'attn_heads': self.attn_heads,
                'attn_heads_reduction': self.attn_heads_reduction,
                'F_': self.F_,
            }
        )
        return config
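The NumPy sketch below shows what a single attention head computes. It mirrors the forward pass of GraphAttention.call for one head, but leaves out dropout, bias, and the final activation. The helper names, the toy graph, and the random weights are assumptions for illustration. The real layer also averages (or concatenates) several heads before applying its activation.

```python
import numpy as np


def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)


def gat_head(X, A, W, a_self, a_neigh):
    """One attention head without dropout or bias, mirroring GraphAttention.call."""
    H = X @ W                              # (N, F') projected features Wh
    e = H @ a_self + (H @ a_neigh).T       # (N, N): a_1^T Wh_i + a_2^T Wh_j
    e = leaky_relu(e)
    e = e + -10e9 * (1.0 - A)              # mask out non-edges
    e = e - e.max(axis=1, keepdims=True)   # numerically stable softmax
    att = np.exp(e)
    att = att / att.sum(axis=1, keepdims=True)  # row i: weights over i's neighbours
    return att @ H, att                    # (N, F') aggregated node features


rng = np.random.default_rng(0)
N, F, F_out = 4, 3, 2
X = rng.normal(size=(N, F))
# Path graph 0-1-2-3 with self-loops, like edge + np.eye(...) in load_cora.
A = np.eye(N)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    A[i, j] = A[j, i] = 1.0
W = rng.normal(size=(F, F_out))
a_self, a_neigh = rng.normal(size=(F_out, 1)), rng.normal(size=(F_out, 1))

out, att = gat_head(X, A, W, a_self, a_neigh)
print(out.shape)                          # (4, 2)
print(np.allclose(att.sum(axis=1), 1.0))  # True
print(att[0, 2], att[0, 3])               # 0.0 0.0: node 0 ignores non-neighbours
```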

Implement the fuse layer

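The fuse model runs on the party that holds the labels. It works as follows:

  1. Generate the final embedding from the concatenated embeddings of all parties.

  2. Predict the class with softmax.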

[5]:

class ServerNet(tf.keras.layers.Layer):
    def __init__(
        self,
        in_channel: int,
        hidden_size: int,
        num_layer: int,
        num_class: int,
        dropout: float,
        **kwargs,
    ):
        # Initialize the base Layer once, forwarding any Keras kwargs (e.g. name).
        super(ServerNet, self).__init__(**kwargs)
        self.num_class = num_class
        self.num_layer = num_layer
        self.hidden_size = hidden_size
        self.in_channel = in_channel
        self.dropout = dropout
        self.dense_layers = []

    def build(self, input_shape):
        # num_layer Dense layers in total: an input layer, (num_layer - 2)
        # hidden layers, and an output layer with one unit per class.
        self.dense_layers.append(tf.keras.layers.Dense(self.hidden_size))
        for _ in range(self.num_layer - 2):
            self.dense_layers.append(tf.keras.layers.Dense(self.hidden_size))
        self.dense_layers.append(tf.keras.layers.Dense(self.num_class))
        super(ServerNet, self).build(input_shape)

    def call(self, inputs):
        x = inputs
        x = Dropout(self.dropout)(x)
        for i in range(self.num_layer):
            x = Dropout(self.dropout)(x)
            x = self.dense_layers[i](x)
        # Class probabilities for each node.
        return K.softmax(x)

    def compute_output_shape(self, input_shape):
        # (number of nodes, number of classes)
        return input_shape[0], self.num_class

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                'in_channel': self.in_channel,
                'hidden_size': self.hidden_size,
                'num_layer': self.num_layer,
                'num_class': self.num_class,
                'dropout': self.dropout,
            }
        )
        return config

Build the base model

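Each party uses a base model to generate embeddings. The base model produces them with a single graph attention layer.

The embeddings of all parties are sent to the label-holding party for further processing.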

[6]:
from tensorflow.keras.models import Model

def create_base_model(input_shape, n_hidden, l2_reg, num_heads, dropout_rate, learning_rate):
    def base_model():
        feature_input = tf.keras.Input(shape=(input_shape[1],))
        graph_input = tf.keras.Input(shape=(input_shape[0],))
        regular = tf.keras.regularizers.l2(l2_reg)
        outputs = GraphAttention(
            F_=n_hidden,
            attn_heads=num_heads,
            attn_heads_reduction='average',  # {'concat', 'average'}
            dropout_rate=dropout_rate,
            activation='relu',
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            attn_kernel_initializer='glorot_uniform',
            kernel_regularizer=regular,
            bias_regularizer=None,
            attn_kernel_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            attn_kernel_constraint=None,
        )([feature_input, graph_input])
        # outputs = tf.keras.layers.Flatten()(outputs)
        model = Model(inputs=[feature_input, graph_input], outputs=outputs, name="embed_model")
        # Compile model
        model.summary()
        metrics = ['acc']
        optimizer = tf.keras.optimizers.get(
            {
                'class_name': 'adam',
                'config': {'learning_rate': learning_rate},
            }
        )
        model.compile(
            loss='categorical_crossentropy',
            weighted_metrics=metrics,
            optimizer=optimizer,
        )
        return model

    return base_model

Build the fuse model

The fuse model concatenates the embeddings of all parties. It is used only by the party that holds the labels.

[7]:
from tensorflow import keras
from tensorflow.keras import layers


def create_fuse_model(hidden_units, hidden_size, n_classes, layer_num, learning_rate):
    def fuse_model():
        inputs = [keras.Input(shape=(size,)) for size in hidden_units]
        x = layers.concatenate(inputs)
        input_shape = x.shape[-1]
        y_pred = ServerNet(
            in_channel=input_shape,
            hidden_size=hidden_size,
            num_layer=layer_num,
            num_class=n_classes,
            dropout=0.0,
        )(x)
        # Create the model.
        model = keras.Model(inputs=inputs, outputs=y_pred, name="fuse_model")
        model.summary()
        metrics = ['acc']
        optimizer = tf.keras.optimizers.get(
            {'class_name': 'adam', 'config': {'learning_rate': learning_rate},}
        )
        model.compile(
            loss='categorical_crossentropy',
            weighted_metrics=metrics,
            optimizer=optimizer,
        )
        return model

    return fuse_model
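The NumPy sketch below traces the fuse model's forward pass. Random embeddings stand in for real base-model outputs, and random weights stand in for trained Dense layers; both are assumptions for illustration. As in ServerNet, it stacks layer_num linear layers with no activations between them and ends with softmax.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    p = np.exp(z)
    return p / p.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
N, hidden_size, n_classes, layer_num = 5, 8, 7, 3

# Hypothetical embeddings produced by alice's and bob's base models.
h_alice = rng.normal(size=(N, hidden_size))
h_bob = rng.normal(size=(N, hidden_size))

x = np.concatenate([h_alice, h_bob], axis=1)  # (N, 2 * hidden_size)

# layer_num linear layers: input + (layer_num - 2) hidden + output layer,
# with no activation in between, as in ServerNet.call.
sizes = [x.shape[1]] + [hidden_size] * (layer_num - 1) + [n_classes]
for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
    W = rng.normal(scale=0.1, size=(fan_in, fan_out))
    b = np.zeros(fan_out)
    x = x @ W + b

probs = softmax(x)
print(probs.shape)                          # (5, 7)
print(np.allclose(probs.sum(axis=1), 1.0))  # True
```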

Train the GNN model with split learning

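Build a split learning model.

alice, as the label holder, owns a base model and the fuse model; bob owns only a base model.

The complete model architecture is shown below: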

split_learning_gnn_model.png
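The training loop behind this architecture can be sketched in plain NumPy. This is a simplified illustration, not SecretFlow's implementation. It uses linear base models and a linear fuse model in place of the GAT and ServerNet models, runs on toy random data, and exchanges embeddings and gradients with no protection at all. What it shows is the data flow. Each party computes its embedding locally, and only embeddings travel to alice. Alice then sends each party only the gradient of that party's own embedding.

```python
import numpy as np

rng = np.random.default_rng(0)
N, F_a, F_b, H, C = 20, 4, 3, 5, 3
lr = 0.2

# Private toy inputs: alice holds X_a and the labels Y, bob holds X_b.
X_a, X_b = rng.normal(size=(N, F_a)), rng.normal(size=(N, F_b))
Y = np.eye(C)[rng.integers(0, C, size=N)]

# Hypothetical linear stand-ins for the base models and the fuse model.
W_a = rng.normal(scale=0.1, size=(F_a, H))
W_b = rng.normal(scale=0.1, size=(F_b, H))
W_f = rng.normal(scale=0.1, size=(2 * H, C))


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    p = np.exp(z)
    return p / p.sum(axis=1, keepdims=True)


losses = []
for step in range(100):
    # Forward: each party computes its embedding locally ...
    h_a, h_b = X_a @ W_a, X_b @ W_b
    # ... and only the embeddings are sent to alice, who runs the fuse model.
    h = np.concatenate([h_a, h_b], axis=1)
    p = softmax(h @ W_f)
    losses.append(-np.mean(np.sum(Y * np.log(p), axis=1)))

    # Backward on alice: gradients of the mean cross-entropy.
    d_logits = (p - Y) / N
    d_W_f = h.T @ d_logits
    d_h = d_logits @ W_f.T
    # alice returns to each party only the gradient of its own embedding.
    d_h_a, d_h_b = d_h[:, :H], d_h[:, H:]

    W_f -= lr * d_W_f
    W_a -= lr * X_a.T @ d_h_a  # computed on alice's side of the split
    W_b -= lr * X_b.T @ d_h_b  # computed on bob's side of the split

print(losses[0] > losses[-1])  # True: the joint loss decreases
```

Raw features never leave their owners. Each party updates its base model from its local data plus the gradient it receives back.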

[8]:
from secretflow.ml.nn import SLModel

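hidden_size = 256     # embedding size of each base model and hidden size of the fuse model
n_classes = 7         # Cora has 7 subject classes
layer_num = 3         # number of Dense layers in ServerNet
learning_rate = 1e-3
dropout_rate = 0.0
l2_reg = 0.1          # L2 regularization on the GAT kernel
num_heads = 4         # attention heads in each GraphAttention layer
epochs = 10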

partition_shapes = features.partition_shape()

input_shape_alice = partition_shapes[alice]
input_shape_bob = partition_shapes[bob]

sl_model = SLModel(
    base_model_dict={
        alice: create_base_model(
            input_shape_alice,
            hidden_size,
            l2_reg,
            num_heads,
            dropout_rate,
            learning_rate,
        ),
        bob: create_base_model(
            input_shape_bob,
            hidden_size,
            l2_reg,
            num_heads,
            dropout_rate,
            learning_rate,
        ),
    },
    device_y=alice,
    model_fuse=create_fuse_model(
        [hidden_size, hidden_size], hidden_size, n_classes, layer_num, learning_rate
    ),
)


Fit the model.

[9]:

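sl_model.fit(
    x=[features, edge],
    y=Y_train,
    epochs=epochs,
    # One batch holds every node, so each step runs on the whole graph.
    batch_size=input_shape_alice[0],
    # 0/1 mask: only training nodes contribute to the loss.
    sample_weight=idx_train,
    validation_data=([features, edge], Y_val, idx_val),
)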
100%|██████████| 1/1 [00:05<00:00,  5.67s/it, epoch: 0/10 -  train_loss:0.10079389065504074  train_acc:0.12857143580913544  val_loss:0.35441863536834717  val_acc:0.3880000114440918 ]
100%|██████████| 1/1 [00:01<00:00,  1.02s/it, epoch: 1/10 -  train_loss:0.2256714105606079  train_acc:0.44843751192092896  val_loss:0.3481321930885315  val_acc:0.5640000104904175 ]
100%|██████████| 1/1 [00:01<00:00,  1.01s/it, epoch: 2/10 -  train_loss:0.22043751180171967  train_acc:0.637499988079071  val_loss:0.34046509861946106  val_acc:0.6320000290870667 ]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, epoch: 3/10 -  train_loss:0.21422825753688812  train_acc:0.703125  val_loss:0.33100318908691406  val_acc:0.6539999842643738 ]
100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 4/10 -  train_loss:0.20669467747211456  train_acc:0.7250000238418579  val_loss:0.3193798065185547  val_acc:0.6840000152587891 ]
100%|██████████| 1/1 [00:01<00:00,  1.10s/it, epoch: 5/10 -  train_loss:0.19755281507968903  train_acc:0.7484375238418579  val_loss:0.30531173944473267  val_acc:0.7080000042915344 ]
100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 6/10 -  train_loss:0.1866169422864914  train_acc:0.7671874761581421  val_loss:0.28866109251976013  val_acc:0.7279999852180481 ]
100%|██████████| 1/1 [00:01<00:00,  1.03s/it, epoch: 7/10 -  train_loss:0.17386208474636078  train_acc:0.7828124761581421  val_loss:0.26949769258499146  val_acc:0.7319999933242798 ]
100%|██████████| 1/1 [00:01<00:00,  1.04s/it, epoch: 8/10 -  train_loss:0.1594790667295456  train_acc:0.785937488079071  val_loss:0.24824994802474976  val_acc:0.7319999933242798 ]
100%|██████████| 1/1 [00:01<00:00,  1.06s/it, epoch: 9/10 -  train_loss:0.1439441591501236  train_acc:0.785937488079071  val_loss:0.2257150411605835  val_acc:0.7400000095367432 ]
[9]:
{'train_loss': [0.10079389,
  0.22567141,
  0.22043751,
  0.21422826,
  0.20669468,
  0.19755282,
  0.18661694,
  0.17386208,
  0.15947907,
  0.14394416],
 'train_acc': [0.12857144,
  0.4484375,
  0.6375,
  0.703125,
  0.725,
  0.7484375,
  0.7671875,
  0.7828125,
  0.7859375,
  0.7859375],
 'val_loss': [0.35441864,
  0.3481322,
  0.3404651,
  0.3310032,
  0.3193798,
  0.30531174,
  0.2886611,
  0.2694977,
  0.24824995,
  0.22571504],
 'val_acc': [0.388,
  0.564,
  0.632,
  0.654,
  0.684,
  0.708,
  0.728,
  0.732,
  0.732,
  0.74]}
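Passing sample_weight=idx_train, a 0/1 mask, means that only the 140 training nodes contribute to the loss, even though every batch contains all nodes. Suppose the weighted loss is divided by the batch size, which is how we understand Keras's default SUM_OVER_BATCH_SIZE reduction to behave. Then the reported value is scaled down by the fraction of masked-in nodes. This is consistent with the first train_loss above being close to 140 / 2708 × ln 7 ≈ 0.10. The toy NumPy comparison below contrasts the two normalizations.

```python
import numpy as np

n_classes = 7
# Toy batch of 5 nodes; only the first 2 are in the training split.
sample_weight = np.array([1.0, 1.0, 0.0, 0.0, 0.0])
y_true = np.eye(n_classes)[[0, 3, 1, 2, 6]]
y_pred = np.full((5, n_classes), 1.0 / n_classes)  # uniform predictions

per_node = -np.sum(y_true * np.log(y_pred), axis=1)  # ln 7 for every node
weighted = sample_weight * per_node

# Divide by the batch size (all nodes) ...
loss_over_batch = weighted.sum() / len(weighted)
# ... versus dividing by the number of masked-in nodes.
loss_over_weights = weighted.sum() / sample_weight.sum()

print(round(loss_over_batch, 4), round(loss_over_weights, 4))  # 0.7784 1.9459
```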

Evaluate the GNN model's predictions on the test set.

[10]:
sl_model.evaluate(
    x=[features, edge],
    y=Y_test,
    batch_size=input_shape_alice[0],
    sample_weight=idx_test,
)
Evaluate Processing:: 100%|██████████| 1/1 [00:00<00:00,  2.26it/s, loss:0.4411766827106476 acc:0.7720000147819519]
[10]:
{'loss': 0.44117668, 'acc': 0.772}
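Evaluation applies the same masking trick through sample_weight=idx_test. Assume the weighted accuracy metric divides the weighted number of correct predictions by the sum of weights. Under that assumption, the reported acc is the fraction of the 1,000 test nodes that are classified correctly. A toy sketch with hypothetical predictions:

```python
import numpy as np

# Toy predictions for 6 nodes; only the last 4 belong to the test split.
sample_weight = np.array([0, 0, 1, 1, 1, 1], dtype=float)
y_true = np.array([2, 0, 1, 1, 4, 6])
y_pred = np.array([2, 5, 1, 3, 4, 6])  # argmax of the predicted probabilities

correct = (y_true == y_pred).astype(float)
# Weighted accuracy: correct predictions among masked-in nodes only.
acc = np.sum(sample_weight * correct) / np.sum(sample_weight)
print(acc)  # 0.75
```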

Summary

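This tutorial showed how to train a graph neural network with split learning. It implements a basic GNN example. In the future we plan to explore further, for example:

  • Large-scale graph data: in this example every training step runs on the whole graph, which imposes a heavy computational load. We will explore approaches such as mini-batch training to reduce compute and memory consumption.

  • Partially aligned graphs: this example assumes that the nodes of all parties are fully aligned, but in practice only some nodes may be aligned. We will explore scenarios in which the parties share only part of their nodes.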