Использование API "DGLGraph.apply_edges" и "DGLGraph.send_and_recv" (для вычисления сообщений) в качестве замены "DGLGraph.send" и "DGLGraph.recv"
Я использую DGL (пакет Python, предназначенный для глубокого обучения на графиках) для обучения определению графа, определению графовой сверточной сети (GCN) и обучению.
Я столкнулся с проблемой, с которой я имею дело в течение двух недель. Я разработал свой код GCN на основе приведенной ниже ссылки:
Я столкнулся с ошибкой для этой части вышеупомянутого кода:
class GCNLayer(nn.Module): def init(self, in_feats, out_feats): super(GCNLayer, self).init() self.linear = nn.Linear(in_feats, out_feats)
def forward(self, g, inputs):
# g is the graph and the inputs is the input node features
# first set the node features
g.ndata['h'] = inputs
# trigger message passing on all edges
g.send(g.edges(), gcn_message)
# trigger aggregation at all nodes
g.recv(g.nodes(), gcn_reduce)
# get the result node features
h = g.ndata.pop('h')
# perform linear transformation
return self.linear(h)
Я получаю сообщение об ошибке ниже:
dgl._ffi.base.DGLError: DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges API to compute messages as edge data. Then use DGLGraph.send_and_recv and set the message function as dgl.function.copy_e to conduct message aggregation*
Как указано в ошибке, мне интересно узнать, как я могу использовать DGLGraph.apply_edges вместо DGLGraph.send?
В команде «DGLGraph.send» у нас есть 2 аргумента «g.edges()» и «gcn_message» .
Как эти аргументы можно преобразовать в аргументы, необходимые для «DGLGraph.apply_edges» , которые являются (func, edge='ALL', etype=None, inplace=False) (Согласно этой ссылке ?
Кроме того, тот же вопрос для "DGLGraph.send_and_recv" .
В "DGLGraph.recv" у нас было 2 аргумента "g.nodes()" и "gcn_reduce" .
Как эти аргументы можно преобразовать в аргументы, необходимые для «DGLGraph.send_and_recv» , которые являются «(edges, message_func, reduce_func, apply_node_func = None, etype = None, inplace = False)» (Согласно этой ссылке )?
Я был бы очень признателен, если бы вы могли помочь мне с этой большой проблемой.
Спасибо
2 ответа
попробуйте код ниже, он может решить вашу проблему
def forward(self, g, inputs):
g.ndata['h'] = inputs
g.send_and_recv(g.edges(), gcn_message, gcn_reduce)
h = g.ndata.pop('h')
return self.linear(h)
DGLGraph.apply_edges(func, edge='ALL', etype=None, inplace=False) используется для обновления граничных объектов с помощью функции func на всех гранях в 'ребрах'.
DGLGraph.send_and_recv(edges, message_func, reduce_func, apply_node_func=None, etype=None, inplace=False) используется для передачи сообщений, сокращения сообщений и обновления функций узла для всех ребер в «ребрах».
Чтобы заставить ваш метод пересылки работать, вы можете обновить свой код, как показано ниже.
def forward(self, g, inputs):
g.ndata['h'] = inputs
g.send_and_recv(g.edges(), fn.copy_src("h", "m"), fn.sum("m", "h"))
h = g.ndata.pop("h")
return self.linear(h)
Вы можете использовать свои собственные message_func (генерация сообщений) и reduce_func (агрегация сообщений) в соответствии с вашими целями.