从 Java SDK 使用 RunInference
此示例中的管道是用 Java 编写的,并从 Google Cloud Storage 读取输入数据。借助 PythonExternalTransform,调用了一个复合 Python 转换来执行预处理、后处理和推理。最后,数据在 Java 管道中被写回 Google Cloud Storage。
您可以在 Beam 存储库 中找到此示例中使用的代码。
NLP 模型和数据集
bert-base-uncased
自然语言处理 (NLP) 模型用于进行推理。该模型是开源的,可在 HuggingFace 上获得。此 BERT 模型用于根据句子的上下文预测句子的最后一个词。
我们还使用 IMDB 电影评论 数据集,这是一个可从 Kaggle 获得的开源数据集。
以下是预处理后数据的示例
文本 | 最后一个词 |
---|---|
一位评论者提到,在只看了 1 集 Oz 之后,你就会 [MASK] | 上瘾 |
一部很棒的小 [MASK] | 制作 |
所以我不是 Boll 作品的忠实粉丝,但话说回来,并不是很多人 [MASK] | 是 |
这是一部关于三个成为 [MASK] 的囚犯的精彩电影 | 著名 |
有些电影根本不应该 [MASK] | 重拍 |
凯伦·卡彭特的生平故事更多地展示了歌手凯伦·卡彭特复杂的 [MASK] | 人生 |
多语言推理管道
使用多语言管道时,您可以访问更多转换池。有关更多信息,请参阅 Apache Beam 编程指南中的 多语言管道。
自定义 Python 转换
除了运行推理之外,我们还需要对数据进行预处理和后处理。后处理数据可以解释输出。为了完成这三个任务,编写了一个单一的复合自定义 PTransform,每个任务都使用一个单元 DoFn 或 PTransform,如下面的代码片段所示
def expand(self, pcoll):
return (
pcoll
| 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer))
| 'Inference' >> RunInference(KeyedModelHandler(self._model_handler))
| 'Postprocess' >> beam.ParDo(self.Postprocess(
self._tokenizer))
)
首先,对数据进行预处理。在本例中,原始文本数据将被清理并标记化以用于 BERT 模型。所有这些步骤都在 Preprocess
DoFn 中运行。Preprocess
DoFn 将单个元素作为输入,并返回包含原始文本和标记化文本的列表。
然后使用预处理后的数据进行推理。这在 RunInference
PTransform 中完成,该 PTransform 已在 Apache Beam SDK 中提供。RunInference
PTransform 需要一个参数,一个模型处理程序。在本例中,使用了 KeyedModelHandler
,因为 Preprocess
DoFn 还会输出原始句子。您可以根据您的要求更改预处理方式。此模型处理程序在复合 PTransform 的以下初始化函数中定义
def __init__(self, model, model_path):
self._model = model
logging.info(f"Downloading {self._model} model from GCS.")
self._model_config = BertConfig.from_pretrained(self._model)
self._tokenizer = BertTokenizer.from_pretrained(self._model)
self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper(
state_dict_path=(model_path),
model_class=BertForMaskedLM,
model_params={'config': self._model_config},
device='cuda:0')
使用了 PytorchModelHandlerKeyedTensorWrapper
,它是 PytorchModelHandlerKeyedTensor
模型处理程序的包装器。PytorchModelHandlerKeyedTensor
模型处理程序对 PyTorch 模型进行推理。由于从 BertTokenizer
生成的标记化字符串可能具有不同的长度,并且 stack() 要求张量大小相同,因此 PytorchModelHandlerKeyedTensorWrapper
将批次大小限制为 1。将 max_batch_size
限制为 1 表示 run_inference() 调用每个批次包含一个示例。以下代码显示了包装器的定义
class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
def batch_elements_kwargs(self):
return {'max_batch_size': 1}
另一种方法是使所有张量具有相同的长度。此 示例 显示了如何执行此操作。
ModelConfig
和 ModelTokenizer
在初始化函数中加载。ModelConfig
用于定义模型架构,ModelTokenizer
用于标记化输入数据。以下两个参数用于这些任务
model
:用于推理的模型名称。在本例中,它是bert-base-uncased
。model_path
:用于推理的模型state_dict
的路径。在本例中,它是 Google Cloud Storage 存储桶的路径,其中存储着state_dict
。
这两个参数都在 Java PipelineOptions
中指定。
最后,我们在 Postprocess
DoFn 中对模型预测进行后处理。Postprocess
DoFn 返回原始文本、句子的最后一个词和预测的词。
将 Python 代码编译成包
自定义 Python 代码需要写入本地包并编译成 tarball。此包随后可由 Java 管道使用。以下示例显示了如何将 Python 包编译成 tarball
pip install --upgrade build && python -m build --sdist
为了运行这个程序,需要一个 `setup.py` 文件。压缩包的路径将作为参数在 Java pipeline 的管道选项中使用。
运行 Java 管道
Java pipeline 在 MultiLangRunInference
类中定义。在这个管道中,数据从 Google Cloud Storage 读取,应用跨语言 Python 变换,最后将输出写入回 Google Cloud Storage。
PythonExternalTransform
用于将跨语言 Python 变换注入到 Java pipeline 中。PythonExternalTransform
接收一个字符串参数,该参数是 Python 变换的完全限定名。
withKwarg
方法用于指定 Python 变换所需的参数。在本例中,指定了 `model` 和 `model_path` 参数。这些参数用于复合 Python PTransform 的初始化函数中,如第一部分所示。最后,withExtraPackages
方法用于指定 Python 变换所需的额外 Python 依赖项。在本例中,使用 `local_packages` 列表,该列表包含 Python 需求和已编译压缩包的路径。
要运行 pipeline,请使用以下命令:
mvn compile exec:java -Dexec.mainClass=org.apache.beam.examples.MultiLangRunInference \
-Dexec.args="--runner=DataflowRunner \
--project=$GCP_PROJECT\
--region=$GCP_REGION \
--gcpTempLocation=gs://$GCP_BUCKET/temp/ \
--inputFile=gs://$GCP_BUCKET/input/imdb_reviews.csv \
--outputFile=gs://$GCP_BUCKET/output/ouput.txt \
--modelPath=gs://$GCP_BUCKET/input/bert-model/bert-base-uncased.pth \
--modelName=$MODEL_NAME \
--localPackage=$LOCAL_PACKAGE" \
-Pdataflow-runner
指定了标准的 Google Cloud 和 Runner 参数。inputFile
和 outputFile
参数用于指定输入和输出文件。modelPath
和 modelName
自定义参数传递给 PythonExternalTransform
。最后,localPackage
参数用于指定已编译 Python 包的路径,该包包含自定义 Python 变换。
最后说明
使用本例作为创建其他自定义跨语言推理管道的基础。您也可以使用其他 SDK。例如,Go 也具有一个可以进行跨语言变换的包装器。有关更多信息,请参阅 Apache Beam 编程指南中的 在 Go pipeline 中使用跨语言变换。
本例中使用的完整代码可以在 GitHub 上找到。
最后更新时间:2024/10/31
您找到所需的所有内容了吗?
这些内容对您有用且清晰吗?您想更改什么?请告诉我们!