使用 WatchFilePattern 在 RunInference 中自动更新 ML 模型

此示例中的管道使用 RunInference PTransform 使用 TensorFlow 模型对图像进行推理。它使用一个 侧输入 PCollection,它发出 ModelMetadata 来更新模型。

使用侧输入,您可以实时更新模型(它在 ModelHandler 配置对象中传递),即使 Beam 管道仍在运行。这可以通过利用 Beam 提供的模式之一(例如 WatchFilePattern)来完成,也可以通过配置自定义侧输入 PCollection 来定义模型更新的逻辑。

有关侧输入的更多信息,请参阅 Apache Beam 编程指南中的 侧输入 部分。

此示例使用 WatchFilePattern 作为侧输入。WatchFilePattern 用于监视与 file_pattern 匹配的文件更新(基于时间戳)。它发出最新的 ModelMetadata,该元数据在 RunInference PTransform 中使用,以自动更新 ML 模型,而无需停止 Beam 管道。

设置源

要读取图像名称,请使用 Pub/Sub 主题作为源。Pub/Sub 主题发出 UTF-8 编码的模型路径,该路径用于读取和预处理图像以运行推理。

用于图像分割的模型

为了本示例的目的,请使用保存在 HDF5 格式的 TensorFlow 模型。

预处理用于推理的图像

Pub/Sub 主题发出图像路径。我们需要读取和预处理图像才能将其用于 RunInference。read_image 函数用于读取用于推理的图像。

import io
from PIL import Image
from apache_beam.io.filesystems import FileSystems
import numpy
import tensorflow as tf

def read_image(image_file_name):
  with FileSystems().open(image_file_name, 'r') as file:
    data = Image.open(io.BytesIO(file.read())).convert('RGB')
  img = data.resize((224, 224))
  img = numpy.array(img) / 255.0
  img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
  return img_tensor

现在,让我们深入了解管道代码。

管道步骤:

  1. 从 Pub/Sub 主题获取图像名称。
  2. 使用 read_image 函数读取和预处理图像。
  3. 将图像传递给 RunInference PTransform。RunInference 以 model_handlermodel_metadata_pcoll 作为输入参数。

对于 model_handler,我们使用 TFModelHandlerTensor

from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
# initialize TFModelHandlerTensor with a .h5 model saved in a directory accessible by the pipeline.
tf_model_handler = TFModelHandlerTensor(model_uri='gs://<your-bucket>/<model_path.h5>')

model_metadata_pcoll 是 RunInference PTransform侧输入 PCollection。此侧输入用于更新 model_handler 中的模型,而无需停止 beam 管道。我们将使用 WatchFilePattern 作为侧输入来监视与 .h5 文件匹配的 glob 模式。

model_metadata_pcoll 期望一个与 AsSingleton 兼容的 ModelMetadataPCollection。由于管道使用 WatchFilePattern 作为侧输入,因此它将负责窗口化并将输出包装到 ModelMetadata 中。

在管道开始处理数据后,当您看到 RunInference PTransform 发出的某些输出时,将与 file_pattern 匹配的 .h5 TensorFlow 模型上传到 Google Cloud Storage 存储桶。RunInference 将使用 WatchFilePattern 作为侧输入来更新 TFModelHandlerTensormodel_uri

注意:侧输入更新频率是非确定性的,更新之间的间隔可能更长。

import apache_beam as beam
from apache_beam.ml.inference.utils import WatchFilePattern
from apache_beam.ml.inference.base import RunInference
with beam.Pipeline() as pipeline:

  file_pattern = 'gs://<your-bucket>/*.h5'
  pubsub_topic = '<topic_emitting_image_names>'

  side_input_pcoll = (
    pipeline
    | "FilePatternUpdates" >> WatchFilePattern(file_pattern=file_pattern))

  images_pcoll = (
    pipeline
    | "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=pubsub_topic)
    | "DecodeBytes" >> beam.Map(lambda x: x.decode('utf-8'))
    | "PreProcessImage" >> beam.Map(read_image)
  )

  inference_pcoll = (
    images_pcoll
    | "RunInference" >> RunInference(
    model_handler=tf_model_handler,
    model_metadata_pcoll=side_input_pcoll))

后处理 PredictionResult 对象

推理完成后,RunInference 输出一个 PredictionResult 对象,该对象包含 exampleinferencemodel_id 字段。model_id 用于标识用于运行推理的模型。

from apache_beam.ml.inference.base import PredictionResult

class PostProcessor(beam.DoFn):
  """
  Process the PredictionResult to get the predicted label and model id used for inference.
  """
  def process(self, element: PredictionResult) -> typing.Iterable[str]:
    predicted_class = numpy.argmax(element.inference[0], axis=-1)
    labels_path = tf.keras.utils.get_file(
        'ImageNetLabels.txt',
        'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
    )
    imagenet_labels = numpy.array(open(labels_path).read().splitlines())
    predicted_class_name = imagenet_labels[predicted_class]
    return predicted_class_name.title(), element.model_id

post_processor_pcoll = (inference_pcoll | "PostProcessor" >> PostProcessor())

运行管道

result = pipeline.run().wait_until_finish()

注意ModelMetaData 对象的 model_name 将作为前缀附加到 RunInference PTransform 计算的 指标

最后说明

当您将侧输入与 RunInference PTransform 一起使用以在不停止管道的情况下自动更新模型时,您可以将此示例用作模式。您可以在 GitHub 上看到 PyTorch 的类似示例。