Невозможно сохранить отсеваемый слой в модели тензорного потока с помощью tf2onnx.convert.from_keras (и загрузить его с помощью onnx)
Опишите ошибку
Я хочу сохранить модель, которую я обучил с помощью tensorflow и которая содержит слои Dropout. Мне нужны эти слои при выводе, чтобы использовать их в training_mode для измерения эпистемической неопределенности моей модели. Я думаю, что функция convert.from_keras tf2onnx не сохраняет слой дропаута.
Системная информация
Windows 11
Tensorflow Version: 2.7.0
Python version: 3.8.10
Чтобы воспроизвести Вот пример кода, который я использую:
import tf2onnx
import onnx
from tensorflow.keras.layers import ( Dense, Input, Dropout,)
from tensorflow.keras.models import Model
input_layer = Input(shape=(96))
h = Dense(128, activation="relu")(input_layer)
h = Dropout(0.1)(h)
h = Dense(128, activation="relu")(h)
h = Dropout(0.1)(h)
h = Dense(12, activation=None)(h)
model = Model(input_layer, h, name=name)
model_onnx, _ =tf2onnx.convert.from_keras(model, output_path='test_dropout.onnx')
test = onnx.load('test_dropout.onnx')
test.graph.node
В отображаемом узле я не упоминаю об отсеве. Дело в том, что я хочу использовать слой отсева в training_mode во время вывода, чтобы использовать отсев MC для измерения эпистемической неопределенности.
Есть ли что-то, что я делаю неправильно здесь? Или это просто невозможно сделать?
Результат
test.graph.node
является:
[input: "input_1"
input: "model/dense/MatMul/ReadVariableOp:0"
output: "model/dense/MatMul:0"
name: "model/dense/MatMul"
op_type: "MatMul"
, input: "model/dense/MatMul:0"
output: "model/dense/Relu:0"
name: "model/dense/Relu"
op_type: "Relu"
, input: "model/dense/Relu:0"
input: "model/dense_1/MatMul/ReadVariableOp:0"
output: "model/dense_1/MatMul:0"
name: "model/dense_1/MatMul"
op_type: "MatMul"
, input: "model/dense_1/MatMul:0"
output: "model/dense_1/Relu:0"
name: "model/dense_1/Relu"
op_type: "Relu"
, input: "model/dense_1/Relu:0"
input: "model/dense_2/MatMul/ReadVariableOp:0"
output: "dense_2"
name: "model/dense_2/MatMul"
op_type: "MatMul"
]