我用 modelzoo/mmoe 跑训练,但是我还需要生成 saved_model,请问该怎么修改?[阿里云机器学习PAI]

问一个机器学习PAI问题,我用 modelzoo/mmoe 跑训练,但是我还需要生成 saved_model,请问该怎么修改?

现在我是这么写的,就在 MonitoredTrainingSession 调用之后:
inputInfo = {key : tf.saved_model.utils.build_tensor_info(val) for key, val in model._feature.items()}

inputInfo = {“input”: tf.saved_model.utils.build_tensor_info(model.input_emb)}

outputInfo = {“output”: tf.saved_model.utils.build_tensor_info(model.output)}
print(“inputInfo: “, inputInfo)
print(“outputInfo: “, outputInfo)

prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs = inputInfo,
outputs = outputInfo,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

sess.graph._unsafe_unfinalize()
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess._tf_sess(), [tf.saved_model.tag_constants.SERVING],
signature_def_map = {‘predict’: prediction_signature,})
builder.save()

但不管我怎么设置 input 和 output,都会报错:
tensorflow.python.framework.errors_impl.InternalError: Output 3 of type int32 does not match declared output type int64 for node {{node prefetch_2/TensorBufferTake}}

有人知道怎么解决吗?多谢了~

「点点赞赏,手留余香」

    还没有人赞赏,快来当第一个赞赏的人吧!
=====这是一个广告位,招租中,联系qq 78315851====