在线聚类示例

在线聚类示例演示了如何设置一个实时聚类管道,该管道可以从 Pub/Sub 读取文本,使用语言模型将文本转换为嵌入,并使用 BIRCH 对文本进行聚类。

用于聚类的数据集

此示例使用了一个名为 emotion 的数据集,其中包含 20,000 条英文 Twitter 消息,包含 6 种基本情绪:愤怒、恐惧、喜悦、爱、悲伤和惊讶。该数据集有三个拆分:训练、验证和测试。因为它包含文本和数据集的类别(类别),所以它是一个有监督的数据集。要访问此数据集,请使用 Hugging Face 数据集页面

以下文本显示了来自数据集训练拆分的示例

文本情绪类型
我正在抽空发帖,我觉得自己贪得无厌,错了愤怒
我总是怀念壁炉,我知道它还在房产内
我一直服用 10 毫克,比推荐剂量多很多倍,我很快就睡着了,但我感觉也怪怪的恐惧
去丹麦的乘船旅行喜悦
我觉得自己,你知道,基本上就像科幻小说中的一个赝品悲伤
我开始每周出现几次,感觉幻觉折磨着我,移动的人和物体,声音和振动恐惧

聚类算法

对于推文的聚类,我们使用一种称为 BIRCH 的增量聚类算法。它代表使用层次结构的平衡迭代缩减和聚类,它是一种无监督的数据挖掘算法,用于对特别大的数据集执行层次聚类。BIRCH 的一个优点是它能够增量且动态地对传入的多维度度量数据点进行聚类,以尝试在给定的一组资源(内存和时间限制)内生成最佳质量的聚类。

摄取到 Pub/Sub

该示例首先将数据摄取到 Pub/Sub 中,以便我们可以在聚类时从 Pub/Sub 读取推文。Pub/Sub 是一种消息服务,用于在应用程序和服务之间交换事件数据。流式分析和数据集成管道使用 Pub/Sub 来摄取和分发数据。

您可以在 GitHub 中找到将数据摄取到 Pub/Sub 的完整示例代码。

摄取管道的文件结构在以下图表中显示

write_data_to_pubsub_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── utils.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/utils.py 包含用于加载情绪数据集和两个 beam.DoFn 的代码,这些代码用于数据转换。

pipeline/options.py 包含用于配置 Dataflow 管道的管道选项。

config.py 定义了多次使用的某些变量,例如 GCP PROJECT_ID 和 NUM_WORKERS。

setup.py 定义了管道运行所需的软件包和要求。

main.py 包含管道代码和一些用于运行管道的附加函数。

运行管道

首先,安装所需的软件包。

  1. 在您机器上的本地:python main.py
  2. 在 GCP 上用于 Dataflow:python main.py --mode cloud

write_data_to_pubsub_pipeline 包含四个不同的转换

  1. 使用 Hugging Face 数据集加载情绪数据集(为简单起见,我们从三个类别而不是六个类别中获取样本)。
  2. 将每段文本与一个唯一的标识符 (UID) 关联。
  3. 将文本转换为 Pub/Sub 预期格式。
  4. 将格式化的消息写入 Pub/Sub。

对流数据进行聚类

将数据摄取到 Pub/Sub 后,检查第二个管道,我们在其中从 Pub/Sub 读取流式消息,使用语言模型将文本转换为嵌入,并使用 BIRCH 对嵌入进行聚类。

您可以在 GitHub 中找到所有上述步骤的完整示例代码。

clustering_pipeline 的文件结构是

clustering_pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── options.py
│   └── transformations.py
├── __init__.py
├── config.py
├── main.py
└── setup.py

pipeline/transformations.py 包含用于管道中使用的不同 beam.DoFn 的代码。

pipeline/options.py 包含用于配置 Dataflow 管道的管道选项。

config.py 定义了多次使用的变量,例如 Google Cloud PROJECT_ID 和 NUM_WORKERS。

setup.py 定义了管道运行所需的软件包和要求。

main.py 包含管道代码和一些用于运行管道的附加函数。

运行管道

安装所需的软件包并将数据推送到 Pub/Sub。

  1. 在您机器上的本地:python main.py
  2. 在 GCP 上用于 Dataflow:python main.py --mode cloud

管道可以细分为以下步骤

  1. 从 Pub/Sub 读取消息。
  2. 将 Pub/Sub 消息转换为 PCollection,其中键是 UID,值为 Twitter 文本。
  3. 使用标记器将文本编码为可由 Transformer 读取的标记 ID 整数。
  4. 使用 RunInference 从基于 Transformer 的语言模型获取向量嵌入。
  5. 规范化嵌入以进行聚类。
  6. 使用有状态处理执行 BIRCH 聚类。
  7. 打印分配给集群的文本。

以下代码显示了管道的头两个步骤,其中从 Pub/Sub 读取消息并将其转换为字典。

    docs = (
        pipeline
        | "Read from PubSub"
        >> ReadFromPubSub(subscription=cfg.SUBSCRIPTION_ID, with_attributes=True)
        | "Decode PubSubMessage" >> beam.ParDo(Decode())
    )

接下来的部分将检查三个重要的管道步骤

  1. 标记化文本。
  2. 将分词后的文本输入基于 transformer 的语言模型以获取嵌入。
  3. 使用有状态处理执行聚类。

从语言模型获取嵌入

为了对文本数据进行聚类,你需要将文本映射到适合统计分析的数值向量。本示例使用了一个名为sentence-transformers/stsb-distilbert-base/stsb-distilbert-base的基于 transformer 的语言模型。它将句子和段落映射到一个 768 维的稠密向量空间,你可以将其用于聚类或语义搜索等任务。

由于语言模型期望的是分词后的输入而不是原始文本,因此首先要对文本进行分词。分词是一个预处理任务,它将文本转换为可供模型输入以获得预测的结果。

    normalized_embedding = (
        docs
        | "Tokenize Text" >> beam.Map(tokenize_sentence)

这里,tokenize_sentence 是一个函数,它接收一个包含文本和 ID 的字典,对文本进行分词,并返回一个元组 (text, id) 和分词后的输出。

然后将分词后的输出传递给语言模型以获取嵌入。为了从语言模型中获取嵌入,我们使用 Apache Beam 的 RunInference() 函数。

    | "Get Embedding" >> RunInference(KeyedModelHandler(model_handler))

为了获得更好的聚类结果,在获取每条 Twitter 文本的嵌入后,对嵌入进行归一化。

    | "Normalize Embedding" >> beam.ParDo(NormalizeEmbedding())

StatefulOnlineClustering

由于数据是流式传输的,因此你需要使用迭代聚类算法,例如 BIRCH。由于该算法是迭代的,因此你需要一个机制来存储先前状态,以便在 Twitter 文本到达时可以对其进行更新。有状态处理使 DoFn 能够拥有持久状态,可以在处理每个元素时进行读写。有关有状态处理的更多信息,请参阅 使用 Apache Beam 进行有状态处理

在本示例中,每次从 Pub/Sub 读取新消息时,都会检索聚类模型的现有状态,对其进行更新,然后将其写回状态。

    clustering = (
        normalized_embedding
        | "Map doc to key" >> beam.Map(lambda x: (1, x))
        | "StatefulClustering using Birch" >> beam.ParDo(StatefulOnlineClustering())
    )

由于 BIRCH 不支持并行化,因此你需要确保只有一个 worker 执行所有有状态处理。为此,请使用 Beam.Map 将每条文本与同一个键 1 关联。

StatefulOnlineClustering 是一个 DoFn,它接收文本的嵌入并更新聚类模型。为了存储状态,它使用 ReadModifyWriteStateSpec 状态对象,该对象充当存储容器。

class StatefulOnlineClustering(beam.DoFn):

    BIRCH_MODEL_SPEC = ReadModifyWriteStateSpec("clustering_model", PickleCoder())
    DATA_ITEMS_SPEC = ReadModifyWriteStateSpec("data_items", PickleCoder())
    EMBEDDINGS_SPEC = ReadModifyWriteStateSpec("embeddings", PickleCoder())
    UPDATE_COUNTER_SPEC = ReadModifyWriteStateSpec("update_counter", PickleCoder())

本示例声明了四个不同的 ReadModifyWriteStateSpec 对象

这些 ReadModifyWriteStateSpec 对象作为附加参数传递给 process 函数。当有新闻项目到达时,我们将检索不同对象的现有状态,对其进行更新,然后将其写回作为持久共享状态。

def process(
    self,
    element,
    model_state=beam.DoFn.StateParam(BIRCH_MODEL_SPEC),
    collected_docs_state=beam.DoFn.StateParam(DATA_ITEMS_SPEC),
    collected_embeddings_state=beam.DoFn.StateParam(EMBEDDINGS_SPEC),
    update_counter_state=beam.DoFn.StateParam(UPDATE_COUNTER_SPEC),
    *args,
    **kwargs,
):
  """
      Takes the embedding of a document and updates the clustering model

      Args:
        element: The input element to be processed.
        model_state: This is the state of the clustering model. It is a stateful parameter,
        which means that it will be updated after each call to the process function.
        collected_docs_state: This is a stateful dictionary that stores the documents that
        have been processed so far.
        collected_embeddings_state: This is a dictionary of document IDs and their embeddings.
        update_counter_state: This is a counter that keeps track of how many documents have been
      processed.
      """
  # 1. Initialise or load states
  clustering = model_state.read() or Birch(n_clusters=None, threshold=0.7)
  collected_documents = collected_docs_state.read() or {}
  collected_embeddings = collected_embeddings_state.read() or {}
  update_counter = update_counter_state.read() or Counter()

  # 2. Extract document, add to state, and add to clustering model
  _, doc = element
  doc_id = doc["id"]
  embedding_vector = doc["embedding"]
  collected_embeddings[doc_id] = embedding_vector
  collected_documents[doc_id] = {"id": doc_id, "text": doc["text"]}
  update_counter = len(collected_documents)

  clustering.partial_fit(np.atleast_2d(embedding_vector))

  # 3. Predict cluster labels of collected documents
  cluster_labels = clustering.predict(
      np.array(list(collected_embeddings.values())))

  # 4. Write states
  model_state.write(clustering)
  collected_docs_state.write(collected_documents)
  collected_embeddings_state.write(collected_embeddings)
  update_counter_state.write(update_counter)
  yield {
      "labels": cluster_labels,
      "docs": collected_documents,
      "id": list(collected_embeddings.keys()),
      "counter": update_counter,
  }

GetUpdates 是一个 DoFn,它在每次有新消息到达时打印分配给每条 Twitter 消息的聚类。

updated_clusters = clustering | "Format Update" >> beam.ParDo(GetUpdates())