Найти самые большие n уникальных значений и их частоты в R и Rcpp

У меня есть числовой вектор v (с уже опущенными NA) и я хочу получить n-е наибольшие значения и их соответствующие частоты.

Я нашел http://gallery.rcpp.org/articles/top-elements-from-vectors-using-priority-queue/ довольно быстрым.

// [[Rcpp::export]]
std::vector<int> top_i_pq(NumericVector v, unsigned int n)
{

typedef pair<double, int> Elt;
priority_queue< Elt, vector<Elt>, greater<Elt> > pq;
vector<int> result;

for (int i = 0; i != v.size(); ++i) {
    if (pq.size() < n)
      pq.push(Elt(v[i], i));
    else {
      Elt elt = Elt(v[i], i);
      if (pq.top() < elt) {
        pq.pop();
        pq.push(elt);
      }
    }
  }

  result.reserve(pq.size());
  while (!pq.empty()) {
    result.push_back(pq.top().second + 1);
    pq.pop();
  }

  return result ;

}

Однако связи не будут уважаться. На самом деле мне не нужны индексы, возвращение значений также будет в порядке.

То, что я хотел бы получить, это список, содержащий значения и частоты, скажем что-то вроде:

numv <- c(4.2, 4.2, 4.5, 0.1, 4.4, 2.0, 0.9, 4.4, 3.3, 2.4, 0.1)

top_i_pq(numv, 3)
$lengths
[1] 2 2 1

$values
[1] 4.2 4.4 4.5

Ни получение уникального вектора, ни таблицы, ни (полной) сортировки не является хорошей идеей, поскольку n обычно мало по сравнению с длиной v (которая может легко быть>1e6).

Решения до сих пор:

 library(microbenchmark)
 library(data.table)
 library(DescTools)

 set.seed(1789)
 x <- sample(round(rnorm(1000), 3), 1e5, replace = TRUE)
 n <- 5

 microbenchmark(
   BaseR = tail(table(x), n),
   data.table = data.table(x)[, .N, keyby = x][(.N - n + 1):.N],
   DescTools = Large(x, n, unique=TRUE),
   Coatless = ...
 )

Unit: milliseconds
       expr       min         lq       mean     median        uq       max neval
      BaseR 188.09662 190.830975 193.189422 192.306297 194.02815 253.72304   100
 data.table  11.23986  11.553478  12.294456  11.768114  12.25475  15.68544   100
  DescTools   4.01374   4.174854   5.796414   4.410935   6.70704  64.79134   100

Хм, DescTools все еще самый быстрый, но я уверен, что он может быть значительно улучшен с помощью Rcpp (так как это чистый R)!

3 ответа

Решение

Я хотел бы бросить свою шляпу в кольцо с другим решением на основе Rcpp, которое примерно в 7 раз быстрее, чем DescTools подход и ~13x быстрее, чем data.table подход, используя длину 1e5 x а также n = 5 Пример данных выше. Реализация немного длинная, поэтому я приведу пример:

fn.dt <- function(v, n) {
    data.table(v = v)[
      ,.N, keyby = v
      ][(.N - n + 1):.N]
}

microbenchmark(
    "DescTools" = Large(x, n, unique=TRUE),
    "top_n" = top_n(x, 5),
    "data.table" = fn.dt(x, n),
    times = 500L
)
# Unit: microseconds
#        expr      min       lq      mean   median       uq       max neval
#   DescTools 3330.527 3790.035 4832.7819 4070.573 5323.155 54921.615   500
#       top_n  566.207  587.590  633.3096  593.577  640.832  3568.299   500
#  data.table 6920.636 7380.786 8072.2733 7764.601 8585.472 14443.401   500

Обновить

Если ваш компилятор поддерживает C++11, вы можете воспользоваться std::priority_queue::emplace для (удивительно) заметного прироста производительности (по сравнению с версией C++98 ниже). Я не буду публиковать эту версию, поскольку она в основном идентична, за исключением нескольких звонков std::move а также emplace, но вот ссылка на него.

Тестирование по сравнению с предыдущими тремя функциями и использование data.table 1.9.7 (что немного быстрее, чем 1.9.6) дает

print(res2, order = "median", signif = 3)
# Unit: relative
#              expr  min    lq      mean median    uq   max neval  cld
#            top_n2  1.0  1.00  1.000000   1.00  1.00  1.00  1000    a   
#             top_n  1.6  1.58  1.666523   1.58  1.75  2.75  1000    b  
#         DescTools 10.4 10.10  8.512887   9.68  7.19 12.30  1000    c 
#  data.table-1.9.7 16.9 16.80 14.164139  15.50 10.50 43.70  1000    d 

где top_n2 это версия C++ 11.


top_n Функция реализована следующим образом:

#include <Rcpp.h>
#include <utility>
#include <queue>

class histogram {
private:
    struct paired {
        typedef std::pair<double, unsigned int> pair_t;

        pair_t pair;
        unsigned int is_set;

        paired() 
            : pair(pair_t()),
              is_set(0)
        {}

        paired(double x)
            : pair(std::make_pair(x, 1)),
              is_set(1)
        {}

        bool operator==(const paired& other) const {
            return pair.first == other.pair.first;
        }

        bool operator==(double other) const {
            return is_set && (pair.first == other);
        }

        bool operator>(double other) const {
            return is_set && (pair.first > other);
        }

        bool operator<(double other) const {
            return is_set && (pair.first < other);
        }

        paired& operator++() {
            ++pair.second;
            return *this;
        }

        paired operator++(int) {
            paired tmp(*this);
            ++(*this);
            return tmp;
        }
    };

    struct greater {
        bool operator()(const paired& lhs, const paired& rhs) const {
            if (!lhs.is_set) return false;
            if (!rhs.is_set) return true;
            return lhs.pair.first > rhs.pair.first;
        }
    };  

    typedef std::priority_queue<
        paired,
        std::vector<paired>,
        greater
    > queue_t;

    unsigned int sz;
    queue_t queue;

    void insert(double x) {
        if (queue.empty()) {
            queue.push(paired(x));
            return;
        }

        if (queue.top() > x && queue.size() >= sz) return;

        queue_t qtmp;
        bool matched = false;

        while (queue.size()) {
            paired elem = queue.top();
            if (elem == x) {
                qtmp.push(++elem);
                matched = true;
            } else {
                qtmp.push(elem);
            }
            queue.pop();
        }

        if (!matched) {
            if (qtmp.size() >= sz) qtmp.pop();
            qtmp.push(paired(x));
        }

        std::swap(queue, qtmp);
    }

public:
    histogram(unsigned int sz_) 
        : sz(sz_), 
          queue(queue_t())
    {}

    template <typename InputIt>
    void insert(InputIt first, InputIt last) {
        for ( ; first != last; ++first) {
            insert(*first);
        }
    }

    Rcpp::List get() const {
        Rcpp::NumericVector values(sz);
        Rcpp::IntegerVector freq(sz);
        R_xlen_t i = 0;

        queue_t tmp(queue);
        while (tmp.size()) {
            values[i] = tmp.top().pair.first;
            freq[i] = tmp.top().pair.second;
            ++i;
            tmp.pop();
        }

        return Rcpp::List::create(
            Rcpp::Named("value") = values,
            Rcpp::Named("frequency") = freq);
    }
};


// [[Rcpp::export]]
Rcpp::List top_n(Rcpp::NumericVector x, int n = 5) {
    histogram h(n);
    h.insert(x.begin(), x.end());
    return h.get();
} 

Там много всего происходит в histogram класс выше, но просто коснемся некоторых из ключевых моментов:

  • paired Тип по сути является классом-оберткой вокруг std::pair<double, unsigned int>, который связывает значение с количеством, предоставляя некоторые удобные функции, такие как operator++() / operator++(int) для прямого пре-/ постинкрементного подсчета и модифицированных операторов сравнения.
  • histogram класс обертывает своего рода "управляемую" очередь приоритетов, в том смысле, что размер std::priority_queue ограничен определенным значением sz,
  • Вместо использования по умолчанию std::less заказ std::priority_queueЯ использую компаратор больше, чем можно проверить значения кандидатов std::priority_queue::top() чтобы быстро определить, должны ли они (а) быть отброшены, (б) заменить текущее минимальное значение в очереди или (в) обновить счетчик одного из существующих значений в очереди. Это возможно только потому, что размер очереди ограничен <= sz,

Я держал пари data.table конкурентоспособен:

library(data.table)

data <- data.table(v)

data[ , .N, keyby = v][(.N - n + 1):.N]

где n это номер, который вы хотите получить

Примечание: предыдущая версия реплицировала функциональность для table() а не цель. Эта версия была удалена и будет доступна за пределами сайта.

Карта плана атаки

Ниже приведено решение с использованием map,

С ++98

Прежде всего, нам нужно найти "уникальные" значения для вектора чисел.

Для этого мы решили сохранить число, которое считается key в пределах std::map и увеличить value каждый раз мы наблюдаем это число.

Используя структуру заказа std::map мы знаем что верх n цифры находятся в конце std::map, Таким образом, мы используем итератор для извлечения этих элементов и их экспорта в массив.

C++11

Если у вас есть доступ к компилятору C++11, альтернативой является использование std::unordered_map, который имеет большой O O(1) для вставки и поиска (O(n) если плохие хеши) против std::map который имеет большой O O(log(n)),

Чтобы получить правильный топ n, тогда можно было бы использовать std::partial_sort() сделать это.

Реализация

С ++98

#include <Rcpp.h>

// [[Rcpp::export]]
Rcpp::List top_n_map(const Rcpp::NumericVector & v, int n)
{

  // Initialize a map
  std::map<double, int> Elt;

  Elt.clear();

  // Count each element
  for (int i = 0; i != v.size(); ++i) {
    Elt[ v[i] ] += 1;
  }

  // Find out how many unique elements exist... 
  int n_obs = Elt.size();

  // If the top number, n, is greater than the number of observations,
  // then drop it.  
  if(n > n_obs ) { n = n_obs; }

  // Pop the last n elements as they are already sorted. 

  // Make an iterator to access map info
  std::map<double,int>::iterator itb = Elt.end();

  // Advance the end of the iterator up to 5.
  std::advance(itb, -n);

  // Recast for R
  Rcpp::NumericVector result_vals(n);

  Rcpp::NumericVector result_keys(n);

  unsigned int count = 0;

  // Start at the nth element and move to the last element in the map.
  for( std::map<double,int>::iterator it = itb; it != Elt.end(); ++it )
  {
    // Move them into split vectors
    result_keys(count) = it->first;
    result_vals(count) = it->second;

    count++;
  }

  return Rcpp::List::create(Rcpp::Named("lengths") = result_vals,
                            Rcpp::Named("values") = result_keys);
}

Краткий тест

Давайте проверим, что это работает, запустив некоторые данные:

# Set seed for reproducibility
set.seed(1789)
x <- sample(round(rnorm(1000), 3), 1e5, replace = TRUE)
n <- 5

И теперь мы стремимся получить информацию о происшествии:

# Call our function
top_n_map(a)

Дает нам:

$lengths
[1] 101 104 101 103 103

$values
[1] 2.468 2.638 2.819 3.099 3.509

Ориентиры

Unit: microseconds
       expr        min          lq        mean      median         uq        max neval
      BaseR 112750.403 115946.7175 119493.4501 117676.2840 120712.595 166067.530   100
 data.table   6583.851   6994.3665   8311.8631   7260.9385   7972.548  47482.559   100
  DescTools   3291.626   3503.5620   5047.5074   3885.4090   5057.666  43597.451   100
   Coatless   6097.237   6240.1295   6421.1313   6365.7605   6528.315   7543.271   100
nrussel_c98    513.932    540.6495    571.5362    560.0115    584.628    797.315   100
nrussel_c11    489.616    512.2810    549.6581    533.2950    553.107    961.221   100

Как мы видим, эта реализация выбивает data.table, но становится жертвой попыток DescTools и @nrussel.

Другие вопросы по тегам