Реализация сита Аткина в Python

Я пытаюсь реализовать алгоритм Сита Аткина, приведенный в ссылке на Википедию, как показано ниже:

Сито Аткин

До сих пор я пробовал реализацию в Python, представленную следующим кодом:

import math
is_prime = list()
limit = 100
for i in range(5,limit):
    is_prime.append(False)

for x in range(1,int(math.sqrt(limit))+1):
    for y in range(1,int(math.sqrt(limit))+1):
        n = 4*x**2 + y**2

        if n<=limit and (n%12==1 or n%12==5):
            # print "1st if"
            is_prime[n] = not is_prime[n]
        n = 3*x**2+y**2
        if n<= limit and n%12==7:
            # print "Second if"
            is_prime[n] = not is_prime[n]
        n = 3*x**2 - y**2
        if x>y and n<=limit and n%12==11:
            # print "third if"
            is_prime[n] = not is_prime[n]

for n in range(5,int(math.sqrt(limit))):
    if is_prime[n]:
        for k in range(n**2,limit+1,n**2):
            is_prime[k] = False
print 2,3
for n in range(5,limit):
    if is_prime[n]: print n

Теперь я получаю ошибку как

is_prime[n] = not is_prime[n]
IndexError: list index out of range

это означает, что я обращаюсь к значению в списке, где индекс больше длины списка. Рассмотрим Условие, когда x,y = 100, тогда, конечно, условие n=4x^2+y^2 даст значение, которое больше длины списка. Я что-то здесь не так делаю? Пожалуйста помоги.

РЕДАКТИРОВАТЬ 1 Как предложил Гейб, используя

is_prime = [False] * (limit + 1)

инстед из:

for i in range(5,limit):
    is_prime.append(False)

действительно решил проблему.

4 ответа

Решение

Вы проблема в том, что ваш лимит составляет 100, но ваш is_prime только в списке limit-5 элементы в нем из-за инициализации с range(5, limit),

Поскольку этот код предполагает, что он может получить доступ к limit Индекс, вам нужно иметь limit+1 элементы в нем: is_prime = [False] * (limit + 1)

Обратите внимание, что это не имеет значения, что 4x^2+y^2 больше, чем limit потому что это всегда проверяет n <= limit,

Вот решение

import math

def sieveOfAtkin(limit):
    P = [2,3]
    sieve=[False]*(limit+1)
    for x in range(1,int(math.sqrt(limit))+1):
        for y in range(1,int(math.sqrt(limit))+1):
            n = 4*x**2 + y**2
            if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
            n = 3*x**2+y**2
            if n<= limit and n%12==7 : sieve[n] = not sieve[n]
            n = 3*x**2 - y**2
            if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
    for x in range(5,int(math.sqrt(limit))):
        if sieve[x]:
            for y in range(x**2,limit+1,x**2):
                sieve[y] = False
    for p in range(5,limit):
        if sieve[p] : P.append(p)
    return P

print sieveOfAtkin(100)

Спасибо за очень интересный вопрос!

Поскольку ошибки в вашем коде уже исправлены другими ответами, я решил реализовать с нуля свои собственные, очень оптимизированные версии Решета Аткина , а также Решета Эратосфена (для сравнения).

Оказывается, из 4 реализованных мной функций лучшая работает в 193 раза быстрее исходного кода, невероятное ускорение! Если на моем медленном ноутбуке лимит в 10 миллионов в вашем коде занимает 50 секунд, в моей функции тот же лимит занимает всего 0,26 секунды.

Лучшее ускорение в моем коде достигается с помощью пакетов Numba и Numpy . Они используются только для ускорения выполнения кода (с помощью прекомпиляции), но никаких научных функций из этих пакетов я не использовал.

Сначала я покажу тайминги, это вывод на консоль, если вы запустите мой код, расположенный в конце поста.

      Limit 10_000_000
SieveOfAtkin_Mahadeva           : time  50.513 sec, boost    1.00x
SieveOfAtkin_bergee             : time  13.016 sec, boost    3.88x
SieveOfEratosthenes_Arty_Python : time   5.768 sec, boost    8.76x
SieveOfAtkin_Arty_Python        : time   3.632 sec, boost   13.91x
SieveOfEratosthenes_Arty_Numba  : time   0.445 sec, boost  113.51x
SieveOfAtkin_Arty_Numba         : time   0.261 sec, boost  193.54x

Помимо моих 4 функций, я использовал оригинальный код спрашивающего @Mahadeva и лучший (с точки зрения скорости) код ответа .

Я сделал 2 версии функции Аткина (первую на чистом Python, вторую с использованием Numba и Numpy) и 2 версии Решета Эратосфена (те же, одну на чистом Python, другую с Numba/Numpy).

Также я провел следующие оптимизации, вы можете узнать их, прочитав официальную Wiki Atkin (раздел псевдокода):

  1. Вместо вычисления циклов X и Y до выполнения шага 1, согласно Wiki, вы можете выполнить шаг 2 в половине циклов.

  2. Кроме того, если вы выполните домашнее задание по математике, то легко увидеть, что цикл X не нужно запускать до тех пор, покаSqrt(limit), но вместо этого первый цикл X может выполняться до тех пор, покаSqrt(limit / 4), второй цикл X доSqrt(limit / 3), третья петля X доSqrt(limit / 2). Все это можно получить, обратив термин4 * x * xи3 * x * x.

  3. Также в Википедии написано, что не нужно обрабатывать ВСЕ значенияn % 12 == 1 or n % 12 == 5иn % 12 == 7иn % 12 == 11, но только на 30-50% меньшее подмножество напоминаний по модулю 60. Моя функцияSieveOfAtkin_Arty_Numba()(и Wiki тоже) показывает, какие напоминания использовать.

  4. Вместо сохранения массиваis_prime[]логических значений илиbyteзначений (в случае Numpy), достаточно сохранить массив битов. Это уменьшит использование памяти ровно в 8 раз . Это не только ускоряет вычисления за счет использования кэша ЦП, но также позволяет вычислить гораздо больше простых чисел, если у вас ограниченная память. Две версии Numba выполняют битовую арифметику для работы с битовым массивом.

  5. Предварительная компиляция Numba выполняет большую часть работы по оптимизации. Потому что он преобразует код в промежуточное представление LLVM , которое является своего рода ассемблерным кодом, который по скорости аналогичен оптимизированному коду C/C++. По сути, благодаря помощи Numba код становится таким же быстрым, как если бы вы написали его на чистом C/C++, а не на Python. Но это всё равно код Python, но автоматически оптимизированный Numba.

Посмотрите на функциюSieveOfAtkin_Arty_Python()— по сути, эта функция — это то, на что вам нужно обратить внимание, чтобы изучить мой код. Он написан на чистом Python(без Numba), но в 13,9 раз быстрее, чем ваш исходный код, и в 3,58 раза быстрее, чем лучший другой @bergeeответ @bergee .

Если вам не нужна тяжелая Numba в ваших проектах, лучше всего скопировать и вставить код функции SieveOfAtkin_Arty_Python(), это лучший вариант из чистых решений Python.

Перед запуском кода выполните установку пакетов только один раз.python -m pip install numba numpy -U. Если вам не нравится Numba, удалите импорт пакетов Numba и Numpy из моей первой строки кода, а также удалите две функции.SieveOfEratosthenes_Arty_Numba()иSieveOfAtkin_Arty_Numbda().

Попробуйте онлайн!

      import numba as nb, numpy as np, math, time

def SieveOfAtkin_Arty_Python(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Atkin
    end = limit + 1
    sqrt_end = int(end ** 0.5 + 2.01)
    primes = [2, 3]
    is_prime = [False] * end
    for x in range(1, int((end / 4) ** 0.5 + 2.01)):
        xx4 = 4 * x * x
        for y in range(1, sqrt_end, 2):
            n = xx4 + y * y
            if n >= end:
                break
            if n % 12 == 1 or n % 12 == 5:
                is_prime[n] = not is_prime[n]
    for x in range(1, int((end / 3) ** 0.5 + 2.01), 2):
        xx3 = 3 * x * x
        for y in range(2, sqrt_end, 2):
            n = xx3 + y * y
            if n >= end:
                break
            if n % 12 == 7:
                is_prime[n] = not is_prime[n]
    for x in range(2, int((end / 2) ** 0.5 + 2.01)):
        xx3 = 3 * x * x
        for y in range(x - 1, 0, -2):
            n = xx3 - y * y
            if n >= end:
                break
            if n % 12 == 11:
                is_prime[n] = not is_prime[n]
    for x in range(5, sqrt_end):
        if is_prime[x]:
            for y in range(x * x, end, x * x):
                is_prime[y] = False
    for p in range(5, end, 2):
        if is_prime[p]:
            primes.append(p)
    return primes

def SieveOfEratosthenes_Arty_Python(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
    end = limit + 1
    composite = [False] * end
    for i in range(3, int(end ** 0.5 + 2.01)):
        if not composite[i]:
            for j in range(i * i, end, i):
                composite[j] = True
    return [2] + [i for i in range(3, end, 2) if not composite[i]]

@nb.njit(cache = True)
def SieveOfEratosthenes_Arty_Numba(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes
    end = limit + 1
    composite = np.zeros(((end + 7) // 8,), dtype = np.uint8)
    for i in range(3, int(end ** 0.5 + 2.01)):
        if not (composite[i // 8] & (1 << (i % 8))):
            for j in range(i * i, end, i):
                composite[j // 8] |= 1 << (j % 8)
    return np.array([2] + [i for i in range(3, end, 2)
        if not (composite[i // 8] & (1 << (i % 8)))], dtype = np.uint32)

@nb.njit(cache = True)
def SieveOfAtkin_Arty_Numba(limit):
    # https://en.wikipedia.org/wiki/Sieve_of_Atkin
    # https://github.com/mccricardo/sieve_of_atkin/blob/master/sieve_of_atkin.py
    # https://stackoverflow.com/questions/21783160/
    end = limit + 1
    is_prime = np.zeros(((end + 7) // 8,), dtype = np.uint8)
    # Subset of n % 12 == 1 or n % 12 == 5
    set0 = np.array([int(i in {1, 13, 17, 29, 37, 41, 49, 53}) for i in range(60)], dtype = np.uint8)
    # Subset of n % 12 == 7
    set1 = np.array([int(i in {7, 19, 31, 43}) for i in range(60)], dtype = np.uint8)
    # Subset of n % 12 == 11
    set2 = np.array([int(i in {11, 23, 47, 59}) for i in range(60)], dtype = np.uint8)
    sqrt_end = int(math.sqrt(end) + 1.01)
    
    for x in range(1, int(sqrt_end / math.sqrt(4) + 2.01)):
        xx4 = 4 * x * x
        for y in range(1, sqrt_end, 2):
            n = xx4 + y * y
            if n >= end:
                break
            if set0[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    for x in range(1, int(sqrt_end / math.sqrt(3) + 2.01), 2):
        xx3 = 3 * x * x
        for y in range(2, sqrt_end, 2):
            n = xx3 + y * y
            if n >= end:
                break
            if set1[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    for x in range(2, int(sqrt_end / math.sqrt(2) + 2.01)):
        xx3 = 3 * x * x
        for y in range(x - 1, 0, -2):
            n = xx3 - y * y
            if n >= end:
                break
            if set2[n % 60]:
                is_prime[n // 8] ^= 1 << (n % 8)
    
    for n in range(7, sqrt_end):
        if is_prime[n // 8] & (1 << (n % 8)):
            for k in range(n * n, end, n * n):
                is_prime[k // 8] &= ~np.uint8(1 << (k % 8))
    
    return np.array([2, 3, 5] + [n for n in range(7, end, 2)
        if is_prime[n // 8] & (1 << (n % 8))], dtype = np.uint32)
        
def SieveOfAtkin_Mahadeva(limit):
    # https://stackoverflow.com/q/21783160/941531
    
    is_prime = [False] * (limit + 1)
    
    for x in range(1,int(math.sqrt(limit))+1):
        for y in range(1,int(math.sqrt(limit))+1):
            n = 4*x**2 + y**2

            if n<=limit and (n%12==1 or n%12==5):
                # print "1st if"
                is_prime[n] = not is_prime[n]
            n = 3*x**2+y**2
            if n<= limit and n%12==7:
                # print "Second if"
                is_prime[n] = not is_prime[n]
            n = 3*x**2 - y**2
            if x>y and n<=limit and n%12==11:
                # print "third if"
                is_prime[n] = not is_prime[n]

    for n in range(5,int(math.sqrt(limit))):
        if is_prime[n]:
            for k in range(n**2,limit+1,n**2):
                is_prime[k] = False
    return [2,3] + [n for n in range(5,limit) if is_prime[n]]

def SieveOfAtkin_bergee(limit):
    # https://stackoverflow.com/a/71490622/941531
    P = [2,3]
    r = range(1,int(math.sqrt(limit))+1)
    sieve=[False]*(limit+1)
    for x in r:
        for y in r:
            xx=x*x
            yy=y*y
            xx3 = 3*xx
            n = 4*xx + yy
            if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
            n = xx3 + yy
            if n<=limit and n%12==7 : sieve[n] = not sieve[n]
            n = xx3 - yy
            if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
    for x in range(5,int(math.sqrt(limit))):
        if sieve[x]:
            xx=x*x
            for y in range(xx,limit+1,xx):
                sieve[y] = False
    for p in range(5,limit):
        if sieve[p] : P.append(p)
    return P

def Test():
    limit = 5 * 10 ** 6
    # Do pretty printing of limit
    print(f'Limit', ''.join(reversed(''.join([['', '_'][i > 0 and i % 3 == 0] + c for i, c in enumerate(reversed(str(limit)))]))))
    rtim, rres = None, None
    for f in [
        SieveOfAtkin_Mahadeva,
        SieveOfAtkin_bergee, 
        SieveOfEratosthenes_Arty_Python,
        SieveOfAtkin_Arty_Python,
        SieveOfEratosthenes_Arty_Numba,
        SieveOfAtkin_Arty_Numba,
    ]:
        fname = f.__name__
        print(f'{fname:<31} : ', end = '', flush = True)
        f(1 << 10) # Pre-compute function, Numba needs it for pre-compilation
        tim = time.time()
        res = np.array(f(limit), dtype = np.uint32)
        tim = time.time() - tim
        if rtim is None:
            rtim = tim
        if rres is None:
            rres = res
        else:
            assert np.all(rres == res)
        print(f'time {tim:>7.3f} sec, boost {rtim / tim:>7.2f}x', flush = True)
        
if __name__ == '__main__':
    Test()

Это оптимизированная реализация, предложенная Zsolt KOVACS:

          import math
    import sys
    
    def sieveOfAtkin(limit):
        P = [2,3]
        r = range(1,int(math.sqrt(limit))+1)
        sieve=[False]*(limit+1)
        for x in r:
            for y in r:
                xx=x*x
                yy=y*y
                xx3 = 3*xx
                n = 4*xx + yy
                if n<=limit and (n%12==1 or n%12==5) : sieve[n] = not sieve[n]
                n = xx3 + yy
                if n<=limit and n%12==7 : sieve[n] = not sieve[n]
                n = xx3 - yy
                if x>y and n<=limit and n%12==11 : sieve[n] = not sieve[n]
        for x in range(5,int(math.sqrt(limit))):
            if sieve[x]:
                xx=x*x
                for y in range(xx,limit+1,xx):
                    sieve[y] = False
        for p in range(5,limit):
            if sieve[p] : P.append(p)
        return P
    
    primes = sieveOfAtkin(int(sys.argv[1]))    
    print (primes)

Вы передаете верхний предел в качестве первого аргумента. Эта программа работает примерно за 6 секунд на моей машине по сравнению с оригиналом, который работает за 21 секунду при ограничении в 10 миллионов. Что я сделал:

  • заменил возведение в степень умножением
  • предварительно вычислил некоторые умножения
Другие вопросы по тегам