Segmented prime sieve
Count the primes in a large range by splitting it into segments and sieving each segment on a worker. The sieve itself uses only the Python standard library (Scaler is needed only to distribute the work), and each task returns just a count.
[ ]:
# Connection settings -- edit these to point at your running scheduler.
SCHEDULER_ADDRESS = "ws://127.0.0.1:2345"  # supports tcp:// or ws://; only ws:// works from JupyterLite (browser)
OBJECT_STORAGE_ADDRESS = None  # leave None to use whatever the scheduler advertises

# Problem size: primes are counted in [0, SEARCH_UPPER), which is cut into
# N_SEGMENTS contiguous segments; each segment is submitted as one task.
# Default sweeps the first 2e9 integers across 128 segments; on 16 workers
# expect roughly a minute of wall-clock time.
SEARCH_UPPER = 2_000_000_000
N_SEGMENTS = 128
[ ]:
import time
from scaler import Client
def _small_primes(upper: int) -> list[int]:
"""Plain Eratosthenes sieve up to (but not including) `upper`."""
if upper < 3:
return []
flags = bytearray([1]) * upper
flags[0] = flags[1] = 0
for i in range(2, int(upper ** 0.5) + 1):
if flags[i]:
flags[i * i :: i] = bytearray(len(flags[i * i :: i]))
return [i for i, flag in enumerate(flags) if flag]
def count_primes_in_segment(low: int, high: int, base_primes: list[int]) -> int:
    """Worker-side: count the primes in [low, high) with a segmented sieve.

    Args:
        low: Inclusive lower bound; values below 2 are treated as 2.
        high: Exclusive upper bound.
        base_primes: Primes in ascending order. The result is only correct
            if this contains every prime p with p * p < high, because only
            multiples of the supplied primes are struck out.

    Returns:
        The number of primes n with low <= n < high; 0 for an empty or
        inverted range.
    """
    if high <= 2:
        return 0
    low = max(low, 2)
    size = high - low
    if size <= 0:
        return 0
    # flags[k] refers to the integer low + k.
    flags = bytearray(b"\x01") * size
    for p in base_primes:
        # Ascending order lets us stop at the first prime whose square is
        # outside the segment: it cannot be the smallest factor of anything
        # below `high`.
        if p * p >= high:
            break
        # First multiple of p that is >= low, but never below p*p so that p
        # itself is not struck out when it lies inside the segment.
        first = max(p * p, -(-low // p) * p)
        # Length is computed arithmetically instead of copying the slice.
        flags[first - low :: p] = bytearray(len(range(first, high, p)))
    return flags.count(1)
# Cut [0, SEARCH_UPPER) into N_SEGMENTS half-open ranges. Every range is
# segment_size wide except the last, which runs up to SEARCH_UPPER and so
# absorbs the remainder of the integer division.
segment_size = SEARCH_UPPER // N_SEGMENTS
_cuts = [i * segment_size for i in range(N_SEGMENTS)] + [SEARCH_UPPER]
segments = list(zip(_cuts, _cuts[1:]))

# Sieving primes: everything up to (and including) the integer square root
# of SEARCH_UPPER.
sieve_limit = int(SEARCH_UPPER ** 0.5) + 1
base_primes = _small_primes(sieve_limit)

with Client(address=SCHEDULER_ADDRESS, object_storage_address=OBJECT_STORAGE_ADDRESS) as client:
    # The handle returned by send_object is passed to every task in place of
    # the list itself -- presumably so the list is shipped once rather than
    # with each submission (confirm against the Scaler docs).
    base_ref = client.send_object(base_primes, name="base-primes")

    # Timing covers submission plus waiting for every result; the upload
    # above is deliberately outside the timed region.
    started = time.perf_counter()
    futures = [client.submit(count_primes_in_segment, *bounds, base_ref) for bounds in segments]
    total_primes = 0
    for future in futures:
        total_primes += future.result()
    elapsed = time.perf_counter() - started

print(f"counted {total_primes:,} primes below {SEARCH_UPPER:,} across {N_SEGMENTS} segments in {elapsed:.2f}s")