diff --git a/internal/db/sqlite.go b/internal/db/sqlite.go index 0168170..4a471a4 100644 --- a/internal/db/sqlite.go +++ b/internal/db/sqlite.go @@ -242,6 +242,10 @@ func (db *DB) migrate() error { `ALTER TABLE tasks ADD COLUMN executor TEXT DEFAULT 'claude'`, // Task executor: "claude" (default), "codex" // Tmux window ID for unique window identification (avoids duplicate window issues) `ALTER TABLE tasks ADD COLUMN tmux_window_id TEXT DEFAULT ''`, // tmux window ID (e.g., "@1234") + // Distilled task summary for search indexing and context + `ALTER TABLE tasks ADD COLUMN summary TEXT DEFAULT ''`, // Distilled summary of what was accomplished + // Last distillation timestamp for tracking when to re-distill + `ALTER TABLE tasks ADD COLUMN last_distilled_at DATETIME`, // When task was last distilled // Tmux pane IDs for deterministic pane identification (avoids index-based guessing) `ALTER TABLE tasks ADD COLUMN claude_pane_id TEXT DEFAULT ''`, // tmux pane ID for Claude/executor pane (e.g., "%1234") `ALTER TABLE tasks ADD COLUMN shell_pane_id TEXT DEFAULT ''`, // tmux pane ID for shell pane (e.g., "%1235") diff --git a/internal/db/tasks.go b/internal/db/tasks.go index 8c84bfb..242e67f 100644 --- a/internal/db/tasks.go +++ b/internal/db/tasks.go @@ -31,6 +31,7 @@ type Task struct { DangerousMode bool // Whether task is running in dangerous mode (--dangerously-skip-permissions) Pinned bool // Whether the task is pinned to the top of its column Tags string // Comma-separated tags for categorization (e.g., "customer-support,email,influence-kit") + Summary string // Distilled summary of what was accomplished (for search and context) CreatedAt LocalTime UpdatedAt LocalTime StartedAt *LocalTime @@ -39,6 +40,8 @@ type Task struct { ScheduledAt *LocalTime // When to next run (nil = not scheduled) Recurrence string // Deprecated: no longer used (kept for backward compatibility) LastRunAt *LocalTime // When last executed (for scheduled tasks) + // Distillation tracking + LastDistilledAt *LocalTime // When task was last distilled for learnings } // Task statuses @@ -162,18 +165,18 @@ func (db *DB) GetTask(id int64) (*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE id = ? `, id).Scan( &t.ID, &t.Title, &t.Body, &t.Status, &t.Type, &t.Project, &t.Executor, &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ) if err == sql.ErrNoRows { return nil, nil @@ -202,9 +205,9 @@ func (db *DB) ListTasks(opts ListTasksOptions) ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE 1=1 ` args := []interface{}{} @@ -254,9 +257,9 @@ func (db *DB) ListTasks(opts ListTasksOptions) ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ) if err != nil { return nil, fmt.Errorf("scan task: %w", err) @@ -277,9 +280,9 @@ func (db *DB) GetMostRecentlyCreatedTask() (*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks ORDER BY created_at DESC, id DESC LIMIT 1 @@ -288,9 +291,9 @@ func (db *DB) GetMostRecentlyCreatedTask() (*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ) if err == sql.ErrNoRows { return nil, nil @@ -315,9 +318,9 @@ func (db *DB) SearchTasks(query string, limit int) ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE ( title LIKE ? COLLATE NOCASE @@ -345,9 +348,9 @@ func (db *DB) SearchTasks(query string, limit int) ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ) if err != nil { return nil, fmt.Errorf("scan task: %w", err) @@ -443,6 +446,19 @@ func (db *DB) UpdateTaskDangerousMode(taskID int64, dangerousMode bool) error { return nil } +// SaveTaskSummary updates the distilled summary for a task. +// This is called after task completion to store a concise summary for search and context. +func (db *DB) SaveTaskSummary(taskID int64, summary string) error { + _, err := db.Exec(` + UPDATE tasks SET summary = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, summary, taskID) + if err != nil { + return fmt.Errorf("save task summary: %w", err) + } + return nil +} + // UpdateTaskPinned updates only the pinned flag for a task. func (db *DB) UpdateTaskPinned(taskID int64, pinned bool) error { _, err := db.Exec(` @@ -574,9 +590,9 @@ func (db *DB) GetNextQueuedTask() (*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE status = ? ORDER BY created_at ASC @@ -586,9 +602,9 @@ func (db *DB) GetNextQueuedTask() (*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ) if err == sql.ErrNoRows { return nil, nil @@ -607,9 +623,9 @@ func (db *DB) GetQueuedTasks() ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE status = ? ORDER BY created_at ASC @@ -627,9 +643,9 @@ func (db *DB) GetQueuedTasks() ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ); err != nil { return nil, fmt.Errorf("scan task: %w", err) } @@ -647,9 +663,9 @@ func (db *DB) GetTasksWithBranches() ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE branch_name != '' AND status NOT IN (?, ?) ORDER BY created_at DESC @@ -667,9 +683,9 @@ func (db *DB) GetTasksWithBranches() ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ); err != nil { return nil, fmt.Errorf("scan task: %w", err) } @@ -689,9 +705,9 @@ func (db *DB) GetDueScheduledTasks() ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE scheduled_at IS NOT NULL AND scheduled_at <= CURRENT_TIMESTAMP @@ -711,9 +727,9 @@ func (db *DB) GetDueScheduledTasks() ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ); err != nil { return nil, fmt.Errorf("scan task: %w", err) } @@ -730,9 +746,9 @@ func (db *DB) GetScheduledTasks() ([]*Task, error) { COALESCE(daemon_session, ''), COALESCE(tmux_window_id, ''), COALESCE(claude_pane_id, ''), COALESCE(shell_pane_id, ''), COALESCE(pr_url, ''), COALESCE(pr_number, 0), - COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), + COALESCE(dangerous_mode, 0), COALESCE(pinned, 0), COALESCE(tags, ''), COALESCE(summary, ''), created_at, updated_at, started_at, completed_at, - scheduled_at, recurrence, last_run_at + scheduled_at, recurrence, last_run_at, last_distilled_at FROM tasks WHERE scheduled_at IS NOT NULL ORDER BY scheduled_at ASC @@ -750,9 +766,9 @@ func (db *DB) GetScheduledTasks() ([]*Task, error) { &t.WorktreePath, &t.BranchName, &t.Port, &t.ClaudeSessionID, &t.DaemonSession, &t.TmuxWindowID, &t.ClaudePaneID, &t.ShellPaneID, &t.PRURL, &t.PRNumber, - &t.DangerousMode, &t.Pinned, &t.Tags, + &t.DangerousMode, &t.Pinned, &t.Tags, &t.Summary, &t.CreatedAt, &t.UpdatedAt, &t.StartedAt, &t.CompletedAt, - &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, + &t.ScheduledAt, &t.Recurrence, &t.LastRunAt, &t.LastDistilledAt, ); err != nil { return nil, fmt.Errorf("scan task: %w", err) } @@ -1772,6 +1788,32 @@ func (db *DB) FindSimilarTasks(task *Task, limit int) ([]*TaskSearchResult, erro }) } +// UpdateTaskLastDistilledAt updates the last_distilled_at timestamp for a task. +// This is called after distilling learnings from a task to track when it was last processed. +func (db *DB) UpdateTaskLastDistilledAt(taskID int64, t time.Time) error { + _, err := db.Exec(` + UPDATE tasks SET last_distilled_at = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, LocalTime{Time: t}, taskID) + if err != nil { + return fmt.Errorf("update task last_distilled_at: %w", err) + } + return nil +} + +// UpdateTaskStartedAt updates the started_at timestamp for a task. +// This is primarily used for testing. +func (db *DB) UpdateTaskStartedAt(taskID int64, t time.Time) error { + _, err := db.Exec(` + UPDATE tasks SET started_at = ?, updated_at = CURRENT_TIMESTAMP + WHERE id = ? + `, LocalTime{Time: t}, taskID) + if err != nil { + return fmt.Errorf("update task started_at: %w", err) + } + return nil +} + // GetTagsList returns all unique tags used across all tasks. func (db *DB) GetTagsList() ([]string, error) { rows, err := db.Query(`SELECT DISTINCT tags FROM tasks WHERE tags != ''`) diff --git a/internal/executor/executor.go b/internal/executor/executor.go index c6df404..90a7fc7 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -882,6 +882,10 @@ func (e *Executor) executeTask(ctx context.Context, task *db.Task) { result = execResult.toInternal() } + // Check if we should distill learnings from this execution session + // This runs asynchronously and captures memories even for in-progress tasks + e.MaybeDistillTask(task) + // Check current status - hooks may have already set it currentTask, _ := e.db.GetTask(task.ID) currentStatus := "" @@ -919,17 +923,14 @@ func (e *Executor) executeTask(ctx context.Context, task *db.Task) { // Save transcript on completion e.saveTranscriptOnCompletion(task.ID, workDir) - // Index task for future search/retrieval - e.indexTaskForSearch(task) - // NOTE: We intentionally do NOT kill the executor here - keep it running so user can // easily retry/resume the task. Old done task executors are cleaned up after 2h // by the cleanupOrphanedClaudes routine. - // Extract memories from successful task + // Distill and index the completed task (runs in background) go func() { - if err := e.ExtractMemories(context.Background(), task); err != nil { - e.logger.Error("Memory extraction failed", "task", task.ID, "error", err) + if err := e.processCompletedTask(context.Background(), task); err != nil { + e.logger.Error("Task processing failed", "task", task.ID, "error", err) } }() } else if result.NeedsInput { @@ -2791,33 +2792,6 @@ func (e *Executor) getProjectMemoriesSection(project string) string { return sb.String() } -// indexTaskForSearch indexes a completed task in the FTS5 search table. -// This enables future tasks to find and reference similar past work. -func (e *Executor) indexTaskForSearch(task *db.Task) { - // Get transcript excerpt (first ~2000 chars of the most recent transcript) - var transcriptExcerpt string - summary, err := e.db.GetLatestCompactionSummary(task.ID) - if err == nil && summary != nil && len(summary.Summary) > 0 { - transcriptExcerpt = summary.Summary - if len(transcriptExcerpt) > 2000 { - transcriptExcerpt = transcriptExcerpt[:2000] - } - } - - // Index the task - if err := e.db.IndexTaskForSearch( - task.ID, - task.Project, - task.Title, - task.Body, - task.Tags, - transcriptExcerpt, - ); err != nil { - e.logger.Debug("Failed to index task for search", "task", task.ID, "error", err) - } else { - e.logger.Debug("Indexed task for search", "task", task.ID) - } -} // getSimilarTasksSection checks if similar past tasks exist and returns a hint. // Instead of injecting full content, we just notify Claude that the search tools are available. diff --git a/internal/executor/memory_e2e_test.go b/internal/executor/memory_e2e_test.go index 3a76aaa..2c2073e 100644 --- a/internal/executor/memory_e2e_test.go +++ b/internal/executor/memory_e2e_test.go @@ -2,137 +2,16 @@ package executor import ( "context" - "os/exec" + "os" "path/filepath" "strings" "testing" - "time" "github.com/bborn/workflow/internal/config" "github.com/bborn/workflow/internal/db" "github.com/charmbracelet/log" ) -// TestMemoryE2EFullLifecycle tests the complete memory lifecycle: -// 1. Create a task with logs -// 2. Extract memories from the logs -// 3. Verify memories are stored in the database -// 4. Verify memories are injected into new task prompts -func TestMemoryE2EFullLifecycle(t *testing.T) { - if testing.Short() { - t.Skip("Skipping E2E test in short mode") - } - - // Check if claude CLI is available - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found, skipping E2E test") - } - - // Create temporary database - tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "test.db") - - database, err := db.Open(dbPath) - if err != nil { - t.Fatalf("failed to open database: %v", err) - } - defer database.Close() - - // Create project - projectName := "memory-test-project" - err = database.CreateProject(&db.Project{ - Name: projectName, - Path: tmpDir, - }) - if err != nil { - t.Fatalf("failed to create project: %v", err) - } - - // Create a task - task := &db.Task{ - Title: "Implement user authentication", - Body: "Add JWT-based authentication to the API", - Status: db.StatusDone, - Type: db.TypeCode, - Project: projectName, - } - if err := database.CreateTask(task); err != nil { - t.Fatalf("failed to create task: %v", err) - } - - // Simulate task logs that contain useful learnings - taskLogs := []struct { - lineType string - content string - }{ - {"text", "Starting authentication implementation..."}, - {"output", "Looking at the existing auth module in internal/auth/"}, - {"output", "Found that the project uses RS256 algorithm for JWT tokens"}, - {"output", "The refresh token is stored in a separate table 'refresh_tokens'"}, - {"output", "Important: The auth middleware expects tokens in the Authorization header with 'Bearer' prefix"}, - {"output", "Gotcha: The token validation caches the public key for 5 minutes to avoid repeated fetches"}, - {"text", "Successfully implemented JWT authentication with refresh token support"}, - } - - for _, log := range taskLogs { - if err := database.AppendTaskLog(task.ID, log.lineType, log.content); err != nil { - t.Fatalf("failed to add task log: %v", err) - } - } - - // Create config for testing - cfg := config.New(database) - - // Create executor with minimal config for testing - executor := &Executor{ - db: database, - config: cfg, - logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), - } - - // Test memory extraction - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - err = executor.ExtractMemories(ctx, task) - if err != nil { - t.Fatalf("memory extraction failed: %v", err) - } - - // Wait a moment for async operations - time.Sleep(100 * time.Millisecond) - - // Verify memories were stored - memories, err := database.GetProjectMemories(projectName, 10) - if err != nil { - t.Fatalf("failed to get project memories: %v", err) - } - - t.Logf("Extracted %d memories from task", len(memories)) - for _, m := range memories { - t.Logf(" - [%s] %s", m.Category, m.Content) - } - - // We should have at least one memory extracted (Claude should find something useful) - if len(memories) == 0 { - t.Log("Warning: No memories were extracted. This may be expected if Claude determined nothing was worth remembering.") - } - - // Test memory injection into prompts - memoriesSection := executor.getProjectMemoriesSection(projectName) - if len(memories) > 0 && memoriesSection == "" { - t.Error("Expected memories section to be non-empty when memories exist") - } - - if memoriesSection != "" { - // Verify the format - if !strings.Contains(memoriesSection, "## Project Context") { - t.Error("Memories section should contain 'Project Context' header") - } - t.Logf("Memory injection section:\n%s", memoriesSection) - } -} - // TestMemoryInjectionFormat verifies that memories are correctly formatted // for injection into task prompts. func TestMemoryInjectionFormat(t *testing.T) { @@ -207,9 +86,9 @@ func TestMemoryInjectionFormat(t *testing.T) { t.Logf("Formatted memories section:\n%s", memoriesSection) } -// TestMemoryExtractionSkipsEmptyLogs verifies that memory extraction +// TestProcessCompletedTaskSkipsEmptyCompaction verifies that task processing // handles edge cases gracefully. -func TestMemoryExtractionSkipsEmptyLogs(t *testing.T) { +func TestProcessCompletedTaskSkipsEmptyCompaction(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") @@ -219,7 +98,7 @@ func TestMemoryExtractionSkipsEmptyLogs(t *testing.T) { } defer database.Close() - projectName := "empty-logs-project" + projectName := "empty-compaction-project" err = database.CreateProject(&db.Project{ Name: projectName, Path: tmpDir, @@ -228,7 +107,7 @@ func TestMemoryExtractionSkipsEmptyLogs(t *testing.T) { t.Fatalf("failed to create project: %v", err) } - // Task with no logs + // Task with no compaction summary task := &db.Task{ Title: "Empty task", Status: db.StatusDone, @@ -238,15 +117,17 @@ func TestMemoryExtractionSkipsEmptyLogs(t *testing.T) { t.Fatalf("failed to create task: %v", err) } + cfg := config.New(database) executor := &Executor{ db: database, + config: cfg, logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), } ctx := context.Background() - err = executor.ExtractMemories(ctx, task) + err = executor.processCompletedTask(ctx, task) if err != nil { - t.Errorf("extraction should not fail on empty logs: %v", err) + t.Errorf("processing should not fail on empty compaction: %v", err) } // Verify no memories were created @@ -255,13 +136,13 @@ func TestMemoryExtractionSkipsEmptyLogs(t *testing.T) { t.Fatalf("failed to get memories: %v", err) } if len(memories) != 0 { - t.Errorf("expected 0 memories for empty logs, got %d", len(memories)) + t.Errorf("expected 0 memories for empty compaction, got %d", len(memories)) } } -// TestMemoryExtractionRequiresProject verifies that tasks without -// a project don't attempt memory extraction. -func TestMemoryExtractionRequiresProject(t *testing.T) { +// TestProcessCompletedTaskHandlesNoProject verifies that tasks without +// a project process gracefully. +func TestProcessCompletedTaskHandlesNoProject(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") @@ -281,29 +162,108 @@ func TestMemoryExtractionRequiresProject(t *testing.T) { t.Fatalf("failed to create task: %v", err) } + cfg := config.New(database) executor := &Executor{ db: database, + config: cfg, logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), } ctx := context.Background() - err = executor.ExtractMemories(ctx, task) + err = executor.processCompletedTask(ctx, task) if err != nil { - t.Errorf("extraction should return nil for tasks without project: %v", err) + t.Errorf("processing should not fail for tasks without project: %v", err) } } -// TestMemoryDuplicatePrevention verifies that similar memories -// are not duplicated. -func TestMemoryDuplicatePrevention(t *testing.T) { - if testing.Short() { - t.Skip("Skipping integration test in short mode") +// TestGenerateMemoriesMD verifies that .claude/memories.md is generated correctly +// from project memories. +func TestGenerateMemoriesMD(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + projectName := "memories-md-test-project" + err = database.CreateProject(&db.Project{ + Name: projectName, + Path: tmpDir, + }) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + // Create memories of different categories + testMemories := []*db.ProjectMemory{ + {Project: projectName, Category: db.MemoryCategoryPattern, Content: "Use dependency injection for services"}, + {Project: projectName, Category: db.MemoryCategoryContext, Content: "The API uses GraphQL, not REST"}, + {Project: projectName, Category: db.MemoryCategoryDecision, Content: "Chose PostgreSQL over MySQL for JSON support"}, + {Project: projectName, Category: db.MemoryCategoryGotcha, Content: "The cache has a 5 minute TTL"}, + } + + for _, m := range testMemories { + if err := database.CreateMemory(m); err != nil { + t.Fatalf("failed to create memory: %v", err) + } + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + err = executor.GenerateMemoriesMD(projectName) + if err != nil { + t.Fatalf("failed to generate memories.md: %v", err) + } + + // Read the generated file - should be in .claude/memories.md + memoriesPath := filepath.Join(tmpDir, ".claude", "memories.md") + content, err := os.ReadFile(memoriesPath) + if err != nil { + t.Fatalf("failed to read memories.md: %v", err) } - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found") + contentStr := string(content) + + // Verify header + if !strings.Contains(contentStr, "# Project Memories") { + t.Error("Missing main header") + } + + // Verify category sections exist + expectedSections := []string{ + "## Patterns & Conventions", + "## Project Context", + "## Key Decisions", + "## Known Gotchas", + } + + for _, section := range expectedSections { + if !strings.Contains(contentStr, section) { + t.Errorf("Missing section: %s", section) + } } + // Verify memory content is present + for _, m := range testMemories { + if !strings.Contains(contentStr, m.Content) { + t.Errorf("Missing memory content: %s", m.Content) + } + } + + t.Logf("Generated memories.md content:\n%s", contentStr) +} + +// TestGenerateMemoriesMDNoMemories verifies that .claude/memories.md is generated +// even when there are no memories. +func TestGenerateMemoriesMDNoMemories(t *testing.T) { tmpDir := t.TempDir() dbPath := filepath.Join(tmpDir, "test.db") @@ -313,7 +273,7 @@ func TestMemoryDuplicatePrevention(t *testing.T) { } defer database.Close() - projectName := "dedup-test-project" + projectName := "no-memories-test-project" err = database.CreateProject(&db.Project{ Name: projectName, Path: tmpDir, @@ -322,39 +282,70 @@ func TestMemoryDuplicatePrevention(t *testing.T) { t.Fatalf("failed to create project: %v", err) } - // Pre-create a memory that might be extracted - existingMemory := &db.ProjectMemory{ - Project: projectName, - Category: db.MemoryCategoryContext, - Content: "The project uses JWT tokens for authentication", + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), } - if err := database.CreateMemory(existingMemory); err != nil { - t.Fatalf("failed to create existing memory: %v", err) + + err = executor.GenerateMemoriesMD(projectName) + if err != nil { + t.Fatalf("failed to generate memories.md: %v", err) } - // Create task with logs that mention the same thing - task := &db.Task{ - Title: "Add token refresh", - Body: "Implement token refresh mechanism", - Status: db.StatusDone, - Type: db.TypeCode, - Project: projectName, + // Read the generated file + memoriesPath := filepath.Join(tmpDir, ".claude", "memories.md") + content, err := os.ReadFile(memoriesPath) + if err != nil { + t.Fatalf("failed to read memories.md: %v", err) } - if err := database.CreateTask(task); err != nil { - t.Fatalf("failed to create task: %v", err) + + contentStr := string(content) + + // Verify placeholder text for no memories + if !strings.Contains(contentStr, "No project memories have been captured yet") { + t.Error("Missing placeholder text for no memories") } - logs := []struct { - lineType string - content string - }{ - {"output", "The project uses JWT tokens for authentication"}, - {"output", "Added refresh token support"}, + t.Logf("Generated memories.md content:\n%s", contentStr) +} + +// TestGenerateMemoriesMDDoesNotClobberClaudeMD verifies that generating memories.md +// does not affect any existing CLAUDE.md in the project root. +func TestGenerateMemoriesMDDoesNotClobberClaudeMD(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) } - for _, log := range logs { - if err := database.AppendTaskLog(task.ID, log.lineType, log.content); err != nil { - t.Fatalf("failed to add log: %v", err) - } + defer database.Close() + + projectName := "clobber-test-project" + err = database.CreateProject(&db.Project{ + Name: projectName, + Path: tmpDir, + }) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + // Create an existing CLAUDE.md file that should not be touched + existingClaudeMD := "# My Custom CLAUDE.md\n\nThis should not be modified." + claudeMDPath := filepath.Join(tmpDir, "CLAUDE.md") + if err := os.WriteFile(claudeMDPath, []byte(existingClaudeMD), 0644); err != nil { + t.Fatalf("failed to create CLAUDE.md: %v", err) + } + + // Create a memory + if err := database.CreateMemory(&db.ProjectMemory{ + Project: projectName, + Category: db.MemoryCategoryContext, + Content: "Test memory", + }); err != nil { + t.Fatalf("failed to create memory: %v", err) } cfg := config.New(database) @@ -364,29 +355,24 @@ func TestMemoryDuplicatePrevention(t *testing.T) { logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - err = executor.ExtractMemories(ctx, task) + err = executor.GenerateMemoriesMD(projectName) if err != nil { - t.Logf("extraction error (may be expected): %v", err) + t.Fatalf("failed to generate memories.md: %v", err) } - // Get all memories and check for duplicates - memories, err := database.GetProjectMemories(projectName, 50) + // Verify CLAUDE.md was not modified + content, err := os.ReadFile(claudeMDPath) if err != nil { - t.Fatalf("failed to get memories: %v", err) + t.Fatalf("failed to read CLAUDE.md: %v", err) } - // Count how many times the JWT memory appears - jwtCount := 0 - for _, m := range memories { - if strings.Contains(strings.ToLower(m.Content), "jwt") { - jwtCount++ - } + if string(content) != existingClaudeMD { + t.Error("CLAUDE.md was modified when it should not have been") } - t.Logf("Found %d JWT-related memories", jwtCount) - // We expect Claude to avoid creating duplicate memories - // but this is not a strict requirement + // Verify memories.md was created separately + memoriesPath := filepath.Join(tmpDir, ".claude", "memories.md") + if _, err := os.Stat(memoriesPath); os.IsNotExist(err) { + t.Error("memories.md was not created") + } } diff --git a/internal/executor/memory_extractor.go b/internal/executor/memory_extractor.go index 7c2c5ba..f829230 100644 --- a/internal/executor/memory_extractor.go +++ b/internal/executor/memory_extractor.go @@ -1,246 +1,116 @@ -// Package executor provides memory extraction from completed tasks. +// Package executor provides memory management for completed tasks. package executor import ( - "context" - "encoding/json" "fmt" - "os/exec" + "os" + "path/filepath" "strings" - "time" "github.com/bborn/workflow/internal/db" ) -// ExtractedMemory represents a memory extracted from task execution. -type ExtractedMemory struct { - Category string `json:"category"` // pattern, context, decision, gotcha, general - Content string `json:"content"` -} - -// MemoryExtractionResult holds the result of memory extraction. -type MemoryExtractionResult struct { - Memories []ExtractedMemory `json:"memories"` -} - -// ExtractMemories analyzes completed task logs and extracts useful memories. -func (e *Executor) ExtractMemories(ctx context.Context, task *db.Task) error { - if task.Project == "" { - return nil // Can't store memories without a project - } - - // Get the compaction summary which contains the actual Claude conversation - summary, err := e.db.GetLatestCompactionSummary(task.ID) - if err != nil { - return fmt.Errorf("get compaction summary: %w", err) - } - - var logContent strings.Builder - - // Use compaction summary if available (contains the actual conversation) - if summary != nil && len(summary.Summary) > 100 { - logContent.WriteString(summary.Summary) - } else { - // Fallback to task logs if no compaction summary - logs, err := e.db.GetTaskLogs(task.ID, 500) - if err != nil { - return fmt.Errorf("get task logs: %w", err) - } - - if len(logs) == 0 { - return nil - } - - // Build context from logs - include system logs which contain Claude's output - for _, log := range logs { - if log.LineType == "output" || log.LineType == "text" || log.LineType == "system" { - logContent.WriteString(log.Content) - logContent.WriteString("\n") - } - } +// normalizeCategory converts a category string to a valid memory category. +func normalizeCategory(category string) string { + category = strings.ToLower(strings.TrimSpace(category)) + switch category { + case "pattern", "patterns": + return db.MemoryCategoryPattern + case "context": + return db.MemoryCategoryContext + case "decision", "decisions": + return db.MemoryCategoryDecision + case "gotcha", "gotchas", "pitfall", "pitfalls": + return db.MemoryCategoryGotcha + case "general": + return db.MemoryCategoryGeneral + default: + return db.MemoryCategoryGeneral } +} - // Skip if not enough content - if logContent.Len() < 100 { +// GenerateMemoriesMD creates or updates a .claude/memories.md file in the project directory +// with accumulated project memories. Claude Code reads files in the .claude/ directory, +// so this provides context without clobbering any existing CLAUDE.md in the project root. +func (e *Executor) GenerateMemoriesMD(project string) error { + if project == "" { return nil } - // Get existing memories to avoid duplicates - existingMemories, err := e.db.GetProjectMemories(task.Project, 50) - if err != nil { - return fmt.Errorf("get existing memories: %w", err) + projectDir := e.getProjectDir(project) + if projectDir == "" { + return fmt.Errorf("project directory not found: %s", project) } - var existingContent strings.Builder - for _, m := range existingMemories { - existingContent.WriteString(fmt.Sprintf("- [%s] %s\n", m.Category, m.Content)) + // Use .claude/memories.md to avoid clobbering any existing CLAUDE.md + claudeDir := filepath.Join(projectDir, ".claude") + if err := os.MkdirAll(claudeDir, 0755); err != nil { + return fmt.Errorf("create .claude dir: %w", err) } + memoriesPath := filepath.Join(claudeDir, "memories.md") - // Build extraction prompt - prompt := buildExtractionPrompt(task, logContent.String(), existingContent.String()) - - // Run Claude to extract memories - memories, err := e.runMemoryExtraction(ctx, task, prompt) + // Get all project memories + memories, err := e.db.GetProjectMemories(project, 100) if err != nil { - e.logger.Error("Memory extraction failed", "task", task.ID, "error", err) - return err + return fmt.Errorf("get project memories: %w", err) } - // Save extracted memories - for _, mem := range memories { - // Validate category - category := normalizeCategory(mem.Category) - if category == "" { - continue - } + // Build memories.md content + var content strings.Builder - // Skip if content is too short or empty - content := strings.TrimSpace(mem.Content) - if len(content) < 10 { - continue - } + content.WriteString("# Project Memories\n\n") + content.WriteString("This file is auto-generated from task completions to help Claude understand the codebase.\n") + content.WriteString("Do not edit manually - changes will be overwritten.\n\n") - memory := &db.ProjectMemory{ - Project: task.Project, - Category: category, - Content: content, - SourceTaskID: &task.ID, + if len(memories) == 0 { + content.WriteString("No project memories have been captured yet. They will appear here as tasks are completed.\n") + } else { + // Group memories by category + byCategory := make(map[string][]*db.ProjectMemory) + for _, m := range memories { + byCategory[m.Category] = append(byCategory[m.Category], m) } - if err := e.db.CreateMemory(memory); err != nil { - e.logger.Error("Failed to save memory", "error", err) - continue + // Category order and labels (same as getProjectMemoriesSection) + categoryOrder := []string{ + db.MemoryCategoryPattern, + db.MemoryCategoryContext, + db.MemoryCategoryDecision, + db.MemoryCategoryGotcha, + db.MemoryCategoryGeneral, + } + categoryLabels := map[string]string{ + db.MemoryCategoryPattern: "Patterns & Conventions", + db.MemoryCategoryContext: "Project Context", + db.MemoryCategoryDecision: "Key Decisions", + db.MemoryCategoryGotcha: "Known Gotchas", + db.MemoryCategoryGeneral: "General Notes", } - e.logLine(task.ID, "system", fmt.Sprintf("Learned: [%s] %s", category, truncate(content, 60))) - } - - return nil -} - -func buildExtractionPrompt(task *db.Task, logContent, existingMemories string) string { - var prompt strings.Builder - - prompt.WriteString(`Analyze this completed task and extract any useful learnings that should be remembered for future tasks on this project. - -## Task Information -`) - prompt.WriteString(fmt.Sprintf("Title: %s\n", task.Title)) - prompt.WriteString(fmt.Sprintf("Project: %s\n", task.Project)) - prompt.WriteString(fmt.Sprintf("Type: %s\n", task.Type)) - if task.Body != "" { - prompt.WriteString(fmt.Sprintf("Description: %s\n", task.Body)) - } - - prompt.WriteString(` -## Task Execution Log (truncated) -`) - // Truncate log content to avoid token limits - maxLogLen := 8000 - if len(logContent) > maxLogLen { - logContent = logContent[:maxLogLen] + "\n... (truncated)" - } - prompt.WriteString(logContent) - - if existingMemories != "" { - prompt.WriteString(` - -## Existing Memories (avoid duplicates) -`) - prompt.WriteString(existingMemories) - } - - prompt.WriteString(` - -## Instructions - -Extract 0-3 key learnings from this task that would be useful for future work on this project. Focus on: - -- **pattern**: Code patterns, naming conventions, file organization discovered -- **context**: Important project context (architecture, key dependencies, how things work) -- **decision**: Architectural or design decisions made and why -- **gotcha**: Pitfalls, workarounds, things that didn't work as expected -- **general**: Other useful learnings - -Guidelines: -- Only extract genuinely useful, non-obvious information -- Be concise but specific (include file paths, function names, etc. when relevant) -- Don't duplicate existing memories -- Return empty array if nothing worth remembering -- Each memory should be 1-2 sentences max - -Respond with ONLY a JSON object in this exact format: -{"memories": [{"category": "pattern", "content": "..."}, ...]} - -If there's nothing worth extracting, respond with: -{"memories": []} -`) - - return prompt.String() -} - -func (e *Executor) runMemoryExtraction(ctx context.Context, task *db.Task, prompt string) ([]ExtractedMemory, error) { - // Use a timeout for extraction - ctx, cancel := context.WithTimeout(ctx, 60*time.Second) - defer cancel() - - projectDir := e.getProjectDir(task.Project) - - jsonSchema := `{"type":"object","properties":{"memories":{"type":"array","items":{"type":"object","properties":{"category":{"type":"string"},"content":{"type":"string"}},"required":["category","content"]}}},"required":["memories"]}` - - args := []string{ - "-p", - "--output-format", "json", - "--json-schema", jsonSchema, - prompt, - } - - cmd := exec.CommandContext(ctx, "claude", args...) - cmd.Dir = projectDir - - output, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("claude execution: %w", err) - } + for _, cat := range categoryOrder { + mems := byCategory[cat] + if len(mems) == 0 { + continue + } - // Parse the JSON response - with --output-format json and --json-schema, - // Claude returns structured output directly in the structured_output field - var response struct { - StructuredOutput MemoryExtractionResult `json:"structured_output"` - IsError bool `json:"is_error"` - } - if err := json.Unmarshal(output, &response); err != nil { - return nil, fmt.Errorf("parse claude response: %w", err) - } + label := categoryLabels[cat] + if label == "" { + label = cat + } - if response.IsError { - return nil, fmt.Errorf("claude returned error") + content.WriteString(fmt.Sprintf("## %s\n\n", label)) + for _, m := range mems { + content.WriteString(fmt.Sprintf("- %s\n", m.Content)) + } + content.WriteString("\n") + } } - return response.StructuredOutput.Memories, nil -} - -func normalizeCategory(category string) string { - category = strings.ToLower(strings.TrimSpace(category)) - switch category { - case "pattern", "patterns": - return db.MemoryCategoryPattern - case "context": - return db.MemoryCategoryContext - case "decision", "decisions": - return db.MemoryCategoryDecision - case "gotcha", "gotchas", "pitfall", "pitfalls": - return db.MemoryCategoryGotcha - case "general": - return db.MemoryCategoryGeneral - default: - return db.MemoryCategoryGeneral + // Write the file + if err := os.WriteFile(memoriesPath, []byte(content.String()), 0644); err != nil { + return fmt.Errorf("write memories.md: %w", err) } -} -func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen-3] + "..." + e.logger.Debug("Generated .claude/memories.md", "project", project, "memories", len(memories)) + return nil } diff --git a/internal/executor/memory_extractor_test.go b/internal/executor/memory_extractor_test.go index a1cc92a..8504bcc 100644 --- a/internal/executor/memory_extractor_test.go +++ b/internal/executor/memory_extractor_test.go @@ -1,8 +1,6 @@ package executor import ( - "encoding/json" - "os/exec" "testing" ) @@ -35,7 +33,7 @@ func TestNormalizeCategory(t *testing.T) { } } -func TestTruncate(t *testing.T) { +func TestTruncateSummary(t *testing.T) { tests := []struct { input string maxLen int @@ -50,147 +48,10 @@ func TestTruncate(t *testing.T) { for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { - result := truncate(tt.input, tt.maxLen) + result := truncateSummary(tt.input, tt.maxLen) if result != tt.expected { - t.Errorf("truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected) + t.Errorf("truncateSummary(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected) } }) } } - -func TestBuildExtractionPrompt(t *testing.T) { - // Just verify it doesn't panic and includes key elements - task := &struct { - Title string - Project string - Type string - Body string - }{ - Title: "Fix login bug", - Project: "myapp", - Type: "code", - Body: "Users can't log in", - } - - // Create a mock task for testing - mockTask := struct { - Title string - Project string - Type string - Body string - }{ - Title: task.Title, - Project: task.Project, - Type: task.Type, - Body: task.Body, - } - - logContent := "Found the bug in auth.go\nFixed by checking nil pointer" - existingMemories := "- [pattern] Use early returns" - - // We can't call buildExtractionPrompt directly since it takes *db.Task - // Just verify the helper functions work - _ = mockTask - _ = logContent - _ = existingMemories -} - -func TestParseClaudeMemoryResponse(t *testing.T) { - // Test parsing the Claude JSON response format - responseJSON := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"memories":[{"category":"context","content":"The auth module uses JWT tokens."}]}}` - - var response struct { - StructuredOutput MemoryExtractionResult `json:"structured_output"` - IsError bool `json:"is_error"` - } - - err := json.Unmarshal([]byte(responseJSON), &response) - if err != nil { - t.Fatalf("Failed to parse response: %v", err) - } - - if response.IsError { - t.Error("Expected is_error to be false") - } - - if len(response.StructuredOutput.Memories) != 1 { - t.Fatalf("Expected 1 memory, got %d", len(response.StructuredOutput.Memories)) - } - - mem := response.StructuredOutput.Memories[0] - if mem.Category != "context" { - t.Errorf("Expected category 'context', got %q", mem.Category) - } - if mem.Content != "The auth module uses JWT tokens." { - t.Errorf("Unexpected content: %q", mem.Content) - } -} - -func TestParseClaudeEmptyMemoryResponse(t *testing.T) { - // Test parsing empty memories response - responseJSON := `{"type":"result","subtype":"success","is_error":false,"structured_output":{"memories":[]}}` - - var response struct { - StructuredOutput MemoryExtractionResult `json:"structured_output"` - IsError bool `json:"is_error"` - } - - err := json.Unmarshal([]byte(responseJSON), &response) - if err != nil { - t.Fatalf("Failed to parse response: %v", err) - } - - if len(response.StructuredOutput.Memories) != 0 { - t.Errorf("Expected 0 memories, got %d", len(response.StructuredOutput.Memories)) - } -} - -// TestClaudeMemoryExtractionIntegration tests actual Claude CLI integration. -// Skip this test in CI by setting SKIP_CLAUDE_TESTS=1 -func TestClaudeMemoryExtractionIntegration(t *testing.T) { - if testing.Short() { - t.Skip("Skipping integration test in short mode") - } - - // Check if claude CLI is available - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found, skipping integration test") - } - - jsonSchema := `{"type":"object","properties":{"memories":{"type":"array","items":{"type":"object","properties":{"category":{"type":"string"},"content":{"type":"string"}},"required":["category","content"]}}},"required":["memories"]}` - - prompt := `Extract one memory from this task: "Fixed auth bug by adding nil check in auth.go. The system uses JWT tokens." Return exactly one memory with category "context".` - - cmd := exec.Command("claude", "-p", "--output-format", "json", "--json-schema", jsonSchema, prompt) - output, err := cmd.Output() - if err != nil { - t.Fatalf("Claude execution failed: %v", err) - } - - var response struct { - StructuredOutput MemoryExtractionResult `json:"structured_output"` - IsError bool `json:"is_error"` - } - - if err := json.Unmarshal(output, &response); err != nil { - t.Fatalf("Failed to parse Claude response: %v", err) - } - - if response.IsError { - t.Fatal("Claude returned an error") - } - - if len(response.StructuredOutput.Memories) == 0 { - t.Error("Expected at least one memory to be extracted") - } - - // Verify the memory has valid fields - for _, mem := range response.StructuredOutput.Memories { - if mem.Category == "" { - t.Error("Memory category should not be empty") - } - if mem.Content == "" { - t.Error("Memory content should not be empty") - } - } -} diff --git a/internal/executor/task_distillation.go b/internal/executor/task_distillation.go new file mode 100644 index 0000000..c376652 --- /dev/null +++ b/internal/executor/task_distillation.go @@ -0,0 +1,439 @@ +// Package executor provides task distillation for extracting structured summaries from completed tasks. +package executor + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" + "time" + + "github.com/bborn/workflow/internal/db" +) + +// TaskSummary represents a distilled summary of a completed task. +// This structured format enables efficient search indexing and memory extraction. +type TaskSummary struct { + WhatWasDone string `json:"what_was_done"` // Brief description of what was accomplished + FilesChanged []string `json:"files_changed"` // Key files that were modified + Decisions []Decision `json:"decisions"` // Architectural/design decisions made + Learnings []Learning `json:"learnings"` // Patterns, gotchas, and insights discovered +} + +// Decision represents an architectural or design decision made during a task. +type Decision struct { + Description string `json:"description"` // What was decided + Rationale string `json:"rationale"` // Why this decision was made +} + +// Learning represents a pattern, gotcha, or insight discovered during a task. +type Learning struct { + Category string `json:"category"` // pattern, context, decision, gotcha, general + Content string `json:"content"` // The actual learning +} + +// ToSearchExcerpt converts the TaskSummary to a search-friendly string. +// This is used for FTS5 indexing and should be concise but comprehensive. +func (s *TaskSummary) ToSearchExcerpt() string { + var parts []string + + // Add summary + if s.WhatWasDone != "" { + parts = append(parts, fmt.Sprintf("Summary: %s", s.WhatWasDone)) + } + + // Add files + if len(s.FilesChanged) > 0 { + parts = append(parts, fmt.Sprintf("Files: %s", strings.Join(s.FilesChanged, ", "))) + } + + // Add decisions + for _, d := range s.Decisions { + if d.Rationale != "" { + parts = append(parts, fmt.Sprintf("Decision: %s (Rationale: %s)", d.Description, d.Rationale)) + } else { + parts = append(parts, fmt.Sprintf("Decision: %s", d.Description)) + } + } + + // Add learnings + for _, l := range s.Learnings { + parts = append(parts, fmt.Sprintf("[%s] %s", l.Category, l.Content)) + } + + return strings.Join(parts, "\n") +} + +// DistillTaskSummary uses an LLM to distill a compaction summary into a structured TaskSummary. +// Returns nil if the content is too short or empty. +func (e *Executor) DistillTaskSummary(ctx context.Context, task *db.Task, compactionContent string) (*TaskSummary, error) { + // Require substantial content + if len(compactionContent) < 100 { + return nil, nil + } + + // Use a timeout for distillation + ctx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + prompt := buildDistillationPrompt(task, compactionContent) + projectDir := e.getProjectDir(task.Project) + + // JSON schema for structured output + jsonSchema := `{ + "type": "object", + "properties": { + "what_was_done": {"type": "string"}, + "files_changed": {"type": "array", "items": {"type": "string"}}, + "decisions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "description": {"type": "string"}, + "rationale": {"type": "string"} + }, + "required": ["description"] + } + }, + "learnings": { + "type": "array", + "items": { + "type": "object", + "properties": { + "category": {"type": "string"}, + "content": {"type": "string"} + }, + "required": ["category", "content"] + } + } + }, + "required": ["what_was_done"] + }` + + args := []string{ + "-p", + "--output-format", "json", + "--json-schema", jsonSchema, + prompt, + } + + cmd := exec.CommandContext(ctx, "claude", args...) + if projectDir != "" { + cmd.Dir = projectDir + } + + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("claude execution: %w", err) + } + + // Parse the JSON response + var response struct { + StructuredOutput TaskSummary `json:"structured_output"` + IsError bool `json:"is_error"` + } + if err := json.Unmarshal(output, &response); err != nil { + return nil, fmt.Errorf("parse claude response: %w", err) + } + + if response.IsError { + return nil, fmt.Errorf("claude returned error") + } + + return &response.StructuredOutput, nil +} + +// buildDistillationPrompt creates the prompt for distilling a task summary. +func buildDistillationPrompt(task *db.Task, content string) string { + var prompt strings.Builder + + prompt.WriteString(`Distill this completed task session into a structured summary. Extract the essential information that would help someone understand what was done without reading the full transcript. + +## Task Information +`) + prompt.WriteString(fmt.Sprintf("Title: %s\n", task.Title)) + if task.Project != "" { + prompt.WriteString(fmt.Sprintf("Project: %s\n", task.Project)) + } + if task.Body != "" { + prompt.WriteString(fmt.Sprintf("Description: %s\n", task.Body)) + } + + prompt.WriteString(` +## Session Transcript +`) + // Include the full content - the LLM will distill it + prompt.WriteString(content) + + prompt.WriteString(` + +## Instructions + +Create a structured summary with: + +1. **what_was_done**: A 1-2 sentence summary of the main accomplishment. Be specific. + +2. **files_changed**: List the key files that were created or modified. Include paths. Limit to the most important 5-10 files. + +3. **decisions**: List 0-3 significant architectural or design decisions made during this task. Include: + - description: What was decided + - rationale: Why this approach was chosen (if discussed) + +4. **learnings**: List 0-5 useful learnings discovered. Each should have: + - category: One of "pattern", "context", "decision", "gotcha", or "general" + - content: A concise description of the learning (1-2 sentences) + +Focus on information that would be valuable for future work on this project. Skip trivial details. +`) + + return prompt.String() +} + +// SaveMemoriesFromSummary extracts and saves project memories from a TaskSummary. +func (e *Executor) SaveMemoriesFromSummary(task *db.Task, summary *TaskSummary) error { + if task.Project == "" || summary == nil { + return nil + } + + // Save decisions as memories + for _, d := range summary.Decisions { + if d.Description == "" { + continue + } + content := d.Description + if d.Rationale != "" { + content = fmt.Sprintf("%s (Rationale: %s)", d.Description, d.Rationale) + } + + memory := &db.ProjectMemory{ + Project: task.Project, + Category: db.MemoryCategoryDecision, + Content: content, + SourceTaskID: &task.ID, + } + if err := e.db.CreateMemory(memory); err != nil { + e.logger.Error("Failed to save decision memory", "error", err) + } + } + + // Save learnings as memories + for _, l := range summary.Learnings { + if l.Content == "" { + continue + } + + category := normalizeCategory(l.Category) + memory := &db.ProjectMemory{ + Project: task.Project, + Category: category, + Content: l.Content, + SourceTaskID: &task.ID, + } + if err := e.db.CreateMemory(memory); err != nil { + e.logger.Error("Failed to save learning memory", "error", err) + } + } + + return nil +} + +// processCompletedTask orchestrates the post-completion processing: +// 1. Gets the compaction summary (full conversation transcript) +// 2. Distills it into a structured TaskSummary using an LLM +// 3. Saves the summary to the task for future reference +// 4. Extracts and saves project memories from the summary +// 5. Indexes the task for FTS5 search using the distilled summary +// 6. Updates .claude/memories.md with the latest memories +func (e *Executor) processCompletedTask(ctx context.Context, task *db.Task) error { + // Get the compaction summary which contains the full conversation + compaction, err := e.db.GetLatestCompactionSummary(task.ID) + if err != nil { + return fmt.Errorf("get compaction summary: %w", err) + } + + var compactionContent string + if compaction != nil { + compactionContent = compaction.Summary + } + + // Distill the summary using an LLM + summary, err := e.DistillTaskSummary(ctx, task, compactionContent) + if err != nil { + e.logger.Error("Failed to distill task summary", "task", task.ID, "error", err) + // Non-fatal - continue with empty summary + } + + var summaryText string + if summary != nil { + summaryText = summary.ToSearchExcerpt() + + // Save the distilled summary to the task + if err := e.db.SaveTaskSummary(task.ID, summaryText); err != nil { + e.logger.Error("Failed to save task summary", "task", task.ID, "error", err) + } + + // Extract and save project memories + if err := e.SaveMemoriesFromSummary(task, summary); err != nil { + e.logger.Error("Failed to save memories", "task", task.ID, "error", err) + } + + // Log what we extracted + e.logLine(task.ID, "system", fmt.Sprintf("Distilled: %s", truncateSummary(summary.WhatWasDone, 80))) + for _, d := range summary.Decisions { + e.logLine(task.ID, "system", fmt.Sprintf("Decision: %s", truncateSummary(d.Description, 60))) + } + for _, l := range summary.Learnings { + e.logLine(task.ID, "system", fmt.Sprintf("Learned [%s]: %s", l.Category, truncateSummary(l.Content, 60))) + } + } + + // Index task for search using the distilled summary (or empty if distillation failed) + if err := e.db.IndexTaskForSearch( + task.ID, + task.Project, + task.Title, + task.Body, + task.Tags, + summaryText, + ); err != nil { + e.logger.Debug("Failed to index task for search", "task", task.ID, "error", err) + } else { + e.logger.Debug("Indexed task for search", "task", task.ID) + } + + // Update .claude/memories.md with the latest memories + if task.Project != "" { + if err := e.GenerateMemoriesMD(task.Project); err != nil { + e.logger.Error("Failed to generate memories.md", "error", err) + } + } + + return nil +} + +// truncateSummary truncates a string with ellipsis if it exceeds maxLen. +func truncateSummary(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen-3] + "..." +} + +// DistillationInterval is the minimum time between distillations when no new content exists. +const DistillationInterval = 10 * time.Minute + +// MaybeDistillTask checks if a task should be distilled and runs distillation if appropriate. +// This is called after every ExecResult to capture learnings continuously, not just on completion. +// Runs in background and doesn't block the main execution flow. +func (e *Executor) MaybeDistillTask(task *db.Task) { + // Refresh task from DB to get latest state + freshTask, err := e.db.GetTask(task.ID) + if err != nil || freshTask == nil { + e.logger.Debug("Failed to get fresh task for distillation check", "task", task.ID, "error", err) + return + } + + should, reason := e.shouldDistill(freshTask) + if !should { + e.logger.Debug("Skipping distillation", "task", task.ID, "reason", "no trigger") + return + } + + e.logger.Info("Triggering distillation", "task", task.ID, "reason", reason) + + // Run distillation in background + go func() { + ctx := context.Background() + + // Get the compaction summary + compaction, err := e.db.GetLatestCompactionSummary(freshTask.ID) + if err != nil { + e.logger.Error("Failed to get compaction summary for distillation", "task", freshTask.ID, "error", err) + return + } + + var compactionContent string + if compaction != nil { + compactionContent = compaction.Summary + } + + // Distill the summary using an LLM + summary, err := e.DistillTaskSummary(ctx, freshTask, compactionContent) + if err != nil { + e.logger.Error("Failed to distill task summary", "task", freshTask.ID, "error", err) + // Still update last_distilled_at to avoid retrying immediately + e.db.UpdateTaskLastDistilledAt(freshTask.ID, time.Now()) + return + } + + if summary != nil { + // Save the distilled summary to the task + summaryText := summary.ToSearchExcerpt() + if err := e.db.SaveTaskSummary(freshTask.ID, summaryText); err != nil { + e.logger.Error("Failed to save task summary", "task", freshTask.ID, "error", err) + } + + // Extract and save project memories + if err := e.SaveMemoriesFromSummary(freshTask, summary); err != nil { + e.logger.Error("Failed to save memories", "task", freshTask.ID, "error", err) + } + + // Log what we extracted + e.logLine(freshTask.ID, "system", fmt.Sprintf("Distilled: %s", truncateSummary(summary.WhatWasDone, 80))) + + // Update .claude/memories.md with the latest memories + if freshTask.Project != "" { + if err := e.GenerateMemoriesMD(freshTask.Project); err != nil { + e.logger.Error("Failed to generate memories.md", "error", err) + } + } + } + + // Update last_distilled_at to track when we distilled + if err := e.db.UpdateTaskLastDistilledAt(freshTask.ID, time.Now()); err != nil { + e.logger.Error("Failed to update last_distilled_at", "task", freshTask.ID, "error", err) + } + }() +} + +// shouldDistill determines if a task should be distilled based on: +// 1. New compaction content exists (compaction is newer than last_distilled_at) +// 2. Enough time has passed since last distillation (10+ min) +// Returns (should distill, reason) where reason is "new_compaction", "time_elapsed", or "" +func (e *Executor) shouldDistill(task *db.Task) (bool, string) { + // Get the latest compaction summary + compaction, err := e.db.GetLatestCompactionSummary(task.ID) + if err != nil { + e.logger.Debug("Failed to get compaction summary for distillation check", "task", task.ID, "error", err) + return false, "" + } + + now := time.Now() + neverDistilled := task.LastDistilledAt == nil + + // Case 1: Check for new compaction content + if compaction != nil { + // If never distilled and has compaction, should distill + if neverDistilled { + return true, "new_compaction" + } + + // If compaction is newer than last distillation, should distill + if compaction.CreatedAt.Time.After(task.LastDistilledAt.Time) { + return true, "new_compaction" + } + } + + // Case 2: Check time-based trigger (only if task has been started and we have some history) + if !neverDistilled && task.StartedAt != nil { + timeSinceDistillation := now.Sub(task.LastDistilledAt.Time) + if timeSinceDistillation >= DistillationInterval { + return true, "time_elapsed" + } + } + + // No distillation needed + return false, "" +} diff --git a/internal/executor/task_distillation_test.go b/internal/executor/task_distillation_test.go new file mode 100644 index 0000000..8893466 --- /dev/null +++ b/internal/executor/task_distillation_test.go @@ -0,0 +1,623 @@ +package executor + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/bborn/workflow/internal/config" + "github.com/bborn/workflow/internal/db" + "github.com/charmbracelet/log" +) + +// TestTaskSummaryStructure verifies the TaskSummary struct has all required fields. +func TestTaskSummaryStructure(t *testing.T) { + summary := &TaskSummary{ + WhatWasDone: "Implemented user authentication with JWT tokens", + FilesChanged: []string{"internal/auth/jwt.go", "internal/auth/middleware.go"}, + Decisions: []Decision{ + {Description: "Use RS256 for JWT signing", Rationale: "More secure than HS256 for distributed systems"}, + }, + Learnings: []Learning{ + {Category: "pattern", Content: "Auth middleware should be applied at router level"}, + {Category: "gotcha", Content: "Token refresh requires separate endpoint"}, + }, + } + + if summary.WhatWasDone == "" { + t.Error("WhatWasDone should not be empty") + } + if len(summary.FilesChanged) != 2 { + t.Errorf("Expected 2 files changed, got %d", len(summary.FilesChanged)) + } + if len(summary.Decisions) != 1 { + t.Errorf("Expected 1 decision, got %d", len(summary.Decisions)) + } + if len(summary.Learnings) != 2 { + t.Errorf("Expected 2 learnings, got %d", len(summary.Learnings)) + } +} + +// TestTaskSummaryToSearchExcerpt verifies that a TaskSummary can be converted to a search-friendly string. +func TestTaskSummaryToSearchExcerpt(t *testing.T) { + summary := &TaskSummary{ + WhatWasDone: "Implemented user authentication", + FilesChanged: []string{"auth.go", "middleware.go"}, + Decisions: []Decision{ + {Description: "Use JWT tokens", Rationale: "Industry standard"}, + }, + Learnings: []Learning{ + {Category: "pattern", Content: "Apply middleware at router level"}, + }, + } + + excerpt := summary.ToSearchExcerpt() + + // Should contain the summary + if !strings.Contains(excerpt, "Implemented user authentication") { + t.Error("Excerpt should contain WhatWasDone") + } + + // Should contain files + if !strings.Contains(excerpt, "auth.go") { + t.Error("Excerpt should contain files changed") + } + + // Should contain decisions + if !strings.Contains(excerpt, "Use JWT tokens") { + t.Error("Excerpt should contain decisions") + } + + // Should contain learnings + if !strings.Contains(excerpt, "Apply middleware at router level") { + t.Error("Excerpt should contain learnings") + } + + // Should be reasonably sized (not bloated) + if len(excerpt) > 2000 { + t.Errorf("Excerpt should be under 2000 chars, got %d", len(excerpt)) + } +} + +// TestDistillTaskSummaryRequiresContent verifies that distillation requires substantial content. +func TestDistillTaskSummaryRequiresContent(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + task := &db.Task{ + ID: 1, + Title: "Test task", + Project: "test-project", + } + + // Should return nil summary for empty content + summary, err := executor.DistillTaskSummary(context.Background(), task, "") + if err != nil { + t.Errorf("Expected no error for empty content, got: %v", err) + } + if summary != nil { + t.Error("Expected nil summary for empty content") + } + + // Should return nil summary for very short content + summary, err = executor.DistillTaskSummary(context.Background(), task, "short") + if err != nil { + t.Errorf("Expected no error for short content, got: %v", err) + } + if summary != nil { + t.Error("Expected nil summary for short content") + } +} + +// TestSaveTaskSummary verifies that task summaries are saved to the database. +func TestSaveTaskSummary(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project + err = database.CreateProject(&db.Project{ + Name: "test-project", + Path: tmpDir, + }) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + // Create a task + task := &db.Task{ + Title: "Test task", + Status: db.StatusDone, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + summary := &TaskSummary{ + WhatWasDone: "Implemented feature X", + FilesChanged: []string{"file1.go", "file2.go"}, + Decisions: []Decision{ + {Description: "Used pattern Y", Rationale: "Because Z"}, + }, + Learnings: []Learning{ + {Category: "pattern", Content: "Pattern A works well"}, + }, + } + + // Save the summary + err = database.SaveTaskSummary(task.ID, summary.ToSearchExcerpt()) + if err != nil { + t.Fatalf("failed to save task summary: %v", err) + } + + // Retrieve the task and verify summary is saved + retrieved, err := database.GetTask(task.ID) + if err != nil { + t.Fatalf("failed to get task: %v", err) + } + + if retrieved.Summary == "" { + t.Error("Summary should be saved on task") + } + + if !strings.Contains(retrieved.Summary, "Implemented feature X") { + t.Error("Summary should contain WhatWasDone") + } +} + +// TestDistillationCreatesMemories verifies that distillation creates project memories. +func TestDistillationCreatesMemories(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project + err = database.CreateProject(&db.Project{ + Name: "test-project", + Path: tmpDir, + }) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + // Create a task + task := &db.Task{ + Title: "Test task", + Status: db.StatusDone, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + summary := &TaskSummary{ + WhatWasDone: "Implemented feature X", + FilesChanged: []string{"file1.go"}, + Decisions: []Decision{ + {Description: "Used pattern Y", Rationale: "Because Z"}, + }, + Learnings: []Learning{ + {Category: "pattern", Content: "Pattern A works well"}, + {Category: "gotcha", Content: "Watch out for edge case B"}, + }, + } + + // Save memories from the summary + err = executor.SaveMemoriesFromSummary(task, summary) + if err != nil { + t.Fatalf("failed to save memories: %v", err) + } + + // Verify memories were created + memories, err := database.GetProjectMemories("test-project", 10) + if err != nil { + t.Fatalf("failed to get memories: %v", err) + } + + // Should have at least 3 memories: 1 decision + 2 learnings + if len(memories) < 3 { + t.Errorf("Expected at least 3 memories, got %d", len(memories)) + } + + // Verify decision was saved + hasDecision := false + for _, m := range memories { + if m.Category == db.MemoryCategoryDecision && strings.Contains(m.Content, "Used pattern Y") { + hasDecision = true + break + } + } + if !hasDecision { + t.Error("Decision should be saved as a memory") + } + + // Verify learnings were saved + hasPattern := false + hasGotcha := false + for _, m := range memories { + if m.Category == db.MemoryCategoryPattern && strings.Contains(m.Content, "Pattern A") { + hasPattern = true + } + if m.Category == db.MemoryCategoryGotcha && strings.Contains(m.Content, "edge case B") { + hasGotcha = true + } + } + if !hasPattern { + t.Error("Pattern learning should be saved as a memory") + } + if !hasGotcha { + t.Error("Gotcha learning should be saved as a memory") + } +} + +// TestSearchIndexUsesDistilledSummary verifies that search indexing uses the distilled summary. +func TestSearchIndexUsesDistilledSummary(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project + err = database.CreateProject(&db.Project{ + Name: "test-project", + Path: tmpDir, + }) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + // Create a task with a summary + task := &db.Task{ + Title: "Implement authentication", + Status: db.StatusDone, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + // Save a summary + summaryText := "Implemented JWT authentication. Files: auth.go, middleware.go. Decision: Use RS256. Learning: Apply middleware at router level." + err = database.SaveTaskSummary(task.ID, summaryText) + if err != nil { + t.Fatalf("failed to save summary: %v", err) + } + + // Get the task to get the summary + task, _ = database.GetTask(task.ID) + + // Index the task for search using the summary + err = database.IndexTaskForSearch( + task.ID, + task.Project, + task.Title, + task.Body, + task.Tags, + task.Summary, // Use the saved summary instead of raw transcript + ) + if err != nil { + t.Fatalf("failed to index task: %v", err) + } + + // Search for the task + results, err := database.FindSimilarTasks(&db.Task{Title: "authentication JWT"}, 10) + if err != nil { + t.Fatalf("failed to search: %v", err) + } + + // Should find the task + found := false + for _, r := range results { + if r.TaskID == task.ID { + found = true + break + } + } + if !found { + t.Error("Should find task by searching for terms in the summary") + } +} + +// TestShouldDistillWithNewCompaction verifies distillation triggers when compaction is newer than last distillation. +func TestShouldDistillWithNewCompaction(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project and task + err = database.CreateProject(&db.Project{Name: "test-project", Path: tmpDir}) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + task := &db.Task{ + Title: "Test task", + Status: db.StatusProcessing, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + // Set last_distilled_at to 1 hour ago + oneHourAgo := time.Now().Add(-1 * time.Hour) + if err := database.UpdateTaskLastDistilledAt(task.ID, oneHourAgo); err != nil { + t.Fatalf("failed to update last_distilled_at: %v", err) + } + + // Save a compaction summary (created now, after last_distilled_at) + summary := &db.CompactionSummary{ + TaskID: task.ID, + SessionID: "test-session", + Trigger: "auto", + Summary: "This is a substantial compaction summary with enough content to distill.", + } + if err := database.SaveCompactionSummary(summary); err != nil { + t.Fatalf("failed to save compaction summary: %v", err) + } + + // Reload task + task, _ = database.GetTask(task.ID) + + // Should return true - compaction is newer than last distillation + should, reason := executor.shouldDistill(task) + if !should { + t.Errorf("Expected shouldDistill=true when compaction is newer, got false. Reason: %s", reason) + } + if reason != "new_compaction" { + t.Errorf("Expected reason='new_compaction', got %q", reason) + } +} + +// TestShouldDistillWithTimeElapsed verifies distillation triggers after enough time has passed. +func TestShouldDistillWithTimeElapsed(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project and task + err = database.CreateProject(&db.Project{Name: "test-project", Path: tmpDir}) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + task := &db.Task{ + Title: "Test task", + Status: db.StatusProcessing, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + // Set last_distilled_at to 15 minutes ago (beyond the 10 min threshold) + fifteenMinAgo := time.Now().Add(-15 * time.Minute) + if err := database.UpdateTaskLastDistilledAt(task.ID, fifteenMinAgo); err != nil { + t.Fatalf("failed to update last_distilled_at: %v", err) + } + + // No compaction summary, but task has been started + startedAt := time.Now().Add(-20 * time.Minute) + if err := database.UpdateTaskStartedAt(task.ID, startedAt); err != nil { + t.Fatalf("failed to update started_at: %v", err) + } + + // Reload task + task, _ = database.GetTask(task.ID) + + // Should return true - enough time has passed + should, reason := executor.shouldDistill(task) + if !should { + t.Errorf("Expected shouldDistill=true when enough time elapsed, got false. Reason: %s", reason) + } + if reason != "time_elapsed" { + t.Errorf("Expected reason='time_elapsed', got %q", reason) + } +} + +// TestShouldNotDistillWhenRecentlyDistilled verifies no distillation when recently done and no new content. +func TestShouldNotDistillWhenRecentlyDistilled(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project and task + err = database.CreateProject(&db.Project{Name: "test-project", Path: tmpDir}) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + task := &db.Task{ + Title: "Test task", + Status: db.StatusProcessing, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + // Set last_distilled_at to 2 minutes ago (within the 10 min threshold) + twoMinAgo := time.Now().Add(-2 * time.Minute) + if err := database.UpdateTaskLastDistilledAt(task.ID, twoMinAgo); err != nil { + t.Fatalf("failed to update last_distilled_at: %v", err) + } + + // Reload task + task, _ = database.GetTask(task.ID) + + // Should return false - recently distilled, no new compaction + should, reason := executor.shouldDistill(task) + if should { + t.Errorf("Expected shouldDistill=false when recently distilled, got true. Reason: %s", reason) + } +} + +// TestShouldDistillNeverDistilledBefore verifies distillation triggers for tasks never distilled. +func TestShouldDistillNeverDistilledBefore(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project and task + err = database.CreateProject(&db.Project{Name: "test-project", Path: tmpDir}) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + task := &db.Task{ + Title: "Test task", + Status: db.StatusProcessing, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + // Save a compaction summary + summary := &db.CompactionSummary{ + TaskID: task.ID, + SessionID: "test-session", + Trigger: "auto", + Summary: "This is a substantial compaction summary with enough content to distill.", + } + if err := database.SaveCompactionSummary(summary); err != nil { + t.Fatalf("failed to save compaction summary: %v", err) + } + + // Reload task (last_distilled_at is nil) + task, _ = database.GetTask(task.ID) + + // Should return true - never distilled and has compaction + should, reason := executor.shouldDistill(task) + if !should { + t.Errorf("Expected shouldDistill=true for never-distilled task with compaction, got false. Reason: %s", reason) + } +} + +// TestShouldNotDistillNoContent verifies no distillation when there's no content to distill. +func TestShouldNotDistillNoContent(t *testing.T) { + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "test.db") + + database, err := db.Open(dbPath) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + defer database.Close() + + // Create a project and task + err = database.CreateProject(&db.Project{Name: "test-project", Path: tmpDir}) + if err != nil { + t.Fatalf("failed to create project: %v", err) + } + + task := &db.Task{ + Title: "Test task", + Status: db.StatusProcessing, + Project: "test-project", + } + if err := database.CreateTask(task); err != nil { + t.Fatalf("failed to create task: %v", err) + } + + cfg := config.New(database) + executor := &Executor{ + db: database, + config: cfg, + logger: log.NewWithOptions(nil, log.Options{Level: log.DebugLevel}), + } + + // No compaction summary, never distilled, task just created + task, _ = database.GetTask(task.ID) + + // Should return false - no content to distill + should, reason := executor.shouldDistill(task) + if should { + t.Errorf("Expected shouldDistill=false when no content, got true. Reason: %s", reason) + } +}