Beam 中的单元测试:一个有主见的指南

测试仍然是软件工程中最基本的部分之一。在这篇博文中,我们重点介绍了 Apache Beam 提供的一些用于测试的结构。我们涵盖了一套编写数据管道单元测试的最佳实践。本文不包括集成测试,您需要单独编写这些测试。本文中的所有代码段都包含在此笔记本中。此外,要查看体现最佳实践的测试,请查看Beam 初学者项目,其中包含体现最佳实践的测试。

最佳实践

在测试 Beam 管道时,我们建议遵循以下最佳实践

  1. 不要为 Beam 库中已经支持的连接器编写单元测试,例如 ReadFromBigQueryWriteToText。这些连接器已经在 Beam 的测试套件中进行了测试,以确保其功能正确。它们会给单元测试带来不必要的成本和依赖关系。

  2. 在将函数与 MapFlatMapFilter 一起使用时,请确保该函数经过充分测试。您可以假设在使用 Map(your_function) 时,您的函数将按预期工作。

  3. 对于更复杂的转换,如 ParDo、侧输入、时间戳检查等,请将整个转换视为一个单元并对其进行测试。

  4. 如果需要,使用模拟来模拟 DoFn 中可能存在的任何 API 调用。模拟的目的是对您的功能进行广泛测试,即使此测试需要从 API 调用获取特定响应。

    1. 确保将 API 调用模块化到单独的函数中,而不是在 DoFn 中直接进行 API 调用。此步骤在模拟外部 API 调用时提供更简洁的体验。

示例 1

以下管道用作示例。假设函数 median_house_value_per_bedroom 在代码中的其他地方进行了单元测试,那么您不必编写单独的单元测试来测试此管道中该函数的上下文。您可以相信 Map 原语按预期工作(这说明了先前提到的要点 #2)。

# The following code computes the median house value per bedroom.

with beam.Pipeline() as p1:
   result = (
       p1
       | ReadFromText("/content/sample_data/california_housing_test.csv",skip_header_lines=1)
       | beam.Map(median_house_value_per_bedroom)
       | WriteToText("/content/example2")
   )

示例 2

以下函数用作示例。函数 median_house_value_per_bedroommultiply_by_factor 在其他地方进行了测试,但整个管道(由复合转换组成)没有进行测试。

with beam.Pipeline() as p2:
    result = (
        p2
        | ReadFromText("/content/sample_data/california_housing_test.csv",skip_header_lines=1)
        | beam.Map(median_house_value_per_bedroom)
        | beam.Map(multiply_by_factor)
        | beam.CombinePerKey(sum)
        | WriteToText("/content/example3")
    )

对先前代码的最佳实践是在 ReadFromTextWriteToText 之间的所有函数中创建一个转换。此步骤将转换逻辑与 I/O 分开,使您能够对转换逻辑进行单元测试。以下示例是对先前代码的重构

def transform_data_set(pcoll):
  return (pcoll
          | beam.Map(median_house_value_per_bedroom)
          | beam.Map(multiply_by_factor)
          | beam.CombinePerKey(sum))

# Define a new class that inherits from beam.PTransform.
class MapAndCombineTransform(beam.PTransform):
  def expand(self, pcoll):
    return transform_data_set(pcoll)

with beam.Pipeline() as p2:
   result = (
       p2
       | ReadFromText("/content/sample_data/california_housing_test.csv",skip_header_lines=1)
       | MapAndCombineTransform()
       | WriteToText("/content/example3")
   )

此代码显示了先前示例的相应单元测试

import unittest
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to


class TestBeam(unittest.TestCase):

# This test corresponds to example 3, and is written to confirm the pipeline works as intended.
  def test_transform_data_set(self):
    expected=[(1, 10570.185786231425), (2, 13.375337533753376), (3, 13.315649867374006)]
    input_elements = [
      '-122.050000,37.370000,27.000000,3885.000000,661.000000,1537.000000,606.000000,6.608500,344700.000000',
      '121.05,99.99,23.30,39.5,55.55,41.01,10,34,74.30,91.91',
      '122.05,100.99,24.30,40.5,56.55,42.01,11,35,75.30,92.91',
      '-120.05,39.37,29.00,4085.00,681.00,1557.00,626.00,6.8085,364700.00'
    ]
    with beam.Pipeline() as p2:
      result = (
                p2
                | beam.Create(input_elements)
                | beam.Map(MapAndCombineTransform())
        )
      assert_that(result,equal_to(expected))

示例 3

假设我们编写了一个管道,该管道从 JSON 文件中读取数据,将其通过一个自定义函数(该函数进行外部 API 调用以进行解析),然后将其写入自定义目标(例如,如果我们需要进行一些自定义数据格式化,以便为下游应用程序准备数据)。

管道具有以下结构

# The following packages are used to run the example pipelines.

import apache_beam as beam
from apache_beam.io import ReadFromText, WriteToText
from apache_beam.options.pipeline_options import PipelineOptions

class MyDoFn(beam.DoFn):
  def process(self,element):
          returned_record = MyApiCall.get_data("http://my-api-call.com")
          if len(returned_record)!=10:
            raise ValueError("Length of record does not match expected length")
          yield returned_record

with beam.Pipeline() as p3:
  result = (
          p3
          | ReadFromText("/content/sample_data/anscombe.json")
          | beam.ParDo(MyDoFn())
          | WriteToText("/content/example1")
  )

此测试检查 API 响应是否为错误长度的记录,如果测试失败,则会抛出预期的错误。

!pip install mock  # Install the 'mock' module.
# Import the mock package for mocking functionality.
from unittest.mock import Mock,patch
# from MyApiCall import get_data
import mock


# MyApiCall is a function that calls get_data to fetch some data by using an API call.
@patch('MyApiCall.get_data')
def test_error_message_wrong_length(self, mock_get_data):
 response = ['field1','field2']
 mock_get_data.return_value = Mock()
 mock_get_data.return_value.json.return_value=response

 input_elements = ['-122.050000,37.370000,27.000000,3885.000000,661.000000,1537.000000,606.000000,6.608500,344700.000000'] #input length 9
 with self.assertRaisesRegex(ValueError,
                             "Length of record does not match expected length'"):
     p3 = beam.Pipeline()
     result = p3 | beam.create(input_elements) | beam.ParDo(MyDoFn())
     result

其他测试最佳实践

  1. 测试您引发的所有错误消息。
  2. 涵盖数据中可能存在的任何边缘情况。
  3. 示例 1 本来可以使用 lambda 函数来编写 beam.Map 步骤,而不是 beam.Map(median_house_value_per_bedroom)
beam.Map(lambda x: x.strip().split(',')) | beam.Map(lambda x: float(x[8])/float(x[4])

使用 beam.Map(median_house_value_per_bedroom) 将 lambda 函数分离到辅助函数中是推荐的更易于测试代码的方法,因为对该函数的更改将被模块化。

  1. 使用 assert_that 语句来确保 PCollection 值正确匹配,如前面的示例所示。

有关 Beam 和 Dataflow 上测试的更多指导,请参阅Google Cloud 文档。有关 Beam 中单元测试的更多示例,请参阅 base_test.py 代码

特别感谢 Robert Bradshaw、Danny McCormick、XQ Hu、Surjit Singh 和 Rebecca Spzer,他们在完善本文中的想法方面提供了帮助。