Как загрузить контрольные точки в разных версиях pytorch (1.3.1 и 1.6.x) с помощью ppc64le и x86?
Как я уже отмечал здесь, я застрял в использовании старых версий pytorch и torchvision из-за оборудования, например, с использованием архитектур IBM ppc64le.
По этой причине у меня возникают проблемы при отправке и получении контрольных точек между разными компьютерами, кластерами и моим личным Mac. Интересно, есть ли способ загрузить модели, чтобы избежать этой проблемы? например, возможно сохранение моделей в старом и новом формате при использовании 1.6.x. Конечно, с 1.3.1 на 1.6.x это невозможно, но, по крайней мере, я надеялся, что что-то сработает.
Любой совет? Конечно, мое идеальное решение состоит в том, что мне не нужно об этом беспокоиться, и я всегда могу загрузить и сохранить свои контрольные точки и все, что я обычно обрабатываю, равномерно на всем моем оборудовании.
Первой моей ошибкой была ошибка zip jit:
RuntimeError: /home/miranda9/data/f.pt is a zip archive (did you mean to use torch.jit.load()?)
поэтому я использовал это (и другие библиотеки рассола):
# %%
import torch
from pathlib import Path
def load(path):
import torch
import pickle
import dill
path = str(path)
try:
db = torch.load(path)
f = db['f']
except Exception as e:
db = torch.jit.load(path)
f = db['f']
#with open():
# db = pickle.load(open(path, "r+"))
# db = dill.load(open(path, "r+"))
#raise ValueError(f'FAILED: {e}')
return db, f
p = "~/data/f.pt"
path = Path(p).expanduser()
db, f = load(path)
Din, nb_examples = 1, 5
x = torch.distributions.Normal(loc=0.0, scale=1.0).sample(sample_shape=(nb_examples, Din))
y = f(x)
print(y)
print('Success!\a')
но я получаю жалобы на разные версии pytorch, которые я вынужден использовать:
Traceback (most recent call last):
File "hal_pg.py", line 27, in <module>
db, f = load(path)
File "hal_pg.py", line 16, in load
db = torch.jit.load(path)
File "/home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/jit/__init__.py", line 239, in load
cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbc (0x7fff7b527b9c in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff1d293c78 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0x7fff1d2950d8 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x64 (0x7fff1e624664 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff7c0ae210 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff7bc2efc4 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #26: <unknown function> + 0x25280 (0x7fff84b35280 in /lib64/libc.so.6)
frame #27: __libc_start_main + 0xc4 (0x7fff84b35474 in /lib64/libc.so.6)
есть идеи, как сделать все согласованным в кластерах? Я даже не могу открыть файлы с рассолом.
возможно, это просто невозможно с текущей версией pytorch, которую я вынужден использовать:(
RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbc (0x7fff83ba7b9c in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff25993c78 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0x7fff259950d8 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x64 (0x7fff26d24664 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff8472e210 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff842aefc4 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #23: <unknown function> + 0x25280 (0x7fff8d335280 in /lib64/libc.so.6)
frame #24: __libc_start_main + 0xc4 (0x7fff8d335474 in /lib64/libc.so.6)
используя код:
from pathlib import Path
import torch
path = '/home/miranda9/data/dataset/'
path = Path(path).expanduser() / 'fi_db.pt'
path = str(path)
# db = torch.load(path)
# torch.jit.load(path)
db = torch.jit.load(str(path))
print(db)
Ссылки по теме:
- Как загрузить контрольные точки в разных версиях pytorch (1.3.1 и 1.6.x) с помощью ppc64le и x86?
- https://discuss.pytorch.org/t/how-to-load-checkpoints-across-different-versions-of-pytorch-1-3-1-and-1-6-x-using-ppc64le-and-x86/97829
- связанный gitissue: https://github.com/pytorch/pytorch/issues/43766
- reddit: https://www.reddit.com/r/pytorch/comments/jvza7v/how_to_load_checkpoints_across_different_versions/
4 ответа
Я считаю, что разработчики намерены передать флаг для сохранения в виде рассола. Просто изменение поведения по умолчанию.
Для файлов с ранее установленными контрольными точками перезагрузите zip-файл с сохраненными весами в новом env (с pytorch> = 1.6), а затем снова проверьте контрольную точку как рассол (нет необходимости повторно обучать);
обновите свой код и добавьте флаг со следующего раза
Мы переключили torch.save на использование формата на основе zip-файла по умолчанию, а не на старый формат на основе Pickle. torch.load сохранил возможность загрузки старого формата, но рекомендуется использовать новый формат. Новый формат:
более удобный для проверки и создания инструментов для управления файлами сохранения, исправляет давнюю проблему, когда функции сериализации ( getstate , setstate ) в модулях, зависящие от сериализованных значений Tensor, получали неправильные данные, такие же, как формат сериализации TorchScript, делая сериализацию более сложной. согласован в PyTorch
Использование следующее:
m = MyMod()
torch.save(m.state_dict(), 'mymod.pt') # Saves a zipfile to mymod.pt
Чтобы использовать старый формат, передайте флаг
_use_new_zipfile_serialization=False
m = MyMod()
torch.save(m.state_dict(), 'mymod.pt', _use_new_zipfile_serialization=False) # Saves pickle
Это не идеальное решение, но оно работает для переноса контрольных точек из более новых версий в более старые версии.
Я также использую ppc64le и сталкиваюсь с теми же проблемами. Можно сохранить модель в текстовом формате, который может быть прочитан любой версией PyTorch. У меня установлен PyTorch v1.3.0 на машине ppc64le и v1.7.0 на моем ноутбуке (для которого не требуется видеокарта).
Шаг 1. Сохраните модель в более новой версии PyTorch.
def save_model_txt(model, path):
fout = open(path, 'w')
for k, v in model.state_dict().items():
fout.write(str(k) + '\n')
fout.write(str(v.tolist()) + '\n')
fout.close()
Перед сохранением загружаю модель вот так
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)
Шаг 2. Перенесите текстовый файл.
Шаг 3. Загрузите текстовый файл в старый PyTorch.
def load_model_txt(model, path):
data_dict = {}
fin = open(path, 'r')
i = 0
odd = 1
prev_key = None
while True:
s = fin.readline().strip()
if not s:
break
if odd:
prev_key = s
else:
print('Iter', i)
val = eval(s)
if type(val) != type([]):
data_dict[prev_key] = torch.FloatTensor([eval(s)])[0]
else:
data_dict[prev_key] = torch.FloatTensor(eval(s))
i += 1
odd = (odd + 1) % 2
# Replace existing values with loaded
print('Loading...')
own_state = model.state_dict()
print('Items:', len(own_state.items()))
for k, v in data_dict.items():
if not k in own_state:
print('Parameter', k, 'not found in own_state!!!')
else:
try:
own_state[k].copy_(v)
except:
print('Key:', k)
print('Old:', own_state[k])
print('New:', v)
sys.exit(0)
print('Model loaded')
Перед загрузкой модель должна быть инициализирована. Пустая модель передается в функцию.
Ограничения
Если ваша модель state_dict содержит что-то еще, кроме значений (str: torch.Tensor), этот метод не будет работать. Вы можете проверить содержимое вашего state_dict с помощью
for k, v in model.state_dict().items():
...
Прочтите это для понимания:
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html
https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113
Основываясь на ответе @maxim velikanov, я создал отдельный OrderedDict, где ключи такие же, как и исходное состояние dict для модели, но каждое значение тензора преобразуется в список.
Этот OrderedDict - это они выгружены в файл JSON.
def save_model_json(model, path):
actual_dict = OrderedDict()
for k, v in model.state_dict().items():
actual_dict[k] = v.tolist()
with open(path, 'w') as f:
json.dump(actual_dict, f)
Затем загрузчик может загрузить файл как JSON, и каждый список / целое число будет преобразовано обратно в Tensor, прежде чем его значения будут скопированы в исходное состояние dict.
def load_model_json(model, path):
data_dict = OrderedDict()
with open(path, 'r') as f:
data_dict = json.load(f)
own_state = model.state_dict()
for k, v in data_dict.items():
print('Loading parameter:', k)
if not k in own_state:
print('Parameter', k, 'not found in own_state!!!')
if type(v) == list or type(v) == int:
v = torch.tensor(v)
own_state[k].copy_(v)
model.load_state_dict(own_state)
print('Model loaded')
У меня возникла аналогичная проблема при загрузке обработанных данных. Я ранее сохранял данные в torch 1.8 как «xxx.pt», но загрузил их в torch 1.2. Мне не удалось его загрузить даже с помощью torch.jit.load(). Мое единственное решение - снова сохранить данные в более старой версии.