使用 WatchFilePattern 在 RunInference 中自动更新 ML 模型
此示例中的管道使用 RunInference PTransform
使用 TensorFlow 模型对图像进行推理。它使用一个 侧输入 PCollection
,它发出 ModelMetadata
来更新模型。
使用侧输入,您可以实时更新模型(它在 ModelHandler
配置对象中传递),即使 Beam 管道仍在运行。这可以通过利用 Beam 提供的模式之一(例如 WatchFilePattern
)来完成,也可以通过配置自定义侧输入 PCollection
来定义模型更新的逻辑。
有关侧输入的更多信息,请参阅 Apache Beam 编程指南中的 侧输入 部分。
此示例使用 WatchFilePattern
作为侧输入。WatchFilePattern
用于监视与 file_pattern
匹配的文件更新(基于时间戳)。它发出最新的 ModelMetadata
,该元数据在 RunInference PTransform
中使用,以自动更新 ML 模型,而无需停止 Beam 管道。
设置源
要读取图像名称,请使用 Pub/Sub 主题作为源。Pub/Sub 主题发出 UTF-8
编码的模型路径,该路径用于读取和预处理图像以运行推理。
用于图像分割的模型
为了本示例的目的,请使用保存在 HDF5 格式的 TensorFlow 模型。
预处理用于推理的图像
Pub/Sub 主题发出图像路径。我们需要读取和预处理图像才能将其用于 RunInference。read_image
函数用于读取用于推理的图像。
import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf
def read_image(image_file_name):
with FileSystems().open(image_file_name, 'r') as file:
data = Image.open(io.BytesIO(file.read())).convert('RGB')
img = data.resize((224, 224))
img = numpy.array(img) / 255.0
img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
return img_tensor
现在,让我们深入了解管道代码。
管道步骤:
- 从 Pub/Sub 主题获取图像名称。
- 使用
read_image
函数读取和预处理图像。 - 将图像传递给 RunInference
PTransform
。RunInference 以model_handler
和model_metadata_pcoll
作为输入参数。
对于 model_handler
,我们使用 TFModelHandlerTensor。
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')
model_metadata_pcoll
是 RunInference PTransform
的 侧输入 PCollection
。此侧输入用于更新 model_handler
中的模型,而无需停止 beam 管道。我们将使用 WatchFilePattern
作为侧输入来监视与 .h5
文件匹配的 glob 模式。
model_metadata_pcoll
期望一个与 AsSingleton 兼容的 ModelMetadata
的 PCollection
。由于管道使用 WatchFilePattern
作为侧输入,因此它将负责窗口化并将输出包装到 ModelMetadata
中。
在管道开始处理数据后,当您看到 RunInference PTransform
发出的某些输出时,将与 file_pattern
匹配的 .h5
TensorFlow
模型上传到 Google Cloud Storage 存储桶。RunInference 将使用 WatchFilePattern
作为侧输入来更新 TFModelHandlerTensor
的 model_uri
。
注意:侧输入更新频率是非确定性的,更新之间的间隔可能更长。
import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:
file_pattern = 'gs://<your-bucket>/*.h5'
pubsub_topic = '<topic_emitting_image_names>'
side_input_pcoll = (
pipeline
| "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))
images_pcoll = (
pipeline
| "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
| "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
| "PreProcessImage" >> beam.Map(read_image)
)
inference_pcoll = (
images_pcoll
| "RunInference" >> RunInference(
model_handler=tf_model_handler,
model_metadata_pcoll=side_input_pcoll))
后处理 PredictionResult
对象
推理完成后,RunInference 输出一个 PredictionResult
对象,该对象包含 example
、inference
和 model_id
字段。model_id
用于标识用于运行推理的模型。
from apache_beam.ml.inference.base import PredictionResult
class PostProcessor(beam.DoFn):
"""
Process the PredictionResult to get the predicted label and model id used for inference.
"""
def process(self, element: PredictionResult) -> typing.Iterable[str]:
predicted_class = numpy.argmax(element.inference[0], axis=-1)
labels_path = tf.keras.utils.get_file(
'ImageNetLabels.txt',
'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
)
imagenet_labels = numpy.array(open(labels_path).read().splitlines())
predicted_class_name = imagenet_labels[predicted_class]
return predicted_class_name.title(), element.model_id
post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())
运行管道
result = pipeline.run().wait_until_finish()
注意:ModelMetaData
对象的 model_name
将作为前缀附加到 RunInference PTransform
计算的 指标。
最后说明
当您将侧输入与 RunInference PTransform
一起使用以在不停止管道的情况下自动更新模型时,您可以将此示例用作模式。您可以在 GitHub 上看到 PyTorch 的类似示例。
上次更新时间:2024/10/31
您是否找到了您要查找的所有内容?
所有内容都有用且清晰吗?您想更改任何内容吗?请告诉我们!