Почему scala не выполняет оптимизацию хвостовых вызовов?

Просто играю с продолжениями. Цель состоит в том, чтобы создать функцию, которая получит другую функцию в качестве параметра и количества выполнения - и функцию возврата, которая будет применять параметр заданное количество раз.

Реализация выглядит довольно очевидно

def n_times[T](func:T=>T,count:Int):T=>T = {
  @tailrec
  def n_times_cont(cnt:Int, continuation:T=>T):T=>T= cnt match {
        case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
        case 1 => continuation
        case _ => n_times_cont(cnt-1,i=>continuation(func(i)))
      }
  n_times_cont(count, func)
}

def inc (x:Int) = x+1

    val res1 = n_times(inc,1000)(1)  // Works OK, returns 1001

val res = n_times(inc,10000000)(1) // FAILS

Но проблем нет - этот код завершается с ошибкой Stackru. Почему здесь нет оптимизации хвостового вызова?

Я запускаю его в Eclipse, используя плагин Scala, и он возвращает исключение в потоке "main" java.lang.StackruError в scala.runtime.BoxesRunTime.boxToInteger(неизвестный источник) в Task_Mult$$anonfun$1.apply(Task_Mult.scala:25) в Task_Mult$$anonfun$n_times_cont$1$1.применить (Task_Mult.scala:18)

п.с.

Код F#, который является почти прямым переводом, работает без проблем

let n_times_cnt func count = 
    let rec n_times_impl count' continuation = 
        match count' with
        | _ when count'<1 -> failwith "wrong count"
        | 1 -> continuation
        | _ -> n_times_impl (count'-1) (func >> continuation) 
    n_times_impl count func

let inc x = x+1
let res = (n_times_cnt inc 10000000) 1

printfn "%o" res

2 ответа

Решение

Стандартная библиотека Scala имеет реализацию батутов в scala.util.control.TailCalls, Итак, вернемся к вашей реализации... Когда вы создаете вложенные вызовы с continuation(func(t))это хвостовые вызовы, просто не оптимизированные компилятором. Итак, давайте создадим T => TailRec[T]где кадры стека будут заменены объектами в куче. Затем верните функцию, которая примет аргумент и передаст его этой батутной функции:

import util.control.TailCalls._
def n_times_trampolined[T](func: T => T, count: Int): T => T = {
  @annotation.tailrec
  def n_times_cont(cnt: Int, continuation: T => TailRec[T]): T => TailRec[T] = cnt match {
    case _ if cnt < 1 => throw new IllegalArgumentException(s"count was wrong $count")
    case 1 => continuation
    case _ => n_times_cont(cnt - 1, t => tailcall(continuation(func(t))))
  }
  val lifted : T => TailRec[T] = t => done(func(t))
  t => n_times_cont(count, lifted)(t).result
}

Я могу ошибаться, но я подозреваю, что n_times_cont внутренняя функция должным образом преобразована для использования хвостовой рекурсии; виновник не там.

Стек взорван собранным continuation закрытие (то есть i=>continuation(func(i))), которые делают 10000000 вложенных звонков на ваш inc метод, как только вы примените результат основной функции.

на самом деле вы можете попробовать

scala> val rs = n_times(inc, 1000000)
rs: Int => Int = <function1> //<- we're happy here

scala> rs(1) //<- this blows up the stack!

Как в стороне, вы можете переписать

i=>continuation(func(i))

как

continuation compose func

ради большей читабельности

Другие вопросы по тегам