Python: Оставьте значения Numpy NaN из тепловой карты matplotlib и ее легенды

У меня есть массив массивов, который мне нужно построить в виде тепловой карты. Массив numpy также будет содержать значения NaN, которые мне нужно исключить из графика. В других постах мне говорили, что numpy автоматически маскирует значения NaN на графике, но это как-то не работает для меня. Вот пример кода

column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494,  0.52349944,  0.0254628 ,  0.5104103 ],
         [ 0.07320069,  0.91278731,  0.97094436,  0.70533351],
         [ 0.30162006,  0.49068337,  0.41837729,  0.71139215],
         [ 0.19786101,  0.15882713,  0.59028841,  0.06242765],
         [ 0.51505872,  0.07798389,  0.58790067,  0.44782683],
         [ 0.68975694,  0.53535385,  0.15696023,  0.35641951],
         [ 0.66481995,  0.03576846,  0.9623601 ,  0.96006395],
         [ 0.45865404,  0.50433582,  0.18182575,  0.35126449],])

data[3,:] = np.nan
heatmap = ax.pcolor(data, cmap=plt.cm.seismic)

fig.colorbar(heatmap)
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)

# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()

ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show()

Сюжеты выглядят как

Очевидно, это очень отличается от сюжета без Нэн, который выглядит

Я хочу полностью исключить значения NaN из легенды и желательно пометить их каким-нибудь символом, например, X. Как я могу добиться того же?

1 ответ

Решение

nans мешать pcolor определение диапазона значений, содержащихся в data поскольку

In [72]: data.min(), data.max()
Out[72]: (nan, nan)

Вы можете обойти проблему, объявив диапазон значений самостоятельно, используя np.nanmin а также np.nanmax найти минимальное и максимальное значения, отличные от NaN, в data:

heatmap = ax.pcolor(data, cmap=plt.cm.seismic, 
                    vmin=np.nanmin(data), vmax=np.nanmax(data))

поскольку

In [73]: np.nanmin(data), np.nanmax(data)
Out[73]: (0.025462800000000001, 0.97094435999999995)

import numpy as np
import matplotlib.pyplot as plt

column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494,  0.52349944,  0.0254628 ,  0.5104103 ],
         [ 0.07320069,  0.91278731,  0.97094436,  0.70533351],
         [ 0.30162006,  0.49068337,  0.41837729,  0.71139215],
         [ 0.19786101,  0.15882713,  0.59028841,  0.06242765],
         [ 0.51505872,  0.07798389,  0.58790067,  0.44782683],
         [ 0.68975694,  0.53535385,  0.15696023,  0.35641951],
         [ 0.66481995,  0.03576846,  0.9623601 ,  0.96006395],
         [ 0.45865404,  0.50433582,  0.18182575,  0.35126449],])

data[3,:] = np.nan
heatmap = ax.pcolor(data, cmap=plt.cm.seismic, 
                    vmin=np.nanmin(data), vmax=np.nanmax(data))
heatmap.cmap.set_under('black')

bar = fig.colorbar(heatmap, extend='both')

# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)

# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()

ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show() 


Другой вариант (основанный на решении Джо Кингтона) - рисовать прямоугольные пятна с штриховками где бы то ни было data является NaN.

Приведенный выше пример показывает, что pcolor цвета в ячейках со значениями NaN, как если бы NaN были очень отрицательными числами. Напротив, если вы передаете pcolor замаскированный массив, pcolor оставляет маскированные области прозрачными. Таким образом, вы можете нарисовать штриховки на осях фонового патча, ax.patch, чтобы показать штриховки на замаскированных областях.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494,  0.52349944,  0.0254628 ,  0.5104103 ],
         [ 0.07320069,  0.91278731,  0.97094436,  0.70533351],
         [ 0.30162006,  0.49068337,  0.41837729,  0.71139215],
         [ 0.19786101,  0.15882713,  0.59028841,  0.06242765],
         [ 0.51505872,  0.07798389,  0.58790067,  0.44782683],
         [ 0.68975694,  0.53535385,  0.15696023,  0.35641951],
         [ 0.66481995,  0.03576846,  0.9623601 ,  0.96006395],
         [ 0.45865404,  0.50433582,  0.18182575,  0.35126449],])

data[3,:] = np.nan
data = np.ma.masked_invalid(data)

heatmap = ax.pcolor(data, cmap=plt.cm.seismic, 
                    vmin=np.nanmin(data), vmax=np.nanmax(data))
# https://stackru.com/a/16125413/190597 (Joe Kington)
ax.patch.set(hatch='x', edgecolor='black')
fig.colorbar(heatmap)

# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)

# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()

ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show() 


Если вы хотите использовать более одного типа штриховых меток, скажем, один для NaN, а другой для отрицательных значений, вы можете использовать цикл для добавления заштрихованных прямоугольников:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494,  0.52349944,  0.0254628 ,  0.5104103 ],
         [ 0.07320069,  0.91278731,  0.97094436,  0.70533351],
         [ 0.30162006,  0.49068337,  0.41837729,  0.71139215],
         [ 0.19786101,  0.15882713,  0.59028841,  0.06242765],
         [ 0.51505872,  0.07798389,  0.58790067,  0.44782683],
         [ 0.68975694,  0.53535385,  0.15696023,  0.35641951],
         [ 0.66481995,  0.03576846,  0.9623601 ,  0.96006395],
         [ 0.45865404,  0.50433582,  0.18182575,  0.35126449],])
data -= 0.5
data[3,:] = np.nan
data = np.ma.masked_invalid(data)
heatmap = ax.pcolor(data, cmap=plt.cm.seismic, 
                    vmin=np.nanmin(data), vmax=np.nanmax(data))

# https://stackru.com/a/16125413/190597 (Joe Kington)
ax.patch.set(hatch='x', edgecolor='black')

# draw a hatched rectangle wherever the data is negative
# http://matthiaseisen.com/pp/patterns/p0203/
mask = data < 0
for j, i in np.column_stack(np.where(mask)):
      ax.add_patch(
          mpatches.Rectangle(
              (i, j),     # (x,y)
              1,          # width
              1,          # height
              fill=False, 
              edgecolor='blue',
              snap=False,
              hatch='x' # the more slashes, the denser the hash lines 
          ))

fig.colorbar(heatmap)

# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)

# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()

ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show() 

Другие вопросы по тегам