Как можно равномерно расширить список, включив экстраполированные средние значения?

У меня есть модуль Python, который предоставляет цветовые палитры и утилиты для работы с ними. Объект цветовой палитры просто наследуется от list и это просто список цветов, указанных в строках HEX. Объект цветовой палитры имеет возможность расширяться, чтобы обеспечить столько цветов, сколько необходимо. Представьте себе граф с множеством представляемых наборов данных: палитру можно попросить увеличить количество цветов до степени, необходимой для обеспечения уникальных цветов для каждого набора данных графика. Он делает это, просто беря среднее значение смежных цветов и вставляя этот новый средний цвет.

extend_palette Функция работает, но не расширяет палитру равномерно. Например, палитра может выглядеть следующим образом:

Расширение до 15 цветов все еще можно использовать:

Расширение до 30 цветов делает проблему с алгоритмом расширения очевидной; новые цвета добавляются только в одном конце списка цветов:

Как должна функционировать extend_palette модуля, чтобы сделать расширенные новые цвета более равномерно распределенными в палитре?

Код следует (с функцией extend_palette особое внимание и другие фрагменты кода для удобства экспериментов):

def clamp(x): 
    return max(0, min(x, 255))

def RGB_to_HEX(RGB_tuple):
    # This function returns a HEX string given an RGB tuple.
    r = RGB_tuple[0]
    g = RGB_tuple[1]
    b = RGB_tuple[2]
    return "#{0:02x}{1:02x}{2:02x}".format(clamp(r), clamp(g), clamp(b))

def HEX_to_RGB(HEX_string):
    # This function returns an RGB tuple given a HEX string.
    HEX = HEX_string.lstrip("#")
    HEX_length = len(HEX)
    return tuple(
        int(HEX[i:i + HEX_length // 3], 16) for i in range(
            0,
            HEX_length,
            HEX_length // 3
        )
    )

def mean_color(colors_in_HEX):
    # This function returns a HEX string that represents the mean color of a
    # list of colors represented by HEX strings.
    colors_in_RGB = []
    for color_in_HEX in colors_in_HEX:
        colors_in_RGB.append(HEX_to_RGB(color_in_HEX))
    sum_r = 0
    sum_g = 0
    sum_b = 0
    for color_in_RGB in colors_in_RGB:
        sum_r += color_in_RGB[0]
        sum_g += color_in_RGB[1]
        sum_b += color_in_RGB[2]
    mean_r = sum_r / len(colors_in_RGB)
    mean_g = sum_g / len(colors_in_RGB)
    mean_b = sum_b / len(colors_in_RGB)
    return RGB_to_HEX((mean_r, mean_g, mean_b))

class Palette(list):

    def __init__(
        self,
        name        = None, # string name
        description = None, # string description
        colors      = None, # list of colors
        *args
        ):
        super(Palette, self).__init__(*args)
        self._name          = name
        self._description   = description
        self.extend(colors)

    def name(
        self
        ):
        return self._name

    def set_name(
        self,
        name = None
        ):
        self._name = name

    def description(
        self
        ):
        return self._description

    def set_description(
        self,
        description = None
        ):
        self._description = description

    def extend_palette(
        self,
        minimum_number_of_colors_needed = 15
        ):
        colors = extend_palette(
            colors = self,
            minimum_number_of_colors_needed = minimum_number_of_colors_needed
        )
        self = colors

    def save_image_of_palette(
        self,
        filename = "palette.png"
        ):
        save_image_of_palette(
            colors   = self,
            filename = filename
        )

def extend_palette(
    colors = None, # list of HEX string colors
    minimum_number_of_colors_needed = 15
    ):
    while len(colors) < minimum_number_of_colors_needed:
        for index in range(1, len(colors), 2):
            colors.insert(index, mean_color([colors[index - 1], colors[index]]))
    return colors

def save_image_of_palette(
    colors   = None, # list of HEX string colors
    filename = "palette.png"
    ):
    import numpy
    import Image
    scale_x = 200
    scale_y = 124
    data = numpy.zeros((1, len(colors), 3), dtype = numpy.uint8)
    index = -1
    for color in colors:
        index += 1
        color_RGB = HEX_to_RGB(color)
        data[0, index] = [color_RGB[0], color_RGB[1], color_RGB[2]]
    data = numpy.repeat(data, scale_x, axis=0)
    data = numpy.repeat(data, scale_y, axis=1)
    image = Image.fromarray(data)
    image.save(filename)

# Define color palettes.
palettes = []
palettes.append(Palette(
    name        = "palette1",
    description = "primary colors for white background",
    colors      = [
                  "#fc0000",
                  "#ffae3a",
                  "#00ac00",
                  "#6665ec",
                  "#a9a9a9",
                  ]
))
palettes.append(Palette(
    name        = "palette2",
    description = "ATLAS clarity",
    colors      = [
                  "#FEFEFE",
                  "#AACCFF",
                  "#649800",
                  "#9A33CC",
                  "#EE2200",
                  ]
))

def save_images_of_palettes():
    for index, palette in enumerate(palettes):
        save_image_of_palette(
            colors   = palette,
            filename = "palette_{index}.png".format(index = index + 1)
        )

def access_palette(
    name = "palette1"
    ):
    for palette in palettes:
        if palette.name() == name:
            return palette
    return None

2 ответа

Решение

Я думаю, что проблему, с которой вы столкнулись, легче понять, если вы начнете с упрощенного примера:

nums = [1, 100]

def extend_nums(nums, min_needed):
    while len(nums) < min_needed:
        for index in range(1, len(nums), 2):
            nums.insert(index, mean(nums[index - 1], nums[index]))
    return nums


def mean(x, y):
    return (x + y) / 2

Здесь я скопировал ваш код, но использовал цифры вместо цветов, чтобы упростить задачу. Вот что происходит, когда я запускаю его:

>>> nums = [0, 100]
>>> extend_nums(nums, 5)
[0, 12.5, 25.0, 37.5, 50.0, 100]

Что мы имеем здесь?

  • 50 - это среднее от 0 до 100.
  • 25 - среднее значение от 0 до 50.
  • 12,5 - среднее значение от 0 до 25.
  • 37,5 - это среднее от 25 до 50.

Странно, не правда ли? Ну нет: я модифицирую nums на месте. Значение index в for-переключения, когда я вставляю новые элементы: nums[3] изменения до и после nums.insert(1, something),

Давайте попробуем создать новый список на каждой итерации:

def extend_nums(nums, min_needed):
    while len(nums) < min_needed:
        new_nums = []  # This new list will hold the extended nums.
        for index in range(1, len(nums)):
            new_nums.append(nums[index - 1])
            new_nums.append(mean(nums[index - 1], nums[index]))
        new_nums.append(nums[-1])
        nums = new_nums
    return nums

Давай попробуем:

>>> nums = [0, 100]
>>> extend_nums(nums, 5)
[0, 25.0, 50.0, 75.0, 100]

Это решение работает (есть возможности для улучшения). Зачем? Потому что в нашем новом for-loop, index имеет правильное значение. Ранее мы вставляли элементы без смещения index,

Этот код

while len(colors) < minimum_number_of_colors_needed:
    for index in range(1, len(colors), 2):
        colors.insert(index, mean_color([colors[index - 1], colors[index]]))

не распределяет средние цвета равномерно. Вы можете увидеть эффект, запустив:

colors = range(5)
while len(colors) < 15:
    for index in range(1, len(colors), 2):
        colors.insert(index, 99)
print(colors)

который дает

[0, 99, 99, 99, 99, 99, 99, 99, 1, 99, 99, 99, 2, 3, 4]

Слишком много средств, представленных 99-ми, расположены ближе к началу, и ни одно к концу.


К счастью, так как у вас есть NumPy, вы можете использовать np.interp равномерно интерполировать цвета. Например, если у вас есть функция с точками данных (0, 10), (0.5, 20), (1, 30), то вы можете интерполировать в x = [0, 0.33, 0.67, 1], чтобы найти соответствующий y ценности:

In [80]: np.interp([0, 0.33, 0.67, 1], [0, 0.5, 1], [10, 20, 30])
Out[80]: array([ 10. ,  16.6,  23.4,  30. ])

поскольку np.interp работает только с одномерными массивами, мы можем применить его к каждому каналу RGB отдельно:

[np.interp(np.linspace(0,1,min_colors), np.linspace(0,1,ncolors), self.rgb[:,i]) 
 for i in range(nchannels)])

Например,

import numpy as np
import Image

def RGB_to_HEX(RGB_tuple):
    """
    Return a HEX string given an RGB tuple.
    """
    return "#{0:02x}{1:02x}{2:02x}".format(*np.clip(RGB_tuple, 0, 255))


def HEX_to_RGB(HEX_string):
    """
    Return an RGB tuple given a HEX string.
    """
    HEX = HEX_string.lstrip("#")
    HEX_length = len(HEX)
    return tuple(
        int(HEX[i:i + HEX_length // 3], 16) for i in range(
            0,
            HEX_length,
            HEX_length // 3 ))

class Palette(object):

    def __init__(self, name=None, description=None, colors=None, *args):
        super(Palette, self).__init__(*args)
        self.name = name
        self.description = description
        self.rgb = np.array(colors)

    @classmethod
    def from_hex(cls, name=None, description=None, colors=None, *args):
        colors = np.array([HEX_to_RGB(c) for c in colors])
        return cls(name, description, colors, *args)

    def to_hex(self):
        return [RGB_to_HEX(color) for color in self.rgb]

    def extend_palette(self, min_colors=15):
        ncolors, nchannels = self.rgb.shape
        if ncolors >= min_colors:
            return self.rgb

        return np.column_stack(
            [np.interp(
                np.linspace(0,1,min_colors), np.linspace(0,1,ncolors), self.rgb[:,i]) 
             for i in range(nchannels)])

def save_image_of_palette(rgb, filename="palette.png"):
    scale_x = 200
    scale_y = 124
    data = (np.kron(rgb[np.newaxis,...], np.ones((scale_x, scale_y, 1)))
            .astype(np.uint8))
    image = Image.fromarray(data)
    image.save(filename)


# Define color palettes.
palettes = []
palettes.append(Palette.from_hex(
    name="palette1",
    description="primary colors for white background",
    colors=[
        "#fc0000",
        "#ffae3a",
        "#00ac00",
        "#6665ec",
        "#a9a9a9", ]))
palettes.append(Palette.from_hex(
    name="palette2",
    description="ATLAS clarity",
    colors=[
        "#FEFEFE",
        "#AACCFF",
        "#649800",
        "#9A33CC",
        "#EE2200",]))
palettes = {p.name:p for p in palettes}


p = palettes['palette1']
save_image_of_palette(p.extend_palette(), '/tmp/out.png')

доходность


Обратите внимание, что интерполяция в цветовом пространстве HSV (а не в цветовом пространстве RGB) дает лучшие результаты.

Другие вопросы по тегам