Как эффективно вычислить взаимную информацию между мини-пакетами изображений в Tensorflow?
Я хочу использовать взаимную информацию (MI) между моими изображениями ввода-вывода как дополнительную потерю для обучения моей нейронной сети в Tensorflow. Как сделать это эффективно, чтобы ускорить обучение?
Я реализовал версию TF, используя подход гистограммы (с tf.while_loop
) и это работает. Но это безумно медленно. Основные функции, которые я написал, включают в себя:
- mi_loss
(который вычисляет MI для мини-пакета, вызывая следующее для каждого элемента мини-пакета),
- tf_mi
(который вычисляет MI между двумя изображениями, вызывая следующее)
- tf_hist2d
(который вычисляет объединенную гистограмму между двумя изображениями).
def tf_hist2d(A, B, nbins=256):
a = (A - tf.reduce_min(A))/(tf.reduce_max(A)-tf.reduce_min(A)) * (nbins-1.)
a = tf.cast( tf.round(a), dtype=tf.int32 )
b = (B - tf.reduce_min(B))/(tf.reduce_max(B)-tf.reduce_min(B)) * (nbins-1.)
b = tf.cast( tf.round(b), dtype=tf.int32 )
i0 = tf.constant(0)
hist = tf.zeros((1, nbins), dtype=tf.int32)
cond = lambda i, his: i < nbins
def body(i, his):
idx = tf.where(tf.equal(a, i), tf.ones(tf.shape(a), dtype=tf.bool),
tf.zeros(tf.shape(a), dtype=tf.bool))
ab = tf.boolean_mask(b, idx)
h = tf.histogram_fixed_width(tf.reshape(ab, [-1]), [0, nbins], nbins=nbins)
his = tf.concat((his, h[None,:]), axis=0)
return [i+1, his]
res = tf.while_loop(cond, body, [i0, hist],
shape_invariants=[i0.get_shape(), tf.TensorShape([None, nbins])],
parallel_iterations=50, swap_memory=True, name='MI_hist_loop')
hist = res[1][1:,:]
return hist
def tf_mi(A, B, nbins=256):
hist = tf.cast(tf_hist2d(A, B, nbins), dtype=tf.float32)
pab = hist / tf.reduce_sum(hist)
pa = tf.reduce_sum(pab, axis=1)
pb = tf.reduce_sum(pab, axis=0)
papb = pa[...,None]*pb[None,...]
idx = tf.where(tf.greater(pab,0), tf.ones(tf.shape(pab), dtype=tf.bool),
tf.zeros(tf.shape(pab), dtype=tf.bool))
idx = tf.where(tf.greater(papb,0), idx,
tf.zeros(tf.shape(pab), dtype=tf.bool))
pab = tf.boolean_mask(pab, idx)
papb= tf.boolean_mask(papb, idx)
mi = tf.reduce_sum( pab * tf.log(pab/papb) )
return mi
def mi_loss(Abat, Bbat, nbins=256):
"""
Abat: input batch - image RGB
Bbat: feature batch - output features 1-chan
"""
#gray-scale then rescale image
Abat = tf.image.rgb_to_grayscale(Abat)
Abat = tf.image.resize_images(Abat, (tf.shape(Bbat)[1], tf.shape(Bbat)[2]))
#loop over minibatch elements
mi = tf.constant(0.)
Nb = tf.shape(Abat)[0]
it = tf.constant(0)
cond = lambda i, m, A, B: tf.less(i, Nb)
def body(i, m, A, B):
m += tf_mi(A[i,...,0], B[i,...,0], nbins=nbins)
return [i+1, m, A, B]
res = tf.while_loop(cond, body, [it, mi, Abat, Bbat],
parallel_iterations=50, swap_memory=True, name='MI_batch_loop')
return res[1]/tf.cast(Nb, dtype=tf.float32)
Я подозреваю, что вложенный tf.while_loop
S замедляют код. Всякий раз, когда я включаю потери MI, время вычислений увеличивается, и я вижу, что использование моего графического процессора падает до ~6%. Если я не включаю потери MI, время вычислений почти вдвое меньше, а использование GPU составляет около 90%.
Есть ли способ сделать это более эффективным? Я пытался думать о:
а) ускорение tf.while_loop
с некоторыми вариантами, но я не мог найти обходной путь
б) вычисление MI в пакетном режиме без зацикливания элементов минибата, но я не уверен, как: можно tf.histogram_fixed_width
вычислить гистограммы вдоль оси (сомневаетесь)?
Спасибо