Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification

背景消息傳遞模型(Message Passing Model)基于拉普拉斯平滑假設(領居是相似的) , 試圖聚合圖中的鄰居的信息來獲取足夠的依據,以實現更魯棒的半監督節點分類 。
圖神經網絡(Graph Neural Networks, GNN)和標簽傳播算法(Label Propagation, LPA)均為消息傳遞算法,其中GNN主要基于傳播特征來提升預測效果,而LPA基于迭代式的標簽傳播來作預測 。
【Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification】一些工作要么用LPA對GNN預測結果做后處理,要么用LPA對GNN進行正則化 。但是 , 它們仍不能直接將GNN和LPA有效地整合到消息傳遞模型中 。
為解決這個問題,本文提出了統一消息傳遞模型(UNIMP)[1],它可以在訓練和推理時結合特征和標簽傳播 。UniMP基于兩個簡單而有效的想法:

  • 將特征嵌入和標簽嵌入同時作為輸入信息進行傳播
  • 隨機掩碼部分標簽信息,并在訓練時對其進行預測
UniMP在概念上統一了特征傳播和標簽傳播,具有強大的經驗能力 。
Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification

文章插圖
實現關鍵部分
  • 將標簽進行嵌入(原有的C類One-hot標簽,通過線性變換成與原始節點特征相同的維度) 。
  • 然后,將標簽嵌入和節點特征相加作為GNN輸入 。
為避免訓練時使用標簽導致標簽泄露,這里使用了掩碼標簽訓練的策略 。每個Epoch隨機將訓練集中部分節點的標簽置(掩碼)0(視為訓練監督信號),然后利用節點特征 \(\mathbf{X}\) 和 \(\mathbf{A}\)以及剩余的標簽去預測被掩碼的標簽) 。
模型部分UniMP中使用了GraphTransformer(Transformer中的Q、K、V注意力形式,加上邊特征),同時引入了H-GCN的門控殘差機制來緩解過平滑 。
個人實驗將標簽作為輸入,在ArixV數據集節點分類任務上,能在小數點后第2位提升接近2個點 。
在論文BOT[2]中也對標簽作為輸入做了闡述,其作者還發表了相應的論文來論證標簽作為輸入的有效性的原因 。
總結標簽有效的直覺就是,在圖上的節點分類任務中,鄰居標簽也是預測目標節點標簽的關鍵特征(這也和標簽傳播的思想一致)
標簽嵌入和掩碼標簽預測是提升節點分類任務簡單有效的方法 。
參考文獻
[1] Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification[2] Bag of Tricks for Node Classification with Graph Neural Networks
2022-10-29 11:10:13 星期六

    推薦閱讀