Диаграмма рассеяния Python Matplotlib: Укажите цветовые точки в зависимости от условий

У меня есть два массива, x и y, по 7000 элементов в каждом. Я хочу составить точечный график, на котором каждая точка будет иметь разный цвет в зависимости от следующих условий:

-BLACK if x[i]<10.

-RED if x[i]>=10 and y[i]<=-0.5

-BLUE if x[i]>=10 and y[i]>-0.5 

Я попытался создать список такой же длины, что и данные, с цветом, который я хочу назначить каждой точке, а затем нанести на график данные с помощью цикла, но мне потребовалось много времени для его запуска. Вот мой код:

import numpy as np
import matplotlib.pyplot as plt

#color list with same length as the data
col=[]
for i in range(0,len(x)):
    if x[i]<10:
        col.append('k') 
    elif x[i]>=10 and y[i]<=-0.5:
        col.append('r') 
    else:
        col.append('b') 

#scatter plot
for i in range(len(x)):
    plt.scatter(x[i],y[i],c=col[i],s=5, linewidth=0)

#add horizontal line and invert y-axis
plt.gca().invert_yaxis()
plt.axhline(y=-0.5,linewidth=2,c='k')

До этого я пытался создать тот же список цветов таким же образом, но отображая данные без цикла:

#scatter plot
plt.scatter(x,y,c=col,s=5, linewidth=0)

Несмотря на то, что данные наносятся на график намного, намного быстрее, чем при использовании цикла for, некоторые рассеянные точки отображаются с неправильным цветом. Почему не использование цикла для отображения данных приводит к неправильному цвету некоторых точек?

Я также попытался определить три набора данных, по одному для каждого цвета, и добавить их к графику отдельно. Но это не то, что я ищу.

Есть ли способ указать в аргументах диаграмм рассеяния список цветов, которые я хочу использовать для каждой точки, чтобы не использовать цикл for?

PS: это сюжет, который я получаю, когда не использую цикл for (неправильный):

И этот, когда я использую цикл for (правильно):

1 ответ

Это можно сделать с помощью numpy.where, Так как я не использую ваши точные значения x и y, мне придется использовать некоторые поддельные данные:

import numpy as np
import matplotlib.pyplot as plt

#generate some fake data
x = np.random.random(10000)*10
y = np.random.random(10000)*10

col = np.where(x<1,'k',np.where(y<5,'b','r'))

plt.scatter(x, y, c=col, s=5, linewidth=0)
plt.show()

Это дает сюжет ниже:

Линия col = np.where(x<1,'k',np.where(y<5,'b','r')) это важный. Это создает список того же размера, что и x и y. Это заполняет этот список 'k','b' или же 'r' в зависимости от условия, которое написано до него. Так что, если х меньше 1, 'k' будет добавлен в список, иначе, если у меньше 5 'b' будет добавлено, и если ни одно из этих условий не будет выполнено, 'r' будет добавлен в список. Таким образом, вам не нужно использовать цикл для построения графика.

Для ваших конкретных данных вам придется изменить значения в условиях np.where,

Другие вопросы по тегам