This project focuses on improving context selection for answer generation over Python repositories. By combining advanced chunking techniques, hybrid retrieval, and query classification to choose the right tools, the LLM is given all of the context it may need to answer the query.
A lot of time and resources go into understanding programming projects, whether for self-learning, studying the implementation of new techniques, replicating something or, as in most cases, joining an existing project. Understanding a repository in depth matters in all of these cases. The problem is that naive RAG techniques perform poorly on code repositories: a repository contains many interdependencies that similarity search alone cannot capture, and standard chunking techniques do not respect the structure of the code, so information is lost during retrieval.
This project aims to make RAG pipelines over repositories much more precise for Python repositories, although the approach can be applied to any other programming language.
We do this by leveraging the syntax of Python as a programming language. In the indexing pipeline, we optimize the chunks, both in size and content, which leads to better context in the retrieval phase. We also capture relationships that are essential when doing RAG over repositories, such as a function being called by another function elsewhere in the code. Then, through query classification, we select the right tools to retrieve the necessary context. The retrieval phase mixes similarity search and keyword search.
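The mix of similarity search and keyword search can be sketched as follows. This is a minimal, self-contained illustration, not the project's actual retriever; the blending weight `alpha`, the toy scoring functions, and the function names are all assumptions made for the example:

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

def keyword_score(query, text):
    """Fraction of query tokens that appear verbatim in the chunk text."""
    tokens = query.lower().split()
    hits = sum(1 for t in tokens if t in text.lower())
    return hits / len(tokens) if tokens else 0.0

def hybrid_rank(query, query_vec, chunks, alpha=0.5):
    """Blend embedding similarity and keyword overlap; alpha weights the embedding side.
    `chunks` is a list of (text, embedding) pairs."""
    scored = []
    for text, vec in chunks:
        score = alpha * cosine(query_vec, vec) + (1 - alpha) * keyword_score(query, text)
        scored.append((score, text))
    return [t for _, t in sorted(scored, reverse=True)]
```

The keyword component rescues exact identifier matches (e.g. a function name typed verbatim in the query) that a pure embedding search might rank low.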
Using an external library developed by the author of this project, rag-pychunk, we leverage the Python language syntax to accomplish two things. First, by optimizing chunk size and content, we make sure that the full definition of each class, function, method, and independent block of code stays together in a single chunk. For example:
```python
def function_a(x, y):
    z = x + y
    return z

def function_b(x, y):
    z = function_a(x, y)
    return z - 10

block_of_code = [...]  # block of code
```
From this piece of code, three chunks will be created: one for function_a, one for function_b, and one for block_of_code. For each chunk, metadata such as the function name, the function's arguments, and its lines of code will be saved.
The relationship that can be extracted from this piece of code is the call to function_a inside the body of function_b.
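Both ideas, one chunk per definition plus extracted call relationships, can be sketched with Python's standard ast module. This is a minimal illustration of the technique, not rag-pychunk's actual implementation; all function names here are ours:

```python
import ast

source = """
def function_a(x, y):
    z = x + y
    return z

def function_b(x, y):
    z = function_a(x, y)
    return z - 10
"""

def chunk_functions(code):
    """Split a module into one chunk per top-level function, with metadata."""
    tree = ast.parse(code)
    chunks = []
    for node in tree.body:
        if isinstance(node, ast.FunctionDef):
            chunks.append({
                "function_name": node.name,
                "arguments": [a.arg for a in node.args.args],
                "lines": (node.lineno, node.end_lineno),
                "text": ast.get_source_segment(code, node),
            })
    return chunks

def called_names(code):
    """Names of all functions called anywhere in the given code."""
    return {
        n.func.id
        for n in ast.walk(ast.parse(code))
        if isinstance(n, ast.Call) and isinstance(n.func, ast.Name)
    }

chunks = chunk_functions(source)
known = {c["function_name"] for c in chunks}
# function_b calls function_a, so a relationship function_b -> function_a exists
relationships = {c["function_name"]: called_names(c["text"]) & known for c in chunks}
```

Each chunk keeps a whole definition together with its metadata, and the relationship map records which chunk refers to which other chunk.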
Why is this indexing pipeline important?
For two reasons:
More information can be seen in the Appendix.
The retrieval phase is where agents come in. There are different types of retrievers depending on the type of the query. What decides which category a query falls into is its number of subjects. A subject is a concrete code entity the question refers to, such as a function, method, or class name. To classify the query, there are two different types of agents, which return different types of tools depending on the task. Both agents have been built with prompt engineering and few-shot examples.
Simple: Zero or one subject present in the question. Example: How does the function x work?
Complex: More than one subject. Example: What are the differences between Class A and Class B?
Tool returned: SimilarityRetriever
When classifying between simple and complex, multiple things can happen:
In both of these cases, a retry will be applied. We want to get a coherent answer, which means:
This is controlled with Pydantic validators:
```python
from typing import Any, Dict, Optional, Set

from pydantic import BaseModel, model_validator


class Output(BaseModel):
    query: str
    reasoning: str
    question_type: str
    subject: Set[str | None]
    valid: Optional[bool] = True

    @model_validator(mode='before')
    def coherence_between_question_type_and_length_of_subjects(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        subject, question_type = values.get('subject'), values.get('question_type')
        if question_type not in ('simple', 'complex'):
            return values
        if (len(subject) <= 1 and question_type != 'simple') or (len(subject) > 1 and question_type != 'complex'):
            values['valid'] = False
        return values
```
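The coherence rule the validator enforces can be exercised on its own. Below is a dependency-free sketch of the same check ("simple" means at most one subject, "complex" means more than one); the function name is ours, not part of the project:

```python
def is_coherent(question_type, subjects):
    """A classification is coherent when the declared type matches the
    subject count: 'simple' needs zero or one subject, 'complex' more than one."""
    if question_type == 'simple':
        return len(subjects) <= 1
    if question_type == 'complex':
        return len(subjects) > 1
    return False  # unknown type: the agent answer should trigger a retry
```

An incoherent result, such as question_type "simple" with two subjects, marks the output invalid and the agent is asked again.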
For simple queries, a second agent classifies the question into one of two categories:
Particular: The question involves only the subject. Example: How does the function x work?
General: The question is formulated in such a way that it does not only concern the subject. Example: Will my changes break anything?
Tool returned: SimilarityRetriever (for particular) or GeneralRetriever (for general)
When the agent pipeline returns a tool, that tool will be used to obtain the nodes which will be used to feed the LLM and generate an answer to the query.
What's a tool?
This retriever will be used when the query is identified as particular or complex. Depending on each case, the retrieval will be done differently.
For each subject, we'll try to get the Node in which the subject appears directly from the database, since in the column node_metadata we are storing either the method, function or class name of the node:
```python
for subject in subjects:  # complex case --> multiple subjects
    # try to find the node directly in the database, in case the subject
    # is a method, function or class name
    node_of_subject = self._db.query(Node).join(NodeMetadata, NodeMetadata.node_id == Node.id)\
        .filter(
            or_(
                NodeMetadata.node_metadata['additional_metadata']['function_name'].astext == subject,
                NodeMetadata.node_metadata['additional_metadata']['method_name'].astext == subject,
                NodeMetadata.node_metadata['additional_metadata']['class_name'].astext == subject
            )).all()
```
If this fails, we'll get the nodes via similarity search. Either way, we'll get the relationships of the retrieved nodes.
Complex case
In case we have multiple subjects and the node can't be obtained directly from the database we want to maximize the probability of getting the correct nodes via similarity search. In order to do so, we modify the query like this:
```python
query_to_embed = (
    (query.replace(subject, "") if subject in query else query.lower().replace(subject, ""))
    if len(subjects) > 1
    else query
)
```
That way, only one subject name will appear in the embedded query. For each subject, the other subjects are removed, decreasing the similarity with other nodes and increasing it for the one that matters.
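This per-subject query rewriting can be illustrated with a standalone sketch; the helper name and the exact whitespace handling are our assumptions, not the project's code:

```python
def per_subject_queries(query, subjects):
    """For each subject, build a variant of the query with every other
    subject removed, so the embedding is dominated by one subject at a time."""
    variants = {}
    for subject in subjects:
        q = query
        for other in subjects:
            if other != subject:
                q = q.replace(other, "")
        variants[subject] = " ".join(q.split())  # collapse leftover whitespace
    return variants
```

For the query "What are the differences between ClassA and ClassB?", the variant embedded for ClassA no longer mentions ClassB, and vice versa.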
With this tool we want to answer queries like: "Will changing X break something?", "What will happen if I change this parameter of the function X to...?", "Are there any errors in the repo?"
One subject
We need to get the proper node and everything that depends on it. Again, we first try to get the node directly from the database by its subject name, and fall back to similarity search if that does not succeed. Besides that, we also want all of the nodes whose node_relationships column contains the id of this node. Why? Because that means the retrieved node appears in those nodes in some way, so changing the retrieved node would affect those nodes as well.
```python
all_nodes_related_to_this_node = self._db.query(Node).filter(
    Node.node_relationships.has_key(str(valid_node.id))
).all()
```
Where valid_node.id is the retrieved node.
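The reverse lookup, finding every node whose node_relationships mention the retrieved node's id, can be mimicked with plain dictionaries. The node structure below is a simplification of the real model, used only for illustration:

```python
def dependents_of(node_id, nodes):
    """Return the ids of nodes whose relationships mention node_id,
    i.e. the nodes that would be affected by changing it."""
    return [
        n["id"]
        for n in nodes
        if node_id in n.get("node_relationships", {})
    ]

nodes = [
    {"id": "n1", "node_relationships": {}},           # e.g. def hello(): ...
    {"id": "n2", "node_relationships": {"n1": [2]}},  # calls hello() on line 2
    {"id": "n3", "node_relationships": {}},
]
```

Here changing n1 would affect n2, because n2's relationships record that n1 appears inside it.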
Zero subjects
If no subjects are detected, all of the nodes are retrieved and parsed one by one (possibly together with their relationships, to give the LLM more context) so the query can be answered over each of them.
This retriever is not available for the agent to return; we use it as a postprocessor. With this "retriever", the relationships are filtered to keep only those similar enough to the retrieved nodes, given a threshold:
```python
def filter_relationships(self, threshold: float) -> Dict[str, List[Node]] | Dict:
    filtered_relationships = {}
    for node in self._nodes:
        for rel, rel_nodes in self._relationships.items():
            filtered_nodes = []
            for rel_node in rel_nodes:
                # cosine similarity between the node embedding and the related node's embedding
                a = np.array(node.embedding)
                b = np.array(self._embeddings(rel_node.text))
                similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
                if similarity > threshold:
                    filtered_nodes.append(rel_node)
            filtered_relationships[rel] = filtered_nodes
    return filtered_relationships
```
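The per-node filtering step can be checked in isolation with a pure-Python sketch; the function names here are ours, and the toy embeddings are assumptions:

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def keep_similar(anchor, candidates, threshold):
    """Keep only candidate embeddings whose similarity to the anchor
    exceeds the threshold, mirroring the per-node relationship filter."""
    return [c for c in candidates if cosine_similarity(anchor, c) > threshold]
```

A higher threshold prunes more related nodes, trading recall for precision in the context handed to the LLM.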
Query: How do the methods: _check_common_parent_nodes, _check_relationships_of_retrieved_nodes, return_nodes_after_apply_threshold_filter and return_nodes_with_score_after_apply_threshold_filter work together to improve the result of the function query_vector_database?
To understand how the methods _check_common_parent_nodes, _check_relationships_of_retrieved_nodes, return_nodes_after_apply_threshold_filter, and return_nodes_with_score_after_apply_threshold_filter improve the results of the query_vector_database function, let’s break down each method and explain their roles.
Code Explanation:
```python
def _check_common_parent_nodes(self) -> List[Tuple[str, int]]:
    parent_node_ids = {}
    file_of_node_ids = {}
    for node in self._retrieved_nodes_score:
        node: Node = node.node
        file_id = node.file_id
        if node.node_type.value == NodeType.METHOD.value or node.node_type.value == NodeType.CODE.value:
            # walk up the parent chain until the top-level parent is found
            parent_node_id = node.parent_node_id
            previous_parent_node_id = parent_node_id
            while parent_node_id is not None:
                previous_parent_node_id = parent_node_id
                parent_node_id = self._db.get(Node, parent_node_id).parent_node_id
            parent_node_id = previous_parent_node_id
            parent_node_ids[parent_node_id] = 1 if parent_node_id not in parent_node_ids else parent_node_ids[parent_node_id] + 1
        file_of_node_ids[file_id] = 1 if file_id not in file_of_node_ids else file_of_node_ids[file_id] + 1
    return [
        [(parent_node_id, frequency) for (parent_node_id, frequency) in parent_node_ids.items() if frequency >= self._min_parent_nodes],
        [(file_id, frequency) for (file_id, frequency) in file_of_node_ids.items() if frequency >= self._min_file_nodes]
    ]
```
Code Explanation:
```python
def _check_relationships_of_retrieved_nodes(self, nodes: List[Node], depth: int) -> List[str]:
    relations = []
    for node in nodes:
        id = node.id
        node_relationships = self._db.get(Node, id).node_relationships
        if not node_relationships or not len(node_relationships):
            continue
        for node_relationship_id, _ in node_relationships.items():
            node_ = self._db.get(Node, node_relationship_id)
            relations.extend(self.__retrieve_relationship_nodes(base_id=id, node=node_, depth=depth - 1))
    try:
        relations = [x for x in relations if x not in [n.id for n in self._retrieved_nodes]]
    except:
        pass
    return relations
```
Code Explanation:
```python
def return_nodes_after_apply_threshold_filter(self):
    return self._retrieved_nodes
```
Code Explanation:
```python
def return_nodes_with_score_after_apply_threshold_filter(self):
    return self._retrieved_nodes_score
```
Collaboration for Improved Results in query_vector_database
Together, these methods improve the results of the query_vector_database function by:
Identifying Context: _check_common_parent_nodes provides context by focusing on significant parent nodes.
Establishing Relationships: _check_relationships_of_retrieved_nodes ensures that the nodes’ interconnections inform the final output.
Filtering for Relevance: The methods return_nodes_after_apply_threshold_filter and return_nodes_with_score_after_apply_threshold_filter ensure that only the most relevant nodes based on certain thresholds and scoring are returned, which helps narrow down the results to those most relevant to the initial query.
This structured approach to managing relationships and filtering results leads to better, contextually informed output in response to the user’s query.
Agent reasoning response: The subjects are the methods _check_common_parent_nodes, _check_relationships_of_retrieved_nodes, return_nodes_after_apply_threshold_filter, return_nodes_with_score_after_apply_threshold_filter, and the function query_vector_database. Therefore, the answer is complex because there is more than one subject.
Agent answer: complex
Tool decided by the agent: SimilarityRetriever
Original Query: How do the methods: _check_common_parent_nodes, _check_relationships_of_retrieved_nodes, return_nodes_after_apply_threshold_filter and return_nodes_with_score_after_apply_threshold_filter work together to improve the result of the function query_vector_database?
```text
Exact match of subject: query_vector_database in the database. --> 1
Relationships of node retrieve: 18
Exact match of subject: return_nodes_with_score_after_apply_threshold_filter in the database. --> 1
Relationships of node retrieve: 0
Exact match of subject: _check_common_parent_nodes in the database. --> 1
Relationships of node retrieve: 2
Exact match of subject: _check_relationships_of_retrieved_nodes in the database. --> 1
Relationships of node retrieve: 6
Exact match of subject: return_nodes_after_apply_threshold_filter in the database. --> 1
```
These are logs from the query pipeline process. Since we have the definition of each and every one of the subjects, we are able to feed all of the important context to the LLM.
Query: How will changing the value of the argument threshold of the function query_pipeline affect the rest of the code?
The threshold argument in the query_pipeline function is used to filter relationships based on similarity scores. Let's break down how it affects the rest of the code by examining its context in the query_pipeline function and related components.
Here’s the relevant part of the query_pipeline function:
```python
async def query_pipeline(agent: MultiAgent, query: str, llm: LlamaClient, db: Session, threshold: float = 0.25) -> Any:
    # [...]
    if relationships:
        relationship_retriever = RelationShipRetriever(query=query, nodes=nodes, relationships=relationships)
        filtered_relationships = relationship_retriever.filter_relationships(threshold=threshold)
    # [...]
```
Explanation of how threshold is used:
Purpose of threshold:
The threshold parameter is used in the filter_relationships method of the RelationShipRetriever class. It determines the minimum similarity score a relationship must have to be included in the filtered_relationships.
Inside the filter_relationships Method:
In this method, every node's embedding is compared to the embeddings of its related nodes. A similarity score is calculated, and nodes with a score greater than the defined threshold are included in the filtered_relationships dictionary.
Here’s how it looks in the filter_relationships method:
```python
def filter_relationships(self, threshold: float) -> Dict[str, List[Node]] | Dict:
    filtered_relationships = {}
    for node in self._nodes:
        for rel, rel_nodes in self._relationships.items():
            filtered_nodes = []
            for rel_node in rel_nodes:
                similarity = np.dot(
                    np.array(node.embedding),
                    np.array(self._embeddings(rel_node.text))
                ) / (np.linalg.norm(np.array(node.embedding)) * np.linalg.norm(np.array(self._embeddings(rel_node.text))))
                if similarity > threshold:
                    filtered_nodes.append(rel_node)
            filtered_relationships[rel] = filtered_nodes
    return filtered_relationships
```
Effect of Changing threshold:
Summary
Changing the threshold value directly influences the number of relationships considered in subsequent operations in the query_pipeline, which can affect the final answer provided by the LLM and the overall outcome of the query. The balance between precision and recall in your results is directly tied to how you specify this threshold value.
In summary, the threshold parameter is a crucial lever in controlling the quality and quantity of relationships processed in the pipeline.
Without hardcoding the RelationShipRetriever, naive RAG would never have retrieved this chunk, because nothing in the query would have made the similarity score high enough. That was not the case here, because there is a well-defined relationship between the function query_pipeline and the class RelationShipRetriever.
In this project, we've leveraged the syntax of the Python programming language to improve the retrieval phase of the RAG pipeline, enriching the context fed to the LLM and, therefore, the quality of the answers.
Our results are strong, even for complicated queries that require relationship awareness to gather all of the context needed to answer properly.
By optimizing chunk content and size, and by extracting the relationships between the different components of a Python project, we can substantially improve the quality of the generated answers, since the LLM receives much more context and is aware of the structure of the code itself.
The chunks, their relationships, and their metadata are stored in the following PostgreSQL models:
Node Model
```python
class Node(Base):
    __tablename__ = "node"

    id = Column(UUID(as_uuid=True), primary_key=True, index=True, default=uuid.uuid4)
    node_type = Column(Enum(NodeType), nullable=False)
    file_id = Column(UUID, ForeignKey("file.id", ondelete='CASCADE'), nullable=False)
    parent_node_id = Column(UUID(as_uuid=True), ForeignKey("node.id"), nullable=True)
    previous_node_id = Column(UUID(as_uuid=True), ForeignKey("node.id"), nullable=True)
    next_node_id = Column(UUID(as_uuid=True), ForeignKey("node.id"), nullable=True)
    text = Column(Text, nullable=False)
    embedding = Column(Vector(384))
    hash = Column(String, nullable=False, index=True)
    node_relationships = Column(JSONB)

    parent = relationship("Node", foreign_keys=[parent_node_id], remote_side=[id],
                          backref=backref("children", cascade="all, delete-orphan"))
    previous = relationship("Node", foreign_keys=[previous_node_id], remote_side=[id],
                            backref=backref("next_node"), uselist=False)
    next = relationship("Node", foreign_keys=[next_node_id], remote_side=[id],
                        backref=backref("previous_node"), uselist=False)
    file = relationship("File", back_populates="nodes", foreign_keys=[file_id])
    node_metadata = relationship("NodeMetadata", back_populates="node", cascade="all, delete-orphan")
```
NodeMetadata Model
```python
class NodeMetadata(Base):
    __tablename__ = "node_metadata"

    node_id = Column(UUID(as_uuid=True), ForeignKey("node.id", ondelete="CASCADE"), primary_key=True)
    node_metadata = Column(JSONB)

    node = relationship("Node", foreign_keys=[node_id], back_populates="node_metadata")
```
In the node_metadata column of the NodeMetadata model, some key-value pairs will be stored for each node. For example:
{'method_name': <name of the method if the node is a method>}
{'class_name': <name of the class if the node is a class>}
{'function_name': <name of the function if the node is a function>}
The node_relationships column of the Node model stores entries of the form:
{node_id: [lines in which the node appears]}
where node_id is the id of the node that appears inside a specific node. For example:
```python
text_of_node_1 = "def hello(): ..."
text_of_node_2 = "def bye(): hello(); print('bye!')"
relationships_of_node_2 = {'id of node 1': [line 1]}
```