-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcall_methods.py
More file actions
74 lines (62 loc) · 2.31 KB
/
call_methods.py
File metadata and controls
74 lines (62 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from models.multi_query import MultiQueryRetrieval
from models.fusion_rrf import FusionRetrieval
from models.decomposition_ltm import DecompositionRetrieval
class RAGMethodCaller:
    """
    Dispatch to one of the available RAG retrieval strategies by name.

    Parameters
    ----------
    retriever:
        The retriever object handed to whichever strategy is selected.

    Methods
    -------
    call_method(method_name):
        Build and run the strategy registered under ``method_name``.
    """

    # Accepted method names (compared after ``str.lower()``) mapped to
    # (label used in error messages, factory taking the retriever).
    # The lambdas look up the strategy class only when they are called.
    _METHODS = {
        "multi_query": ("MultiQueryRetrieval", lambda r: MultiQueryRetrieval(r)),
        "fusion": ("FusionRetrieval", lambda r: FusionRetrieval(r)),
        "decomposition": (
            "DecompositionRetrieval",
            lambda r: DecompositionRetrieval(r),
        ),
    }

    def __init__(self, retriever):
        """
        Store the retriever used by every strategy this caller runs.

        Parameters
        ----------
        retriever:
            The retriever object that will be used to retrieve documents.

        Raises
        ------
        ValueError:
            If ``retriever`` is None.
        """
        # Compare identity with None rather than truthiness: a retriever
        # object that happens to be falsy (e.g. defines __len__ returning 0)
        # is still a real object and must not be reported as "None".
        if retriever is None:
            raise ValueError("Retriever cannot be None")
        self.retriever = retriever

    def call_method(self, method_name):
        """
        Build and run the RAG strategy selected by ``method_name``.

        Parameters
        ----------
        method_name: str
            One of ``"multi_query"``, ``"fusion"`` or ``"decomposition"``,
            matched case-insensitively.

        Returns
        -------
        The value returned by the selected strategy's ``run()``, or None if
        the name is not recognised or the strategy raised ``ValueError``.

        Notes
        -----
        Problems are reported with ``print`` rather than raised. An unknown
        name prints the list of valid choices. A ``ValueError`` raised while
        building or running the strategy is printed as
        ``"Error in <StrategyName>: <message>"``. Any other exception
        propagates to the caller.
        """
        # Non-string input (e.g. None) is treated as an unknown name instead
        # of failing on the .lower() call.
        entry = None
        if isinstance(method_name, str):
            entry = self._METHODS.get(method_name.lower())
        if entry is None:
            print("Invalid method name provided. Please choose from 'multi_query', 'fusion', or 'decomposition'.")
            return None

        label, factory = entry
        try:
            return factory(self.retriever).run()
        except ValueError as e:
            print(f"Error in {label}: {e}")
            return None