Функция потерь не уменьшается; потеря валидации далеко от потери обучения; f1 смешно

Я пытаюсь использовать простую модель, использующую нейронную сеть графа (GNN) для реализации двоичной классификации. Мои входные данные - это наборы данных, содержащие графики, два набора для обучения (сигнал, фон) и два набора для проверки (сигнал, фон). Цель состоит в том, чтобы сеть могла предсказать, является ли график сигнальным (следовательно, он заслуживает оценки, стремящейся к 1) или фоновым (следовательно, заслуживает оценки, стремящейся к 0). Каждый граф полностью связан и имеет несколько узлов и ребер. Моя модель проста:

      import dgl
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d,ReLU,MaxPool2d,Linear, BatchNorm1d


node_hidden_size = 25


class EdgeNetwork(nn.Module):
    def __init__(self,inputsize,hidden_layer_size,output_size):
        super().__init__()
    
        self.net = nn.Sequential( 
                        nn.Linear( inputsize, hidden_layer_size*3), 
                        nn.ReLU(), 
                        nn.Linear(hidden_layer_size*3, output_size) 
                        )
        
    def forward(self, x):
        
        input_data = torch.cat((
                        x.dst['features'], 
                        x.dst['node_hidden_rep'],
                        x.src['features'],
                        x.src['node_hidden_rep'],
                        x.data['features']), dim=-1)
        
        
        output = self.net( input_data ) 
        
        return {'edge_hidden_rep': output }

    
class NodeNetwork(nn.Module):
    def __init__(self,inputsize,hidden_layer_size,output_size):
        super().__init__()

        self.net = nn.Sequential( 
                        nn.Linear (inputsize, hidden_layer_size*3 ), 
                        nn.ReLU(), 
                        nn.Linear(hidden_layer_size*3, output_size) 
                        )


    def forward(self, x):
        
        
        message_sum = torch.sum(x.mailbox['edge_hidden_rep'] ,dim=1)
        
        input_data = torch.cat((message_sum, x.data['features'], x.data['node_hidden_rep']),dim=1)
        
        out = self.net( input_data )

        
        return {'node_hidden_rep': out }


class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        
        
        # a network to init the hidden rep of the nodes
        self.node_init = nn.Sequential(
                                nn.Linear(4,node_hidden_size*3),
                                nn.ReLU(),
                                nn.Linear(node_hidden_size*3,node_hidden_size)
                                ) #4 = dimension of node_features
        
        self.edge_network = EdgeNetwork( 4 + node_hidden_size+ 4 + node_hidden_size + 2, 50, node_hidden_size ) # source features + destination features + edge features
        
        self.node_network = NodeNetwork( node_hidden_size + 4 + node_hidden_size, 50, node_hidden_size )
#        self.edge_classifier = EdgeNetwork( 4*2 + 2*node_hidden_size + 2*1 + 2, 200, 1) 

        
        self.node_classifier = nn.Sequential(  
                                    nn.Linear(node_hidden_size, 50), 
                                    nn.ReLU(),
                    nn.Linear(50, 100),
                    nn.ReLU(),
                    nn.Linear(100,50),
                    nn.ReLU(),
                    BatchNorm1d(50), 
                                    nn.Linear(50, 1) 
                                    )
       
        ### responsible for the prediction
        self.mlp = nn.Sequential( 
                        nn.Linear( node_hidden_size+4 , 50 ) , 
                        nn.ReLU(), nn.Linear(50, 100), nn.ReLU(), nn.Linear(100, 50), nn.ReLU(), BatchNorm1d(50), nn.Linear( 50, 1)
                        )  
 
 
    def forward(self, batched_g):

        batched_g.ndata['node_hidden_rep'] = self.node_init(batched_g.ndata['features'])
        
        
        GN_block_iterations = 2
 
        for i in range( GN_block_iterations ):
                        
            batched_g.update_all(self.edge_network,self.node_network)
            
            

        new_global_mean = dgl.mean_nodes(batched_g,'node_hidden_rep') #new_global_mean must have shape of new_hidden_rep.size = 25,
        
        broadcasted_sum = dgl.broadcast_nodes(batched_g, new_global_mean)

        batched_g.ndata['global_rep'] = torch.cat((broadcasted_sum, batched_g.ndata['features'] ),dim=1)
        global_rep = dgl.mean_nodes(batched_g,'global_rep')
        
        return self.mlp(global_rep)

Я использую обычную логистическую потерю: BCEWithLogitsLoss и не смогли решить две проблемы:

  1. Тренировки почти не проводятся.
  2. Потери при проверке и обучении чрезвычайно далеки друг от друга.

См. Рисунок, чтобы проиллюстрировать это:потери при проверке и обучении

Что я делаю не так? Что мне не хватает?

0 ответов

Другие вопросы по тегам