Python: как создать легенду на примере

Это из главы 2 в книге Machine Learning In Action и я пытаюсь сделать сюжет, изображенный здесь:

сюжет

Автор разместил здесь код сюжета, который, я думаю, может быть немного хакерским (он также упоминает, что этот код небрежный, поскольку выходит за рамки книги).

Вот моя попытка воссоздать сюжет:

Во-первых, файл.txt, содержащий данные, выглядит следующим образом (источник: "DatingTestSet2.txt" в гл.2 здесь):

40920   8.326976    0.953952    largeDoses
14488   7.153469    1.673904    smallDoses
26052   1.441871    0.805124    didntLike
75136   13.147394   0.428964    didntLike
38344   1.669788    0.134296    didntLike
...

Предполагать datingDataMat это numpy.ndarray формы `(1000L, 2L), где столбец 0 -" Частые мили летчика в год ", столбец 1 -" Время игры в видеоиграх ", а столбец 2 -" литр мороженого, потребляемого в неделю ", как показано в примере выше.

Предполагать datingLabels это list целые числа 1, 2 или 3 означают "Не понравилось", "Понравилось в малых дозах" и "Понравилось в больших дозах" соответственно - связаны с колонкой 3 выше.

Вот код, который я должен создать сюжет (полная информация для file2matrix в конце):

datingDataMat,datingLabels = file2matrix("datingTestSet2.txt")
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot (111)
plt.xlabel("Freq flier miles")
plt.ylabel("% time video games")
# Not sure how to finish this: plt.legend([1, 2, 3], ["did not like", "small doses", "large doses"])
plt.scatter(datingDataMat[:,0], datingDataMat[:,1], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels)) # Change marker color and size 
plt.show()

Выход здесь:

введите описание изображения здесь

Моя главная задача - как создать эту легенду. Есть ли способ сделать это без необходимости прямого управления точками?

Далее, мне любопытно, смогу ли я найти способ переключать цвета в соответствии с цветами графика. Есть ли способ сделать это, не имея какой-то "ручки" для отдельных точек?

Также, если интересно, вот file2matrix реализация:

def file2matrix(filename):
    fr = open(filename)
    numberOfLines = len(fr.readlines())
    returnMat = np.zeros((numberOfLines,3)) #numpy.zeros(shape, dtype=float, order='C') 
    classLabelVector = []
    fr = open(filename)
    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3] # FFmiles/yr, % time gaming, L ice cream/wk
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector

2 ответа

Решение

Вот пример, который имитирует код, который у вас уже есть, который демонстрирует подход, описанный в примере Saullo Castro. Также показано, как установить цвета в примере. Если вы хотите получить дополнительную информацию о доступных цветах, см. Документацию по http://matplotlib.org/api/colors_api.html

Также стоит посмотреть документацию по точечному графику по адресу http://matplotlib.org/1.3.1/api/pyplot_api.html.

from numpy.random import rand, randint
from matplotlib import pyplot as plt
n = 1000
# Generate random data
data = rand(n, 2)
# Make a random array to mimic datingLabels
labels = randint(1, 4, n)
# Separate the data according to the labels
data_1 = data[labels==1]
data_2 = data[labels==2]
data_3 = data[labels==3]
# Plot each set of points separately
# 's' is the size parameter.
# 'c' is the color parameter.
# I have chosen the colors so that they match the plot shown.
# With each set of points, input the desired label for the legend.
plt.scatter(data_1[:,0], data_1[:,1], s=15, c='r', label="label 1")
plt.scatter(data_2[:,0], data_2[:,1], s=30, c='g', label="label 2")
plt.scatter(data_3[:,0], data_3[:,1], s=45, c='b', label="label 3")
# Put labels on the axes
plt.ylabel("ylabel")
plt.xlabel("xlabel")
# Place the Legend in the plot.
plt.gca().legend(loc="upper left")
# Display it.
plt.show()

Серые границы должны стать белыми, если вы используете plt.savefig сохранить рисунок в файл вместо его отображения. Не забудьте бежать plt.clf() или же plt.cla() после сохранения в файл, чтобы очистить оси, чтобы вы не в конечном итоге снова и снова выкладывали одни и те же данные поверх себя.

Чтобы создать легенду, вы должны:

  • дать метки каждой кривой

  • позвонить legend() метод из текущего AxesSubplot объект, который можно получить с помощью plt.gca(), например.

Смотрите пример ниже:

plt.scatter(datingDataMat[:,0], datingDataMat[:,1],
            15.0*np.array(datingLabels), 15.0*np.array(datingLabels),
            label='Label for this data')
plt.gca().legend(loc='upper left')
Другие вопросы по тегам