Как улучшить производительность рекурсивного метода?
Я изучаю структуры данных и алгоритмы, и вот вопрос, который я застрял.
Я должен улучшить производительность рекурсивного вызова, сохраняя значение в памяти.
Но проблема в том, что не улучшенная версия кажется быстрее, чем эта.
Кто-нибудь может мне помочь?
Числа Сиракузы - это последовательность натуральных чисел, определяемая следующими правилами:
Сира (1) ≡ 1
syra (n) ≡ n + syra (n/ 2), если n mod 2 == 0
syra (n) ≡ n + syra ((n* 3) +1), в противном случае
import java.util.HashMap;
import java.util.Map;
public class SyraLengthsEfficient {
int counter = 0;
public int syraLength(long n) {
if (n < 1) {
throw new IllegalArgumentException();
}
if (n < 500 && map.containsKey(n)) {
counter += map.get(n);
return map.get(n);
} else if (n == 1) {
counter++;
return 1;
} else if (n % 2 == 0) {
counter++;
return syraLength(n / 2);
} else {
counter++;
return syraLength(n * 3 + 1);
}
}
Map<Integer, Integer> map = new HashMap<Integer, Integer>();
public int lengths(int n) {
if (n < 1) {
throw new IllegalArgumentException();
}
for (int i = 1; i <= n; i++) {
syraLength(i);
if (i < 500 && !map.containsKey(i)) {
map.put(i, counter);
}
}
return counter;
}
public static void main(String[] args) {
System.out.println(new SyraLengthsEfficient().lengths(5000000));
}
}
Вот нормальная версия, которую я написал:
public class SyraLengths{
int total=1;
public int syraLength(long n) {
if (n < 1)
throw new IllegalArgumentException();
if (n == 1) {
int temp=total;
total=1;
return temp;
}
else if (n % 2 == 0) {
total++;
return syraLength(n / 2);
}
else {
total++;
return syraLength(n * 3 + 1);
}
}
public int lengths(int n){
if(n<1){
throw new IllegalArgumentException();
}
int total=0;
for(int i=1;i<=n;i++){
total+=syraLength(i);
}
return total;
}
public static void main(String[] args){
System.out.println(new SyraLengths().lengths(5000000));
}
}
РЕДАКТИРОВАТЬ
Это медленнее, чем не расширенная версия.
import java.util.HashMap;
import java.util.Map;
public class SyraLengthsEfficient {
private Map<Long, Long> map = new HashMap<Long, Long>();
public long syraLength(long n, long count) {
if (n < 1)
throw new IllegalArgumentException();
if (!map.containsKey(n)) {
if (n == 1) {
count++;
map.put(n, count);
} else if (n % 2 == 0) {
count++;
map.put(n, count + syraLength(n / 2, 0));
} else {
count++;
map.put(n, count + syraLength(3 * n + 1, 0));
}
}
return map.get(n);
}
public int lengths(int n) {
if (n < 1) {
throw new IllegalArgumentException();
}
int total = 0;
for (int i = 1; i <= n; i++) {
// long temp = syraLength(i, 0);
// System.out.println(i + " : " + temp);
total += syraLength(i, 0);
}
return total;
}
public static void main(String[] args) {
System.out.println(new SyraLengthsEfficient().lengths(50000000));
}
}
ЗАКЛЮЧИТЕЛЬНОЕ РЕШЕНИЕ (пометить как правильное по школьной системе автоматической пометки)
public class SyraLengthsEfficient {
private int[] values = new int[10 * 1024 * 1024];
public int syraLength(long n, int count) {
if (n <= values.length && values[(int) (n - 1)] != 0) {
return count + values[(int) (n - 1)];
} else if (n == 1) {
count++;
values[(int) (n - 1)] = 1;
return count;
} else if (n % 2 == 0) {
count++;
if (n <= values.length) {
values[(int) (n - 1)] = count + syraLength(n / 2, 0);
return values[(int) (n - 1)];
} else {
return count + syraLength(n / 2, 0);
}
} else {
count++;
if (n <= values.length) {
values[(int) (n - 1)] = count + syraLength(n * 3 + 1, 0);
return values[(int) (n - 1)];
} else {
return count + syraLength(n * 3 + 1, 0);
}
}
}
public int lengths(int n) {
if (n < 1) {
throw new IllegalArgumentException();
}
int total = 0;
for (int i = 1; i <= n; i++) {
total += syraLength(i, 0);
}
return total;
}
public static void main(String[] args) {
SyraLengthsEfficient s = new SyraLengthsEfficient();
System.out.println(s.lengths(50000000));
}
}
3 ответа
Забудьте об ответах, которые говорят, что ваш код неэффективен из-за использования Map
это не причина, почему это идет медленно - это тот факт, что вы ограничиваете кеш вычисляемых чисел n < 500
, Как только вы уберете это ограничение, все начнет работать довольно быстро; Вот подтверждение концепции для вас, чтобы заполнить детали:
private Map<Long, Long> map = new HashMap<Long, Long>();
public long syraLength(long n) {
if (!map.containsKey(n)) {
if (n == 1)
map.put(n, 1L);
else if (n % 2 == 0)
map.put(n, n + syraLength(n/2));
else
map.put(n, n + syraLength(3*n+1));
}
return map.get(n);
}
Если вы хотите узнать больше о том, что происходит в программе и почему так быстро, взгляните на эту статью в Википедии о мемоизации.
Кроме того, я думаю, что вы злоупотребляете counter
переменная, вы увеличиваете ее (++
) когда значение рассчитывается в первый раз, но вы накапливаете его (+=
) когда значение найдено на карте. Это не кажется мне правильным, и я сомневаюсь, что это дает ожидаемый результат.
Не используйте карту. сохраните временный результат в поле (это называется аккумулятором) и выполняйте итерацию в цикле до тех пор, пока n = 1. после каждого цикла ваш аккумулятор будет расти на n. и в каждом цикле ваш n будет расти в 3 раза + 1 или будет уменьшаться в 2 раза. надеюсь, что это поможет вам решить домашнее задание
Конечно, это не так хорошо, вы добавляете много накладных расходов в вызовах map.put и map.get (хеширование, создание сегмента и т. Д.). Плюс вы автобокс, который добавляет беспорядок создания объектов. Я предполагаю, что накладные расходы карты намного перевешивают выгоду.
Попробуйте вместо этого использовать два массива. один для хранения значений и для хранения флагов, которые сообщают вам, установлено ли значение или нет.
int [] syr = new int[Integer.MAX_VALUE];
boolean [] syrcomputed = new boolean[Integer.MAX_VALUE];
и используйте их вместо карты:
if (syrcomputed[n]) {
return syr[n];
}
else {
syrcomputed[n] = true;
syr[n] = ....;
}
Кроме того, я думаю, что вы можете столкнуться с некоторым переполнением здесь с большими числами (когда syr приближается к MAX_INT/3, вы определенно увидите это, если он не делится на 2).
Поэтому вам, вероятно, следует использовать длинные типы для всех ваших вычислений.
PS: если ваша цель действительно понять рекурсию, вы не должны хранить значения как переменную экземпляра, а должны передавать ее как аккумулятор:
public int syr(int n) {
return syr(n, new int[Integer.MAX_VALUE], new boolean[Integer.MAX_VALUE]);
}
private int syr(int n, int[] syr, boolean[] syrcomputed) {
if (syrcomputed[n]) {
return syr[n];
}
else {
s = [ block for recursive computation ]
syrcomputed[n] = true;
syr = s;
}
}
В некоторых функциональных языках (схема, эрланг и т. Д.) Это фактически разворачивается как хвостовой вызов (что позволяет избежать создания стека). Хотя горячая точка jvm этого не делает (по крайней мере, насколько мне известно), это все еще важная концепция.