Вывод tflite предсказывает только одну метку, несмотря на обучение меток нескольких классов
Я обучил мультиклассовый классификатор для распознавания речи с использованием тензорного потока. Затем преобразовал модель с помощью конвертера tflite. Модель может предсказывать, но всегда выводит один класс. Я предполагаю, что проблема связана с кодом вывода, потому что модель .h5 может предсказать мультикласс без каких-либо проблем. Я искал в Интернете несколько дней для некоторого понимания, но я не могу понять это. Вот мой код. Любые предложения будут действительно оценены.
import sounddevice as sd
import numpy as np
import scipy.signal
import timeit
import python_speech_features
import tflite_runtime.interpreter as tflite
import importlib
# Parameters
debug_time = 0
debug_acc = 0
word_threshold = 0.95
rec_duration = 0.5 # 0.5
sample_length = 0.5
window_stride = 0.5 # 0.5
sample_rate = 8000 # The mic requires at least 44100 Hz to work
resample_rate = 8000
num_channels = 1
num_mfcc = 16
model_path = 'model.tflite'
mfccs_old = np.zeros((32, 25))
# Load model (interpreter)
interpreter = tflite.Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
# Filter and downsample
def decimate(signal, old_fs, new_fs):
# Check to make sure we're downsampling
if new_fs > old_fs:
print("Error: target sample rate higher than original")
return signal, old_fs
# Downsampling is possible only by an integer factor
dec_factor = old_fs / new_fs
if not dec_factor.is_integer():
print("Error: can only downsample by integer factor")
# Do decimation
resampled_signal = scipy.signal.decimate(signal, int(dec_factor))
return resampled_signal, new_fs
# Callback that gets called every 0.5 seconds
def sd_callback(rec, frames, time, status):
# Start timing for debug purposes
start = timeit.default_timer()
# Notify errors
if status:
print('Error:', status)
global mfccs_old
# Compute MFCCs
mfccs = python_speech_features.base.mfcc(rec,
samplerate=resample_rate,
winlen=0.02,
winstep=0.02,
numcep=num_mfcc,
nfilt=26,
nfft=512, # 2048
preemph=0.0,
ceplifter=0,
appendEnergy=True,
winfunc=np.hanning)
delta = python_speech_features.base.delta(mfccs, 2)
mfccs_delta = np.append(mfccs, delta, axis=1)
mfccs_new = mfccs_delta.transpose()
mfccs = np.append(mfccs_old, mfccs_new, axis=1)
# mfccs = np.insert(mfccs, [0], 0, axis=1)
mfccs_old = mfccs_new
# Run inference and make predictions
in_tensor = np.float32(mfccs.reshape(1, mfccs.shape[0], mfccs.shape[1], 1))
interpreter.set_tensor(input_details[0]['index'], in_tensor)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
val = np.amax(output_data) # DEFINED FOR BINARY CLASSIFICATION, CHANGE TO MULTICLASS
ind = np.where(output_data == val)
prediction = ind[1].astype(int)
if val > word_threshold:
print('index:', ind[1])
print('accuracy', val, '/n')
print(int(prediction))
if debug_acc:
# print('accuracy:', val)
# print('index:', ind[1])
print('out tensor:', output_data)
if debug_time:
print(timeit.default_timer() - start)
# Start recording from microphone
with sd.InputStream(channels=num_channels,
samplerate=sample_rate,
blocksize=int(sample_rate * rec_duration),
callback=sd_callback):
while True:
pass
1 ответ
Поскольку я разобрался с проблемой, я сам отвечаю на нее, если другие сочтут это полезным.
Проблема в том, что в вашем наборе данных нет класса «фоновый шум». Также убедитесь, что у вас достаточно данных для фоновых шумов. Если вы посмотрите на аудиопроект обучаемой машины Google ( https://teachablemachine.withgoogle.com/train/audio ), класс «фоновый шум» уже существует, вы не можете удалить или отключить класс.
Я протестировал оба кода, представленные в примере github tensorflow ( https://github.com/tensorflow/examples/blob/master/lite/examples/sound_classification/raspberry_pi/classify.py ) и на веб-сайте tensorflow ( https://www. tensorflow.org/tutorials/audio/simple_audio ). Они оба хорошо работают для вашего прогноза, если в вашем наборе данных достаточно образцов фонового шума с учетом конкретной среды, в которой вы его тестируете.
Я внес небольшие изменения в код github tensorflow, чтобы вывести имя категории и показатель достоверности категории.
# Loop until the user close the classification results plot.
while True:
# Wait until at least interval_between_inference seconds has passed since
# the last inference.
now = time.time()
diff = now - last_inference_time
if diff < interval_between_inference:
time.sleep(pause_time)
continue
last_inference_time = now
# Load the input audio and run classify.
tensor_audio.load_from_audio_record(audio_record)
result = classifier.classify(tensor_audio)
for category in result.classifications[0].categories:
print(category.category_name, category.score)
Надеюсь, это будет полезно для людей, играющих с похожими проектами.