Python: Оставьте значения Numpy NaN из тепловой карты matplotlib и ее легенды
У меня есть массив массивов, который мне нужно построить в виде тепловой карты. Массив numpy также будет содержать значения NaN, которые мне нужно исключить из графика. В других постах мне говорили, что numpy автоматически маскирует значения NaN на графике, но это как-то не работает для меня. Вот пример кода
column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494, 0.52349944, 0.0254628 , 0.5104103 ],
[ 0.07320069, 0.91278731, 0.97094436, 0.70533351],
[ 0.30162006, 0.49068337, 0.41837729, 0.71139215],
[ 0.19786101, 0.15882713, 0.59028841, 0.06242765],
[ 0.51505872, 0.07798389, 0.58790067, 0.44782683],
[ 0.68975694, 0.53535385, 0.15696023, 0.35641951],
[ 0.66481995, 0.03576846, 0.9623601 , 0.96006395],
[ 0.45865404, 0.50433582, 0.18182575, 0.35126449],])
data[3,:] = np.nan
heatmap = ax.pcolor(data, cmap=plt.cm.seismic)
fig.colorbar(heatmap)
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show()
Очевидно, это очень отличается от сюжета без Нэн, который выглядит
Я хочу полностью исключить значения NaN из легенды и желательно пометить их каким-нибудь символом, например, X. Как я могу добиться того же?
1 ответ
nans
мешать pcolor
определение диапазона значений, содержащихся в data
поскольку
In [72]: data.min(), data.max()
Out[72]: (nan, nan)
Вы можете обойти проблему, объявив диапазон значений самостоятельно, используя np.nanmin
а также np.nanmax
найти минимальное и максимальное значения, отличные от NaN, в data
:
heatmap = ax.pcolor(data, cmap=plt.cm.seismic,
vmin=np.nanmin(data), vmax=np.nanmax(data))
поскольку
In [73]: np.nanmin(data), np.nanmax(data)
Out[73]: (0.025462800000000001, 0.97094435999999995)
import numpy as np
import matplotlib.pyplot as plt
column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494, 0.52349944, 0.0254628 , 0.5104103 ],
[ 0.07320069, 0.91278731, 0.97094436, 0.70533351],
[ 0.30162006, 0.49068337, 0.41837729, 0.71139215],
[ 0.19786101, 0.15882713, 0.59028841, 0.06242765],
[ 0.51505872, 0.07798389, 0.58790067, 0.44782683],
[ 0.68975694, 0.53535385, 0.15696023, 0.35641951],
[ 0.66481995, 0.03576846, 0.9623601 , 0.96006395],
[ 0.45865404, 0.50433582, 0.18182575, 0.35126449],])
data[3,:] = np.nan
heatmap = ax.pcolor(data, cmap=plt.cm.seismic,
vmin=np.nanmin(data), vmax=np.nanmax(data))
heatmap.cmap.set_under('black')
bar = fig.colorbar(heatmap, extend='both')
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show()
Другой вариант (основанный на решении Джо Кингтона) - рисовать прямоугольные пятна с штриховками где бы то ни было data
является NaN.
Приведенный выше пример показывает, что pcolor
цвета в ячейках со значениями NaN, как если бы NaN были очень отрицательными числами. Напротив, если вы передаете pcolor
замаскированный массив, pcolor
оставляет маскированные области прозрачными. Таким образом, вы можете нарисовать штриховки на осях фонового патча, ax.patch
, чтобы показать штриховки на замаскированных областях.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494, 0.52349944, 0.0254628 , 0.5104103 ],
[ 0.07320069, 0.91278731, 0.97094436, 0.70533351],
[ 0.30162006, 0.49068337, 0.41837729, 0.71139215],
[ 0.19786101, 0.15882713, 0.59028841, 0.06242765],
[ 0.51505872, 0.07798389, 0.58790067, 0.44782683],
[ 0.68975694, 0.53535385, 0.15696023, 0.35641951],
[ 0.66481995, 0.03576846, 0.9623601 , 0.96006395],
[ 0.45865404, 0.50433582, 0.18182575, 0.35126449],])
data[3,:] = np.nan
data = np.ma.masked_invalid(data)
heatmap = ax.pcolor(data, cmap=plt.cm.seismic,
vmin=np.nanmin(data), vmax=np.nanmax(data))
# https://stackru.com/a/16125413/190597 (Joe Kington)
ax.patch.set(hatch='x', edgecolor='black')
fig.colorbar(heatmap)
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show()
Если вы хотите использовать более одного типа штриховых меток, скажем, один для NaN, а другой для отрицательных значений, вы можете использовать цикл для добавления заштрихованных прямоугольников:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
column_labels = list('ABCDEFGH')
row_labels = list('WXYZ')
fig, ax = plt.subplots()
data = np.array([[ 0.96753494, 0.52349944, 0.0254628 , 0.5104103 ],
[ 0.07320069, 0.91278731, 0.97094436, 0.70533351],
[ 0.30162006, 0.49068337, 0.41837729, 0.71139215],
[ 0.19786101, 0.15882713, 0.59028841, 0.06242765],
[ 0.51505872, 0.07798389, 0.58790067, 0.44782683],
[ 0.68975694, 0.53535385, 0.15696023, 0.35641951],
[ 0.66481995, 0.03576846, 0.9623601 , 0.96006395],
[ 0.45865404, 0.50433582, 0.18182575, 0.35126449],])
data -= 0.5
data[3,:] = np.nan
data = np.ma.masked_invalid(data)
heatmap = ax.pcolor(data, cmap=plt.cm.seismic,
vmin=np.nanmin(data), vmax=np.nanmax(data))
# https://stackru.com/a/16125413/190597 (Joe Kington)
ax.patch.set(hatch='x', edgecolor='black')
# draw a hatched rectangle wherever the data is negative
# http://matthiaseisen.com/pp/patterns/p0203/
mask = data < 0
for j, i in np.column_stack(np.where(mask)):
ax.add_patch(
mpatches.Rectangle(
(i, j), # (x,y)
1, # width
1, # height
fill=False,
edgecolor='blue',
snap=False,
hatch='x' # the more slashes, the denser the hash lines
))
fig.colorbar(heatmap)
# put the major ticks at the middle of each cell
ax.set_xticks(np.arange(data.shape[1])+0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0])+0.5, minor=False)
# want a more natural, table-like display
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.set_xticklabels(row_labels, minor=False)
ax.set_yticklabels(column_labels, minor=False)
plt.show()