Segmented prime sieve

Count the primes in a large range by splitting it into segments and sieving each segment on a worker. The sieve itself uses only the Python standard library, and each task returns just a count.

[ ]:
# Connection settings -- edit these to point at your running scheduler.
SCHEDULER_ADDRESS = "ws://127.0.0.1:2345"  # supports tcp:// or ws://; only ws:// works from JupyterLite (browser)
OBJECT_STORAGE_ADDRESS = None  # leave None to use whatever the scheduler advertises

# Problem size. SEARCH_UPPER is the exclusive upper bound of the sweep (primes
# are counted in [0, SEARCH_UPPER)); N_SEGMENTS is the number of slices the
# range is cut into, with one task submitted per slice.
# Default sweeps the first 2e9 integers across 128 segments; on 16 workers
# expect roughly a minute of wall-clock time.
SEARCH_UPPER = 2_000_000_000
N_SEGMENTS = 128
[ ]:
import math
import time

from scaler import Client


def _small_primes(upper: int) -> list[int]:
    """Plain Eratosthenes sieve up to (but not including) `upper`."""
    if upper < 3:
        return []
    flags = bytearray([1]) * upper
    flags[0] = flags[1] = 0
    for i in range(2, int(upper ** 0.5) + 1):
        if flags[i]:
            flags[i * i :: i] = bytearray(len(flags[i * i :: i]))
    return [i for i, flag in enumerate(flags) if flag]


def count_primes_in_segment(low: int, high: int, base_primes: list[int]) -> int:
    """Count the primes in the half-open range ``[low, high)``.

    Worker-side segmented sieve of Eratosthenes: one flag byte is kept per
    integer in the segment, and the multiples of each base prime are cleared.

    Args:
        low: Inclusive lower bound; values below 2 are clamped to 2.
        high: Exclusive upper bound.
        base_primes: Primes in ascending order, containing at least every
            prime ``p`` with ``p * p < high``. The list is not validated:
            if it is incomplete or unsorted, composites are left unmarked
            and the result is an over-count.

    Returns:
        The number of primes ``n`` with ``low <= n < high``; 0 when the
        range is empty.
    """
    low = max(low, 2)
    size = high - low
    if size <= 0:
        # Covers high <= 2 as well as empty or inverted ranges.
        return 0
    flags = bytearray(b"\x01") * size
    for p in base_primes:
        if p * p >= high:
            # Every later prime has its first relevant multiple beyond the segment.
            break
        # First multiple of p inside the segment, never below p * p so that
        # p itself (and multiples already handled by smaller primes) is skipped.
        start = max(p * p, ((low + p - 1) // p) * p)
        offset = start - low
        # Compute the number of cleared cells arithmetically rather than
        # copying the strided slice just to take its length.
        flags[offset::p] = bytearray(len(range(offset, size, p)))
    # Flags are only ever 0 or 1, so counting ones equals summing them.
    return flags.count(1)


# Cut [0, SEARCH_UPPER) into N_SEGMENTS half-open slices of equal width; the
# final slice is stretched to SEARCH_UPPER so it absorbs the division remainder.
segment_size = SEARCH_UPPER // N_SEGMENTS
segments = [
    (i * segment_size, SEARCH_UPPER if i == N_SEGMENTS - 1 else (i + 1) * segment_size)
    for i in range(N_SEGMENTS)
]

# Every segment is sieved with the same primes up to sqrt(SEARCH_UPPER).
base_primes = _small_primes(int(SEARCH_UPPER ** 0.5) + 1)

with Client(address=SCHEDULER_ADDRESS, object_storage_address=OBJECT_STORAGE_ADDRESS) as client:
    # Send the shared prime list once and hand each task the returned reference.
    base_ref = client.send_object(base_primes, name="base-primes")

    # The timer covers task submission and result collection only.
    started = time.perf_counter()
    futures = [
        client.submit(count_primes_in_segment, *bounds, base_ref)
        for bounds in segments
    ]
    total_primes = sum(future.result() for future in futures)
    elapsed = time.perf_counter() - started

print(f"counted {total_primes:,} primes below {SEARCH_UPPER:,} across {N_SEGMENTS} segments in {elapsed:.2f}s")