Free ~> Trampoline: рекурсивные сбои программы с OutOfMemoryError

Предположим, что я пытаюсь реализовать очень простой предметно-ориентированный язык только с одной операцией:

printLine(line)

Затем я хочу написать программу, которая принимает целое число n в качестве ввода печатает что-то, если n делится на 10k, а затем вызывает себя с n + 1, до тех пор n достигает некоторого максимального значения N,

Опуская все синтаксические шумы, вызванные для понимания, я хочу:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

По сути, это будет своего рода "шипучий шум".

Вот несколько попыток реализовать это с помощью Free monad из Scalaz 7.3.0-M7:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

К сожалению, простой перевод (p0), кажется, работает с некоторой нагрузкой O(N^2) и завершается с ошибкой OutOfMemoryError. Проблема в том, что for-понимание добавляет map{x => ()} после рекурсивного вызова p0, который заставляет Free Монаду, чтобы заполнить всю память напоминаниями "закончить" p0 ", а затем ничего не делать". Если я вручную "разверну" for понимание и выпиши последний flatMap явно (как в p3 а также p4), то проблема уходит, и все идет гладко. Это, однако, чрезвычайно хрупкий обходной путь: поведение программы резко меняется, если мы просто добавляем map(id) к этому, и это map(id) даже не виден в коде, потому что он генерируется автоматически for-comprehension.

В этом старом посте здесь: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ неоднократно советовали заключать рекурсивные вызовы в suspend, Вот попытка с Applicative экземпляр и suspend:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

Вставка suspend не очень помогло: все равно медленно и вылетает с OutOfMemoryErrors.

Должен ли я использовать suspend как-то иначе?

Может быть, есть какое-то чисто синтаксическое средство, которое позволяет использовать для-понимания без генерации map в конце?

Я был бы очень признателен, если бы кто-то мог указать, что я здесь делаю неправильно и как это исправить.

1 ответ

Решение

Это лишнее map добавленный компилятором Scala перемещает рекурсию из хвостовой позиции в не хвостовую позицию. Свободная монада все еще делает этот стек безопасным, но сложность пространства становится O(N) вместо O (1). (В частности, это все еще не O (N2).)

Можно ли сделать scalac оптимизировать это map прочь делает для отдельного вопроса (на который я не знаю ответа).

Я попытаюсь проиллюстрировать, что происходит при интерпретации p1 против p3, (Я буду игнорировать перевод Trampoline, который является избыточным (см. ниже).)

p3 (т.е. без доплаты map)

Позвольте мне использовать следующую стенографию:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

Сейчас p3(0) интерпретируется следующим образом

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

и так далее... Вы видите, что объем памяти, необходимый в любой точке, не превышает некоторой постоянной верхней границы.

p1 (т.е. с дополнительным map)

Я буду использовать следующие сокращения:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prg[Unit] = // constant pure unit
  ignore => Free.pure(())

Сейчас p1(0) интерпретируется следующим образом:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

и так далее... Вы видите, что потребление памяти линейно зависит от N, Мы просто перенесли оценку из стека в кучу.

Забрать: сохранить Free память дружественная, держать рекурсию в "хвостовой позиции", то есть на правой стороне flatMap (или же map).

В сторону: перевод на Trampoline не нужно, так как Free уже батут. Вы могли бы интерпретировать непосредственно Id и использовать foldMapRec для стекового безопасного перевода:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

Это вернет вам часть памяти (но не устранит проблему).

Другие вопросы по тегам