Chain-of-Retrieval Augmented Generation
This paper introduces an approach for training o1-like RAG models that retrieve and reason over relevant information step by step before generating the final answer. Conventional RAG methods usually perform a single retrieval step before the generation process, which limits their effectiveness in addressing complex queries due to imperfect retrieval results. In contrast, our proposed method, CoRAG (Chain-of-Retrieval Augmented Generation), allows the model to dynamically reformulate the query based on the evolving state. To train CoRAG effectively, we utilize rejection sampling to automatically generate intermediate retrieval chains, thereby augmenting existing RAG datasets that only provide the correct final answer. At test time, we propose various decoding strategies to scale the model's test-time compute by controlling the length and number of sampled retrieval chains. Experimental results across multiple benchmarks validate the efficacy of CoRAG, particularly in multi-hop question answering tasks, where we observe an improvement of more than 10 points in EM score compared to strong baselines. On the KILT benchmark, CoRAG establishes a new state-of-the-art performance across a diverse range of knowledge-intensive tasks. Furthermore, we offer comprehensive analyses to understand the scaling behavior of CoRAG, laying the groundwork for future research aimed at developing factual and grounded foundation models.
Discussion
Host: Hey everyone, and welcome back to the podcast! Today, we're diving into a fascinating area of AI research, something that's really pushing the boundaries of how we build these incredibly powerful language models. I'm your host, Leo, and I'm super excited to unpack this topic with you all.
Host: We're going to be exploring a new approach called 'Chain-of-Retrieval Augmented Generation,' or CoRAG for short. Now, that might sound like a mouthful, but trust me, it's actually a really clever way of getting AI models to be more accurate and reliable, especially when dealing with complex questions. We'll break it down, talk about how it works, and what makes it different from other methods we've seen.
Host: Think of it like this: instead of just throwing a question at a model and hoping it has all the answers memorized, CoRAG allows the model to actively search for information, piece by piece, almost like a human would when researching a topic. It's not just about finding information; it's about how the model finds it, and that’s what we'll really get into today.
Host: Before we jump into the nitty-gritty, it’s worth acknowledging that we are basing our discussion on a research paper from a team at Microsoft and Renmin University of China. The paper introduces this CoRAG method, and we are really dissecting it to understand it better. This isn't just some abstract concept; it's something actively being worked on and developed in the AI research community.
Host: Let’s start with a little background. We're all familiar with Retrieval-Augmented Generation, right? It’s that technique that combines these massive language models with retrieval systems. The idea is pretty straightforward: instead of relying solely on what a model learned during its training, we give it the ability to go out and find relevant information to help answer a question. It's all about grounding the answers in real-world knowledge, reducing those annoying hallucinations, and making the responses more factual. We've talked about RAG on previous episodes, but CoRAG takes it a step further.
Host: The traditional RAG approach typically involves a single retrieval step. You have a question, the system retrieves relevant documents, and then the language model generates the answer, using both the original question and the retrieved info. It’s like a quick search-and-respond method. However, this can become problematic when dealing with those intricate, multi-layered questions that require a bit more complex reasoning. What happens if the first retrieval is not perfect or doesn’t bring back all the required information? This is where the limitations of traditional RAG start to show.
Host: This is exactly where CoRAG enters the scene. It doesn’t just perform one single search. Instead, CoRAG introduces a multi-step retrieval system. The key thing is that the system can dynamically reformulate the question based on what it already found. It’s like a detective following leads, adjusting their investigation as new evidence emerges. It's not just blindly retrieving documents; it is about evolving the search based on the previous steps. That really resonates with human problem-solving where we break down complex problems into simpler parts.
Host: The paper specifically notes how conventional RAG methods perform only a single retrieval step, and this limits their ability to address complex questions: imperfect retrieval results become a bottleneck for the whole system. The idea of dynamically reformulating the query is super interesting; it suggests a thinking process that adapts to the evolving state. Now, how do you even train a model to do this iterative kind of retrieval?
Host: Well, here's where things get really clever. The team used something called 'rejection sampling' to automatically generate these intermediate retrieval chains. This is crucial because, typically, RAG datasets only have the initial question and the final answer; they don't include all these little steps. So, to train CoRAG, they had to create these intermediate retrieval steps artificially. Basically, the model tries out different ways to search for information, and chains that lead down the wrong path are essentially rejected; only the chains that point toward the correct answer are kept as step-by-step training data. It is like trial and error, but turned into supervised learning.
Host: And it’s not just about generating these retrieval chains; it’s also about figuring out the best way to utilize them at test time. The paper discusses various decoding strategies to manage the computation costs by controlling how long the retrieval chains are and how many chains to sample. This is key because, obviously, the more you explore with retrieval steps, the more computational resources it takes. The research here really highlights the trade-off between computational resources and model performance, and how this can be managed at test time. They mention greedy decoding, best-of-N sampling, and tree search, all to achieve the best performance while optimizing resource usage.
Host: Now, let's zoom in a little on the technical side. The idea of a retrieval chain is pretty interesting. Each chain starts with the original question, then has a series of sub-queries and corresponding sub-answers. These sub-queries are generated by an LLM based on the original question, along with previous queries and answers. This method is really about learning to navigate the information space with an intelligent plan, rather than just jumping to the answer.
Host: The iterative process is very interesting. To generate the sub-answers, they use a text retriever to find the top-k most relevant documents for each sub-query, and then an LLM generates the sub-answer from those retrieved documents. The iteration continues until either the chain hits the maximum length or the model arrives at the correct final answer. It's a pretty elaborate process to generate one chain, right?
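Host: To make that loop concrete, here is a minimal sketch in Python of what a single chain rollout might look like. To be clear, this is our illustration, not the paper's code: `llm` and `retriever` are hypothetical wrappers around an instruction-tuned model like Llama-3.1-8B-Instruct and a dense retriever like E5-large, the prompt strings are placeholders, and we've left out the early stop when the chain already reaches the correct answer.

```python
# A minimal sketch of one retrieval-chain rollout, as described above.
# Assumptions (ours, not the paper's): `llm` wraps an instruction-tuned model,
# `retriever` wraps a text retriever, and the prompts are placeholders.

def sample_retrieval_chain(question, llm, retriever, max_len=6, top_k=5, chain=None):
    chain = list(chain or [])  # list of (sub_query, sub_answer) pairs
    for _ in range(max_len):
        # Generate the next sub-query from the original question plus the chain so far.
        sub_query = llm.generate(
            f"Question: {question}\nPrevious steps: {chain}\nNext sub-query:"
        )
        # Retrieve the top-k documents for that sub-query.
        docs = retriever.search(sub_query, k=top_k)
        # Generate the sub-answer from the retrieved documents.
        sub_answer = llm.generate(
            f"Documents: {docs}\nSub-query: {sub_query}\nSub-answer:"
        )
        chain.append((sub_query, sub_answer))
    return chain
```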
Host: Absolutely! And the key here is that the quality of each retrieval chain is assessed. They calculate the log-likelihood of the correct answer given the information in the chain, and the chain with the highest likelihood is then used to augment the existing dataset. So, it’s not just about randomly generating chains; it’s about selecting the chains that are most likely to lead to the correct answer. This is a super important step to make sure the model is trained on quality data.
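Host: The selection on top of those rollouts could look roughly like this. Again, it's a sketch under the assumption that the model exposes a scoring call, here named `llm.log_likelihood`, which is our hypothetical interface rather than anything from the paper.

```python
import math

# Sketch of the rejection-sampling selection step: sample several chains and keep
# the one under which the model assigns the highest log-likelihood to the known
# correct answer. `llm.log_likelihood` is a hypothetical scoring call.

def select_best_chain(question, gold_answer, llm, retriever, n_samples=4):
    best_chain, best_score = None, -math.inf
    for _ in range(n_samples):
        chain = sample_retrieval_chain(question, llm, retriever)
        # Score the chain by log P(gold_answer | question, chain).
        score = llm.log_likelihood(
            target=gold_answer,
            prompt=f"Question: {question}\nRetrieval chain: {chain}\nFinal answer:",
        )
        if score > best_score:
            best_chain, best_score = chain, score
    return best_chain  # this chain augments the original (question, answer) pair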
Host: So, we've covered how these chains are generated. The next step is training the model on them. The training data includes the original question, the correct answer, and the best retrieval chain they found. The model is fine-tuned on this data with standard next-token prediction, learning to predict the next sub-query, the sub-answer, and the final answer. It’s essentially a multi-task setup, with the model trained on all three tasks simultaneously, all aimed at improving the overall reasoning and retrieval process. It is also important that the same prompt templates used during data generation are used during training, so the model becomes very good at this step-by-step retrieval and generation process.
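Host: If you want to picture how one augmented example becomes ordinary next-token-prediction data for those three sub-tasks, it might be assembled something like this; the dict layout, field names, and prompt wording are purely illustrative, not the paper's actual templates.

```python
# Sketch of unrolling one augmented example into next-token-prediction targets
# for the three sub-tasks (next sub-query, sub-answer, final answer).

def build_training_examples(question, answer, chain, retrieved_docs):
    examples = []
    for i, (sub_query, sub_answer) in enumerate(chain):
        history = chain[:i]
        # Sub-task 1: predict the next sub-query from the question and the history.
        examples.append({
            "prompt": f"Question: {question}\nHistory: {history}\nNext sub-query:",
            "target": sub_query,
        })
        # Sub-task 2: predict the sub-answer from the sub-query and its retrieved documents.
        examples.append({
            "prompt": f"Sub-query: {sub_query}\nDocuments: {retrieved_docs[i]}\nSub-answer:",
            "target": sub_answer,
        })
    # Sub-task 3: predict the final answer given the question and the full chain.
    examples.append({
        "prompt": f"Question: {question}\nChain: {chain}\nFinal answer:",
        "target": answer,
    })
    return examples  # fed to standard causal-LM fine-tuning, with loss on the targets
```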
Host: And the fun doesn't stop at the training stage, right? The test-time strategies are crucial as well. The paper mentions different decoding strategies: greedy decoding, best-of-N sampling, and tree search. Greedy decoding simply generates the chain step by step until it reaches an answer. Best-of-N sampling samples multiple chains, chooses the best one based on a penalty score, and uses it to generate the final answer. Tree search is a more exhaustive approach that explores multiple retrieval paths breadth-first and picks the most promising one, which is naturally more computationally expensive. All of these are ways to control the trade-off between model performance and computational cost, aiming for the sweet spot where the model delivers the best results without requiring unbounded resources.
Host: Let's delve deeper into these decoding strategies, starting with greedy decoding. This is the simplest, right? It just generates the sub-queries and sub-answers sequentially, step-by-step, until reaching the final answer. It's like the most direct route. But what are the downsides?
Host: Exactly. Greedy decoding might not be the most effective, because it commits to the best-looking decision at each step without considering future implications. That's where best-of-N sampling comes in: it samples N retrieval chains, exploring more of the search space, and then selects the best one for answer generation based on a penalty score. The clever part is that they compute the conditional log-likelihood of the phrase “No relevant information found” as the penalty for each chain. It’s a way of figuring out which chain the model finds most informative and trustworthy.
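Host: Here's roughly how that best-of-N selection could be wired up, reusing the hypothetical `sample_retrieval_chain` and `llm.log_likelihood` helpers from the sketches earlier; again, this illustrates the idea rather than the paper's implementation.

```python
import math

# Sketch of best-of-N decoding: sample N chains, compute each chain's penalty as
# the conditional log-likelihood of "No relevant information found", and answer
# from the chain with the lowest penalty.

def best_of_n_decode(question, llm, retriever, n=4):
    best_chain, best_penalty = None, math.inf
    for _ in range(n):
        chain = sample_retrieval_chain(question, llm, retriever)
        # Weak evidence makes "No relevant information found" more likely,
        # so a higher log-likelihood here means a worse chain.
        penalty = llm.log_likelihood(
            target="No relevant information found",
            prompt=f"Question: {question}\nRetrieval chain: {chain}\nAnswer:",
        )
        if penalty < best_penalty:
            best_chain, best_penalty = chain, penalty
    # Generate the final answer conditioned on the selected chain.
    return llm.generate(f"Question: {question}\nChain: {best_chain}\nFinal answer:")
```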
Host: And then there's tree search, which is an even more comprehensive exploration method. It's a breadth-first search where each state is expanded by sampling sub-queries and performing rollouts. Each state is evaluated by its average penalty score, and the state with the lowest score is chosen for further expansion. So, it's like trying out multiple different paths to see which one is most promising, but that also means more computational cost than the other two strategies. The paper uses a simplified version of breadth-first search, because a full tree search would be very computationally expensive.
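Host: A very rough sketch of that simplified tree search might look like the following, with the beam width, rollout count, and rollout length all being assumptions we picked just to illustrate the shape of the algorithm; it reuses the hypothetical helpers from the earlier sketches.

```python
# Rough sketch of a simplified breadth-first tree search: at each depth, expand
# every kept state with a few candidate steps, score each new state by the average
# penalty over a couple of short rollouts, and keep only the lowest-penalty states.

def tree_search_decode(question, llm, retriever, depth=4, expand=3, keep=2, rollouts=2):
    states = [[]]  # each state is a partial retrieval chain
    for _ in range(depth):
        candidates = []
        for chain in states:
            for _ in range(expand):
                # Expand the state by one (sub-query, sub-answer) step.
                new_chain = sample_retrieval_chain(
                    question, llm, retriever, max_len=1, chain=chain)
                # Estimate state quality with the average penalty over a few rollouts.
                penalties = []
                for _ in range(rollouts):
                    rollout = sample_retrieval_chain(
                        question, llm, retriever, max_len=2, chain=new_chain)
                    penalties.append(llm.log_likelihood(
                        target="No relevant information found",
                        prompt=f"Question: {question}\nChain: {rollout}\nAnswer:"))
                candidates.append((sum(penalties) / len(penalties), new_chain))
        # Keep the lowest-penalty states for the next round (lower penalty = better).
        candidates.sort(key=lambda c: c[0])
        states = [chain for _, chain in candidates[:keep]]
    return llm.generate(f"Question: {question}\nChain: {states[0]}\nFinal answer:")
```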
Host: The researchers note that the chain length, 'L', is very important for controlling test-time compute across all decoding strategies, and for best-of-N sampling the value of N offers an extra way to scale computation. These are all knobs you can turn to see how performance changes as different amounts of resources are used. It's a fine balance, right? You want to spend enough compute to find the best answer, but without overspending and wasting resources. It's about being efficient and effective at the same time.
Host: So, they really tested out their model, right? They used two main sets of benchmarks: a collection of multi-hop QA datasets and the KILT benchmark. Multi-hop QA datasets, like 2WikiMultihopQA, HotpotQA, Bamboogle, and MuSiQue, are specifically designed to test a model's ability to handle complex reasoning questions, whereas the KILT benchmark is a broader spectrum of knowledge-intensive tasks. What was their setup like?
Host: For each training dataset, they used the open-source Llama-3.1-8B-Instruct model to generate those retrieval chains using rejection sampling. They also used E5-large as the text retriever for those intermediate retrieval steps. So, that's their baseline for retrieval. The retrieval corpus they used was the English Wikipedia, provided by KILT, which is pretty massive. The idea is to augment the original dataset with the generated retrieval chains and to train the LLM with these augmented datasets.
Host: And when it comes to evaluation, for the multi-hop QA datasets, they used metrics like the exact match score and F1 scores, standard in the QA field, I guess. Then for KILT, they submitted their predictions to the official evaluation server and reported the results of their downstream metrics. That gives a clear comparison against previous methods using an official leaderboard. The idea is to show how CoRAG performs in various scenarios.
Host: They did full parameter fine-tuning on these augmented datasets, starting with the Llama-3.1-8B-Instruct checkpoint. Two separate models were trained, one for the multi-hop QA datasets and the other for KILT. It’s worth noting that they also fine-tuned an E5-Mistral retriever and a RankLLaMA re-ranker on KILT’s data, which means their performance on the KILT benchmark might be partly attributed to this better-tuned retrieval system. All of their training was done using 8 A100 GPUs. That gives you a sense of the computational power that went into these experiments.
Host: Let's talk about the results! For multi-hop QA, CoRAG-8B significantly outperformed all the other models on every dataset except Bamboogle. They tested against models like a few-shot Llama-3.1-8B-Instruct, GPT-4o, Self-RAG, and many more. It's interesting that, in most cases, the CoRAG model did much better even though it is built on a weaker LLM than some baselines, such as Search-o1-32B, which is based on QwQ. The paper does point out that fine-tuning on multi-hop QA datasets gives CoRAG-8B an advantage over the few-shot setting used for the DRAG and IterDRAG baselines.
Host: And they acknowledged that the Bamboogle dataset had some issues with the limited instances and the need for more recent knowledge, which they didn’t have in their knowledge base. It's interesting that the models relying on commercial search engines had an advantage here. But the overall trend was pretty clear: CoRAG with multi-step retrieval shows superior performance on multi-hop QA tasks.
Host: Moving on to the KILT benchmark, CoRAG-8B achieved state-of-the-art performance on almost every task; the one exception was FEVER, where it was slightly behind a model with 11B parameters. That's still a huge performance boost overall. They mentioned that, for submission to the KILT leaderboard, they chose the best decoding configuration for each task based on the validation set. It's pretty impressive to see CoRAG outperforming many other models, considering it is only an 8B model. It shows how much the multi-step retrieval approach improves performance.
Host: The paper also explores scaling test-time compute to see how performance changes based on resource allocation. Like OpenAI’s o1 model, they explored the trade-off between computation and performance without changing the weights of the model. So, they varied retrieval chain length and the number of sampled chains for best-of-N sampling. What were some of their key observations?
Host: They found that as they increased retrieval chain length, the model performance improved, particularly when the chain length was small. However, as the length increased, there was a decrease in marginal gains. Increasing N for best-of-N sampling had mixed results, depending on the dataset. The best sampling N also seems to depend on the complexity of the dataset. For MuSiQue, a more challenging dataset, a larger N helped, but for 2WikiMultihopQA, which is less challenging, a smaller N was enough. This shows that, to achieve optimal resource use, the test-time scaling needs to be adjusted based on the complexity of the task.
Host: And it is very interesting how the Pareto frontier, the optimal trade-off between performance and resources, follows a log-linear trend. The paper notes that this frontier gives practitioners an idea of how much compute to allocate to the model. They also acknowledge that the study simplifies a few things, like treating prompt tokens and generated tokens equally and ignoring retrieval costs; a more rigorous analysis would account for those factors. But it still gives us great insight into test-time compute scaling.
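Host: Just to pin down what log-linear means here: our reading of that trend, which is a simplification rather than the paper's exact fitted model, is roughly P(C) ≈ a · log C + b, where C is the total token consumption at test time, P is the task metric, and a and b are task-dependent constants. In other words, each multiplicative increase in compute buys roughly a constant additive gain, so absolute returns diminish as you keep scaling up.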
Host: Let's jump into some of the analysis they did. The researchers explored whether iterative training would improve performance: the idea is that the trained model can be used to generate more retrieval chains, which are then used to train it again, much like how LLM post-training is often done in rounds. What did they find?
Host: The results were mixed; they saw improvements on the 2WikiMultihopQA dataset but slight declines on the other datasets. This may suggest that instruction-tuned LLMs already have a pretty good ability to generate high-quality retrieval chains. Still, it is good to see that there are further avenues for improving the model through iterative training. The authors also examined the robustness and generalization ability of their model, specifically with different text retrievers and in a weak-to-strong generalization setting. Let's break these down.
Host: The paper examined different text retrievers, looking at how performance changes when the E5-large retriever is replaced with two weaker alternatives: E5-base and BM25. While CoRAG consistently improved with more test-time compute across all datasets, they found, as expected, that stronger retrievers led to higher absolute performance. This shows that retrieval quality and overall model performance are still very closely related, and improving retriever quality remains crucial. So, even though CoRAG can compensate for a weaker retriever through chain-of-retrieval, the optimal scenario is still to use a strong retriever.
Host: The weak-to-strong generalization analysis is very fascinating. The idea is to reduce the computational cost of data generation. They used weaker LLMs for the retrieval chain generation, such as the Llama 1B and 3B models, and then they fine-tuned the stronger Llama-3.1-8B-Inst model using the generated data from weaker models. What did they find?
Host: They found that chains generated with the Llama-3B model yielded performance very close to those from the 8B model, but the Llama-1B model led to a noticeable performance drop, probably because the 1B model had difficulty following the instructions and generated sub-optimal retrieval chains. This weak-to-strong generalization setting shows how we can cut the cost of data generation by using weaker, more efficient models, while still retaining most of the performance when the stronger model is trained on that data.
Host: Now, this is an interesting question that the paper discusses. Does chain-of-retrieval always help? They found that multi-hop QA datasets, which are specifically made for evaluating reasoning, really benefit from CoRAG, as they expected. However, for tasks where a single retrieval step is sufficient, the benefit of using chain-of-retrieval is not as large. Datasets like NQ and TriviaQA do not really require multi-step reasoning and information retrieval to generate the final answer. The authors note that the decoding strategies should be more adaptive based on the complexity of the task, and the chain-of-retrieval approach should be used when it is beneficial.
Host: Also, they explored the idea of learning to stop at test time, where the model dynamically decides when to stop retrieving information. Rather than always performing every retrieval step, they prompted the model after each step to predict whether it has enough information to give the final answer. What was the idea behind this, and did it work?
Host: They added a loss term for this stop-prediction task, with the target output being “Yes” if enough information is available, and “No” otherwise. And, by adjusting the logit bias of the “Yes” token, they could control how early the model would stop. While early stopping could save on token consumption, it did come with some performance degradation. The overall conclusion is that the optimal configuration still depends on the specific characteristics of the dataset and the performance requirements. In short, dynamically controlling retrieval is very challenging, and more research is needed.
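Host: A minimal sketch of how that early-stopping mechanism could work at inference time, assuming a hypothetical `llm.next_token_logits` call and reusing the rollout helper sketched earlier; the prompt wording and bias mechanism are our assumptions, not the paper's exact setup.

```python
# Sketch of learning-to-stop at test time: after every retrieval step, ask the
# model whether it already has enough information, and add a bias to the "Yes"
# logit to trade early stopping against accuracy.

def decode_with_learned_stop(question, llm, retriever, max_len=6, yes_bias=0.0):
    chain = []
    for _ in range(max_len):
        # Take one more retrieval step (reusing the rollout sketch from earlier).
        chain = sample_retrieval_chain(question, llm, retriever, max_len=1, chain=chain)
        # Stop-prediction: compare the biased "Yes" logit against "No".
        logits = llm.next_token_logits(
            f"Question: {question}\nSteps so far: {chain}\n"
            "Do you have enough information to answer? Answer Yes or No:")
        if logits["Yes"] + yes_bias > logits["No"]:
            break  # a larger bias makes the model stop earlier, saving tokens
    return llm.generate(f"Question: {question}\nChain: {chain}\nFinal answer:")
```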
Host: In conclusion, the paper introduces CoRAG, a very promising framework that teaches LLMs to perform iterative retrieval and reasoning for complex queries. It does this by generating intermediate retrieval chains with rejection sampling, and it shows that performance can be tuned by managing the trade-off between accuracy and compute at test time through different decoding strategies. The results show state-of-the-art performance on multi-hop QA datasets and the KILT benchmark. What’s next for this research direction?