Классификация нескольких снимков с помощью Reptile Keras
Я пытаюсь использовать пример обучения с несколькими выстрелами, предоставленный Keras, с использованием алгоритма Reptile ( https://keras.io/examples/vision/reptile/ ) с моей собственной базой данных. Но после загрузки базы данных я получаю сообщение об ошибке, когда хочу визуализировать некоторые примеры из базы данных.
строка, в которой у меня проблема:
sample_keys = list(train_dataset())
он выдает следующее:TypeError: объект «Набор данных» не вызывается
может кто-нибудь помочь мне исправить это, так как я новичок в пространстве ML и все еще учусь.
моя база данных 3 класса все разделены по папкам с именами. каждый класс имеет 20 изображений RGB, всего 60.
Код с моими изменениями:
learning_rate = 0.003
meta_step_size = 0.25
inner_batch_size = 15
eval_batch_size = 15
meta_iters = 2000
eval_iters = 5
inner_iters = 4
eval_interval = 1
train_shots = 20
shots = 5
classes = 3
class Dataset:
def __init__(self, training):
# Download the tfrecord files containing the omniglot data and convert to a
# dataset.
split = "train" if training else "test"
images_ds = tf.data.Dataset.list_files('packages/*/*', shuffle=False)
image_count = len (images_ds)
#label=["Box","Cardboard","Plastic_bag"]
ds = images_ds.take(image_count)
# Iterate over the dataset to get each individual image and its class,
# and put that data into a dictionary.
self.data = {}
def get_label(file_path):
return tf.strings.split(file_path, os.path.sep)[-2]
def extraction(file_path):
# This function will shrink the Omniglot images to the desired size,
# scale pixel values and convert the RGB image to grayscale
label = get_label(file_path)
img = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(img)
image = tf.image.convert_image_dtype(image, tf.float32)
#image = tf.image.rgb_to_grayscale(image)
image = tf.image.resize(image, [800, 800])
return image, label
for image, label in images_ds.map(extraction):
image = image.numpy()
label = str(label.numpy())
if label not in self.data:
self.data[label] = []
self.data[label].append(image)
self.labels = list(self.data.keys())
def get_mini_dataset(
self, batch_size, repetitions, shots, num_classes, split=False
):
temp_labels = np.zeros(shape=(num_classes * shots))
temp_images = np.zeros(shape=(num_classes * shots, 800, 800, 1))
if split:
test_labels = np.zeros(shape=(num_classes))
test_images = np.zeros(shape=(num_classes, 800, 800, 1))
# Get a random subset of labels from the entire label set.
label_subset = random.choices(self.labels, k=num_classes)
for class_idx, class_obj in enumerate(label_subset):
# Use enumerated index value as a temporary label for mini-batch in
# few shot learning.
temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
# If creating a split dataset for testing, select an extra sample from each
# label to create the test dataset.
if split:
test_labels[class_idx] = class_idx
images_to_split = random.choices(
self.data[label_subset[class_idx]], k=shots + 1
)
test_images[class_idx] = images_to_split[-1]
temp_images[
class_idx * shots : (class_idx + 1) * shots
] = images_to_split[:-1]
else:
# For each index in the randomly selected label_subset, sample the
# necessary number of images.
temp_images[
class_idx * shots : (class_idx + 1) * shots
] = random.choices(self.data[label_subset[class_idx]], k=shots)
dataset = tf.data.Dataset.from_tensor_slices(
(temp_images.astype(np.float32), temp_labels.astype(np.int32))
)
dataset = dataset.shuffle(60).batch(batch_size).repeat(repetitions)
if split:
return dataset, test_images, test_labels
return dataset
import urllib3
urllib3.disable_warnings() # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)
_, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))
sample_keys = list(train_dataset()) #Here i get the error
for a in range(5):
for b in range(5):
temp_image = train_dataset[sample_keys[a]][b]
temp_image = np.stack((temp_image[:, :, 0],) * 3, axis=2)
temp_image *= 255
temp_image = np.clip(temp_image, 0, 255).astype("uint8")
if b == 2:
axarr[a, b].set_title("Class : " + sample_keys[a])
axarr[a, b].imshow(temp_image)#, cmap="rgb")
axarr[a, b].xaxis.set_visible(False)
axarr[a, b].yaxis.set_visible(False)
plt.show()
Ошибка:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[23], line 5
1 ###################### Visualize some examples from the dataset ####################################
3 _, axarr = plt.subplots(nrows=5, ncols=5, figsize=(20, 20))
----> 5 sample_keys = list(train_dataset())
7 for a in range(5):
8 for b in range(5):
TypeError: 'Dataset' object is not callable
я тоже пробовал заменить
метка = ул(метка.numpy())
с
label = label, поскольку он уже находится в строковом формате, но затем я также получаю сообщение об ошибке, которое я все еще пытаюсь понять, почему