Диаграмма рассеяния Python Matplotlib: Укажите цветовые точки в зависимости от условий
У меня есть два массива, x и y, по 7000 элементов в каждом. Я хочу составить точечный график, на котором каждая точка будет иметь разный цвет в зависимости от следующих условий:
-BLACK if x[i]<10.
-RED if x[i]>=10 and y[i]<=-0.5
-BLUE if x[i]>=10 and y[i]>-0.5
Я попытался создать список такой же длины, что и данные, с цветом, который я хочу назначить каждой точке, а затем нанести на график данные с помощью цикла, но мне потребовалось много времени для его запуска. Вот мой код:
import numpy as np
import matplotlib.pyplot as plt
#color list with same length as the data
col=[]
for i in range(0,len(x)):
if x[i]<10:
col.append('k')
elif x[i]>=10 and y[i]<=-0.5:
col.append('r')
else:
col.append('b')
#scatter plot
for i in range(len(x)):
plt.scatter(x[i],y[i],c=col[i],s=5, linewidth=0)
#add horizontal line and invert y-axis
plt.gca().invert_yaxis()
plt.axhline(y=-0.5,linewidth=2,c='k')
До этого я пытался создать тот же список цветов таким же образом, но отображая данные без цикла:
#scatter plot
plt.scatter(x,y,c=col,s=5, linewidth=0)
Несмотря на то, что данные наносятся на график намного, намного быстрее, чем при использовании цикла for, некоторые рассеянные точки отображаются с неправильным цветом. Почему не использование цикла для отображения данных приводит к неправильному цвету некоторых точек?
Я также попытался определить три набора данных, по одному для каждого цвета, и добавить их к графику отдельно. Но это не то, что я ищу.
Есть ли способ указать в аргументах диаграмм рассеяния список цветов, которые я хочу использовать для каждой точки, чтобы не использовать цикл for?
PS: это сюжет, который я получаю, когда не использую цикл for (неправильный):
И этот, когда я использую цикл for (правильно):
1 ответ
Это можно сделать с помощью numpy.where
, Так как я не использую ваши точные значения x и y, мне придется использовать некоторые поддельные данные:
import numpy as np
import matplotlib.pyplot as plt
#generate some fake data
x = np.random.random(10000)*10
y = np.random.random(10000)*10
col = np.where(x<1,'k',np.where(y<5,'b','r'))
plt.scatter(x, y, c=col, s=5, linewidth=0)
plt.show()
Это дает сюжет ниже:
Линия col = np.where(x<1,'k',np.where(y<5,'b','r'))
это важный. Это создает список того же размера, что и x и y. Это заполняет этот список 'k','b'
или же 'r'
в зависимости от условия, которое написано до него. Так что, если х меньше 1, 'k'
будет добавлен в список, иначе, если у меньше 5 'b'
будет добавлено, и если ни одно из этих условий не будет выполнено, 'r'
будет добавлен в список. Таким образом, вам не нужно использовать цикл для построения графика.
Для ваших конкретных данных вам придется изменить значения в условиях np.where
,