Удаление определенных подмассивов в NumPy
Я работаю с пустым массивом, в котором хранится определенное состояние игры, и я ищу в этом пространстве состояний, "расширяя" текущее состояние и создавая новый массив, в котором хранятся возможные состояния, которые могут быть достигнуты из текущего состояния. Чтобы избежать циклов поиска, мне нужно удалить все массивы, которые возникли из-за недопустимых ходов (если сделано недопустимое движение, результирующее состояние равно текущему состоянию).
Для этого я пытаюсь использовать numpy.where:
invalid_moves = np.where(np.array_equal(state, current_state) for state in successors)
successors = np.delete(successors, invalid_moves, axis=0)
#successors is an (s, n, n) array where s is the number of states possible
#from a given current state and each state is (n, n)
Однако, когда текущее состояние не имеет возможных недопустимых ходов и никакие возможные состояния не должны быть удалены из преемников, первое состояние всегда удаляется. Может кто-нибудь помочь объяснить, почему это или, возможно, где я ошибся?
1 ответ
Проблема может быть проиллюстрирована на фиктивном примере. Здесь s ваши преемники, и мы используем state>100 в качестве текущего состояния.
a = numpy.arange(10)
successors = [a.copy(), a.copy(), a.copy()]
numpy.where(numpy.array_equal(state, state>100) for state in successors)
>>> (array([0], dtype=int64),)
Массив с элементом 0 отвечает за удаление на следующем шаге.
Предполагая, что преемники - это список, мы получаем:
[numpy.array_equal(i, i>100) for i in successor]
>>> [False, False False]
Для которого numpy.where просто возвращает первый элемент.
Чтобы избежать проблемы, проверьте, совпадает ли какое-либо из состояний, и если нет, не выполняйте удаление.
a = [np.array_equal(state, current_state) for state in successors]
if any(a):
invalid_moves = np.where(a)
successors = np.delete(successors, invalid_moves, axis=0)