首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >如何在GCP AI平台上使用TFRecord文件进行批量预测?

如何在GCP AI平台上使用TFRecord文件进行批量预测?
EN

Stack Overflow用户
提问于 2020-09-18 17:28:52
回答 1查看 390关注 0票数 0

TL;DR谷歌云AI平台在进行批量预测时如何解压TFRecord文件?

我已经在Google Cloud AI平台上部署了一个经过训练的Keras模型,但我在批量预测的文件格式方面遇到了问题。为了进行训练,我使用tf.data.TFRecordDataset来读取TFRecord的列表,如下所示,一切都很好。

代码语言:javascript
运行
复制
def unpack_tfrecord(record):
    parsed = tf.io.parse_example(record, {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  # Input
        'class': tf.io.FixedLenFeature([2], tf.int64),            # One-hot classification (binary)
    })

    return (parsed['chunk'], parsed['class'])

files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)

我将保存的模型上传到云存储中,并在AI平台中创建一个新模型。https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data平台文档中写道:“gcloud工具批量支持JSON实例字符串的文本文件或者TFRecord文件(可以压缩)”( AI )。但是,当我提供一个TFRecord文件时,我得到了错误:

代码语言:javascript
运行
复制
("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)

我的TFRecord文件包含一堆Protobuf编码的tf.train.Example。我没有为AI平台提供unpack_tfrecord函数,所以我猜它不能正确地解压它是有道理的,但我有一个节点的想法,从这里开始。我对使用JSON格式不感兴趣,因为数据太大了。

EN

回答 1

Stack Overflow用户

发布于 2020-10-21 07:04:17

我不知道这是不是最好的方法,但是对于TF 2.x,你可以这样做:

代码语言:javascript
运行
复制
import tensorflow as tf

def make_serving_input_fn():
    # your feature spec
    feature_spec = {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  
        'class': tf.io.FixedLenFeature([2], tf.int64),
    }

    serialized_tf_examples = tf.keras.Input(
        shape=[], name='input_example_tensor', dtype=tf.string)

    examples = tf.io.parse_example(serialized_tf_examples, feature_spec)

    # any processing 
    processed_chunks = tf.map_fn(
        <PROCESSING_FN>, 
        examples['chunk'], # ?
        dtype=tf.float32)

    return tf.estimator.export.ServingInputReceiver(
        features={<MODEL_FIRST_LAYER_NAME>: processed_chunks},
        receiver_tensors={"input_example_tensor": serialized_tf_examples}
    )


estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    model_dir=<ESTIMATOR_SAVE_DIR>)

estimator.export_saved_model(
    export_dir_base=<WORKING_DIR>,
    serving_input_receiver_fn=make_serving_input_fn)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63953040

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档