Можно ли использовать продолжения для рекурсивного хвоста foldRight?

Следующая статья блога показывает, как в F# foldBack можно сделать хвост рекурсивным, используя стиль передачи продолжения.

В Scala это будет означать, что:

def foldBack[T,U](l: List[T], acc: U)(f: (T, U) => U): U = {
  l match {
    case x :: xs => f(x, foldBack(xs, acc)(f))
    case Nil => acc
  }
} 

можно сделать хвост рекурсивным, выполнив это:

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    l match {
      case x :: xs => loop(xs, (racc => k(f(x, racc))))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

К сожалению, я все еще получаю переполнение стека для длинных списков. Цикл является хвостовым рекурсивным и оптимизированным, но я думаю, что накопление стека просто перемещено в вызовы продолжения.

Почему это не проблема с F#? И есть ли способ обойти это со Scala?

Изменить: здесь некоторый код, который показывает глубину стека:

def showDepth(s: Any) {
  println(s.toString + ": " + (new Exception).getStackTrace.size)
}

def foldCont[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  @annotation.tailrec
  def loop(l: List[T], k: (U) => U): U = {
    showDepth("loop")
    l match {
      case x :: xs => loop(xs, (racc => { showDepth("k"); k(f(x, racc)) }))
      case Nil => k(acc)
    }
  }
  loop(list, u => u)
} 

foldCont(List.fill(10)(1), 0)(_ + _)

Это печатает:

loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
loop: 50
k: 51
k: 52
k: 53
k: 54
k: 55
k: 56
k: 57
k: 58
k: 59
k: 60
res2: Int = 10

4 ответа

Решение

Проблема в продолжении функции (racc => k(f(x, racc))) сам. Он должен быть оптимизирован для работы всего бизнеса, но это не так.

Scala не может выполнять оптимизацию tailcall для произвольных вызовов tail, только для тех, которые она может преобразовать в циклы (то есть, когда функция вызывает себя, а не какую-то другую функцию).

Джон, нм, спасибо за ваши ответы. Основываясь на ваших комментариях, я решил попробовать батут. Небольшое исследование показывает, что Scala имеет библиотечную поддержку для батутов в TailCalls, Вот то, что я придумал после небольшого возни:

def foldContTC[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  import scala.util.control.TailCalls._
  @annotation.tailrec
  def loop(l: List[T], k: (U) => TailRec[U]): TailRec[U] = {
    l match {
      case x :: xs => loop(xs, (racc => tailcall(k(f(x, racc)))))
      case Nil => k(acc)
    }
  }
  loop(list, u => done(u)).result
} 

Мне было интересно посмотреть, как это сравнивается с решением без батута, а также по умолчанию foldLeft а также foldRight Реализации. Вот код теста и некоторые результаты:

val size = 1000
val list = List.fill(size)(1)
val warm = 10
val n = 1000
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldCont", warm, lots(n, foldCont(list, 0)(_ + _)))
bench("foldRight", warm, lots(n, list.foldRight(0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))

Сроки:

foldContTC: warming...
Elapsed: 0.094
foldCont: warming...
Elapsed: 0.060
foldRight: warming...
Elapsed: 0.160
foldLeft: warming...
Elapsed: 0.076
foldLeft.reverse: warming...
Elapsed: 0.155

Исходя из этого, может показаться, что прыжки на батуте на самом деле дают довольно хорошую производительность. Я подозреваю, что наказание за бокс / распаковку относительно неплохо.

Изменить: как предложено в комментариях Джона, здесь приведены временные параметры для элементов 1М, которые подтверждают, что производительность снижается с большими списками. Также я узнал, что реализация библиотеки List.foldLeft не переопределяется, поэтому я рассчитал следующее: foldLeft2:

def foldLeft2[T,U](list: List[T], acc: U)(f: (T, U) => U): U = {
  list match {
    case x :: xs => foldLeft2(xs, f(x, acc))(f)
    case Nil => acc
  }
} 

val size = 1000000
val list = List.fill(size)(1)
val warm = 10
val n = 2
bench("foldContTC", warm, lots(n, foldContTC(list, 0)(_ + _)))
bench("foldLeft", warm, lots(n, list.foldLeft(0)(_ + _)))
bench("foldLeft2", warm, lots(n, foldLeft2(list, 0)(_ + _)))
bench("foldLeft.reverse", warm, lots(n, list.reverse.foldLeft(0)(_ + _)))
bench("foldLeft2.reverse", warm, lots(n, foldLeft2(list.reverse, 0)(_ + _)))

выходы:

foldContTC: warming...
Elapsed: 0.801
foldLeft: warming...
Elapsed: 0.156
foldLeft2: warming...
Elapsed: 0.054
foldLeft.reverse: warming...
Elapsed: 0.808
foldLeft2.reverse: warming...
Elapsed: 0.221

Так что foldLeft2.reverse является победителем...

Почему это не проблема с F#?

F# оптимизировал все хвостовые вызовы.

И есть ли способ обойти это со Scala?

Вы можете использовать TCO, используя другие приемы, такие как батуты, но теряете взаимодействие, потому что оно меняет соглашение о вызовах и работает в ~10 раз медленнее. Это одна из трех причин, по которым я не использую Scala.

РЕДАКТИРОВАТЬ

Результаты вашего теста показывают, что батуты Scala работают намного быстрее, чем в прошлый раз, когда я их тестировал. Также интересно добавить эквивалентные тесты, используя F# и для больших списков (потому что нет смысла делать CPS для маленьких списков!).

За 1000х в списке из 1000 элементов на моем нетбуке с процессором Intel Atom 1,65 ГГц N570 я получаю:

List.fold     0.022s
List.rev+fold 0.116s
List.foldBack 0.047s
foldContTC    0.334s

Для 1x 1 000 000 элементов списка я получаю:

List.fold     0.024s
List.rev+fold 0.188s
List.foldBack 0.054s
foldContTC    0.570s

Возможно, вас также заинтересуют старые обсуждения этого в списке caml в контексте замены функций нерекурсивного списка OCaml оптимизированными хвостовыми рекурсивными.

Я опоздал на этот вопрос, но я хотел показать, как вы можете написать хвостовой рекурсивный FoldRight без использования полного батута; накапливая список продолжений (вместо того, чтобы они вызывали друг друга, когда это было сделано, что приводит к переполнению стека) и складывая их в конце, что-то вроде хранения стека, но в куче:

object FoldRight {

  def apply[A, B](list: Seq[A])(init: B)(f: (A, B) => B): B = {
    @scala.annotation.tailrec
    def step(current: Seq[A], conts: List[B => B]): B = current match {
      case Seq(last) => conts.foldLeft(f(last, init)) { (acc, next) => next(acc) }
      case Seq(x, xs @ _*) => step(xs, { acc: B => f(x, acc) } +: conts)
      case Nil => init
    }
    step(list, Nil)
  }

}

Сгиб, который происходит в конце, сам по себе является хвост-рекурсивным. Попробуйте это в ScalaFiddle

С точки зрения производительности, он работает немного хуже, чем версия с хвостовым вызовом.

[info] Benchmark            (length)  Mode  Cnt   Score    Error  Units
[info] FoldRight.conts           100  avgt   30   0.003 ±  0.001  ms/op
[info] FoldRight.conts         10000  avgt   30   0.197 ±  0.004  ms/op
[info] FoldRight.conts       1000000  avgt   30  77.292 ±  9.327  ms/op
[info] FoldRight.standard        100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.standard      10000  avgt   30   0.154 ±  0.036  ms/op
[info] FoldRight.standard    1000000  avgt   30  18.796 ±  0.551  ms/op
[info] FoldRight.tailCalls       100  avgt   30   0.002 ±  0.001  ms/op
[info] FoldRight.tailCalls     10000  avgt   30   0.176 ±  0.004  ms/op
[info] FoldRight.tailCalls   1000000  avgt   30  33.525 ±  1.041  ms/op
Другие вопросы по тегам