Как эффективно вычислить взаимную информацию между мини-пакетами изображений в Tensorflow?

Я хочу использовать взаимную информацию (MI) между моими изображениями ввода-вывода как дополнительную потерю для обучения моей нейронной сети в Tensorflow. Как сделать это эффективно, чтобы ускорить обучение?

Я реализовал версию TF, используя подход гистограммы (с tf.while_loop) и это работает. Но это безумно медленно. Основные функции, которые я написал, включают в себя:
- mi_loss (который вычисляет MI для мини-пакета, вызывая следующее для каждого элемента мини-пакета),
- tf_mi (который вычисляет MI между двумя изображениями, вызывая следующее)
- tf_hist2d (который вычисляет объединенную гистограмму между двумя изображениями).

def tf_hist2d(A, B, nbins=256):
    a = (A - tf.reduce_min(A))/(tf.reduce_max(A)-tf.reduce_min(A)) * (nbins-1.)
    a = tf.cast( tf.round(a), dtype=tf.int32 )
    b = (B - tf.reduce_min(B))/(tf.reduce_max(B)-tf.reduce_min(B)) * (nbins-1.)
    b = tf.cast( tf.round(b), dtype=tf.int32 )
    i0 = tf.constant(0)
    hist = tf.zeros((1, nbins), dtype=tf.int32)
    cond = lambda i, his: i < nbins
    def body(i, his):
        idx = tf.where(tf.equal(a, i), tf.ones(tf.shape(a), dtype=tf.bool),
                                        tf.zeros(tf.shape(a), dtype=tf.bool))
        ab = tf.boolean_mask(b, idx)
        h = tf.histogram_fixed_width(tf.reshape(ab, [-1]), [0, nbins], nbins=nbins)
        his = tf.concat((his, h[None,:]), axis=0)
        return [i+1, his]
    res = tf.while_loop(cond, body, [i0, hist],
                shape_invariants=[i0.get_shape(), tf.TensorShape([None, nbins])],
                parallel_iterations=50, swap_memory=True, name='MI_hist_loop')
    hist = res[1][1:,:]
    return hist    


def tf_mi(A, B, nbins=256):
    hist = tf.cast(tf_hist2d(A, B, nbins), dtype=tf.float32)
    pab = hist / tf.reduce_sum(hist)
    pa = tf.reduce_sum(pab, axis=1)
    pb = tf.reduce_sum(pab, axis=0)
    papb = pa[...,None]*pb[None,...]
    idx = tf.where(tf.greater(pab,0), tf.ones(tf.shape(pab), dtype=tf.bool),
                                        tf.zeros(tf.shape(pab), dtype=tf.bool))
    idx = tf.where(tf.greater(papb,0), idx,
                                        tf.zeros(tf.shape(pab), dtype=tf.bool))
    pab = tf.boolean_mask(pab, idx)
    papb= tf.boolean_mask(papb, idx)
    mi = tf.reduce_sum( pab * tf.log(pab/papb) )
    return mi 


def mi_loss(Abat, Bbat, nbins=256):
    """
    Abat: input batch - image RGB
    Bbat: feature batch - output features 1-chan
    """
    #gray-scale then rescale image
    Abat = tf.image.rgb_to_grayscale(Abat)
    Abat = tf.image.resize_images(Abat, (tf.shape(Bbat)[1], tf.shape(Bbat)[2]))
    #loop over minibatch elements
    mi = tf.constant(0.)
    Nb = tf.shape(Abat)[0]
    it = tf.constant(0)
    cond = lambda i, m, A, B: tf.less(i, Nb)
    def body(i, m, A, B):
        m += tf_mi(A[i,...,0], B[i,...,0], nbins=nbins)
        return [i+1, m, A, B]
    res = tf.while_loop(cond, body, [it, mi, Abat, Bbat],
                parallel_iterations=50, swap_memory=True, name='MI_batch_loop')
    return res[1]/tf.cast(Nb, dtype=tf.float32)

Я подозреваю, что вложенный tf.while_loopS замедляют код. Всякий раз, когда я включаю потери MI, время вычислений увеличивается, и я вижу, что использование моего графического процессора падает до ~6%. Если я не включаю потери MI, время вычислений почти вдвое меньше, а использование GPU составляет около 90%.

Есть ли способ сделать это более эффективным? Я пытался думать о:
а) ускорение tf.while_loop с некоторыми вариантами, но я не мог найти обходной путь
б) вычисление MI в пакетном режиме без зацикливания элементов минибата, но я не уверен, как: можно tf.histogram_fixed_width вычислить гистограммы вдоль оси (сомневаетесь)?

Спасибо

0 ответов

Другие вопросы по тегам