Почему выход сверточной сети настолько экспоненциально велик?

Я пытаюсь воспроизвести результат unet в наборе данных Carvana с помощью Ternausnet в PyTorch с использованием Lightning.

Я использую для этого DiceLoss с функцией активации сигмоида. Я думаю, что столкнулся с проблемой исчезающего градиента, потому что все градиенты весов равны 0, и я вижу выход сети с минимальным значением порядка 10^8.

В чем может быть проблема? Как я могу решить проблему исчезающего градиента? Кроме того, если я использую другой критерий, я вижу проблему потери отрицательных значений без остановки (например, для BCE с логитами).

Вот код моего проигрыша в кости:

class DiceLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, targets, eps=0, threshold=None):

        # comment out if your model contains a sigmoid or
        # equivalent activation layer
        proba = torch.sigmoid(logits)
        proba = proba.view(proba.shape[0], 1, -1)
        targets = targets.view(targets.shape[0], 1, -1)
        if threshold:
            proba = (proba > threshold).float()
        # flatten label and prediction tensors

        intersection = torch.sum(proba * targets, dim=1)
        summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
        dice = (2.0 * intersection + eps) / (summation + eps)
        # print(intersection, summation, dice)
        return (1 - dice).mean()

0 ответов

Другие вопросы по тегам