Почему выход сверточной сети настолько экспоненциально велик?
Я пытаюсь воспроизвести результат unet в наборе данных Carvana с помощью Ternausnet в PyTorch с использованием Lightning.
Я использую для этого DiceLoss с функцией активации сигмоида. Я думаю, что столкнулся с проблемой исчезающего градиента, потому что все градиенты весов равны 0, и я вижу выход сети с минимальным значением порядка 10^8.
В чем может быть проблема? Как я могу решить проблему исчезающего градиента? Кроме того, если я использую другой критерий, я вижу проблему потери отрицательных значений без остановки (например, для BCE с логитами).
Вот код моего проигрыша в кости:
class DiceLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, targets, eps=0, threshold=None):
# comment out if your model contains a sigmoid or
# equivalent activation layer
proba = torch.sigmoid(logits)
proba = proba.view(proba.shape[0], 1, -1)
targets = targets.view(targets.shape[0], 1, -1)
if threshold:
proba = (proba > threshold).float()
# flatten label and prediction tensors
intersection = torch.sum(proba * targets, dim=1)
summation = torch.sum(proba, dim=1) + torch.sum(targets, dim=1)
dice = (2.0 * intersection + eps) / (summation + eps)
# print(intersection, summation, dice)
return (1 - dice).mean()