Графическая нейронная сеть с узлами на входе и краями на выходе в DGL

Я хотел бы адаптировать пример DGL GATLayer таким образом, чтобы вместо изучения представлений узлов сеть могла изучать веса ребер. То есть я хочу построить сеть, которая принимает набор узловых функций в качестве входных и выводит края. Этикетки будут представлять собой набор "границ истинности", которые представляют, какие узлы происходят из общего источника, так что я могу научиться кластеризовать невидимые данные таким же образом.

Я использую в качестве отправной точки код из следующего примера DGL:

https://www.dgl.ai/blog/2019/02/17/gat.html

import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    
    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}
    
    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z' : edges.src['z'], 'e' : edges.data['e']}
    
    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}
    
    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
    
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        #   multiple head outputs are concatenated together. Also, only
        #   one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
    
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

Я надеялся, что смогу адаптировать это, чтобы просто вернуть края вместо узлов, например, заменив строку

return self.g.ndata.pop('h')

с участием

return self.e.ndata.pop('e')

Но, похоже, не все так просто. Мне удалось заставить что-то пробежать, но потеря скакала повсюду, и никакого обучения не происходило.

Я новичок в графических сетях, но не в глубоком обучении в целом. Разумно ли то, что я пытаюсь сделать? Я упустил что-то важное в моем понимании того, как это работает? Мне не удалось найти простых для понимания примеров сетей графов, в которых ребра сами по себе являются целью обучения, поэтому я немного запутался в данный момент. Я ценю любую помощь, которую может оказать каждый!

1 ответ

Я не совсем уверен, потому что это зависит от вашего ввода, но self.g, скорее всего, является графом DGL, поэтому в коде они обращаются к ndata, что означает данные узла, если вы хотите получить доступ к данным края графа, вы должны получить доступ к edata. Поэтому вы должны написать return self.g.edata... хотя я не уверен, какие атрибуты краев, к которым вы пытаетесь получить доступ, будут изменять pop(что бы вы ни пытались получить)

Другие вопросы по тегам