diff --git a/dimos/perception/experimental/temporal_memory/README.md b/dimos/perception/experimental/temporal_memory/README.md index 9ef5f6cb22..291de546e3 100644 --- a/dimos/perception/experimental/temporal_memory/README.md +++ b/dimos/perception/experimental/temporal_memory/README.md @@ -30,3 +30,6 @@ Notes - Evidence is extracted in sliding windows, so queries can refer to recent or past entities. - Distance estimation can run in the background to enrich graph relations. - If you want a different output directory, set `TemporalMemoryConfig(output_dir=...)`. + +To visualize, run +` ` and open in `localhost:8080` in your browser. diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 7109459f40..c099b8bf4b 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -227,6 +227,61 @@ def get_entity(self, entity_id: str) -> dict[str, Any] | None: "metadata": json.loads(row["metadata"]) if row["metadata"] else None, } + def update_entity( + self, + entity_id: str, + descriptor: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> bool: + """ + Update an entity's descriptor and/or metadata. 
+ + Args: + entity_id: Entity ID to update + descriptor: New descriptor (optional) + metadata: New metadata dict (optional, will merge with existing) + + Returns: + True if entity was updated, False if not found + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Get existing entity + cursor.execute("SELECT metadata FROM entities WHERE entity_id = ?", (entity_id,)) + row = cursor.fetchone() + if row is None: + return False + + # Merge metadata if provided + existing_metadata = json.loads(row["metadata"]) if row["metadata"] else {} + if metadata: + existing_metadata.update(metadata) + + # Update descriptor and/or metadata + updates = [] + params: list[Any] = [] + + if descriptor is not None: + updates.append("descriptor = ?") + params.append(descriptor) + + if metadata is not None: + updates.append("metadata = ?") + params.append(json.dumps(existing_metadata)) + + if not updates: + return True # Nothing to update + + params.append(entity_id) + cursor.execute( + f"UPDATE entities SET {', '.join(updates)} WHERE entity_id = ?", + params, + ) + conn.commit() + logger.debug(f"Updated entity {entity_id}") + return True + def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any]]: """Get all entities, optionally filtered by type.""" conn = self._get_connection() diff --git a/dimos/perception/experimental/temporal_memory/graph_viz_server.py b/dimos/perception/experimental/temporal_memory/graph_viz_server.py new file mode 100644 index 0000000000..88bc5a9220 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/graph_viz_server.py @@ -0,0 +1,688 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Real-time graph database visualization server. + +Usage: + python -m dimos.perception.experimental.temporal_memory.graph_viz_server + +Then open http://localhost:8080 in your browser. +""" + +from pathlib import Path +import sys +from threading import Lock +import time +from typing import Any + +from flask import Flask, jsonify, render_template_string + +from dimos.perception.experimental.temporal_memory.entity_graph_db import EntityGraphDB + +app = Flask(__name__) + +_db: EntityGraphDB | None = None +_db_path: Path | None = None +_output_dir: Path | None = None +_db_lock = Lock() +_last_update = 0.0 + +HTML_TEMPLATE = """ + + + + Temporal Memory Graph Visualization + + + + + +
+
+ +
def _empty_graph_payload() -> dict[str, Any]:
    """Return the placeholder payload served while no database is open."""
    return {
        "stats": {"entities": 0, "relations": 0, "distances": 0},
        "entities": [],
        "relations": [],
        "distances": [],
        "waiting": True,
    }


@app.route("/")
def index() -> str:
    """Serve the visualization page rendered from ``HTML_TEMPLATE``."""
    return render_template_string(HTML_TEMPLATE)


def _try_init_db() -> bool:
    """Open the database lazily once its file exists.

    Returns:
        True when a database handle is available (already open, or opened by
        this call); False when the path is unset, the file does not exist,
        or opening it failed.
    """
    global _db

    with _db_lock:
        if _db is not None:
            return True

        if _db_path is None or not _db_path.exists():
            return False

        try:
            _db = EntityGraphDB(db_path=_db_path)
        except Exception as e:
            # Deliberate best effort: the server keeps running and the next
            # request retries, so report the failure instead of raising.
            print(f"Warning: Failed to initialize database: {e}")
            return False
        return True


@app.route("/api/graph")
def get_graph() -> Any:
    """Return the current graph state as JSON.

    The response carries ``stats``, ``entities``, ``relations`` (the 100 most
    recent), ``distances`` and a ``waiting`` flag. While the database is not
    available the same shape is returned with empty values and
    ``waiting=True``.
    """
    global _last_update

    if not _try_init_db():
        return jsonify(_empty_graph_payload())

    with _db_lock:
        db = _db
        # Explicit None check: do not rely on the truthiness of the DB object.
        if db is None:
            return jsonify(_empty_graph_payload())

        stats = db.get_stats()
        entities = db.get_all_entities()
        recent_relations = db.get_recent_relations(limit=100)

        # One get_distance() call per unordered entity pair, made while the
        # lock is held; pairs without a recorded distance are omitted.
        distances = []
        entity_ids = [e["entity_id"] for e in entities]
        for i, first_id in enumerate(entity_ids):
            for second_id in entity_ids[i + 1 :]:
                dist = db.get_distance(first_id, second_id)
                if dist:
                    distances.append(dist)

        _last_update = time.time()

    return jsonify(
        {
            "stats": stats,
            "entities": entities,
            "relations": recent_relations,
            "distances": distances,
            "waiting": False,
        }
    )


def main() -> None:
    """Run the visualization server.

    Expects the database path as the first command-line argument. The file
    does not have to exist yet: the server starts anyway and requests retry
    opening it.
    """
    global _db_path, _output_dir

    if len(sys.argv) < 2:
        module = "dimos.perception.experimental.temporal_memory.graph_viz_server"
        # Name the required argument so the usage line is actionable.
        print(f"Usage: python -m {module} <db_path>")
        print(f"Example: python -m {module} assets/temporal_memory/entity_graph.db")
        sys.exit(1)

    db_path = Path(sys.argv[1])
    _db_path = db_path
    # Infer output_dir from db_path (db is in output_dir/entity_graph.db)
    _output_dir = db_path.parent

    # Try to initialize DB if file exists, but don't fail if it doesn't
    if db_path.exists():
        if _try_init_db():
            print(f"✅ Database loaded: {db_path}")
        else:
            print(f"⚠️ Database file exists but couldn't be opened: {db_path}")
    else:
        print(f"⏳ Waiting for database file: {db_path}")
        print(" (The server will start and wait for the file to appear)")

    print("🚀 Graph visualization server starting...")
    print(f"📊 Database path: {db_path}")
    print("🌐 Open http://localhost:8080 in your browser")
    print("Press Ctrl+C to stop")

    app.run(host="127.0.0.1", port=8080, debug=False, threaded=True)


if __name__ == "__main__":
    main()
config: TemporalMemoryConfig | None = None + self, + vlm: VlModel | None = None, + config: TemporalMemoryConfig | None = None, + global_config: GlobalConfig | None = None, ) -> None: super().__init__() self._vlm = vlm # Can be None for blueprint usage self.config: TemporalMemoryConfig = config or TemporalMemoryConfig() + self._global_config = global_config # Store it # single lock protects all state self._state_lock = threading.Lock() @@ -202,12 +213,23 @@ def vlm(self) -> VlModel: def start(self) -> None: super().start() + # Connect to Rerun if backend is Rerun + if self._global_config and self._global_config.viewer_backend.startswith("rerun"): + connect_rerun(global_config=self._global_config) + with self._state_lock: self._stopped = False if self._video_start_wall_time is None: self._video_start_wall_time = time.time() def on_frame(image: Image) -> None: + # Log image to Rerun if enabled + if self._global_config and self._global_config.viewer_backend.startswith("rerun"): + try: + rr.log("world/temporal_memory/camera/rgb", image.to_rerun()) + except Exception as e: + logger.debug(f"Failed to log image to Rerun: {e}") + with self._state_lock: video_start = self._video_start_wall_time if video_start is None: @@ -539,7 +561,25 @@ def query(self, question: str) -> str: # query vlm (slow, outside lock) try: answer_text = self.vlm.query(latest_frame, prompt) - return answer_text.strip() + answer_text = answer_text.strip() + + # Check for rename commands in the response + rename_pattern = r'RENAME_ENTITY:\s*entity_id="([^"]+)"\s+new_name="([^"]+)"' + matches = re.findall(rename_pattern, answer_text) + + if matches: + # Execute renames + for entity_id, new_name in matches: + success = self.rename_entity(entity_id=entity_id, new_name=new_name) + if success: + logger.info(f"Renamed entity {entity_id} to '{new_name}' via query") + else: + logger.warning(f"Failed to rename entity {entity_id} to '{new_name}'") + + # Remove rename commands from response + answer_text = 
@rpc
def rename_entity(
    self, entity_id: str, new_name: str | None = None, new_descriptor: str | None = None
) -> bool:
    """Rename an entity and/or replace its descriptor based on human input.

    Args:
        entity_id: Entity ID to rename (e.g., "E8").
        new_name: Optional name to store in the entity metadata under
            ``"name"`` (e.g., "stash").
        new_descriptor: Optional explicit descriptor. When omitted and
            ``new_name`` is given, the descriptor becomes
            ``"<new_name> (<current descriptor>)"``.

    Returns:
        True if the entity was updated, False if there is no graph database
        or the entity was not found.
    """
    if not self._graph_db:
        return False

    metadata: dict[str, Any] = {}
    if new_name:
        metadata["name"] = new_name
        # If no descriptor provided, derive one that includes the name.
        if new_descriptor is None:
            record = self._graph_db.get_entity(entity_id)
            if record:
                base_desc = record.get("descriptor") or ""
                # An earlier rename stored its name in metadata and wrapped the
                # descriptor as "<old name> (<base>)". Unwrap it so repeated
                # renames replace the name instead of nesting parentheses.
                stored_meta = record.get("metadata")
                prior_name = stored_meta.get("name") if isinstance(stored_meta, dict) else None
                if prior_name:
                    wrapper_prefix = f"{prior_name} ("
                    if base_desc.startswith(wrapper_prefix) and base_desc.endswith(")"):
                        base_desc = base_desc[len(wrapper_prefix) : -1]
                # Without a base descriptor the name alone is the descriptor
                # (avoids "name ()").
                new_descriptor = f"{new_name} ({base_desc})" if base_desc else new_name

    success = self._graph_db.update_entity(
        entity_id=entity_id,
        descriptor=new_descriptor,
        metadata=metadata if metadata else None,
    )

    # Keep the in-memory entity roster in step with the database.
    if success:
        with self._state_lock:
            roster = self._state.get("entity_roster", [])
            for roster_entry in roster:
                if roster_entry.get("id") == entity_id:
                    if new_descriptor:
                        roster_entry["descriptor"] = new_descriptor
                    if new_name:
                        roster_entry["name"] = new_name
                    break

    return success


@classmethod
def rerun_views(cls) -> list[Any]:  # type: ignore[no-untyped-def]
    """Return Rerun view blueprints for temporal memory camera visualization."""
    return [
        rrb.Spatial2DView(
            name="Temporal Memory Camera",
            origin="world/temporal_memory/camera/rgb",
        ),
    ]
When in doubt, ask for clarification and list the possible matches. + Provide a concise answer. """ return prompt diff --git a/dimos/perception/experimental/temporal_memory/video_temporal_example.py b/dimos/perception/experimental/temporal_memory/video_temporal_example.py new file mode 100644 index 0000000000..d15cba7663 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/video_temporal_example.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage of TemporalMemory module with a VLM. + +This example demonstrates how to: +1. Deploy a camera module +2. Deploy TemporalMemory with the camera +3. 
from pathlib import Path
import sys
import threading
import time
from typing import Any

import cv2
from dotenv import load_dotenv
from flask import Flask, jsonify, request
import numpy as np
from numpy.typing import NDArray

from dimos import core
from dimos.core import Module, Out, rpc
from dimos.msgs.sensor_msgs import Image, ImageFormat
from dimos.perception.experimental.temporal_memory import TemporalMemoryConfig
from dimos.perception.experimental.temporal_memory.temporal_memory_deploy import deploy
from dimos.stream.video_provider import VideoProvider

# Load environment variables
load_dotenv()

# Flask app for query endpoint
app = Flask(__name__)
# Set by example_usage() once TemporalMemory is deployed; None means "not ready".
_temporal_memory_ref = None


@app.route("/api/query", methods=["POST"])
def query_endpoint() -> Any:
    """Answer a question against the running TemporalMemory.

    Expects a JSON object with a ``"question"`` field. Responds 503 while no
    TemporalMemory is registered, 400 when the body is not a JSON object with
    that field, and 500 when the query itself raises.
    """
    if _temporal_memory_ref is None:
        return jsonify({"error": "TemporalMemory not initialized"}), 503

    # silent=True yields None for a missing/invalid JSON body, so the client
    # gets this handler's JSON 400 rather than Flask's default error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "question" not in data:
        return jsonify({"error": "Missing 'question' field"}), 400

    try:
        answer = _temporal_memory_ref.query(data["question"])
        return jsonify({"answer": answer, "question": data["question"]})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


def start_query_server() -> None:
    """Run the Flask query server (blocking); meant for a background thread."""
    app.run(host="127.0.0.1", port=8081, debug=False, threaded=True)


# Simple video file module
class VideoFileModule(Module):
    """Publish frames of a video file on ``color_image``."""

    color_image: Out[Image] = None  # type: ignore[assignment]

    def __init__(self, video_path: str):
        super().__init__()
        self.video_provider = VideoProvider(dev_name="mp4", video_source=video_path)

    @rpc
    def start(self) -> None:
        """Subscribe to the video provider and publish each frame as a BGR image."""

        def on_frame(frame: NDArray[Any]) -> None:
            img = Image.from_numpy(frame, format=ImageFormat.BGR)
            self.color_image.publish(img)

        self._disposables.add(
            self.video_provider.capture_video_as_observable(realtime=True).subscribe(on_frame)
        )

    @rpc
    def stop(self) -> None:
        """Stop the video provider."""
        super().stop()


def _probe_video_duration(video_path: str) -> tuple[float, float, float]:
    """Return ``(duration_s, frame_count, fps)``; duration is 0 when fps is not positive."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    finally:
        # Release the capture even if reading the properties fails.
        cap.release()
    duration = frame_count / fps if fps > 0 else 0
    return duration, frame_count, fps


def example_usage() -> None:
    """Run TemporalMemory over a video file and print query results.

    The video path is taken from the first command-line argument, defaulting
    to ``assets/simple_demo.mp4``. Exits with status 1 if the file is missing.
    """
    global _temporal_memory_ref

    # Validate the input before starting anything that needs tearing down.
    video_path = sys.argv[1] if len(sys.argv) > 1 else "assets/simple_demo.mp4"
    if not Path(video_path).exists():
        print(f"Error: Video file not found: {video_path}")
        sys.exit(1)

    # Initialize variables to None for cleanup
    temporal_memory = None
    camera = None
    dimos = None

    try:
        # Create Dimos cluster
        dimos = core.start(1)

        # Deploy video file module
        camera = dimos.deploy(VideoFileModule, video_path=video_path)  # type: ignore[attr-defined]
        camera.start()

        # Deploy temporal memory using the deploy function
        output_dir = Path("./temporal_memory_output")
        temporal_memory = deploy(
            dimos,
            camera,
            vlm=None,  # Will auto-create OpenAIVlModel if None
            config=TemporalMemoryConfig(
                fps=1.0,  # Process 1 frame per second
                window_s=2.0,  # Analyze 2-second windows
                stride_s=2.0,  # New window every 2 seconds
                summary_interval_s=10.0,  # Update rolling summary every 10 seconds
                max_frames_per_window=3,  # Max 3 frames per window
                output_dir=output_dir,
            ),
        )

        # Store reference for query endpoint
        _temporal_memory_ref = temporal_memory

        # Start query server in background
        server_thread = threading.Thread(target=start_query_server, daemon=True)
        server_thread.start()
        print("✅ Query server started on http://127.0.0.1:8081/api/query")

        print("TemporalMemory deployed and started!")
        print(f"Artifacts will be saved to: {output_dir}")

        # Wait roughly as long as the video plays so it is fully processed.
        video_duration, frame_count, fps = _probe_video_duration(video_path)
        if video_duration > 0:
            print(
                f"Video duration: {video_duration:.1f} seconds ({frame_count:.0f} frames @ {fps:.1f} fps)"
            )
            print(f"Processing video... (this will take ~{video_duration:.1f} seconds)")
            # Wait for video duration + a small buffer for processing
            time.sleep(video_duration + 5)
        else:
            print("Could not determine video duration, waiting 30 seconds...")
            time.sleep(30)

        # Query the temporal memory
        questions = [
            "Are there any people in the scene?",
            "Describe the main activity happening now",
            "What has happened in the last few seconds?",
            "What entities are currently visible?",
        ]

        for question in questions:
            print(f"\nQuestion: {question}")
            answer = temporal_memory.query(question)
            print(f"Answer: {answer}")

        # Get current state
        state = temporal_memory.get_state()
        print("\n=== Current State ===")
        print(f"Entity count: {state['entity_count']}")
        print(f"Frame count: {state['frame_count']}")
        print(f"Rolling summary: {state['rolling_summary']}")
        print(f"Entities: {state['entities']}")

        # Get entity roster
        entities = temporal_memory.get_entity_roster()
        print("\n=== Entity Roster ===")
        for entity in entities:
            print(f" {entity['id']}: {entity['descriptor']}")

        # Check graph database stats
        graph_stats = temporal_memory.get_graph_db_stats()
        print("\n=== Graph Database Stats ===")
        if "error" in graph_stats:
            print(f"Error: {graph_stats['error']}")
        else:
            print(f"Stats: {graph_stats['stats']}")
            print(f"\nEntities in DB ({len(graph_stats['entities'])}):")
            for entity in graph_stats["entities"]:
                print(f" {entity['entity_id']} ({entity['entity_type']}): {entity['descriptor']}")
            print(f"\nRecent relations ({len(graph_stats['recent_relations'])}):")
            for rel in graph_stats["recent_relations"]:
                print(
                    f" {rel['subject_id']} --{rel['relation_type']}--> {rel['object_id']} (confidence: {rel['confidence']:.2f})"
                )

    finally:
        # Single shutdown path for success and failure alike, so each module
        # is stopped exactly once.
        # Clear the endpoint's reference first so late requests get a 503
        # instead of querying a module that is being stopped.
        _temporal_memory_ref = None
        if temporal_memory is not None:
            print("\nStopping TemporalMemory...")
            temporal_memory.stop()
            print("TemporalMemory stopped")
        if camera is not None:
            camera.stop()
        if dimos is not None:
            dimos.close_all()  # type: ignore[attr-defined]


if __name__ == "__main__":
    example_usage()
+ config=TemporalMemoryConfig( + clear_memory_on_start=True, + ), + ), )