Невозможно импортировать jax-пакет python в Google TPU
Я работаю над консолью linux, и при вводе python я попадаю в консоль python. Когда я использую следующую команду на машине TPU
import jax
затем он генерирует следующую команду mss и выходит из командной строки python.
paramjeetsingh80@t1v-n-1c883486-w-0:~$ python3
Python 3.8.5 (default, Jan 27 2021, 15:41:15)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2021-07-08 17:41:39.660523: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)
paramjeetsingh80@t1v-n-1c883486-w-0:~$
Эта проблема вызывает проблему в моем коде, поэтому я хотел бы выяснить, что это за проблема и как от нее избавиться?
3 ответа
Возможно, в вашей системе не установлена правильная версия libtpu. Попробуйте установить версию, указанную здесь .
Вы должны иметь возможность делать это автоматически с помощью
$ pip install -U jax
$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html
Выше команда дает некоторую ошибку, но я исследовал, и команда ниже работала для меня. Но ваш ответ дал мне понять, что это проблема пакета.
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Эта проблема решена, и теперь возникает другая ошибка, описанная ниже, и скрипт зависает.
(pid=9454, ip=10.164.0.9) jax runtime initialization starting
2021-07-09 12:11:59,794 ERROR worker.py:78 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ray::NetworkRunner.run() (pid=9454, ip=10.164.0.9)
File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/train_actor.py", line 24, in run
TypeError: __new__() missing 1 required positional argument: 'loops'