Building a retrieval-backed chatbot

Introduction

DPR (Dense Passage Retrieval) is a technique for looking up relevant documents given a query; it is essentially the retrieval task now recognized in the Massive Text Embedding Benchmark. The basic idea is to use a model trained with a contrastive loss to embed both passages and queries, producing an embedding space where questions and the snippets that answer them sit near each other. In some systems, including DPR itself, queries and passages are embedded by two separate models trained together.

After chunking and embedding our documents into this space, we can search it with an appropriate nearest-neighbor algorithm, hand the retrieved snippets to a chat system as context, and let the chat system synthesize them into a plain-English answer.

This system is not without precedent. YouTube is rife with tutorials using e.g. LangChain to perform this task, by now canonized as the very sus ‘pdf-chat: chat with your pdf-files!’, but I’m not sure a LangChain system is scalable at the time of writing. I also found Haystack and Vespa, which seem to sit in this or related domains.

Initial build

I found Facebook’s models trained specifically for DPR. I think they were trained on a Wikipedia-based dataset, so they will likely work best on properly formatted wiki-style data.

The ever-helpful SBERT website tells us we need to encode text formatted like this:

"London [SEP] London is the capital and largest city of England and the United Kingdom."

I found this to be a key to getting good results out of the model in the end.
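
For reference, encoding in that format with the DPR checkpoints hosted by SBERT looks roughly like the sketch below; the question and passage are toy examples, not data from my wiki.

from sentence_transformers import SentenceTransformer, util

# Two separate encoders, trained together: one for passages, one for questions.
ctx_encoder = SentenceTransformer("facebook-dpr-ctx_encoder-single-nq-base")
question_encoder = SentenceTransformer("facebook-dpr-question_encoder-single-nq-base")

# Passages are encoded as "title [SEP] text", per the SBERT docs.
passage_emb = ctx_encoder.encode(
    "London [SEP] London is the capital and largest city of England and the United Kingdom."
)
query_emb = question_encoder.encode("What is the capital of England?")

print(util.dot_score(query_emb, passage_emb))  # higher score = more relevant passage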

I used a project called WikiExtractor to clean the dumped XML from my target wiki, then wrote a script that runs a window over the cleaned text, extracts correctly formatted snippets, and loads their embeddings into an index file for use with Annoy, the Spotify library I’ve used previously. What I like about Annoy is that it writes the index to disk, so if I spend a little extra on disk size I can bake the index into Docker images. It ain’t pretty but it gets the job done for now. In the future I can scale further by using shared volumes, since the index is read-only, but for now this will suffice.
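
The indexing script boils down to something like this. The window size, the iter_wikiextractor_articles helper, and the snippets.json side file are illustrative assumptions rather than my exact code:

import json
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer

EMBEDDING_DIM = 768   # DPR embeddings are 768-dimensional
WINDOW_SIZE = 100     # illustrative chunk size, in words

ctx_encoder = SentenceTransformer("facebook-dpr-ctx_encoder-single-nq-base")
index = AnnoyIndex(EMBEDDING_DIM, "dot")  # DPR is trained for dot-product similarity

snippets = []
for title, text in iter_wikiextractor_articles("extracted/"):  # hypothetical helper yielding (title, text)
    words = text.split()
    for start in range(0, len(words), WINDOW_SIZE):
        chunk = " ".join(words[start:start + WINDOW_SIZE])
        snippet = f"{title} [SEP] {chunk}"
        index.add_item(len(snippets), ctx_encoder.encode(snippet))
        snippets.append(snippet)

index.build(10)           # number of trees; more trees = better recall, bigger file
index.save("wiki.ann")    # written to disk, so it can be baked into a Docker image
with open("snippets.json", "w") as f:
    json.dump(snippets, f)  # keep the raw text so ids can be mapped back at query time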

At this point we can embed a user’s query, retrieve the nearest snippets from the index, and place them alongside the query in the prompt to obtain output.
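
Query time then looks roughly like this; the prompt template and top-k value are arbitrary choices for illustration:

import json
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer

EMBEDDING_DIM = 768
question_encoder = SentenceTransformer("facebook-dpr-question_encoder-single-nq-base")

index = AnnoyIndex(EMBEDDING_DIM, "dot")
index.load("wiki.ann")  # memory-maps the index built earlier
with open("snippets.json") as f:
    snippets = json.load(f)

def build_prompt(query: str, top_k: int = 3) -> str:
    query_emb = question_encoder.encode(query)
    ids = index.get_nns_by_vector(query_emb, top_k)
    context = "\n".join(snippets[i] for i in ids)
    return (
        "Answer the question using the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\nAnswer:"
    )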

Choice of chatbot

I chose mosaicml/mpt-7b for a few reasons. First, it fits on the GPU in my home server. Second, it has a massive context size. Third, it’s fully open-source under the Apache-2.0 license.

The downside of this model is that, in order to support such a large context, it has an unusual architecture: it replaces standard positional embeddings with ALiBi (attention with linear biases).
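
In practice the unusual architecture mostly just means passing trust_remote_code when loading. A minimal sketch with plain transformers, before any serving concerns (the dtype and generation settings are illustrative, and this is not my exact serving code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The MPT architecture ships as custom code in the model repo, hence trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b")

prompt = "Answer the question using the context below.\n\nContext: ...\n\nQuestion: ...\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output_ids = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))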

Load testing the chatbot

How much throughput can I get with a single GPU? I decided to do a little bit of very basic load testing:

import datetime

def load_test(num_prompts: int):
    print("begin load test", num_prompts)
    now = datetime.datetime.now()
    for _ in range(num_prompts):
        # handle_message is the bot's existing message handler (prompt assembly + generation).
        handle_message({
            "context": "This is a load test",
            "content": "We are performing a load test, please do your best to create a response "
                       "that looks like you are answering a question",
        })
    print("load test complete")
    print("time taken: ", datetime.datetime.now() - now)

There’s probably some caching gotcha somewhere in here (and the system did not create good output). I just wanted a ballpark number of how much throughput I could hope to get out of this system.

Well, it wasn’t that great. I ran 100 prompts through and it took about 100 seconds, roughly a second per response. That’s not nearly enough throughput to handle production loads (if I ever get production loads).

I had recently been exposed to vLLM, an LLM runtime that claims to be compatible with a number of popular Hugging Face models, including MosaicML’s MPT. The bar charts they present claim very significant speedups, literally 10x.

Out of the box, vLLM seems to roughly triple the throughput, taking about 35 s to process 100 requests, and it greatly simplifies the code. That is without any batching, though. With a batch size of 5, the 100 requests finish in 9 seconds: about an 11x speedup, or roughly 90 ms per response, from my home GPU server. Amazing! For a chat system, that’s pretty excellent.
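
The vLLM version of the generation path is pleasingly short. Roughly, with illustrative sampling parameters:

from vllm import LLM, SamplingParams

# vLLM supports MPT, but the custom architecture still needs trust_remote_code.
llm = LLM(model="mosaicml/mpt-7b", trust_remote_code=True)
sampling = SamplingParams(temperature=0.7, max_tokens=200)

def generate_batch(prompts: list[str]) -> list[str]:
    # Handing vLLM a whole batch lets it schedule the requests together,
    # which is where the throughput win comes from.
    outputs = llm.generate(prompts, sampling)
    return [output.outputs[0].text for output in outputs]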

Next up

Meta came out with Llama 2, with a semi-unrestricted license (free for commercial use up to some huge number of users). It also has a fairly large context window and scores better on benchmarks, so I thought I’d go ahead and swap to it. That should be fairly simple, since vLLM added support for Llama 2 almost immediately.
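
In principle the swap is a one-line change in the vLLM setup; the model id below is the gated chat variant on Hugging Face, for which access has to be requested from Meta:

from vllm import LLM

# Llama 2 is a standard architecture in vLLM, so no remote code is needed.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")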

I’d also like to do a bit more scraping and cleaning to diversify my data sources.