Как создать живой график matplotlib.pyplot в Google Colab?

К сожалению, невозможно создавать живые графики в блокноте Google Colab, используя %matplotlib notebook как в автономном ноутбуке jupyter на моем ПК.

Я нашел два похожих вопроса, отвечая на вопрос, как добиться этого для сюжетных сюжетов ( ссылка_1, ссылка_2). Однако я не могу адаптировать его к matplotlib или вообще не знаю, возможно ли это.

Я следую коду из этого руководства здесь: ссылка GitHub. В частности, я хотел бы запустить этот код, который создает обратный вызов, отображающий вознаграждение за шаг по этапам обучения:

import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook


class PlottingCallback(BaseCallback):
    """
    Callback for plotting the performance in realtime.

    :param verbose: (int)
    """
    def __init__(self, verbose=1):
        super(PlottingCallback, self).__init__(verbose)
        self._plot = None

    def _on_step(self) -> bool:
        # get the monitor's data
        x, y = ts2xy(load_results(log_dir), 'timesteps')
      if self._plot is None: # make the plot
          plt.ion()
          fig = plt.figure(figsize=(6,3))
          ax = fig.add_subplot(111)
          line, = ax.plot(x, y)
          self._plot = (line, ax, fig)
          plt.show()
      else: # update and rescale the plot
          self._plot[0].set_data(x, y)
          self._plot[-2].relim()
          self._plot[-2].set_xlim([self.locals["total_timesteps"] * -0.02, 
                                   self.locals["total_timesteps"] * 1.02])
          self._plot[-2].autoscale_view(True,True,True)
          self._plot[-1].canvas.draw()

# Create log dir
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = make_vec_env('MountainCarContinuous-v0', n_envs=1, monitor_dir=log_dir)

plotting_callback = PlottingCallback()

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(20000, callback=plotting_callback)

1 ответ

Хак, который вы можете использовать, заключается в том, чтобы использовать тот же код, который вы использовали бы в jupyter notbook, создать кнопку и использовать JavaScript для нажатия кнопки, обманывая интерфейс, чтобы запросить обновление, чтобы он продолжал обновлять значения.

Вот пример, который использует ipywidgets.

      from IPython.display import display
import ipywidgets
progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)
import asyncio
async def work(progress):
    total = 100
    for i in range(total):
        await asyncio.sleep(0.2)
        progress.value = float(i+1)/total
display(progress)
asyncio.get_event_loop().create_task(work(progress))
button = ipywidgets.Button(description="This button does nothing... except send a\
 socket request to google servers to receive updated information since the \
 frontend wants to change..")

display(button,ipywidgets.HTML(
    value="""<script>
      var b=setInterval(a=>{
    //Hopefully this is the first button
    document.querySelector('#output-body button').click()},
    1000);
    setTimeout(c=>clearInterval(b),1000*60*1);
    //Stops clicking the button after 1 minute
    </script>"""
))

Конкретно работать с matplotlib немного сложнее, я думал, что могу просто вызвать matplotlib plot для функции asyncio, но это действительно отстает от обновлений, потому что кажется, что он выполняет ненужный рендеринг в фоновом режиме, где никто не видит график. Таким образом, еще один обходной путь — обновить график кода кнопки обновления. Этот код также вдохновлен добавлением точек к графику разброса matlibplot в реальном времени и графическому изображению Matplotlib к base64 . Причина в том, что нет необходимости создавать график для каждого графика, вы можете просто изменить рисунок, который у вас уже есть. Это, конечно, означает больше кода.

      from IPython.display import display
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import io
import base64
def pltToImg(plt):
 s = io.BytesIO()
 plt.savefig(s, format='png', bbox_inches="tight")
 s = base64.b64encode(s.getvalue()).decode("utf-8").replace("\n", "")
 #plt.close()
 return '<img align="left" src="data:image/png;base64,%s">' % s
progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)
import asyncio
async def work(progress):
    total = 100
    for i in range(total):
        await asyncio.sleep(0.5)
        progress.value = float(i+1)/total
display(progress)
asyncio.get_event_loop().create_task(work(progress))
button = ipywidgets.Button(description="Update =D")
a=ipywidgets.HTML(
    value="image here"
)
output = ipywidgets.Output()
plt.ion()
fig, ax = plt.subplots()
plot = ax.scatter([], [])
point = np.random.normal(0, 1, 2)
array = plot.get_offsets()
array = np.append(array, [point], axis=0)
plot.set_offsets(array)
plt.close()
ii=0
def on_button_clicked(b):
       global ii
       ii+=1
       point=np.r_[ii,np.random.normal(0, 1, 1)]
       array = plot.get_offsets()
       array = np.append(array, [point], axis=0)
       plot.set_offsets(array)
       ax.set_xlim(array[:, 0].min() - 0.5, array[:,0].max() + 0.5)
       ax.set_ylim(array[:, 1].min() - 0.5, array[:, 1].max() + 0.5)
       a.value=(pltToImg(fig))
       a.value+=str(progress.value)
       a.value+=" </br>"
       a.value+=str(ii)

button.on_click(on_button_clicked)
display(output,button,ipywidgets.HTML(
    value="""<script>
      var b=setInterval(a=>{
    //Hopefully this is the first button
    document.querySelector('#output-body button')?.click()},
    500);
    setTimeout(c=>clearInterval(b),1000*60*1);
    //Stops clicking the button after 1 minute
    </script>"""
),a)
Другие вопросы по тегам