Нахождение числа xor с элементами последовательности для получения заданной суммы

Недавно я столкнулся со следующей проблемой: нам дана целочисленная последовательность x_i (x_i < 2^60) из n (n < 10^5) целые и целые числа S (S < 2^60) найти наименьшее целое число a так что выполняется следующее:

формула,

Например:

x = [1, 2, 5, 10, 50, 100]
S = 242

Возможные решения для a 21, 23, 37, 39, но самый маленький 21.

(1^21) + (2^21) + (5^21) + (10^21) + (50^21) + (100^21)
= 20 + 23 + 16 + 31 + 39 + 113 
= 242

1 ответ

Решение

Можно понемногу строить результат снизу. Начиная с младшего бита, попробуйте 0 и 1 как младший бит aи посмотрите, соответствует ли младший бит суммы xor соответствующему биту S. Затем попробуйте следующий младший бит, передавая любой перенос с предыдущего шага.

Следуя этому алгоритму, может быть 0, 1 или 2 варианта для каждого бита aТаким образом, в худшем случае нам может понадобиться изучить разные ветви и выбрать ту, которая дает наименьший результат. Чтобы избежать экспоненциального поведения, мы кэшируем ранее просмотренные результаты для переноса с определенным битом. Это дает сложность O(kn) в худшем случае, где k - максимальное количество битов в результате, а n - максимальное значение переноса, учитывая, что входной список имеет длину n.

Вот некоторый код Python, который реализует это:

max_shift = 80

def xor_sum0(xs, S, shift, carry, cache, sums):
    if shift >= max_shift:
        return 1e100 if carry else 0
    key = shift, carry
    if key in cache:
        return cache[key]
    best = 1e100
    for i in xrange(2):
        ss = sums[i][shift] + carry
        if ss & 1 == (S >> shift) & 1:
            best = min(best, i + 2 * xor_sum0(xs, S, shift + 1, ss >> 1, cache, sums))
    cache[key] = best
    return cache[key]

def xor_sum(xs, S):
    sums = [
        [sum(((x >> sh) ^ i) & 1 for x in xs) for sh in xrange(max_shift)]
        for i in xrange(2)]
    return xor_sum0(xs, S, 0, 0, dict(), sums)

В случае, если решения не существует, код возвращает большое (>=1e100) число с плавающей запятой.

И вот тест, который выбирает случайные значения в диапазонах, которые вы дали, выбирает случайный a и вычисляет S, а затем решает. Обратите внимание, что иногда код находит меньший a чем тот, который был использован для вычисления S, так как значения a не всегда уникальны.

import random
xs = [random.randrange(0, 1 << 61) for _ in xrange(random.randrange(10 ** 5))]
a_original = random.randrange(1 << 61)
S = sum(x ^ a_original for x in xs)
print S
print xs

a = xor_sum(xs, S)
assert a < 1e100
print 'a:', a
print 'original a:', a_original

assert a <= a_original

print 'S', S
print 'SUM', sum(x^a for x in xs)

assert sum(x^a for x in xs) == S
Другие вопросы по тегам