From 76234798e565a709e5df4846b088f7a3c8511b41 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 16 Jun 2025 14:28:58 +0330 Subject: [PATCH 01/10] feat: Added tracing the persisted payload! --- .env.example | 7 +- requirements.txt | 2 +- tasks/agent.py | 228 ++++++++++++++++++++++----- tasks/hivemind/agent.py | 4 + tasks/hivemind/query_data_sources.py | 16 +- tasks/mongo_persistence.py | 207 ++++++++++++++++++++++++ tests/unit/test_mongo_persistence.py | 163 +++++++++++++++++++ 7 files changed, 581 insertions(+), 46 deletions(-) create mode 100644 tasks/mongo_persistence.py create mode 100644 tests/unit/test_mongo_persistence.py diff --git a/.env.example b/.env.example index 9655a86..477557c 100644 --- a/.env.example +++ b/.env.example @@ -8,4 +8,9 @@ OPENAI_API_KEY= REDIS_HOST= REDIS_PORT= -REDIS_PASSWORD= \ No newline at end of file +REDIS_PASSWORD= + +MONGODB_HOST= +MONGODB_PORT= +MONGODB_USER= +MONGODB_PASS= diff --git a/requirements.txt b/requirements.txt index f1a23d3..48d76ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ python-dotenv>=1.0.0, <2.0.0 redis==5.2.0 pydantic==2.9.2 crewai==0.105.0 -tc-temporal-backend==1.1.2 +tc-temporal-backend==1.1.3 transformers[torch]==4.49.0 nest-asyncio==1.6.0 openai==1.66.3 diff --git a/tasks/agent.py b/tasks/agent.py index 5cb4eaa..8d5703c 100644 --- a/tasks/agent.py +++ b/tasks/agent.py @@ -8,6 +8,7 @@ from crewai.crews.crew_output import CrewOutput from tasks.hivemind.agent import AgenticHivemindFlow from tasks.redis_memory import RedisMemory + from tasks.mongo_persistence import MongoPersistence from tc_temporal_backend.schema.hivemind import HivemindQueryPayload @@ -20,46 +21,193 @@ async def run_hivemind_agent_activity( It places the resulting answer into payload.content.response. 
""" - memory: RedisMemory | None - chat_history: str | None - - if payload.chat_id: - memory = RedisMemory(key=f"conversation:{payload.chat_id}") - chat_history = memory.get_text() - else: - chat_history = None - memory = None - - # Instantiate the flow with the user query - flow = AgenticHivemindFlow( - community_id=payload.community_id, - user_query=payload.query, - enable_answer_skipping=payload.enable_answer_skipping, - chat_history=chat_history, - ) - - # Run the flow - crew_output = await flow.kickoff_async(inputs={"query": payload.query}) - - if isinstance(crew_output, CrewOutput): - final_answer = crew_output.raw - elif not payload.enable_answer_skipping: - final_answer = "No answer was generated." - else: - final_answer = None - - if isinstance(final_answer, str) and "encountered an error" in final_answer.lower(): - logging.error(f"final_answer: {final_answer}") - final_answer = "Looks like things didn't go through. Please give it another go." - - if memory and final_answer != "NONE": - chat = f"User: {payload.query}\nAgent: {final_answer}" - memory.append_text(chat) - - if final_answer == "NONE": - return None - else: - return final_answer + # Initialize MongoDB persistence + mongo_persistence = MongoPersistence() + workflow_id = None + + try: + # Create initial workflow state in MongoDB + workflow_id = mongo_persistence.create_workflow_state( + community_id=payload.community_id, + query=payload.query, + chat_id=getattr(payload, 'chat_id', None), + enable_answer_skipping=payload.enable_answer_skipping, + ) + + logging.info(f"Created workflow state with ID: {workflow_id}") + + # Update step: Initialization + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="initialization", + step_data={ + "communityId": payload.community_id, + "query": payload.query, + "chatId": getattr(payload, 'chat_id', None), + "enableAnswerSkipping": payload.enable_answer_skipping, + } + ) + + memory: RedisMemory | None + chat_history: str | None + + if 
hasattr(payload, 'chat_id') and payload.chat_id: + memory = RedisMemory(key=f"conversation:{payload.chat_id}") + chat_history = memory.get_text() + + # Update step: Chat history retrieval + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="chat_history_retrieval", + step_data={ + "chatId": payload.chat_id, + "chatHistoryLength": len(chat_history) if chat_history else 0, + } + ) + else: + chat_history = None + memory = None + + # Update step: No chat history + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="no_chat_history", + step_data={"reason": "No chat_id provided"} + ) + + # Update step: Flow initialization + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="flow_initialization", + step_data={ + "flowType": "AgenticHivemindFlow", + "enableAnswerSkipping": payload.enable_answer_skipping, + } + ) + + # Instantiate the flow with the user query + flow = AgenticHivemindFlow( + community_id=payload.community_id, + user_query=payload.query, + enable_answer_skipping=payload.enable_answer_skipping, + chat_history=chat_history, + workflow_id=workflow_id, + ) + + # Update step: Flow execution start + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="flow_execution_start", + step_data={"userQuery": payload.query} + ) + + # Run the flow + crew_output = await flow.kickoff_async(inputs={"query": payload.query}) + + # Update step: Flow execution complete + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="flow_execution_complete", + step_data={ + "crewOutputType": type(crew_output).__name__, + "hasOutput": crew_output is not None, + } + ) + + if isinstance(crew_output, CrewOutput): + final_answer = crew_output.raw + elif not payload.enable_answer_skipping: + final_answer = "No answer was generated." 
+ else: + final_answer = None + + # Update step: Answer processing + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="answer_processing", + step_data={ + "answerType": type(final_answer).__name__, + "answerLength": len(final_answer) if isinstance(final_answer, str) else 0, + "enableAnswerSkipping": payload.enable_answer_skipping, + } + ) + + if isinstance(final_answer, str) and "encountered an error" in final_answer.lower(): + logging.error(f"final_answer: {final_answer}") + fallback_answer = "Looks like things didn't go through. Please give it another go." + + # Update step: Error handling + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="error_handling", + step_data={ + "errorType": "crewai_error", + "originalAnswer": final_answer, + "fallbackAnswer": fallback_answer, + } + ) + final_answer = fallback_answer + + if memory and final_answer != "NONE": + chat = f"User: {payload.query}\nAgent: {final_answer}" + memory.append_text(chat) + + # Update step: Memory update + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="memory_update", + step_data={ + "memoryKey": f"conversation:{payload.chat_id}", + "chatEntryLength": len(chat), + } + ) + + # Update final answer in MongoDB + if final_answer and final_answer != "NONE": + mongo_persistence.update_response( + workflow_id=workflow_id, + response_message=final_answer, + status="completed" + ) + else: + mongo_persistence.update_response( + workflow_id=workflow_id, + response_message="No answer generated", + status="completed_no_answer" + ) + + if final_answer == "NONE": + return None + else: + return final_answer + + except Exception as e: + logging.error(f"Error in run_hivemind_agent_activity: {e}") + + # Update step: Error occurred + if workflow_id: + mongo_persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="error_occurred", + step_data={ + "errorType": type(e).__name__, + "errorMessage": str(e), + }, + 
status="failed" + ) + + # Update final status + mongo_persistence.update_response( + workflow_id=workflow_id, + response_message=None, + status="failed" + ) + + raise + finally: + # Close MongoDB connection + if mongo_persistence: + mongo_persistence.close() @workflow.defn diff --git a/tasks/hivemind/agent.py b/tasks/hivemind/agent.py index 3d4322a..5c26fda 100644 --- a/tasks/hivemind/agent.py +++ b/tasks/hivemind/agent.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from crewai.tools import tool from openai import OpenAI +from typing import Optional class AgenticFlowState(BaseModel): @@ -28,12 +29,14 @@ def __init__( community_id: str, enable_answer_skipping: bool = False, chat_history: str | None = None, + workflow_id: Optional[str] = None, persistence=None, max_retry_count: int = 3, **kwargs, ) -> None: self.enable_answer_skipping = enable_answer_skipping self.community_id = community_id + self.workflow_id = workflow_id self.max_retry_count = max_retry_count super().__init__(persistence, **kwargs) @@ -100,6 +103,7 @@ def do_rag_query(self) -> str: query_data_source_tool = RAGPipelineTool.setup_tools( community_id=self.community_id, enable_answer_skipping=self.enable_answer_skipping, + workflow_id=self.workflow_id, ) q_a_bot_agent = Agent( diff --git a/tasks/hivemind/query_data_sources.py b/tasks/hivemind/query_data_sources.py index 8f0e0f9..41494f1 100644 --- a/tasks/hivemind/query_data_sources.py +++ b/tasks/hivemind/query_data_sources.py @@ -4,7 +4,7 @@ import nest_asyncio from dotenv import load_dotenv -from typing import Type +from typing import Type, Optional from tc_temporal_backend.client import TemporalClient from tc_temporal_backend.schema.hivemind import HivemindQueryPayload from pydantic import BaseModel, Field @@ -15,9 +15,10 @@ class QueryDataSources: - def __init__(self, community_id: str, enable_answer_skipping: bool): + def __init__(self, community_id: str, enable_answer_skipping: bool, workflow_id: Optional[str] = None): self.community_id = 
community_id self.enable_answer_skipping = enable_answer_skipping + self.workflow_id = workflow_id async def query(self, query: str) -> str | None: """ @@ -36,6 +37,10 @@ async def query(self, query: str) -> str | None: enable_answer_skipping=self.enable_answer_skipping, ) + # Add workflow_id to payload if available + if self.workflow_id: + payload.workflow_id = self.workflow_id + hivemind_queue = self.load_hivemind_queue() result = await client.execute_workflow( "HivemindWorkflow", @@ -79,12 +84,14 @@ class RAGPipelineTool(BaseTool): args_schema: Type[BaseModel] = RAGPipelineToolSchema @classmethod - def setup_tools(cls, community_id: str, enable_answer_skipping: bool): + def setup_tools(cls, community_id: str, enable_answer_skipping: bool, workflow_id: Optional[str] = None): """ - Setup the tool with the necessary community identifier and the flag to enable answer skipping. + Setup the tool with the necessary community identifier, the flag to enable answer skipping, + and the workflow ID for tracking. 
""" cls.community_id = community_id cls.enable_answer_skipping = enable_answer_skipping + cls.workflow_id = workflow_id return cls def _run(self, query: str) -> str: @@ -104,6 +111,7 @@ def _run(self, query: str) -> str: query_data_sources = QueryDataSources( community_id=self.community_id, enable_answer_skipping=self.enable_answer_skipping, + workflow_id=self.workflow_id, ) response = asyncio.run(query_data_sources.query(query)) diff --git a/tasks/mongo_persistence.py b/tasks/mongo_persistence.py new file mode 100644 index 0000000..49fa760 --- /dev/null +++ b/tasks/mongo_persistence.py @@ -0,0 +1,207 @@ +import logging + +from datetime import datetime, timezone +from typing import Optional, Dict, Any +from pymongo.database import Database +from pymongo.collection import Collection +from bson import ObjectId +from tc_hivemind_backend.db.mongo import MongoSingleton + +class MongoPersistence: + """A class for persisting workflow state data to MongoDB.""" + + def __init__(self, database_name: str = "hivemind", collection_name: str = "internal_messages"): + """Initialize MongoDB connection using environment variables. + + Parameters + ---------- + collection_name : str + The MongoDB collection name to use for storing workflow states + """ + self.collection_name = collection_name + self.client = MongoSingleton.get_instance().get_client() + self.db: Database = self.client[database_name] + self.collection: Collection = self.db[self.collection_name] + + def create_workflow_state( + self, + community_id: str, + query: str, + source: str = "temporal", + destination: dict[str, str] | None = None, + filters: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + chat_id: Optional[str] = None, + enable_answer_skipping: bool = False, + ) -> str: + """Create a new workflow state document and return its ID. 
+ + Parameters + ---------- + community_id : str + The community identifier + query : str + The user query + source : str + The source of the request (e.g., "discord") + destination : dict[str, str] | None + The destination of the request (e.g., {"queue": "DISCORD_HIVEMIND_ADAPTER", "event": "QUESTION_COMMAND_RECEIVED"}) + filters : Optional[Dict[str, Any]] + Optional filters for the question + metadata : Optional[Dict[str, Any]] + Optional metadata from the client side + chat_id : Optional[str] + The chat identifier + enable_answer_skipping : bool + Whether answer skipping is enabled + + Returns + ------- + str + The MongoDB document ID as a string + """ + try: + workflow_state = { + "communityId": community_id, + "route": { + "source": source, + "destination": destination, + }, + "question": { + "message": query, + "filters": filters + }, + "response": None, + "metadata": metadata or {}, + "createdAt": datetime.now(tz=timezone.utc), + "updatedAt": datetime.now(tz=timezone.utc), + "steps": [], + "currentStep": "initialized", + "status": "running", + "chatId": chat_id, + "enableAnswerSkipping": enable_answer_skipping, + } + + result = self.collection.insert_one(workflow_state) + return str(result.inserted_id) + except Exception as e: + logging.error(f"Error creating workflow state: {e}") + raise + + def update_workflow_step( + self, + workflow_id: str, + step_name: str, + step_data: Dict[str, Any], + status: str = "running", + ) -> bool: + """Update the workflow state with a new step. 
+ + Parameters + ---------- + workflow_id : str + The MongoDB document ID + step_name : str + The name of the current step + step_data : Dict[str, Any] + The data for this step + status : str + The current status of the workflow + + Returns + ------- + bool + True if update was successful, False otherwise + """ + try: + step_entry = { + "stepName": step_name, + "timestamp": datetime.utcnow(), + "data": step_data, + } + + update_data = { + "$push": {"steps": step_entry}, + "$set": { + "currentStep": step_name, + "status": status, + "updatedAt": datetime.utcnow(), + } + } + + result = self.collection.update_one( + {"_id": ObjectId(workflow_id)}, update_data + ) + return result.modified_count > 0 + except Exception as e: + logging.error(f"Error updating workflow step: {e}") + return False + + def update_response( + self, + workflow_id: str, + response_message: str, + status: str = "completed", + ) -> bool: + """Update the workflow state with the response message. + + Parameters + ---------- + workflow_id : str + The MongoDB document ID + response_message : str + The response message from the workflow + status : str + The final status of the workflow + + Returns + ------- + bool + True if update was successful, False otherwise + """ + try: + update_data = { + "$set": { + "response.message": response_message, + "status": status, + "updatedAt": datetime.now(tz=timezone.utc), + } + } + + result = self.collection.update_one( + {"_id": ObjectId(workflow_id)}, update_data + ) + return result.modified_count > 0 + except Exception as e: + logging.error(f"Error updating response: {e}") + return False + + def get_workflow_state(self, workflow_id: str) -> Optional[Dict[str, Any]]: + """Get the workflow state by ID. 
+ + Parameters + ---------- + workflow_id : str + The MongoDB document ID + + Returns + ------- + Optional[Dict[str, Any]] + The workflow state document or None if not found + """ + try: + document = self.collection.find_one({"_id": ObjectId(workflow_id)}) + if document: + # Convert ObjectId to string for JSON serialization + document["_id"] = str(document["_id"]) + return document + except Exception as e: + logging.error(f"Error getting workflow state: {e}") + return None + + def close(self): + """Close the MongoDB connection.""" + try: + self.client.close() + except Exception as e: + logging.error(f"Error closing MongoDB connection: {e}") \ No newline at end of file diff --git a/tests/unit/test_mongo_persistence.py b/tests/unit/test_mongo_persistence.py new file mode 100644 index 0000000..1074f3a --- /dev/null +++ b/tests/unit/test_mongo_persistence.py @@ -0,0 +1,163 @@ +import unittest +from unittest.mock import patch, MagicMock +import os +from datetime import datetime +from bson import ObjectId +from tasks.mongo_persistence import MongoPersistence + + +class TestMongoPersistence(unittest.TestCase): + """Test cases for the MongoPersistence class""" + + def setUp(self): + """Set up test environment""" + # Mock environment variables + self.env_patcher = patch.dict( + os.environ, + { + "MONGODB_HOST": "test-host", + "MONGODB_PORT": "27017", + "MONGODB_USER": "test-user", + "MONGODB_PASS": "test-password", + }, + ) + self.env_patcher.start() + + # Mock the MongoDB client and collection + self.collection_mock = MagicMock() + self.db_mock = MagicMock() + self.db_mock.get_collection.return_value = self.collection_mock + + self.client_mock = MagicMock() + self.client_mock.get_database.return_value = self.db_mock + + self.mongo_patcher = patch("pymongo.MongoClient", return_value=self.client_mock) + self.mongo_mock = self.mongo_patcher.start() + + # Create instance of MongoPersistence with mocked dependencies + self.persistence = MongoPersistence() + + def 
tearDown(self): + """Clean up after tests""" + self.env_patcher.stop() + self.mongo_patcher.stop() + + def test_init_with_env_vars(self): + """Test initialization with environment variables""" + self.mongo_mock.assert_called_once_with( + host="test-host", port=27017, username="test-user", password="test-password" + ) + self.assertEqual(self.persistence.collection_name, "hivemind_workflow_states") + + def test_create_workflow_state(self): + """Test creating a new workflow state""" + # Mock the insert_one result + mock_result = MagicMock() + mock_result.inserted_id = ObjectId("507f1f77bcf86cd799439011") + self.collection_mock.insert_one.return_value = mock_result + + workflow_id = self.persistence.create_workflow_state( + community_id="test-community", + query="test query", + chat_id="test-chat", + enable_answer_skipping=True, + ) + + self.assertEqual(workflow_id, "507f1f77bcf86cd799439011") + self.collection_mock.insert_one.assert_called_once() + + # Check the inserted document structure + inserted_doc = self.collection_mock.insert_one.call_args[0][0] + self.assertEqual(inserted_doc["communityId"], "test-community") + self.assertEqual(inserted_doc["question"]["message"], "test query") + self.assertEqual(inserted_doc["chatId"], "test-chat") + self.assertTrue(inserted_doc["enableAnswerSkipping"]) + self.assertEqual(inserted_doc["currentStep"], "initialized") + self.assertEqual(inserted_doc["status"], "running") + self.assertIn("route", inserted_doc) + self.assertIn("response", inserted_doc) + self.assertIn("metadata", inserted_doc) + self.assertIn("steps", inserted_doc) + + def test_update_workflow_step(self): + """Test updating workflow step""" + # Mock the update_one result + mock_result = MagicMock() + mock_result.modified_count = 1 + self.collection_mock.update_one.return_value = mock_result + + success = self.persistence.update_workflow_step( + workflow_id="507f1f77bcf86cd799439011", + step_name="test_step", + step_data={"key": "value"}, + status="running", + ) + 
+ self.assertTrue(success) + self.collection_mock.update_one.assert_called_once() + + # Check the update operation + call_args = self.collection_mock.update_one.call_args + self.assertEqual(call_args[0][0], {"_id": ObjectId("507f1f77bcf86cd799439011")}) + + update_data = call_args[0][1] + self.assertIn("$push", update_data) + self.assertIn("$set", update_data) + self.assertEqual(update_data["$set"]["currentStep"], "test_step") + + def test_update_response(self): + """Test updating response""" + # Mock the update_one result + mock_result = MagicMock() + mock_result.modified_count = 1 + self.collection_mock.update_one.return_value = mock_result + + success = self.persistence.update_response( + workflow_id="507f1f77bcf86cd799439011", + response_message="Test answer", + status="completed", + ) + + self.assertTrue(success) + self.collection_mock.update_one.assert_called_once() + + # Check the update operation + call_args = self.collection_mock.update_one.call_args + self.assertEqual(call_args[0][0], {"_id": ObjectId("507f1f77bcf86cd799439011")}) + + update_data = call_args[0][1] + self.assertIn("$set", update_data) + self.assertEqual(update_data["$set"]["response.message"], "Test answer") + self.assertEqual(update_data["$set"]["status"], "completed") + + def test_get_workflow_state(self): + """Test getting workflow state""" + # Mock the find_one result + mock_doc = { + "_id": ObjectId("507f1f77bcf86cd799439011"), + "communityId": "test-community", + "question": {"message": "test query"}, + "status": "completed", + } + self.collection_mock.find_one.return_value = mock_doc + + result = self.persistence.get_workflow_state("507f1f77bcf86cd799439011") + + self.assertIsNotNone(result) + self.assertEqual(result["_id"], "507f1f77bcf86cd799439011") + self.assertEqual(result["communityId"], "test-community") + self.assertEqual(result["question"]["message"], "test query") + self.assertEqual(result["status"], "completed") + + def test_get_workflow_state_not_found(self): + """Test 
getting workflow state that doesn't exist""" + self.collection_mock.find_one.return_value = None + + result = self.persistence.get_workflow_state("507f1f77bcf86cd799439011") + + self.assertIsNone(result) + + def test_close(self): + """Test closing the MongoDB connection""" + self.persistence.close() + self.client_mock.close.assert_called_once() From 83b5edbe346969cb80a44c63bbc04182c32bb557 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 16 Jun 2025 14:36:27 +0330 Subject: [PATCH 02/10] feat: Added persist AQC reasoning! --- tasks/agent.py | 1 + tasks/hivemind/agent.py | 54 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/tasks/agent.py b/tasks/agent.py index 8d5703c..27cef66 100644 --- a/tasks/agent.py +++ b/tasks/agent.py @@ -92,6 +92,7 @@ async def run_hivemind_agent_activity( enable_answer_skipping=payload.enable_answer_skipping, chat_history=chat_history, workflow_id=workflow_id, + mongo_persistence=mongo_persistence, ) # Update step: Flow execution start diff --git a/tasks/hivemind/agent.py b/tasks/hivemind/agent.py index 5c26fda..8cbbe79 100644 --- a/tasks/hivemind/agent.py +++ b/tasks/hivemind/agent.py @@ -10,6 +10,7 @@ from crewai.tools import tool from openai import OpenAI from typing import Optional +from tasks.mongo_persistence import MongoPersistence class AgenticFlowState(BaseModel): @@ -30,6 +31,7 @@ def __init__( enable_answer_skipping: bool = False, chat_history: str | None = None, workflow_id: Optional[str] = None, + mongo_persistence: MongoPersistence | None = None, persistence=None, max_retry_count: int = 3, **kwargs, @@ -37,6 +39,7 @@ def __init__( self.enable_answer_skipping = enable_answer_skipping self.community_id = community_id self.workflow_id = workflow_id + self.mongo_persistence = mongo_persistence self.max_retry_count = max_retry_count super().__init__(persistence, **kwargs) @@ -60,18 +63,57 @@ def detect_question(self): # classify using a local model question = 
checker.classify_message(message=self.state.user_query) + # Persist the local model classification result + if self.mongo_persistence and self.workflow_id: + self.mongo_persistence.update_workflow_step( + workflow_id=self.workflow_id, + step_name="local_model_classification", + step_data={ + "result": question, + "model": "local_transformer", + "query": self.state.user_query, + } + ) + if not question: self.state.state = "stop" return # classify using a language model is_question = checker.classify_question_lm(message=self.state.user_query) + # Persist the is_question result and reasoning + if self.mongo_persistence and self.workflow_id: + self.mongo_persistence.update_workflow_step( + workflow_id=self.workflow_id, + step_name="question_classification", + step_data={ + "result": is_question.result, + "reasoning": is_question.reasoning, + "model": "language_model", + "query": self.state.user_query, + } + ) + if not is_question.result: self.state.state = "stop" return # classify if its a RAG question rag_question = checker.classify_message_lm(message=self.state.user_query) + # Persist the rag_question result and reasoning and score + if self.mongo_persistence and self.workflow_id: + self.mongo_persistence.update_workflow_step( + workflow_id=self.workflow_id, + step_name="rag_classification", + step_data={ + "result": rag_question.result, + "score": rag_question.score, + "reasoning": rag_question.reasoning, + "model": "language_model", + "query": self.state.user_query, + } + ) + self.state.state = "continue" if rag_question.result else "stop" @router(detect_question) @@ -90,6 +132,18 @@ def detect_question_type(self) -> str: is_history_query = False if self.state.chat_history: is_history_query = self.classify_query(self.state.user_query) + # Persist the history query classification result + if self.mongo_persistence and self.workflow_id: + self.mongo_persistence.update_workflow_step( + workflow_id=self.workflow_id, + step_name="history_query_classification", + 
step_data={ + "result": is_history_query, + "model": "openai_gpt4", + "query": self.state.user_query, + "hasChatHistory": True, + } + ) if is_history_query: logging.info("History query detected") From f9e63f883e9b884d0015f52d64f9703df4dc1f01 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Mon, 16 Jun 2025 14:36:43 +0330 Subject: [PATCH 03/10] feat: updated README.md to support the changes! --- README.md | 165 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 164 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index df54266..c311aea 100644 --- a/README.md +++ b/README.md @@ -1 +1,164 @@ -# agents-workflow \ No newline at end of file +# Agents Workflow with MongoDB Persistence + +This project implements a CrewAI-based workflow system with comprehensive MongoDB persistence for tracking every step of the workflow execution. + +## Features + +- **MongoDB Persistence**: Every step of the workflow is persisted to MongoDB for audit trails and debugging +- **Workflow Tracking**: Complete visibility into the execution flow with timestamps and step data +- **Error Handling**: Comprehensive error tracking and recovery mechanisms +- **Chat History**: Redis-based chat history management +- **RAG Integration**: Retrieval-Augmented Generation pipeline for data source queries + +## Architecture + +### Components + +1. **MongoPersistence**: Handles all MongoDB operations for workflow state tracking +2. **AgenticHivemindFlow**: CrewAI flow that orchestrates the agent interactions +3. **run_hivemind_agent_activity**: Temporal activity that manages the workflow execution +4. **QueryDataSources**: Handles RAG queries with workflow ID tracking + +### Workflow Steps Tracked + +The system tracks the following steps in MongoDB: + +1. **initialization**: Initial workflow setup with parameters +2. **chat_history_retrieval**: Redis chat history retrieval (if applicable) +3. **no_chat_history**: When no chat history is available +4. 
**flow_initialization**: AgenticHivemindFlow setup +5. **flow_execution_start**: Beginning of CrewAI flow execution +6. **local_model_classification**: Local transformer model classification result +7. **question_classification**: Language model question classification with reasoning +8. **rag_classification**: RAG question classification with score and reasoning +9. **history_query_classification**: History vs RAG query classification (if applicable) +10. **flow_execution_complete**: Completion of CrewAI flow +11. **answer_processing**: Processing of the final answer +12. **error_handling**: Any error handling steps +13. **memory_update**: Redis memory updates (if applicable) +14. **error_occurred**: Any errors during execution + +## Environment Variables + +Use the `.env.example` to prepare your `.env` file. + +## Classification Data Persistence + +The system now persists detailed classification reasoning and results for better audit trails and debugging: + +### Local Model Classification +- **Step Name**: `local_model_classification` +- **Data**: Result from local transformer model +- **Model**: `local_transformer` + +### Question Classification +- **Step Name**: `question_classification` +- **Data**: + - `result`: Boolean indicating if the message is a question + - `reasoning`: Detailed explanation for the classification + - `model`: `language_model` + - `query`: Original user query + +### RAG Classification +- **Step Name**: `rag_classification` +- **Data**: + - `result`: Boolean indicating if RAG is needed + - `score`: Sensitivity score (0-1) + - `reasoning`: Detailed explanation for the score + - `model`: `language_model` + - `query`: Original user query + +### History Query Classification +- **Step Name**: `history_query_classification` +- **Data**: + - `result`: Boolean indicating if it's a history query + - `model`: `openai_gpt4` + - `query`: Original user query + - `hasChatHistory`: Boolean indicating if chat history was available + +## MongoDB Schema + 
+The workflow states are stored in the `internal_messages` collection with the following structure: + +```json +{ + "_id": "ObjectId", + "communityId": "string", + "route": { + "source": "string", + "destination": { + "queue": "string", + "event": "string" + } + }, + "question": { + "message": "string", + "filters": "object (optional)" + }, + "response": { + "message": "string" + }, + "metadata": "object", + "createdAt": "datetime", + "updatedAt": "datetime", + "steps": [ + { + "stepName": "string", + "timestamp": "datetime", + "data": "object" + } + ], + "currentStep": "string", + "status": "string", + "chatId": "string (optional)", + "enableAnswerSkipping": "boolean" +} +``` + +## Usage + +### Running the Worker + +```bash +python worker.py +``` + +### Querying Workflow States + +You can query the MongoDB collection to inspect workflow execution: + +```python +from tasks.mongo_persistence import MongoPersistence + +persistence = MongoPersistence() +workflow_state = persistence.get_workflow_state("workflow_id_here") +print(workflow_state) +``` + +## Testing + +Run the unit tests: + +```bash +python -m pytest tests/unit/test_mongo_persistence.py +``` + +## Dependencies + +- `pymongo==4.8.0`: MongoDB driver +- `redis==5.2.0`: Redis client +- `crewai==0.105.0`: AI agent framework +- `temporalio`: Temporal workflow engine +- `openai==1.66.3`: OpenAI API client + +## Workflow ID Tracking + +The workflow ID is passed through the entire execution chain: + +1. Created in `run_hivemind_agent_activity` +2. Passed to `AgenticHivemindFlow` +3. Passed to `RAGPipelineTool` +4. Passed to `QueryDataSources` +5. Included in `HivemindQueryPayload` for the `HivemindWorkflow` + +This ensures complete traceability from the initial query to the final response. \ No newline at end of file From bd135c38a44c177b02dd5daa84b5161787b8a19e Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 17 Jun 2025 16:24:32 +0330 Subject: [PATCH 04/10] fix: missing pymongo package! 
--- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 48d76ed..b8eb2d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ tc-temporal-backend==1.1.3 transformers[torch]==4.49.0 nest-asyncio==1.6.0 openai==1.66.3 +pymongo==4.13.2 From e1a544cd49120c5db89881841598cb45582c3e9e Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 17 Jun 2025 16:29:14 +0330 Subject: [PATCH 05/10] fix: added the right missing dependency! --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b8eb2d3..b953f80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ tc-temporal-backend==1.1.3 transformers[torch]==4.49.0 nest-asyncio==1.6.0 openai==1.66.3 -pymongo==4.13.2 +tc-hivemind-backend==1.4.3 From d6c9654cf7b77d3e9acb6c11211977a90ad10b3e Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 17 Jun 2025 16:35:44 +0330 Subject: [PATCH 06/10] fix: deprecated now timestamp! --- tasks/mongo_persistence.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks/mongo_persistence.py b/tasks/mongo_persistence.py index 49fa760..9614fc4 100644 --- a/tasks/mongo_persistence.py +++ b/tasks/mongo_persistence.py @@ -116,7 +116,7 @@ def update_workflow_step( try: step_entry = { "stepName": step_name, - "timestamp": datetime.utcnow(), + "timestamp": datetime.now(tz=timezone.utc), "data": step_data, } @@ -125,7 +125,7 @@ def update_workflow_step( "$set": { "currentStep": step_name, "status": status, - "updatedAt": datetime.utcnow(), + "updatedAt": datetime.now(tz=timezone.utc), } } From c9e4612f760175b5598213d48581fb4f4ec40ae1 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 17 Jun 2025 16:36:10 +0330 Subject: [PATCH 07/10] fix: closing mongo client while it is being used in other places!
--- tasks/agent.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tasks/agent.py b/tasks/agent.py index 27cef66..4258a44 100644 --- a/tasks/agent.py +++ b/tasks/agent.py @@ -205,10 +205,6 @@ async def run_hivemind_agent_activity( ) raise - finally: - # Close MongoDB connection - if mongo_persistence: - mongo_persistence.close() @workflow.defn From 371274e5e40a634bcaafa1ba0f5702257253f9b7 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Tue, 17 Jun 2025 16:49:03 +0330 Subject: [PATCH 08/10] fix: removed closing the mongo client! --- tasks/mongo_persistence.py | 7 ------- tests/unit/test_mongo_persistence.py | 5 ----- 2 files changed, 12 deletions(-) diff --git a/tasks/mongo_persistence.py b/tasks/mongo_persistence.py index 9614fc4..a41959f 100644 --- a/tasks/mongo_persistence.py +++ b/tasks/mongo_persistence.py @@ -198,10 +198,3 @@ def get_workflow_state(self, workflow_id: str) -> Optional[Dict[str, Any]]: except Exception as e: logging.error(f"Error getting workflow state: {e}") return None - - def close(self): - """Close the MongoDB connection.""" - try: - self.client.close() - except Exception as e: - logging.error(f"Error closing MongoDB connection: {e}") \ No newline at end of file diff --git a/tests/unit/test_mongo_persistence.py b/tests/unit/test_mongo_persistence.py index 1074f3a..f916bf0 100644 --- a/tests/unit/test_mongo_persistence.py +++ b/tests/unit/test_mongo_persistence.py @@ -156,8 +156,3 @@ def test_get_workflow_state_not_found(self): result = self.persistence.get_workflow_state("507f1f77bcf86cd799439011") self.assertIsNone(result) - - def test_close(self): - """Test closing the MongoDB connection""" - self.persistence.close() - self.client_mock.close.assert_called_once() From 6f33ce4a9755686865cb8601079c54f044087526 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 18 Jun 2025 10:02:37 +0330 Subject: [PATCH 09/10] fix: trying more to fix the test case --- tests/unit/test_mongo_persistence.py | 23 ++--------------------- 1 file 
changed, 2 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_mongo_persistence.py b/tests/unit/test_mongo_persistence.py index f916bf0..eaba857 100644 --- a/tests/unit/test_mongo_persistence.py +++ b/tests/unit/test_mongo_persistence.py @@ -1,27 +1,15 @@ import unittest from unittest.mock import patch, MagicMock -import os -from datetime import datetime +from dotenv import load_dotenv from bson import ObjectId from tasks.mongo_persistence import MongoPersistence - class TestMongoPersistence(unittest.TestCase): """Test cases for the MongoPersistence class""" def setUp(self): """Set up test environment""" - # Mock environment variables - self.env_patcher = patch.dict( - os.environ, - { - "MONGODB_HOST": "test-host", - "MONGODB_PORT": "27017", - "MONGODB_USER": "test-user", - "MONGODB_PASS": "test-password", - }, - ) - self.env_patcher.start() + load_dotenv() # Mock the MongoDB client and collection self.collection_mock = MagicMock() @@ -42,13 +30,6 @@ def tearDown(self): self.env_patcher.stop() self.mongo_patcher.stop() - def test_init_with_env_vars(self): - """Test initialization with environment variables""" - self.mongo_mock.assert_called_once_with( - host="test-host", port=27017, username="test-user", password="test-password" - ) - self.assertEqual(self.persistence.collection_name, "hivemind_workflow_states") - def test_create_workflow_state(self): """Test creating a new workflow state""" # Mock the insert_one result From 2798d7f59fc3de22e8d175e821d788e969739970 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Date: Wed, 18 Jun 2025 12:11:11 +0330 Subject: [PATCH 10/10] fix: update test case to be integration test! 
--- tasks/mongo_persistence.py | 2 +- tests/integration/test_mongo_persistence.py | 330 ++++++++++++++++++++ tests/unit/test_mongo_persistence.py | 139 --------- 3 files changed, 331 insertions(+), 140 deletions(-) create mode 100644 tests/integration/test_mongo_persistence.py delete mode 100644 tests/unit/test_mongo_persistence.py diff --git a/tasks/mongo_persistence.py b/tasks/mongo_persistence.py index a41959f..cd4345f 100644 --- a/tasks/mongo_persistence.py +++ b/tasks/mongo_persistence.py @@ -162,7 +162,7 @@ def update_response( try: update_data = { "$set": { - "response.message": response_message, + "response": {"message": response_message}, "status": status, "updatedAt": datetime.now(tz=timezone.utc), } diff --git a/tests/integration/test_mongo_persistence.py b/tests/integration/test_mongo_persistence.py new file mode 100644 index 0000000..47a6abf --- /dev/null +++ b/tests/integration/test_mongo_persistence.py @@ -0,0 +1,330 @@ +import unittest +import uuid +from dotenv import load_dotenv +from bson import ObjectId +from tasks.mongo_persistence import MongoPersistence + +class TestMongoPersistenceIntegration(unittest.TestCase): + """Integration test cases for the MongoPersistence class that work with real MongoDB""" + + @classmethod + def setUpClass(cls): + """Set up test environment once for all tests""" + load_dotenv() + + # Use a test-specific collection name to avoid interfering with production data + cls.test_collection_name = f"test_internal_messages_{uuid.uuid4().hex[:8]}" + cls.persistence = MongoPersistence(collection_name=cls.test_collection_name) + + # Verify MongoDB connection + try: + # Test the connection by trying to access the collection + cls.persistence.collection.find_one() + print(f"✅ MongoDB connection successful. 
Using test collection: {cls.test_collection_name}") + except Exception as e: + print(f"❌ MongoDB connection failed: {e}") + print("Make sure MongoDB is running and environment variables are set correctly") + raise + + def setUp(self): + """Set up each test case""" + # Clear the test collection before each test + self.persistence.collection.delete_many({}) + + def tearDown(self): + """Clean up after each test""" + # Clear the test collection after each test + self.persistence.collection.delete_many({}) + + @classmethod + def tearDownClass(cls): + """Clean up test environment after all tests""" + # Drop the test collection + try: + cls.persistence.collection.drop() + print(f"✅ Test collection {cls.test_collection_name} dropped successfully") + except Exception as e: + print(f"⚠️ Warning: Could not drop test collection: {e}") + + def test_create_workflow_state(self): + """Test creating a new workflow state with real MongoDB""" + # Create a workflow state + workflow_id = self.persistence.create_workflow_state( + community_id="test-community-123", + query="What is the weather like today?", + chat_id="test-chat-456", + enable_answer_skipping=True, + ) + + # Verify the workflow ID is returned + self.assertIsNotNone(workflow_id) + self.assertIsInstance(workflow_id, str) + + # Verify the document exists in MongoDB + doc = self.persistence.collection.find_one({"_id": ObjectId(workflow_id)}) + self.assertIsNotNone(doc) + + # Verify the document structure + self.assertEqual(doc["communityId"], "test-community-123") + self.assertEqual(doc["question"]["message"], "What is the weather like today?") + self.assertEqual(doc["chatId"], "test-chat-456") + self.assertTrue(doc["enableAnswerSkipping"]) + self.assertEqual(doc["currentStep"], "initialized") + self.assertEqual(doc["status"], "running") + self.assertIn("route", doc) + self.assertIn("response", doc) + self.assertIn("metadata", doc) + self.assertIn("steps", doc) + self.assertIn("createdAt", doc) + self.assertIn("updatedAt", 
doc) + + def test_create_workflow_state_with_optional_params(self): + """Test creating workflow state with all optional parameters""" + workflow_id = self.persistence.create_workflow_state( + community_id="test-community-full", + query="How do I configure the system?", + source="slack", + destination={"queue": "SLACK_HIVEMIND_ADAPTER", "event": "MESSAGE_RECEIVED"}, + filters={"category": "general", "priority": "high"}, + metadata={"user_id": "123", "channel": "general", "timestamp": "2024-01-01"}, + chat_id="test-chat-full", + enable_answer_skipping=False, + ) + + # Verify the document exists and has correct structure + doc = self.persistence.collection.find_one({"_id": ObjectId(workflow_id)}) + self.assertIsNotNone(doc) + + # Check optional parameters + self.assertEqual(doc["route"]["source"], "slack") + self.assertEqual(doc["route"]["destination"]["queue"], "SLACK_HIVEMIND_ADAPTER") + self.assertEqual(doc["route"]["destination"]["event"], "MESSAGE_RECEIVED") + self.assertEqual(doc["question"]["filters"]["category"], "general") + self.assertEqual(doc["question"]["filters"]["priority"], "high") + self.assertEqual(doc["metadata"]["user_id"], "123") + self.assertEqual(doc["metadata"]["channel"], "general") + self.assertEqual(doc["metadata"]["timestamp"], "2024-01-01") + self.assertFalse(doc["enableAnswerSkipping"]) + + def test_update_workflow_step(self): + """Test updating workflow step with real MongoDB""" + # First create a workflow state + workflow_id = self.persistence.create_workflow_state( + community_id="test-community", + query="Test query", + ) + + # Update with a step + step_data = { + "model": "gpt-4", + "confidence": 0.95, + "reasoning": "This is a test step" + } + + success = self.persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="test_classification", + step_data=step_data, + status="processing" + ) + + self.assertTrue(success) + + # Verify the update in MongoDB + doc = self.persistence.collection.find_one({"_id": 
ObjectId(workflow_id)}) + self.assertIsNotNone(doc) + self.assertEqual(doc["currentStep"], "test_classification") + self.assertEqual(doc["status"], "processing") + self.assertEqual(len(doc["steps"]), 1) + + # Check the step data + step = doc["steps"][0] + self.assertEqual(step["stepName"], "test_classification") + self.assertEqual(step["data"]["model"], "gpt-4") + self.assertEqual(step["data"]["confidence"], 0.95) + self.assertEqual(step["data"]["reasoning"], "This is a test step") + self.assertIn("timestamp", step) + + def test_update_workflow_step_multiple_steps(self): + """Test updating workflow with multiple steps""" + workflow_id = self.persistence.create_workflow_state( + community_id="test-community", + query="Test query", + ) + + # Add multiple steps + steps = [ + ("initialization", {"status": "started"}), + ("classification", {"model": "local", "result": True}), + ("rag_query", {"sources": ["doc1", "doc2"]}), + ("response_generation", {"model": "gpt-4", "tokens": 150}), + ] + + for step_name, step_data in steps: + success = self.persistence.update_workflow_step( + workflow_id=workflow_id, + step_name=step_name, + step_data=step_data, + ) + self.assertTrue(success) + + # Verify all steps are stored + doc = self.persistence.collection.find_one({"_id": ObjectId(workflow_id)}) + self.assertEqual(len(doc["steps"]), 4) + self.assertEqual(doc["currentStep"], "response_generation") + + # Check each step + step_names = [step["stepName"] for step in doc["steps"]] + self.assertEqual(step_names, ["initialization", "classification", "rag_query", "response_generation"]) + + def test_update_response(self): + """Test updating response with real MongoDB""" + workflow_id = self.persistence.create_workflow_state( + community_id="test-community", + query="What is the answer?", + ) + + # Update the response + response_message = "The answer is 42. This is a comprehensive response that addresses the user's question." 
+ success = self.persistence.update_response( + workflow_id=workflow_id, + response_message=response_message, + status="completed" + ) + + self.assertTrue(success) + + # Verify the response in MongoDB + doc = self.persistence.collection.find_one({"_id": ObjectId(workflow_id)}) + self.assertIsNotNone(doc) + self.assertEqual(doc["response"]["message"], response_message) + self.assertEqual(doc["status"], "completed") + + def test_get_workflow_state(self): + """Test getting workflow state with real MongoDB""" + # Create a workflow state + original_workflow_id = self.persistence.create_workflow_state( + community_id="test-community", + query="Test query for retrieval", + chat_id="test-chat", + enable_answer_skipping=True, + ) + + # Add some steps + self.persistence.update_workflow_step( + workflow_id=original_workflow_id, + step_name="test_step", + step_data={"key": "value"} + ) + + # Retrieve the workflow state + retrieved_doc = self.persistence.get_workflow_state(original_workflow_id) + + # Verify the retrieved document + self.assertIsNotNone(retrieved_doc) + self.assertEqual(retrieved_doc["_id"], original_workflow_id) + self.assertEqual(retrieved_doc["communityId"], "test-community") + self.assertEqual(retrieved_doc["question"]["message"], "Test query for retrieval") + self.assertEqual(retrieved_doc["chatId"], "test-chat") + self.assertTrue(retrieved_doc["enableAnswerSkipping"]) + self.assertEqual(len(retrieved_doc["steps"]), 1) + self.assertEqual(retrieved_doc["steps"][0]["stepName"], "test_step") + + def test_get_workflow_state_not_found(self): + """Test getting workflow state that doesn't exist""" + # Try to get a non-existent workflow + fake_id = "507f1f77bcf86cd799439011" # Valid ObjectId format but doesn't exist + result = self.persistence.get_workflow_state(fake_id) + self.assertIsNone(result) + + def test_complete_workflow_lifecycle(self): + """Test a complete workflow lifecycle from creation to completion""" + # 1. 
Create workflow state + workflow_id = self.persistence.create_workflow_state( + community_id="test-community-lifecycle", + query="How do I deploy the application?", + source="discord", + destination={"queue": "DISCORD_ADAPTER", "event": "QUESTION_RECEIVED"}, + metadata={"user": "testuser", "channel": "deployment"}, + chat_id="lifecycle-chat", + enable_answer_skipping=False, + ) + + # 2. Add classification step + self.persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="question_classification", + step_data={ + "model": "local_transformer", + "result": True, + "confidence": 0.92, + "reasoning": "This is a deployment-related question" + } + ) + + # 3. Add RAG query step + self.persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="rag_query", + step_data={ + "sources": ["deployment_guide.md", "troubleshooting.md"], + "query": "How do I deploy the application?", + "results_count": 5 + } + ) + + # 4. Add response generation step + self.persistence.update_workflow_step( + workflow_id=workflow_id, + step_name="response_generation", + step_data={ + "model": "gpt-4", + "tokens_used": 245, + "generation_time": 2.3 + } + ) + + # 5. Update with final response + final_response = "To deploy the application, follow these steps: 1. Build the project, 2. Run tests, 3. Deploy to staging, 4. Deploy to production." + self.persistence.update_response( + workflow_id=workflow_id, + response_message=final_response, + status="completed" + ) + + # 6. 
Verify the complete workflow state + final_doc = self.persistence.get_workflow_state(workflow_id) + self.assertIsNotNone(final_doc) + self.assertEqual(final_doc["status"], "completed") + self.assertEqual(final_doc["response"]["message"], final_response) + self.assertEqual(len(final_doc["steps"]), 3) + self.assertEqual(final_doc["currentStep"], "response_generation") + + # Verify all the data is preserved + self.assertEqual(final_doc["communityId"], "test-community-lifecycle") + self.assertEqual(final_doc["route"]["source"], "discord") + self.assertEqual(final_doc["metadata"]["user"], "testuser") + self.assertFalse(final_doc["enableAnswerSkipping"]) + + def test_error_handling_invalid_object_id(self): + """Test error handling with invalid ObjectId""" + # Test with invalid ObjectId format + invalid_id = "invalid-id-format" + + # These should handle the error gracefully + result = self.persistence.get_workflow_state(invalid_id) + self.assertIsNone(result) + + # Update operations should return False for invalid IDs + success = self.persistence.update_workflow_step( + workflow_id=invalid_id, + step_name="test", + step_data={} + ) + self.assertFalse(success) + + success = self.persistence.update_response( + workflow_id=invalid_id, + response_message="test" + ) + self.assertFalse(success) diff --git a/tests/unit/test_mongo_persistence.py b/tests/unit/test_mongo_persistence.py deleted file mode 100644 index eaba857..0000000 --- a/tests/unit/test_mongo_persistence.py +++ /dev/null @@ -1,139 +0,0 @@ -import unittest -from unittest.mock import patch, MagicMock -from dotenv import load_dotenv -from bson import ObjectId -from tasks.mongo_persistence import MongoPersistence - -class TestMongoPersistence(unittest.TestCase): - """Test cases for the MongoPersistence class""" - - def setUp(self): - """Set up test environment""" - load_dotenv() - - # Mock the MongoDB client and collection - self.collection_mock = MagicMock() - self.db_mock = MagicMock() - 
self.db_mock.get_collection.return_value = self.collection_mock - - self.client_mock = MagicMock() - self.client_mock.get_database.return_value = self.db_mock - - self.mongo_patcher = patch("pymongo.MongoClient", return_value=self.client_mock) - self.mongo_mock = self.mongo_patcher.start() - - # Create instance of MongoPersistence with mocked dependencies - self.persistence = MongoPersistence() - - def tearDown(self): - """Clean up after tests""" - self.env_patcher.stop() - self.mongo_patcher.stop() - - def test_create_workflow_state(self): - """Test creating a new workflow state""" - # Mock the insert_one result - mock_result = MagicMock() - mock_result.inserted_id = ObjectId("507f1f77bcf86cd799439011") - self.collection_mock.insert_one.return_value = mock_result - - workflow_id = self.persistence.create_workflow_state( - community_id="test-community", - query="test query", - chat_id="test-chat", - enable_answer_skipping=True, - ) - - self.assertEqual(workflow_id, "507f1f77bcf86cd799439011") - self.collection_mock.insert_one.assert_called_once() - - # Check the inserted document structure - inserted_doc = self.collection_mock.insert_one.call_args[0][0] - self.assertEqual(inserted_doc["communityId"], "test-community") - self.assertEqual(inserted_doc["question"]["message"], "test query") - self.assertEqual(inserted_doc["chatId"], "test-chat") - self.assertTrue(inserted_doc["enableAnswerSkipping"]) - self.assertEqual(inserted_doc["currentStep"], "initialized") - self.assertEqual(inserted_doc["status"], "running") - self.assertIn("route", inserted_doc) - self.assertIn("response", inserted_doc) - self.assertIn("metadata", inserted_doc) - self.assertIn("steps", inserted_doc) - - def test_update_workflow_step(self): - """Test updating workflow step""" - # Mock the update_one result - mock_result = MagicMock() - mock_result.modified_count = 1 - self.collection_mock.update_one.return_value = mock_result - - success = self.persistence.update_workflow_step( - 
workflow_id="507f1f77bcf86cd799439011", - step_name="test_step", - step_data={"key": "value"}, - status="running", - ) - - self.assertTrue(success) - self.collection_mock.update_one.assert_called_once() - - # Check the update operation - call_args = self.collection_mock.update_one.call_args - self.assertEqual(call_args[0][0], {"_id": ObjectId("507f1f77bcf86cd799439011")}) - - update_data = call_args[0][1] - self.assertIn("$push", update_data) - self.assertIn("$set", update_data) - self.assertEqual(update_data["$set"]["currentStep"], "test_step") - - def test_update_response(self): - """Test updating response""" - # Mock the update_one result - mock_result = MagicMock() - mock_result.modified_count = 1 - self.collection_mock.update_one.return_value = mock_result - - success = self.persistence.update_response( - workflow_id="507f1f77bcf86cd799439011", - response_message="Test answer", - status="completed", - ) - - self.assertTrue(success) - self.collection_mock.update_one.assert_called_once() - - # Check the update operation - call_args = self.collection_mock.update_one.call_args - self.assertEqual(call_args[0][0], {"_id": ObjectId("507f1f77bcf86cd799439011")}) - - update_data = call_args[0][1] - self.assertIn("$set", update_data) - self.assertEqual(update_data["$set"]["response.message"], "Test answer") - self.assertEqual(update_data["$set"]["status"], "completed") - - def test_get_workflow_state(self): - """Test getting workflow state""" - # Mock the find_one result - mock_doc = { - "_id": ObjectId("507f1f77bcf86cd799439011"), - "communityId": "test-community", - "question": {"message": "test query"}, - "status": "completed", - } - self.collection_mock.find_one.return_value = mock_doc - - result = self.persistence.get_workflow_state("507f1f77bcf86cd799439011") - - self.assertIsNotNone(result) - self.assertEqual(result["_id"], "507f1f77bcf86cd799439011") - self.assertEqual(result["communityId"], "test-community") - self.assertEqual(result["question"]["message"], 
"test query") - self.assertEqual(result["status"], "completed") - - def test_get_workflow_state_not_found(self): - """Test getting workflow state that doesn't exist""" - self.collection_mock.find_one.return_value = None - - result = self.persistence.get_workflow_state("507f1f77bcf86cd799439011") - - self.assertIsNone(result)