使用共享对象缓存数据

缓存是一个软件组件,它存储数据,以便将来对该数据的请求可以更快地得到服务。要访问缓存,您可以使用侧输入、有状态的 DoFn 以及对外部服务的调用。Python SDK 在共享模块中提供了另一种选择。这种选择可能比侧输入更节省内存,比有状态的 DoFn 更简单,比调用外部服务性能更高,因为它不需要为每个元素或元素捆绑访问外部服务。有关使用 Beam SDK 缓存数据的策略的更多详细信息,请参阅 2022 Beam Summit 上的会议 使用 Beam SDK 在 Dataflow 中缓存数据的策略

本页上的示例演示了如何使用 shared 模块Shared 类来丰富有界和无界 PCollection 对象中的元素。示例中使用了两个数据集:ordercustomer。订单记录包含客户 ID,通过映射客户记录向其添加了客户属性。

在批处理管道上创建缓存

在此示例中,客户缓存作为字典加载到 EnrichOrderFnsetup 方法中。缓存用于向订单记录添加客户属性。由于 Python 字典不支持弱引用,而 Shared 对象封装了对共享资源单例实例的弱引用的封装,因此创建一个包装类。

# The wrapper class is needed for a dictionary, because it does not support weak references.
class WeakRefDict(dict):
    pass

class EnrichOrderFn(beam.DoFn):
    def __init__(self):
        self._customers = {}
        self._shared_handle = shared.Shared()

    def setup(self):
        # setup is a good place to initialize transient in-memory resources.
        self._customer_lookup = self._shared_handle.acquire(self.load_customers)

    def load_customers(self):
        self._customers = expensive_remote_call_to_load_customers()
        return WeakRefDict(self._customers)

    def process(self, element):
        attr = self._customer_lookup.get(element["customer_id"], {})
        yield {**element, **attr}

在流式管道上创建缓存并定期更新

由于客户缓存假设会随着时间的推移而发生变化,因此您需要定期刷新它。要重新加载共享对象,请更改 acquire 方法的 tag 参数。在此示例中,刷新是在 start_bundle 方法中实现的,该方法将当前标记值与与现有共享对象关联的值进行比较。set_tag 方法返回在最大陈旧时间段内相同的标记值。因此,如果标记值大于现有标记值,它将触发客户缓存的刷新。

# The wrapper class is needed for a dictionary, because it does not support weak references.
class WeakRefDict(dict):
    pass

class EnrichOrderFn(beam.DoFn):
    def __init__(self):
        self._max_stale_sec = 60
        self._customers = {}
        self._shared_handle = shared.Shared()

    def setup(self):
        # setup is a good place to initialize transient in-memory resources.
        self._customer_lookup = self._shared_handle.acquire(
            self.load_customers, self.set_tag()
        )

    def set_tag(self):
        # A single tag value is returned within a period, which is upper-limited by the max stale second.
        current_ts = datetime.now().timestamp()
        return current_ts - (current_ts % self._max_stale_sec)

    def load_customers(self):
        # Assign the tag value of the current period for comparison.
        self._customers = expensive_remote_call_to_load_customers(tag=self.set_tag())
        return WeakRefDict(self._customers)

    def start_bundle(self):
        # Update the shared object when the current tag value exceeds the existing value.
        if self.set_tag() > self._customers["tag"]:
            self._customer_lookup = self._shared_handle.acquire(
                self.load_customers, self.set_tag()
            )

    def process(self, element):
        attr = self._customer_lookup.get(element["customer_id"], {})
        yield {**element, **attr}