Идентичное случайное кадрирование на двух изображениях, преобразованных Pytorch

Я пытаюсь передать два изображения в сеть и хочу выполнить идентичное преобразование между этими двумя изображениями. transforms.Compose()берет одно изображение за раз и производит вывод независимо друг от друга, но мне нужно одно и то же преобразование. Я сам написал код дляhflip()теперь мне интересно получить случайный урожай. Есть ли способ сделать это без написания пользовательских функций?

2 ответа

Решение

Я бы использовал такой обходной путь - сделайте свой собственный класс урожая, унаследованный от RandomCrop, переопределив вызов с помощью

…
    if self.call_is_even :
        self.ijhw = self.get_params(img, self.size)
    i, j, h, w = self.ijhw
    self.call_is_even = not self.call_is_even

вместо того

i, j, h, w = self.get_params(img, self.size)

Идея состоит в том, чтобы подавить рандомизатор при нечетных вызовах

PtrBlck рекомендует использовать функциональный API pytorch для создания собственного преобразования, которое будет делать то, что вы хотите [1] , однако я думаю, что в большинстве случаев есть более чистый способ:

Если изображение представляет собой тензор факела, ожидается, что оно будет иметь форму […, H, W], где … означает произвольное количество ведущих измерений
torchvision.RandomCrop .

Вы можете размещать изображения по размеру канала. (Возможно, это даже означает, что вы можете складывать изображения в новом измерении). Таким образом, вы можете применить преобразование один раз ко всем изображениям одновременно.


Приложение для дополнительных изображений

Для двух изображений это работает нормально. Для списков изображений вы можете сделать то же самое. Но если у вас есть вложенные списки изображений, это становится раздражающим.
Для этого варианта использования я написал функцию, которая работает с вложенными списками (глубиной ровно 2). Обратите внимание, что этот подход работает только с тензорами, а не с изображениями PIL, поэтому сначала он преобразует изображения pil в тензоры, если вы не запретите это делать — пример в тестовой функции ниже:

      def apply_same_transform_to_all_images(transform: torch.nn.Module, images: List[List[Image.Image]], 
                                       to_ten: torch.nn.Module = None) -> List[List[torch.Tensor]]:
    """
        applies the transform to all images in the nested list - in a single run.
        Useful e.g. for consistent RandomCrop.
        Note that this will call torchvision.ToTensor() on your images unless you set `to_ten`!
          That rescales values. It also means the passed-in transform should assume tensor inputs, not PIL inputs.
    """
    if to_ten is None: to_ten = torchvision.transforms.ToTensor()

    # traverse the list to collect all images
    all_images = list()
    for outer_idx, inner_list in enumerate(images):
        for _inner_idx, image in enumerate(inner_list):
            all_images.append(to_ten(image))
    # stack all images in a new dimension
    stacked = torch.stack(all_images, dim=0)
    transformed = transform(stacked)
    transformed_iter = iter(transformed)

    # undo the traversing in the same order.
    output_nested_list = list()
    for outer_idx, inner_list in enumerate(images):
        output_nested_list.append(list())
        for _inner_idx, _image in enumerate(inner_list):
            output_nested_list[outer_idx].append(next(transformed_iter))

    return output_nested_list


def test_apply_same_transform_to_all_images():
    import torch, utils
    from torchvision.transforms import RandomCrop, Compose
    identity_transform = Compose([]) # we already have tensors
    img1 = torch.arange(4*4*3).reshape((3,4,4)) # image of shape 4x4 with 3 channels
    img2 = img1 * 2
    crop_transform = RandomCrop((2,2))
    result = apply_same_transform_to_all_images(crop_transform, to_ten = identity_transform, images = [[img1], [img1, img2]])
    assert torch.allclose(result[0][0], result[1][0])
    assert torch.allclose(result[0][0] * 2, result[1][1])


if __name__ == "__main__":
    # just for debugging
    test_apply_same_transform_to_all_images()
Другие вопросы по тегам