DiarizationLM

Overview

Here we open source some functions and tools used in the DiarizationLM paper.

Disclaimer

This is NOT an official Google product.

Instructions

Install the package

You can install the package with:

pip install diarizationlm

Once installed, you can directly use many of the existing functions from the package. For example:

import diarizationlm

# Each speaker string holds one label per word of the matching text.
src_text = "hello good morning hi how are you pretty good"
src_spk = "1 1 1 2 2 2 2 1 1"
tgt_text = "hello morning hi hey are you be good"
tgt_spk = "1 2 2 2 1 1 2 1"
# Transfer the source speakers onto the target words; the target text is kept as-is.
transferred_spk = diarizationlm.transcript_preserving_speaker_transfer(
    src_text, src_spk, tgt_text, tgt_spk)
print(transferred_spk)

Data format

We assume all internal data are stored in JSON files. An example is testdata/example_data.json. The field "utterances" stores a list of utterances, and each utterance has these string fields:

Field                  Description
"utterance_id"         The utterance ID.
"hyp_text"             The sequence of hypothesis words, joined by spaces.
"hyp_spk"              The sequence of hypothesis speakers, joined by spaces.
"hyp_diarized_text"    The text representation of the hypothesis words and speakers. It can be used for debugging and for building the prompts to the LLM.
"ref_*"                Same as the "hyp_*" fields, but holding the ground truth reference rather than the hypothesis.
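
As a quick sanity check on files in this layout, the sketch below verifies that each text field has as many words as its speaker field has labels. It assumes one speaker label per word, as in the usage example above. The sample values and the helper name find_length_mismatches are made up for illustration and are not part of the package.

```python
import json

# Hypothetical sample in the layout described above; all values are made up.
SAMPLE_JSON = """
{
  "utterances": [
    {
      "utterance_id": "utt1",
      "hyp_text": "hello how are you",
      "hyp_spk": "1 1 2 2",
      "ref_text": "hello how are you",
      "ref_spk": "1 2 2 2"
    },
    {
      "utterance_id": "utt2",
      "hyp_text": "good morning",
      "hyp_spk": "1"
    }
  ]
}
"""


def find_length_mismatches(data):
    """Return (utterance_id, prefix) pairs whose word and speaker counts differ."""
    mismatches = []
    for utt in data["utterances"]:
        for prefix in ("hyp", "ref"):
            text = utt.get(f"{prefix}_text")
            spk = utt.get(f"{prefix}_spk")
            if text is None or spk is None:
                continue  # Field pair not present in this utterance.
            if len(text.split()) != len(spk.split()):
                mismatches.append((utt["utterance_id"], prefix))
    return mismatches


data = json.loads(SAMPLE_JSON)
print(find_length_mismatches(data))
```

In this sample, utt2 is reported because its hypothesis has two words but only one speaker label.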

Conversion between representations

In the paper, we mentioned two representations:

  1. The word sequence and speaker sequence representation.
  2. The pure text representation.

Example:

Word sequence:         ["good", "morning", "how", "are", "you"]
Speaker sequence:      [1, 1, 2, 2, 2]
Text representation:   "<spk:1> good morning <spk:2> how are you"

We provide the functions in diarizationlm/utils.py to convert between these two representations:

  • create_diarized_text() converts the word and speaker sequences to the pure text representation.
  • extract_text_and_spk() converts the pure text representation to the word and speaker sequences.
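
To make the two representations concrete, here is a minimal standalone sketch of the round trip. It is not the code in diarizationlm/utils.py: the function names to_diarized_text and from_diarized_text are hypothetical, and the sketch assumes integer speaker labels and the "<spk:N>" token format shown in the example above.

```python
import re

SPEAKER_TOKEN = re.compile(r"<spk:(\d+)>")


def to_diarized_text(words, speakers):
    """Join words into text, emitting a speaker token whenever the speaker changes."""
    if len(words) != len(speakers):
        raise ValueError("words and speakers must have the same length")
    parts = []
    previous = None
    for word, speaker in zip(words, speakers):
        if speaker != previous:
            parts.append(f"<spk:{speaker}>")
            previous = speaker
        parts.append(word)
    return " ".join(parts)


def from_diarized_text(text):
    """Split diarized text back into a word list and a parallel speaker list."""
    words, speakers = [], []
    current = None
    for token in text.split():
        match = SPEAKER_TOKEN.fullmatch(token)
        if match:
            current = int(match.group(1))
        elif current is None:
            raise ValueError("text must start with a speaker token")
        else:
            words.append(token)
            speakers.append(current)
    return words, speakers


text = to_diarized_text(["good", "morning", "how", "are", "you"], [1, 1, 2, 2, 2])
print(text)
print(from_diarized_text(text))
```

The package functions additionally take configuration (such as PromptOptions), which this sketch leaves out.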

Transcript-preserving speaker transfer (TPST)

TPST is a critical data processing algorithm used in multiple places in our paper.

A Python implementation is available in diarizationlm/utils.py, defined as:

def transcript_preserving_speaker_transfer(
    src_text: str, src_spk: str, tgt_text: str, tgt_spk: str
) -> str
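
The following is a simplified illustration of the idea behind speaker transfer, not the package's implementation. It swaps in difflib.SequenceMatcher for word alignment and a majority vote for mapping target speakers to source speakers; both choices are assumptions made to keep the sketch short, and the function name simple_speaker_transfer is hypothetical. Its output may differ from transcript_preserving_speaker_transfer().

```python
from collections import Counter
from difflib import SequenceMatcher


def simple_speaker_transfer(src_text, src_spk, tgt_text, tgt_spk):
    """Assign source speakers to target words while keeping the target text."""
    src_words, src_spks = src_text.split(), src_spk.split()
    tgt_words, tgt_spks = tgt_text.split(), tgt_spk.split()
    if len(src_words) != len(src_spks) or len(tgt_words) != len(tgt_spks):
        raise ValueError("each text needs exactly one speaker label per word")

    # Step 1: align words. Equal runs and same-length replacements are
    # paired position by position; everything else stays unaligned.
    aligned = {}  # target word index -> source word index
    matcher = SequenceMatcher(a=src_words, b=tgt_words, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal" or (tag == "replace" and i2 - i1 == j2 - j1):
            for offset in range(j2 - j1):
                aligned[j1 + offset] = i1 + offset

    # Step 2: map each target speaker to the source speaker it most often
    # co-occurs with on aligned words (majority vote).
    votes = {}
    for j, i in aligned.items():
        votes.setdefault(tgt_spks[j], Counter())[src_spks[i]] += 1
    spk_map = {tgt: counts.most_common(1)[0][0] for tgt, counts in votes.items()}

    # Step 3: aligned words take the source speaker directly; unaligned
    # words fall back to the mapped target speaker.
    result = []
    for j, spk in enumerate(tgt_spks):
        if j in aligned:
            result.append(src_spks[aligned[j]])
        else:
            result.append(spk_map.get(spk, spk))
    return " ".join(result)


print(simple_speaker_transfer("a b c d", "1 1 2 2", "a b x c d", "2 2 2 1 1"))
```

The result always has one label per target word, which is the "transcript-preserving" property: only the speaker labels change, never the target text.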

Training data preparation

We provide a Python script train_data_prep.py for preparing the dataset used to finetune LLMs (i.e. the prompt builder module described in the paper). This tool will:

  1. Segment the prompts and completions based on the input and output length limit.
  2. Optionally apply prefix and suffix to prompts and completions.
  3. Store prompt-completion pairs in different file formats.

The segmentation length, prefix, and suffix are passed in as flags to train_data_prep.py. In Python code, they are configured as PromptOptions defined in utils.py.
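
The three steps listed above can be sketched in miniature as follows. This is an illustration under simplifying assumptions, not what train_data_prep.py does internally: segmentation here is a greedy split on a character limit, the pairing of prompt and completion segments is taken as given, and the helper names segment_words and to_jsonl are hypothetical.

```python
import json


def segment_words(words, max_chars):
    """Greedily pack words into segments of at most max_chars characters.

    A single word longer than max_chars still gets its own segment.
    """
    segments, current, current_len = [], [], 0
    for word in words:
        added = len(word) + (1 if current else 0)  # +1 for the joining space
        if current and current_len + added > max_chars:
            segments.append(" ".join(current))
            current, current_len = [word], len(word)
        else:
            current.append(word)
            current_len += added
    if current:
        segments.append(" ".join(current))
    return segments


def to_jsonl(pairs, prompt_suffix=" --> ", completion_suffix=" [eod]"):
    """Serialize (prompt, completion) pairs as JSON lines with suffixes applied."""
    lines = []
    for prompt, completion in pairs:
        record = {
            "prompt": prompt + prompt_suffix,
            "completion": completion + completion_suffix,
        }
        lines.append(json.dumps(record))
    return "\n".join(lines)


print(segment_words("aa bb cc dd".split(), max_chars=5))
print(to_jsonl([("<spk:1> hi there", "<spk:1> hi <spk:2> there")]))
```

The suffix defaults mirror the values used in the example command in this section; the feature keys "prompt" and "completion" match its --input_feature_key and --output_feature_key flags.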

We support 4 different output file formats:

Format      Description
tfrecord    The TFRecord format can be used by various machine learning libraries.
json        This format is more human readable and can be used for debugging. It's also useful for finetuning PaLM models via the Google Cloud API.
csv         This format can be used by many existing tools. OpenAI also provides a tool to convert csv files to jsonl files.
jsonl       This format can be directly used by the OpenAI API for finetuning GPT models.

Example command:

python3 train_data_prep.py \
--input="testdata/example_data.json" \
--output="/tmp/example_data.jsonl" \
--output_type=jsonl \
--emit_input_length=1000 \
--emit_target_length=1000 \
--prompt_suffix=" --> " \
--completion_suffix=" [eod]" \
--input_feature_key="prompt" \
--output_feature_key="completion"

LLM finetuning and inference

Warning: This step is very costly! Proceed with caution and at your own risk. Also, GPT models are very different from PaLM models, so reproducibility is not guaranteed!

In our paper, we used Google's internal tools to finetune PaLM 2 models and to run the model inference. Google's policy does not allow us to disclose any details about the tools and the PaLM 2 models.

However, if you are interested in reproducing some of our experiments, one option is to use alternative LLMs, such as OpenAI's GPT models.

Using the train_data_prep.py tool mentioned above, you can create csv files and then use OpenAI's tools to convert them to the jsonl format. Example command:

openai tools fine_tunes.prepare_data -f train_data.csv

Once you have the training data in jsonl format, you can finetune GPT models with the data, either via the API or using OpenAI's web UI. For example:

openai api fine_tunes.create -t "train_data.jsonl"

After you have finetuned a model, we provide a Python script run_finetuned_gpt.py to run the GPT model inference on testing data. You need to provide your --api_key and --engine to the script.

Completion parser

During inference, the prompts are sent to the LLM, and the LLM generates the completions. We provide a postprocess_completions.py script that serves as the completion parser module described in the paper. It will:

  1. Truncate the completion suffix, and any text generated after this suffix.
  2. Concatenate the completions of all segments from the same utterance.
  3. Transfer the speakers to the original hypothesis ASR transcript.
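
Steps 1 and 2 above can be sketched as below. This is a minimal illustration rather than the logic of postprocess_completions.py: the helper names are hypothetical, and the " [eod]" suffix is assumed to match the --completion_suffix used when preparing the training data. Step 3 corresponds to the TPST function described earlier and is not repeated here.

```python
def truncate_at_suffix(completion, suffix=" [eod]"):
    """Drop the completion suffix and anything the model generated after it."""
    index = completion.find(suffix)
    return completion if index == -1 else completion[:index]


def merge_segments(completions, suffix=" [eod]"):
    """Truncate each segment's completion, then join them in segment order."""
    cleaned = (truncate_at_suffix(c, suffix).strip() for c in completions)
    return " ".join(part for part in cleaned if part)


segments = [
    "<spk:1> good morning [eod] extra tokens",
    "<spk:2> how are you [eod]",
]
print(merge_segments(segments))
```

A completion that never emits the suffix is kept whole, which is a design choice of this sketch; a stricter parser could flag such completions instead.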

Citation

Our paper is cited as:

@article{wang2024diarizationlm,
  title={{DiarizationLM: Speaker Diarization Post-Processing with Large Language Models}},
  author={Quan Wang and Yiling Huang and Guanlong Zhao and Evan Clark and Wei Xia and Hank Liao},
  journal={arXiv preprint arXiv:2401.03506},
  year={2024}
}
