RAG using OCI GenAI

Oracle Cloud Infrastructure (OCI) has released a Generative AI service that offers the Cohere and Meta Llama families of models. In this article, I will use a Meta Llama 3.1 model with the LangChain framework to perform neural search and use the search results to answer user queries, commonly referred to as Retrieval-Augmented Generation (RAG).
There are multiple ways to implement RAG with LangChain; I will use an approach that provides more control at each step and exposes the various hyper-parameters that can be tuned. Depending on the structure and nature of the content, one solution may not work for all use cases, so we need to be aware of the knobs that can be tuned to get the desired results.
Data Ingestion
Let us start with data ingestion. I will be using Chroma DB as the vector store and OCI Cohere as the embedding model. Regardless of the vector store you choose, look out for the search parameters it offers.
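For reference, the snippets below assume imports roughly along these lines (treat this as a sketch; exact module paths can vary slightly between LangChain versions):
import chromadb
from langchain_chroma import Chroma  # or: from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOCIGenAI
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import OCIGenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.output_parsers import XMLOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter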
I have scraped a few documents from the CNCF project website to ingest. Let us load the documents:
def load_documents():
    loader = DirectoryLoader('../../files', glob="**/*.txt")
    documents = loader.load()
    print("number of documents " + str(len(documents)))
    return documents
Before ingesting the data, we need to chunk it. There are quite a few chunking techniques; I have used the RecursiveCharacterTextSplitter from LangChain, which splits the data recursively based on a list of separators such as newline and space.
def split_text(documents: list[Document]):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,       # Size of each chunk in characters
        chunk_overlap=200,     # Overlap between consecutive chunks
        length_function=len,   # Function to compute the length of the text
        add_start_index=True,  # Flag to add start index to each chunk
    )
    chunks = text_splitter.split_documents(documents)
    return chunks
Look out for the parameters to tune: chunk size and chunk overlap. Chunk size is generally a function of the max tokens supported by the model and also depends on the data. Chunk overlap decides how much contextual data we need from the neighboring chunks. A general rule of thumb is to set the chunk overlap to about one fifth of the chunk size.
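As a quick sanity check of that rule of thumb (simple arithmetic, matching the values used above):
chunk_size = 1000
chunk_overlap = chunk_size // 5  # 200 characters, as configured in split_text()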
Now, let us initialize the vector db and the models.
def initialize_embedding_model():
    embedding_model = OCIGenAIEmbeddings(
        model_id="cohere.embed-english-light-v3.0",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="<COMPARTMENT_ID>",
        auth_profile="<YOUR_OCI_PROFILE>"
    )
    return embedding_model

def initialize_query_model():
    query_model = ChatOCIGenAI(
        model_id="meta.llama-3-70b-instruct",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="<COMPARTMENT_ID>",
        model_kwargs={"temperature": 1, "max_tokens": 500},
        auth_profile="<YOUR_OCI_PROFILE>"
    )
    return query_model
Look out for the model arguments: temperature and max tokens. These should be tuned for the use case; a lower temperature gives more deterministic output, while max tokens caps the length of the generated response.
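For example, a sketch of alternative settings (illustrative values only, not from the original setup) for a more deterministic, extractive RAG answer with room for citations:
model_kwargs = {"temperature": 0, "max_tokens": 700}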
def initialize_vector_store(embedding_model, collection_name):
    persistent_client = chromadb.PersistentClient()
    collection = persistent_client.get_or_create_collection(collection_name)
    vector_store = Chroma(
        client=persistent_client,
        collection_name=collection_name,
        embedding_function=embedding_model,
        collection_metadata={"hnsw:space": "cosine", "hnsw:construction_ef": 50, "hnsw:M": 32, "hnsw:search_ef": 15},
    )
    return vector_store
Look out for the collection metadata. These parameters impact search accuracy, with trade-offs against search time and memory, so check the vector store's documentation for details. The knobs set above are (mapped to Chroma metadata keys in the sketch after this list):
- Search algorithm (HNSW in this case).
- Function to measure similarity (cosine similarity, etc.).
- Number of neighbors to explore while adding a new vector.
- Max number of neighbor connections.
- Number of neighbors to explore while searching.
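Here is how those knobs map onto Chroma's HNSW metadata keys (a sketch using the same values as above; consult the Chroma documentation for the full list of supported keys):
collection_metadata = {
    "hnsw:space": "cosine",        # similarity function (cosine, l2, ip)
    "hnsw:construction_ef": 50,    # neighbors explored while adding a new vector
    "hnsw:M": 32,                  # max neighbor connections per node
    "hnsw:search_ef": 15,          # neighbors explored while searching
}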
Now, let us ingest the data.
def save_to_db(chunks: list[Document], vector_store):
    vector_store.add_documents(documents=chunks)
    print(f"Saved {len(chunks)} chunks")
Search and Generate
Now, let us implement RAG using chains in LangChain. We could use LangChain retrievers; however, I have used the vector store handle directly so that we keep control over the search and over how the results are passed to the model. I have also crafted the prompt to generate citations along with the answer.
def predict_with_citation(query: str, vector_store, query_model, embedding_model):
    template = """
    You are an assistant who is expert in answering questions based on given context.
    Use the provided context only to answer the following question. Say 'I don't know' if you are unable to find the relevant document.
    You must return both an answer and citations. A citation consists of a VERBATIM quote that \
    justifies the answer and the ID of the quote article. Return a citation for every quote across all articles \
    that justify the answer. Use the following format for your final output:
    <cited_answer>
    <answer></answer>
    <citations>
    <citation><source_id></source_id><quote></quote></citation>
    <citation><source_id></source_id><quote></quote></citation>
    ...
    </citations>
    </cited_answer>
    Here is the contextual information to be used: {context}
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)
    # Embed the query and search the vector store directly for the top-k similar chunks
    embedding_vector = embedding_model.embed_query(query)
    docs = vector_store.similarity_search_by_vector(embedding=[embedding_vector], k=5)
    print("Number of documents returned " + str(len(docs)))
    # format_docs_xml (see the sketch below) renders the retrieved chunks as XML with source IDs
    chain = (
        RunnablePassthrough.assign(context=lambda input: format_docs_xml(input["context"]))
        | prompt
        | query_model
        | XMLOutputParser()
    )
    response = chain.invoke({"context": docs, "question": query})
    print(response['cited_answer'])
    return response
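The chain above relies on a format_docs_xml helper that renders the retrieved chunks as XML with source IDs the model can cite; the actual implementation is in the full source, but a minimal sketch could look like this:
def format_docs_xml(docs: list[Document]) -> str:
    # Sketch only: wrap each retrieved chunk in a <source> element with an id
    # that the model can reference in its <source_id> citations.
    formatted = []
    for i, doc in enumerate(docs):
        formatted.append(
            f'<source id="{i}">\n'
            f'  <path>{doc.metadata.get("source", "")}</path>\n'
            f'  <article_snippet>{doc.page_content}</article_snippet>\n'
            f'</source>'
        )
    return "<sources>\n" + "\n\n".join(formatted) + "\n</sources>"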
While retrieving the similar documents, I am passing k=5, which is the number of documents to return. This parameter, along with the chunk size, has to be adapted to the maximum tokens supported by the model. And if we are using conversational chat, the size of the history also adds to the overall token count. The following is the general formula we can use to estimate the total token length:
chunk_length * (number of documents fetched) + history length + prompt length + output length < max tokens
Note: all lengths refer to token counts.
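As a rough, illustrative budget check for this setup (the token counts below are assumptions, not measurements):
chunk_tokens = 250      # ~1000 characters per chunk, assuming roughly 4 characters per token
num_docs = 5            # k passed to the similarity search
history_tokens = 0      # no conversation history in this example
prompt_tokens = 200     # instructions plus the citation format
output_tokens = 500     # max_tokens configured on the query model
total_tokens = chunk_tokens * num_docs + history_tokens + prompt_tokens + output_tokens
print(total_tokens)     # 1950 -- must stay below the model's context window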
Look out for the ‘filter’ argument in the search function. We can pass additional metadata filters while searching, which helps narrow the results based on document metadata. Note that we can also re-rank the documents after retrieval.
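For example, here is a sketch of a filtered search that keeps only chunks loaded from a particular file (DirectoryLoader records the file path in the ‘source’ metadata field; the path below is hypothetical):
docs = vector_store.similarity_search_by_vector(
    embedding=[embedding_vector],
    k=5,
    filter={"source": "../../files/opentelemetry.txt"},  # hypothetical file path
)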
Time to test out our solution.
query = "What are components of Open Telemetry?"
predict_with_citation(query, vector_store, query_model, embedding_model)
Output
<answer>
The components of OpenTelemetry include:
1. A specification for all components
2. A standard protocol that defines the shape of telemetry data
3. Semantic conventions that define a standard naming scheme for common telemetry data types
4. APIs that define how to generate telemetry data
5. Language SDKs that implement the specification, APIs, and export of telemetry data
6. A library ecosystem that implements instrumentation for common libraries and frameworks
7. Automatic instrumentation components that generate telemetry data without requiring code changes
8. The OpenTelemetry Collector, a proxy that receives, processes, and exports telemetry data
9. Various other tools, such as the OpenTelemetry Operator for Kubernetes, OpenTelemetry Helm Charts, and community assets for FaaS
</answer>
<citations>
<citation><source_id>0</source_id><quote>Main OpenTelemetry components OpenTelemetry consists of the following major components: A specification for all components A standard protocol that defines the shape of telemetry data Semantic conventions that define a standard naming scheme for common telemetry data types APIs that define how to generate telemetry data Language SDKs that implement the specification, APIs, and export of telemetry data A library ecosystem that implements instrumentation for common libraries and frameworks Automatic instrumentation components that generate telemetry data without requiring code changes</quote></citation>
<citation><source_id>1</source_id><quote>A library ecosystem that implements instrumentation for common libraries and frameworks Automatic instrumentation components that generate telemetry data without requiring code changes The OpenTelemetry Collector, a proxy that receives, processes, and exports telemetry data Various other tools, such as the OpenTelemetry Operator for Kubernetes, OpenTelemetry Helm Charts, and community assets for FaaS</quote></citation>
</citations>
We can see that it gives an appropriate response and also provides citations for the result. You can find the complete source code on GitHub.