Обслуживающий клиент тензор потока не работает (grpc.framework.interfaces.face.face.AbortionError: AbortionError)

Я развертываю программу соответствия текста с текстом в докере со ссылкой на официальный веб-сайт, все этапы установки и этапы тестирования в порядке, включая состояние работающего сервера, а также OK, но у клиента есть проблема. Позвольте мне объяснить это подробно.

Это мой модельный график с четырьмя входами: введите описание изображения здесь

Это моя модель сервера:

import os import shutil

import tensorflow as tf from preprocess import Word2Vec, TestQA, WebQA from ABCNN import ABCNN from utils import build_path import numpy as np Max_len = 40 d0 = 300 ''' Loads the saved bcnn model, injects additional layers for the input transformation and export the model into protobuf format '''

# Command line arguments tf.app.flags.DEFINE_string('checkpoint_dir', './models/',
                           "Directory where to read training checkpoints.") tf.app.flags.DEFINE_string('output_dir', './models-export',
                           "Directory where to export the model.") tf.app.flags.DEFINE_integer('model_version', 1,
                            "Version number of the model.") FLAGS = tf.app.flags.FLAGS


def test(w, l2_reg, epoch, max_len, model_type, num_layers, data_type, classifier, num_classes):

    model_path = build_path("./models/", data_type, model_type, num_layers)

    model = ABCNN(s=max_len, w=w, l2_reg=l2_reg, model_type=model_type, num_classes=num_classes, num_layers=num_layers)
    with tf.Session() as sess:

        saver = tf.train.Saver()
        print(model_path + "-" + str(12))
        saver.restore(sess, model_path + "-" + str(12))
        x1 = tf.placeholder(tf.float32, shape=[None, d0, max_len])
        x2 = tf.placeholder(tf.float32, shape=[None, d0, max_len])
        y = tf.placeholder(tf.int32, shape=[None])
        features = tf.placeholder(tf.float32, shape=[None, 4]) #num_features = 4

        export_path = os.path.join(
            tf.compat.as_bytes(FLAGS.output_dir),
            tf.compat.as_bytes(str(FLAGS.model_version)))
        if os.path.exists(export_path):
            shutil.rmtree(export_path)

        # create model builder
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        # create tensors info

        inputs = {
               "x1": tf.saved_model.utils.build_tensor_info(x1),
               "x2": tf.saved_model.utils.build_tensor_info(x2),
               "label": tf.saved_model.utils.build_tensor_info(y),
               "features": tf.saved_model.utils.build_tensor_info(features)

           }
        output = {"predict_score": tf.saved_model.utils.build_tensor_info(model.prediction)}
        # build prediction signature
        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs,
                outputs=output,
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
            )
        )

        # save the model
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict_score': prediction_signature
            },
            legacy_init_op=legacy_init_op)

        builder.save()

    print("Successfully exported BCNN model version '{}' into '{}'".format(
        FLAGS.model_version, FLAGS.output_dir))

def main(_):

    # default parameters
    params = {
        "ws": 4,
        "l2_reg": 0.0004,
        "epoch": 20,
        "max_len": 40,
        "model_type": "BCNN",
        "num_layers": 2,
        "num_classes": 2,
        "data_type": "WebQA",
        "classifier": "LR",
        # "word2vec": Word2Vec()
    }


    test(w=int(params["ws"]), l2_reg=float(params["l2_reg"]), epoch=int(params["epoch"]),
         max_len=int(params["max_len"]), model_type=params["model_type"],
         num_layers=int(params["num_layers"]), data_type=params["data_type"],
         classifier=params["classifier"], num_classes=params["num_classes"])

if __name__ == '__main__':
    tf.app.run()

И это мой клиент:

import sys sys.path.insert(0, \"./\")
# from tensorflow_serving_client.protos import predict_pb2, prediction_service_pb2_grpc from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2 from grpc.beta import implementations
# import grpc import tensorflow as tf import numpy as np from tensorflow.python.framework import dtypes import time from preprocess import Word2Vec, MSRP, WikiQA,WebQA

im_name = "dir/files" if __name__ == '__main__':
    test_data = WebQA(word2vec=Word2Vec(), max_len=40)
    test_data.open_file(mode="test")

    s1s, s2s, labels, features = test_data.only_batch()

    s1s = s1s[0]
    s2s = s2s[0]
    labels = labels[0]
    features = features[0]
    print(s1s.shape)
    print(s2s.shape)
    print(labels.shape)
    print(features.shape)

    start_time = time.time()

    channel = implementations.insecure_channel("localhost", 9000)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
    request = predict_pb2.PredictRequest()

    request.model_spec.name = "bcnn"
    request.model_spec.signature_name = "predict_score"

    request.inputs["x1"].CopyFrom(tf.contrib.util.make_tensor_proto(s1s, dtype=dtypes.float32))
    request.inputs["x2"].CopyFrom(tf.contrib.util.make_tensor_proto(s2s, dtype=dtypes.float32))
    request.inputs["label"].CopyFrom(tf.contrib.util.make_tensor_proto(labels, dtype=dtypes.int32))
    request.inputs["features"].CopyFrom(tf.contrib.util.make_tensor_proto(features, dtype=dtypes.float32))

    response = stub.Predict(request, 10.0)
    results = {}
    for key in response.outputs:
        tensor_proto = response.outputs[key]
        nd_array = tf.contrib.util.make_ndarray(tensor_proto)
        results[key] = nd_array
    print("cost %ss to predict: " % (time.time() - start_time))
    print(results["predict_score"])

И это ошибка:

root@15bb1c2766e3:/ABCNN-master# python3 client.py
(1, 300, 40)
(1, 300, 40)
(1,)
(1, 4)
WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Use the retry module or similar alternatives.
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/grpc/beta/_client_adaptations.py", line 193, in _blocking_unary_unary
    credentials=_credentials(protocol_options))
  File "/usr/local/lib/python3.5/dist-packages/grpc/_channel.py", line 487, in __call__
    return _end_unary_response_blocking(state, call, False, deadline)
  File "/usr/local/lib/python3.5/dist-packages/grpc/_channel.py", line 437, in _end_unary_response_blocking
    raise _Rendezvous(state, None, None, deadline)
grpc._channel._Rendezvous: <_Rendezvous of RPC that terminated with (StatusCode.INVALID_ARGUMENT, You must feed a value for placeholder tensor 'x2' with dtype float and shape [?,300,40]
     [[Node: x2 = Placeholder[dtype=DT_FLOAT, shape=[?,300,40], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]])>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "client.py", line 50, in <module>
    response = stub.Predict(request, 10.0)
  File "/usr/local/lib/python3.5/dist-packages/grpc/beta/_client_adaptations.py", line 309, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python3.5/dist-packages/grpc/beta/_client_adaptations.py", line 195, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="You must feed a value for placeholder tensor 'x2' with dtype float and shape [?,300,40]
     [[Node: x2 = Placeholder[dtype=DT_FLOAT, shape=[?,300,40], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]")
ns.py", line 309, in __call__er# on3.5/dist-packages/grpc/beta/_client_adaptatio

Я напечатал форму функций: (1,4)почему до сих пор говорится, что я ввел неправильные тензорные "черты"? Я не могу получить это.

Спасибо за любые предложения заранее.

0 ответов

Другие вопросы по тегам