matplotlib FuncAnimation - убедитесь, что легенда и линии графика имеют одинаковые цвета?

Рассмотрим фреймворк Pandas с несколькими столбцами, каждый столбец - это название страны, и несколько строк, в каждой строке - дата. Ячейки - это данные о странах, которые меняются во времени. Это CSV:

https://pastebin.com/bJbDz7ei

Я хочу создать динамический сюжет (анимацию) в Jupyter, который показывает, как данные меняются во времени. Из всех стран мира я хочу показать только 10 лучших стран в любой момент времени. Таким образом, страны, показанные на диаграмме, могут время от времени меняться (потому что первая десятка развивается).

Я также хочу сохранить единообразие цветов. Одновременно отображаются только 10 стран, а некоторые страны появляются и исчезают почти постоянно, но цвет любой страны не должен меняться на протяжении всей анимации. Цвет любой страны должен оставаться неизменным от начала до конца.

Это код, который у меня есть (EDIT: теперь вы можете скопировать / вставить код в Jupyter, и он работает из коробки, поэтому вы можете легко увидеть ошибку, о которой я говорю):

import pandas as pd
import requests
import os
from matplotlib import pyplot as plt
import matplotlib.animation as ani

rel_big_file = 'rel_big.csv'
rel_big_url = 'https://pastebin.com/raw/bJbDz7ei'

if not os.path.exists(rel_big_file):
    r = requests.get(rel_big_url)
    with open(rel_big_file, 'wb') as f:
        f.write(r.content)

rel_big = pd.read_csv(rel_big_file, index_col='Date')

# history of top N countries
champs = []
# frame draw function
def animate_graph(i=int):
    N = 10
    # get current values for each country
    last_index = rel_big.index[i]
    # which countries are top N in last_index?
    topN = rel_big.loc[last_index].sort_values(ascending=False).head(N).index.tolist()
    # if country not already in champs, add it
    for c in topN:
        if c not in champs:
            champs.append(c)
    # pull a standard color map from matplotlib
    cmap = plt.get_cmap("tab20")
    # draw legend
    plt.legend(topN)
    # make a temporary dataframe with only top N countries
    rel_plot = rel_big[topN].copy(deep=True)
    # plot temporary dataframe
    p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)
    # set color for each country based on index in champs
    for i in range(0, N):
        p[i].set_color(cmap(champs.index(topN[i]) % 20))

%matplotlib notebook
fig = plt.figure(figsize=(10, 6))
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
# x ticks get too crowded, limit their number
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
animator = ani.FuncAnimation(fig, animate_graph, interval = 333)
plt.show()

Это в некоторой степени выполняет свою работу. Я храню лучшие страны в списке чемпионов и назначаю цвета на основе индекса каждой страны в чемпионатах. Но только цвет нанесенных линий назначается правильно, исходя из индекса в champs.

Цвет в легенде назначается жестко: первая страна в легенде всегда имеет один и тот же цвет, вторая страна в легенде всегда получает определенный цвет и т. Д., И в основном цвет каждой страны в легенде меняется на протяжении всей анимации. когда страны перемещаются вверх и вниз в легенде.

Цвета нанесенных линий соответствуют указателю в полях. Цвета стран в легенде зависят от порядка в легенде. Я не этого хочу.

Как назначить цвет каждой стране в легенде так, чтобы он соответствовал линиям сюжета?

2 ответа

Решение

Вот мое решение:

Я удалил ваш код, который генерирует цвета, и установил новый рабочий:

Сначала я инициализировал каждую страну своим уникальным цветом в словаре:

# initializing fixed color to all countries
colorsCountries = {}
for country in rel_big.columns:
    colorsCountries[country] = random.choice(list(mcd.CSS4_COLORS.keys()))

затем я заменил это:

# plot temporary dataframe
p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)

с этим:

# plot temporary dataframe
for keyIndex in rel_plot[:i].keys() :
    p = plt.plot(rel_plot[:i].index,rel_plot[:i][keyIndex].values,color=colorsCountries[keyIndex])

а затем добавил код, который обновляет метку и цвета легенды matplotlib

leg = plt.legend(topN)
for line, text in zip(leg.get_lines(), leg.get_texts()):
    line.set_color(colorsCountries[text.get_text()])

не забудьте добавить импорт:

import matplotlib._color_data as mcd
import random

Вот полное предлагаемое решение:

import pandas as pd
import requests
import os
from matplotlib import pyplot as plt
import matplotlib.animation as ani
import matplotlib._color_data as mcd
import random

rel_big_file = 'rel_big.csv'
rel_big_url = 'https://pastebin.com/raw/bJbDz7ei'

if not os.path.exists(rel_big_file):
    r = requests.get(rel_big_url)
    with open(rel_big_file, 'wb') as f:
        f.write(r.content)

rel_big = pd.read_csv(rel_big_file, index_col='Date')

# history of top N countries
champs = []
# initializing fixed color to all countries
colorsCountries = {}
for country in rel_big.columns:
    colorsCountries[country] = random.choice(list(mcd.CSS4_COLORS.keys()))
# frame draw function
def animate_graph(i=int):
    N = 10
    # get current values for each country
    last_index = rel_big.index[i]
    # which countries are top N in last_index?
    topN = rel_big.loc[last_index].sort_values(ascending=False).head(N).index.tolist()
    # if country not already in champs, add it
    for c in topN:
        if c not in champs:
            champs.append(c)
    # pull a standard color map from matplotlib
    cmap = plt.get_cmap("tab20")
    # draw legend
    plt.legend(topN)
    # make a temporary dataframe with only top N countries
    rel_plot = rel_big[topN].copy(deep=True)
    # plot temporary dataframe
    #### Removed Code
    #p = plt.plot(rel_plot[:i].index, rel_plot[:i].values)
    #### Removed Code
    for keyIndex in rel_plot[:i].keys() :
        p = plt.plot(rel_plot[:i].index,rel_plot[:i][keyIndex].values,color=colorsCountries[keyIndex])
    # set color for each country based on index in champs
    #### Removed Code
    #for i in range(0, N):
        #p[i].set_color(cmap(champs.index(topN[i]) % 20))
    #### Removed Code
    leg = plt.legend(topN)
    for line, text in zip(leg.get_lines(), leg.get_texts()):
        line.set_color(colorsCountries[text.get_text()])

%matplotlib notebook
fig = plt.figure(figsize=(10, 6))
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
# x ticks get too crowded, limit their number
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
animator = ani.FuncAnimation(fig, animate_graph, interval = 333)
plt.show()

Ответ ZINE Mahmoud великолепен. Я изменил только одну вещь - мне нужно было детерминированное распределение цветов при каждом прогоне, поэтому вместо случайного метода я назначил цвета таким странам, как эта:

colorsCountries = {}
colorPalette = list(mcd.CSS4_COLORS.keys())
for country in rel_big.columns:
    colorsCountries[country] = colorPalette[rel_big.columns.tolist().index(country) % len(colorPalette)]
Другие вопросы по тегам