省メモリ版エラトステネスの篩を Scala で


Squeak Smalltalk の #largePrimesUpTo:do: をほぼそのまま Scala に書き換えてみました。手元の環境では残念ながら 1<<30 ですと OutOfMemoryError になるので少し減らして 1<<29 に収まる全素数(28192750個。最大は 536870909)の生成を試したところ、おおよそ Squeak4.2 CogVM の倍速といった感じでした(1.8GHz Core i7Squeak の 27秒に対し Scala は 14秒)。限られた Scala の知識で機械的に変換したものなので Scala をよく知っている人に手を入れてもらえれば軽く10倍くらいには高速化できるとおもいますので後はよろしくお願いします。w


他の言語への移植もチャレンジしてみたいですね。


object PrimesUpTo extends App {
  def primes(upto: Int, doFun: Int => Unit): Unit = {
    if (upto > 25000) return large_primes(upto, doFun)
    var limit = upto - 1
    var flags = new Array[Boolean](limit)
    for (i <- 0 until limit - 1) {
      if (!flags(i)) {
        var prime = i + 2
        var k = i + prime
        while (k < limit) {
          flags(k) = true
          k += prime
        }
        doFun(prime)
      }
    }
  }

  def large_primes(upto: Int, doFun: Int => Unit): Unit = {
    var limit = upto - 1
    var idx_limit = Math.sqrt(limit).toInt

    var flags = (new Array[Int]((limit + 2309) / 2310 * 60 + 60)).map(_ => 0xFF)

    var buf = new scala.collection.mutable.ArrayBuffer[Int]
    var i: Int = 0
    primes(2310, prime => buf += prime)
    var primes_up_to_2310 = buf.toArray

    var mask_bit_idx = new Array[Int](2310)
    mask_bit_idx(0) = 0
    mask_bit_idx(1) = 1
    var bit_idx = 1

    for (i <- 0 to 4) doFun(primes_up_to_2310(i))

    var idx = 5
    for (n <- 2 to 2309) {
      while (primes_up_to_2310(idx) < n) idx += 1
      if (n == primes_up_to_2310(idx)) {
        bit_idx += 1
        mask_bit_idx(n) = bit_idx
      } else {
      if (n % 2 == 0 || n % 3 == 0 || n % 5 == 0 || n % 7 == 0 || n % 11 == 0) {
        mask_bit_idx(n) = 0
        } else {
        bit_idx += 1
        mask_bit_idx(n) = bit_idx
        }
      }
    }


    for (n <- 13 to limit by 2) {
      var mask_bit = mask_bit_idx(n % 2310)
      if (mask_bit != 0) {
        var byte_idx = n / 2310 * 60 + (mask_bit - 1) / 8
        bit_idx = 1 << (mask_bit & 7)
        if ((flags(byte_idx) & bit_idx) != 0) {
          doFun(n)
          if (n < idx_limit) {
            idx = n * n
            if ((idx & 1) == 0) idx += n
            while (idx <= limit) {
              mask_bit = mask_bit_idx(idx % 2310)
              if (mask_bit != 0) {
                byte_idx = idx / 2310 * 60 + (mask_bit - 1) / 8
                mask_bit = 255 - (1 << (mask_bit & 7))
                flags(byte_idx) = flags(byte_idx) & mask_bit
              }
              idx += 2 * n
            }
          }
        }
      }
    }
  }

  var c = 0
  var q = 0
  var start = new java.util.Date().getTime
  primes(1<<29, p => { q = p; c += 1 })
  println(q, c, new java.util.Date().getTime - start)

  // 1.8GHz Core i7
  // Squeek4.2 CogVM, 27457 ms
  // Scala 2.9.2.rdev-2769, 14223 ms
}


参考: