Как уменьшить набор данных xarray с помощью groupby?

Я хотел бы сократить набор данных xarray на основе определенной группы, поэтому я использую groupby выбрать группу, а затем взять 10% образцов в каждой группе. Я использую код ниже, но я получаю IndexError: index 1330 is out of bounds for axis 0 with size 1330 что подсказывает мне, что моя функция возвращает пустой массив, но subset определенно имеет ненулевые размеры.

Я использовал squeeze=True который я думал, что позволит новые измерения в соответствии с документацией GroupBy, но это не помогло, поэтому я изменил его squeeze=False,

Вы знаете, что может происходить? Спасибо!

# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x):
    size = int(0.1 * len(x.cell))
    random_cells = sorted(np.random.choice(x.cell, size=size, replace=False))
    print('number of random cells:', len(random_cells))
    print('\tsome random cells:', random_cells[:5])
    subset = x.sel(cell=random_cells)
    print('subset:', subset)
    return subset

# squeeze=False because the final dataset is smaller than the original
ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
ds_subset

Вот ошибка:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-44-39c7803e9e40> in <module>()
     12 
     13 # squeeze=False because the final dataset is smaller than the original
---> 14 ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
     15 ds_subset

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in apply(self, func, **kwargs)
    615         kwargs.pop('shortcut', None)  # ignore shortcut if set (for now)
    616         applied = (func(ds, **kwargs) for ds in self._iter_grouped())
--> 617         return self._combine(applied)
    618 
    619     def _combine(self, applied):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _combine(self, applied)
    622         coord, dim, positions = self._infer_concat_args(applied_example)
    623         combined = concat(applied, dim)
--> 624         combined = _maybe_reorder(combined, dim, positions)
    625         if coord is not None:
    626             combined[coord.name] = coord

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _maybe_reorder(xarray_obj, dim, positions)
    443         return xarray_obj
    444     else:
--> 445         return xarray_obj[{dim: order}]
    446 
    447 

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in __getitem__(self, key)
    716         """
    717         if utils.is_dict_like(key):
--> 718             return self.isel(**key)
    719 
    720         if hashable(key):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in isel(self, drop, **indexers)
   1141         for name, var in iteritems(self._variables):
   1142             var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
-> 1143             new_var = var.isel(**var_indexers)
   1144             if not (drop and name in var_indexers):
   1145                 variables[name] = new_var

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in isel(self, **indexers)
    568             if dim in indexers:
    569                 key[i] = indexers[dim]
--> 570         return self[tuple(key)]
    571 
    572     def squeeze(self, dim=None):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in __getitem__(self, key)
    398         dims = tuple(dim for k, dim in zip(key, self.dims)
    399                      if not isinstance(k, integer_types))
--> 400         values = self._indexable_data[key]
    401         # orthogonal indexing should ensure the dimensionality is consistent
    402         if hasattr(values, 'ndim'):

~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/indexing.py in __getitem__(self, key)
    476     def __getitem__(self, key):
    477         key = self._convert_key(key)
--> 478         return self._ensure_ndarray(self.array[key])
    479 
    480     def __setitem__(self, key, value):

IndexError: index 1330 is out of bounds for axis 0 with size 1330

2 ответа

Решение

Это совершенно разумная вещь, но, к сожалению, она пока не работает. Xarray использует некоторые эвристики, чтобы решить, является ли apply операция имеет reduce или же transform тип, и в этом случае мы неправильно идентифицируем сгруппированную операцию как "преобразование", поскольку выходные данные повторно используют исходное имя измерения. Я только что подал отчет об ошибке, но, к сожалению, исправление для xarray будет несколько связано.

Вероятно, самый простой обходной путь состоит в том, чтобы применяемая функция возвращала логический DataArray вместо этого, указывая позиции для сохранения. Затем вы можете использовать операцию индексации для выбора исходного объекта.

Вот как я это реализовал. Как предложил @shoyer выше, я вернул логическое значение xarray.DataArray для каждой группы, а затем использовал это логическое значение для подмножества моих данных.

# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x, threshold=0.1):
    random_bools = xr.DataArray(np.random.uniform(size=len(x.cell)) <= threshold,
                               coords=dict(cell=x.cell)) 
    return random_bools

    subset_bools = ds.groupby('group',).apply(select_random_cell_subset, 
                                                    threshold=0.1)
ds_subset = ds.sel(cell=subset_bools)
Другие вопросы по тегам