Использование бинарного поиска для поиска k-го наибольшего числа в таблице умножения n*m
Итак, я пытаюсь решить проблему: http://codeforces.com/contest/448/problem/D
Чемпион Бизон не просто очарователен, он еще и очень умен.
В то время как некоторые из нас изучали таблицу умножения, Чемпион Bizon веселился по-своему. Бизон Чемпион нарисовал таблицу умножения n × m, где элемент на пересечении i-й строки и j-го столбца равен i·j (строки и столбцы таблицы нумеруются, начиная с 1). Затем его спросили: какое число в таблице является k-ым по величине числом? Чемпион Бизон всегда отвечал правильно и сразу. Можете ли вы повторить его успех?
Рассмотрим данную таблицу умножения. Если вы записываете все n·m чисел из таблицы в неубывающем порядке, то k-е число, которое вы записываете, называется k-м наибольшим числом.
Входные данные В единственной строке записаны целые числа n, m и k (1 ≤ n, m ≤ 5·105; 1 ≤ k ≤ n·m).
Выходные данные Выведите k-е наибольшее число в таблице умножения × m.
Что я сделал, я применил бинарный поиск от 1 до n*m
ищу номер, который имеет точно k
элементов меньше, чем это. Для этого я сделал следующий код:
using namespace std;
#define ll long long
#define pb push_back
#define mp make_pair
ll n,m;
int f (int val);
int min (int a, int b);
int main (void)
{
int k;
cin>>n>>m>>k;
int ind = k;
ll low = 1LL;
ll high = n*m;
int ans;
while (low <= high)
{
ll mid = low + (high-low)/2;
if (f(mid) == k)
ans = mid;
else if (f(mid) < k)
low = mid+1;
else
high = mid-1;
}
cout<<ans<<"\n";
return 0;
}
int f (int val)
{
int ret = 0;
for ( int i = 1; i <= n; i++ )
{
ret = ret + min(val/i,m);
}
return ret;
}
int min (int a, int b)
{
if (a < b)
return a;
else
return b;
}
Тем не менее, я не знаю почему, но это дает неправильный ответ на тестовых случаях:
input
2 2 2
output
2
Мой вывод выходит 0
Я изучаю бинарный поиск, но я не знаю, где я ошибаюсь с этой реализацией. Любая помощь будет оценена.
2 ответа
Игнорируя тот факт, что ваш бинарный поиск не самый быстрый метод, вы все равно хотите знать, почему он некорректен.
Во-первых, будьте предельно ясны относительно того, что вы хотите и что возвращает ваш f:
ищем число, которое имеет ровно k элементов меньше его.
Нет! Вы ищете наименьшее число, у которого k элементов меньше или равно ему. И ваша функция f(X)
возвращает количество элементов меньше или равное X.
Поэтому, когда f(X) возвращает слишком маленькое значение, вы знаете, что X должно быть больше как минимум на 1, поэтому low=mid+1
верно. Но когда f(X) возвращает слишком большое значение, X может быть идеальным (это может быть элемент, появляющийся в таблице несколько раз). И наоборот, когда f(X) возвращает точно правильное число, X может все еще быть слишком большим (X может быть значением, которое появляется в таблице ноль раз).
Поэтому, когда f(X) не слишком мало, лучшее, что вы можете сделать, это high=mid
не high=mid-1
while (low < high)
{
ll mid = low + (high-low)/2;
if (f(mid) < k)
low = mid+1;
else
high = mid;
}
Заметьте, что низкий никогда не становится> высоким, поэтому остановитесь, когда они равны, и мы не пытаемся поймать ans
по пути. Вместо этого в конце низкий == высокий == Ответ
В конкурсе указано 1 секунда. На моем компьютере ваш код с этим исправлением решает проблему максимального размера менее чем за секунду. Но я не уверен, что компьютер для судейства такой быстрый.
Редактировать: int
слишком мал для максимального размера проблемы, поэтому вы не можете вернуть int из f:
n, m и i умещаются в 32 бита, но на входе и выходе функции f(), а также k, ret, low и high должны содержаться целые числа до 2.5e11
import java.util.*;
public class op {
static int n,m;
static long k;
public static void main(String args[]){
Scanner s=new Scanner(System.in);
n=s.nextInt();
m=s.nextInt();
k=s.nextLong();
long start=1;
long end=n*m;
long ans=0;
while(end>=start){
long mid=start+end;
mid/=2;
long fmid=f(mid);
long gmid=g(mid);
if(fmid>=k && fmid-gmid<k){
ans=mid;
break;
}
else if(f(mid)>k){
end=mid-1;
}
else{
start=mid+1;
}
}
System.out.println(ans);
}
static long f (long val)
{
long ret = 0;
for ( int i = 1; i <= n; i++ )
{
ret = ret + Math.min(val/i,m);
}
return ret;
}
static long g (long val)
{
long ret = 0;
for ( int i = 1; i <= n; i++ )
{
if(val%i==0 && val/i<=m){
ret++;
}
}
return ret;
}
public static class Pair{
int x,y;
Pair(int a,int b){
x=a;y=b;
}
}
}