Использование бинарного поиска для поиска k-го наибольшего числа в таблице умножения n*m

Итак, я пытаюсь решить проблему: http://codeforces.com/contest/448/problem/D

Чемпион Бизон не просто очарователен, он еще и очень умен.

В то время как некоторые из нас изучали таблицу умножения, Чемпион Bizon веселился по-своему. Бизон Чемпион нарисовал таблицу умножения n × m, где элемент на пересечении i-й строки и j-го столбца равен i·j (строки и столбцы таблицы нумеруются, начиная с 1). Затем его спросили: какое число в таблице является k-ым по величине числом? Чемпион Бизон всегда отвечал правильно и сразу. Можете ли вы повторить его успех?

Рассмотрим данную таблицу умножения. Если вы записываете все n·m чисел из таблицы в неубывающем порядке, то k-е число, которое вы записываете, называется k-м наибольшим числом.

Входные данные В единственной строке записаны целые числа n, m и k (1 ≤ n, m ≤ 5·105; 1 ≤ k ≤ n·m).

Выходные данные Выведите k-е наибольшее число в таблице умножения × m.

Что я сделал, я применил бинарный поиск от 1 до n*m ищу номер, который имеет точно k элементов меньше, чем это. Для этого я сделал следующий код:

using namespace std;
#define ll long long
#define pb push_back
#define mp make_pair
ll n,m;
int f (int val);
int min (int a, int b);
int main (void)
{
    int k;
    cin>>n>>m>>k;
    int ind = k;
    ll low = 1LL;
    ll high = n*m;
    int ans;
    while (low <= high)
    {
        ll mid = low + (high-low)/2;
        if (f(mid) == k)
            ans = mid;
        else if (f(mid) < k)
            low = mid+1;
        else
            high = mid-1;
    }
    cout<<ans<<"\n";
    return 0;

}

int f (int val)
{
    int ret = 0;
    for ( int i = 1; i <= n; i++ )
    {
        ret = ret + min(val/i,m);
    }
    return ret;
}

int min (int a, int b)
{
    if (a < b)
        return a;
    else
        return b;
}

Тем не менее, я не знаю почему, но это дает неправильный ответ на тестовых случаях:

input
2 2 2
output
2

Мой вывод выходит 0

Я изучаю бинарный поиск, но я не знаю, где я ошибаюсь с этой реализацией. Любая помощь будет оценена.

2 ответа

Игнорируя тот факт, что ваш бинарный поиск не самый быстрый метод, вы все равно хотите знать, почему он некорректен.

Во-первых, будьте предельно ясны относительно того, что вы хотите и что возвращает ваш f:

ищем число, которое имеет ровно k элементов меньше его.

Нет! Вы ищете наименьшее число, у которого k элементов меньше или равно ему. И ваша функция f(X) возвращает количество элементов меньше или равное X.

Поэтому, когда f(X) возвращает слишком маленькое значение, вы знаете, что X должно быть больше как минимум на 1, поэтому low=mid+1 верно. Но когда f(X) возвращает слишком большое значение, X может быть идеальным (это может быть элемент, появляющийся в таблице несколько раз). И наоборот, когда f(X) возвращает точно правильное число, X может все еще быть слишком большим (X может быть значением, которое появляется в таблице ноль раз).

Поэтому, когда f(X) не слишком мало, лучшее, что вы можете сделать, это high=mid не high=mid-1

while (low < high)
{
    ll mid = low + (high-low)/2;
    if (f(mid) < k)
        low = mid+1;
    else
        high = mid;
}

Заметьте, что низкий никогда не становится> высоким, поэтому остановитесь, когда они равны, и мы не пытаемся поймать ans по пути. Вместо этого в конце низкий == высокий == Ответ

В конкурсе указано 1 секунда. На моем компьютере ваш код с этим исправлением решает проблему максимального размера менее чем за секунду. Но я не уверен, что компьютер для судейства такой быстрый.

Редактировать: int слишком мал для максимального размера проблемы, поэтому вы не можете вернуть int из f:
n, m и i умещаются в 32 бита, но на входе и выходе функции f(), а также k, ret, low и high должны содержаться целые числа до 2.5e11

import java.util.*;
public class op {
    static int n,m;
    static long k;
    public static void main(String args[]){
        Scanner s=new Scanner(System.in);
        n=s.nextInt();
        m=s.nextInt();
        k=s.nextLong();
        long start=1;
        long end=n*m;
        long ans=0;
        while(end>=start){
            long mid=start+end;
            mid/=2;
            long fmid=f(mid);
            long gmid=g(mid);
            if(fmid>=k && fmid-gmid<k){
                ans=mid;
                break;
            }
            else if(f(mid)>k){
                end=mid-1;
            }
            else{
                start=mid+1;
            }
        }
        System.out.println(ans);

    }
    static long f (long val)
    {
        long ret = 0;
        for ( int i = 1; i <= n; i++ )
        {
            ret = ret + Math.min(val/i,m);
        }
        return ret;
    }
    static long g (long val)
    {
        long ret = 0;
        for ( int i = 1; i <= n; i++ )
        {
            if(val%i==0 && val/i<=m){
                ret++;
            }
        }
        return ret;
    }
    public static class Pair{
        int x,y;
        Pair(int a,int b){
            x=a;y=b;
        }
    }
}
Другие вопросы по тегам