Удаление определенных подмассивов в NumPy

Я работаю с пустым массивом, в котором хранится определенное состояние игры, и я ищу в этом пространстве состояний, "расширяя" текущее состояние и создавая новый массив, в котором хранятся возможные состояния, которые могут быть достигнуты из текущего состояния. Чтобы избежать циклов поиска, мне нужно удалить все массивы, которые возникли из-за недопустимых ходов (если сделано недопустимое движение, результирующее состояние равно текущему состоянию).

Для этого я пытаюсь использовать numpy.where:

invalid_moves = np.where(np.array_equal(state, current_state) for state in successors)
successors = np.delete(successors, invalid_moves, axis=0) 
#successors is an (s, n, n) array where s is the number of states possible 
#from a given current state and each state is (n, n)

Однако, когда текущее состояние не имеет возможных недопустимых ходов и никакие возможные состояния не должны быть удалены из преемников, первое состояние всегда удаляется. Может кто-нибудь помочь объяснить, почему это или, возможно, где я ошибся?

1 ответ

Решение

Проблема может быть проиллюстрирована на фиктивном примере. Здесь s ваши преемники, и мы используем state>100 в качестве текущего состояния.

a = numpy.arange(10)
successors = [a.copy(), a.copy(), a.copy()]

numpy.where(numpy.array_equal(state, state>100) for state in successors)
>>> (array([0], dtype=int64),)

Массив с элементом 0 отвечает за удаление на следующем шаге.

Предполагая, что преемники - это список, мы получаем:

[numpy.array_equal(i, i>100) for i in successor]
>>> [False, False False]

Для которого numpy.where просто возвращает первый элемент.

Чтобы избежать проблемы, проверьте, совпадает ли какое-либо из состояний, и если нет, не выполняйте удаление.

a = [np.array_equal(state, current_state) for state in successors]
if any(a):
    invalid_moves = np.where(a)
    successors = np.delete(successors, invalid_moves, axis=0) 
Другие вопросы по тегам