diff --git a/docs/go-agent-design.md b/docs/go-agent-design.md new file mode 100644 index 0000000000..3e249996f4 --- /dev/null +++ b/docs/go-agent-design.md @@ -0,0 +1,851 @@ +# Genkit Go Agent with Snapshots - Design Document + +## Overview + +This document describes the design for the `Agent` primitive in Genkit Go with snapshot-based state management. An Agent is a stateful, multi-turn conversational agent with automatic snapshot persistence and turn semantics. + +Snapshots provide: +- **State encapsulation**: Messages, user-defined state, and artifacts in a single serializable unit +- **Resumability**: Start new invocations from any previous snapshot +- **Flexibility**: Support for both client-managed and server-managed state patterns +- **Debugging**: Point-in-time state capture for inspection and replay + +This design builds on the bidirectional streaming primitives described in [go-bidi-design.md](go-bidi-design.md). + +## Package Location + +Agent is an AI concept and belongs in `go/ai/x/` (experimental): + +``` +go/ai/x/ +├── agent.go # Agent, AgentFunc, AgentParams, Responder +├── agent_options.go # AgentOption, StreamBidiOption +├── agent_test.go # Tests +``` + +Import as `aix "github.com/firebase/genkit/go/ai/x"`. + +--- + +## 1. Core Type Definitions + +### 1.1 State and Snapshot Types + +**AgentState** is the portable state that flows between client and server. It contains only the data needed for conversation continuity. + +**AgentSnapshot** is a persisted point-in-time capture with metadata. It wraps AgentState with additional fields for storage, debugging, and restoration. + +```go +// AgentState is the portable conversation state. +type AgentState[State any] struct { + // Messages is the conversation history. + Messages []*ai.Message `json:"messages,omitempty"` + // Custom is the user-defined state associated with this conversation. 
+ Custom State `json:"custom,omitempty"` + // Artifacts are named collections of parts produced during the conversation. + Artifacts []*AgentArtifact `json:"artifacts,omitempty"` +} + +// AgentSnapshot is a persisted point-in-time capture of agent state. +type AgentSnapshot[State any] struct { + // SnapshotID is the unique identifier for this snapshot (content-addressed). + SnapshotID string `json:"snapshotId"` + // SessionID identifies the session this snapshot belongs to. + SessionID string `json:"sessionId"` + // ParentID is the ID of the previous snapshot in this session's timeline. + ParentID string `json:"parentId,omitempty"` + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + // TurnIndex is the turn number when this snapshot was created (0-indexed). + TurnIndex int `json:"turnIndex"` + // Event is the snapshot event that triggered this snapshot. + Event SnapshotEvent `json:"event"` + // State is the actual conversation state. + State AgentState[State] `json:"state"` +} + +// AgentArtifact represents a named collection of parts produced during a session. +// Examples: generated files, images, code snippets, diagrams, etc. +type AgentArtifact struct { + // Name identifies the artifact (e.g., "generated_code.go", "diagram.png"). + Name string `json:"name,omitempty"` + // Parts contains the artifact content (text, media, etc.). + Parts []*ai.Part `json:"parts"` + // Metadata contains additional artifact-specific data. + Metadata map[string]any `json:"metadata,omitempty"` +} +``` + +### 1.2 Input/Output Types + +```go +// AgentInput is the input sent to an agent during a conversation turn. +// This wrapper allows future extensibility beyond just messages. +type AgentInput struct { + // Messages contains the user's input for this turn. + Messages []*ai.Message `json:"messages,omitempty"` +} + +// AgentInit is the input for starting an agent invocation. 
+// Provide either SnapshotID (to load from store) or State (direct state). +type AgentInit[State any] struct { + // SnapshotID loads state from a persisted snapshot. + // Mutually exclusive with State. + SnapshotID string `json:"snapshotId,omitempty"` + // State provides direct state for the invocation. + // Mutually exclusive with SnapshotID. + State *AgentState[State] `json:"state,omitempty"` +} + +// AgentResponse is the output when an agent invocation completes. +type AgentResponse[State any] struct { + // SessionID identifies the session for this conversation. + // Use this to list snapshots via store.ListSnapshots(ctx, sessionID). + SessionID string `json:"sessionId"` + // State contains the final conversation state. + State *AgentState[State] `json:"state"` +} +``` + +### 1.3 Stream Types + +```go +// AgentStreamChunk represents a single item in the agent's output stream. +// Only one field is populated per chunk. +type AgentStreamChunk[Stream any] struct { + // Chunk contains token-level generation data. + Chunk *ai.ModelResponseChunk `json:"chunk,omitempty"` + // Status contains user-defined structured status information. + // The Stream type parameter defines the shape of this data. + Status Stream `json:"status,omitempty"` + // Artifact contains a newly produced artifact. + Artifact *AgentArtifact `json:"artifact,omitempty"` + // SnapshotCreated contains the ID of a snapshot that was just persisted. + SnapshotCreated string `json:"snapshotCreated,omitempty"` +} +``` + +### 1.4 Session + +The Session provides mutable working state during an agent invocation. It is propagated via context so that nested operations (tools, sub-agents) can access consistent state. + +```go +// Session holds the working state during an agent invocation. +// It is propagated through context and provides read/write access to state. 
+type Session[State any] struct { + mu sync.RWMutex + id string + state AgentState[State] + store SnapshotStore[State] + + // Snapshot tracking + lastSnapshot *AgentSnapshot[State] + turnIndex int + newSnapshotIDs []string // IDs of snapshots created this invocation +} + +// ID returns the session identifier. +func (s *Session[State]) ID() string + +// State returns a copy of the current agent state. +func (s *Session[State]) State() *AgentState[State] + +// Messages returns the current conversation history. +func (s *Session[State]) Messages() []*ai.Message + +// AddMessages appends messages to the conversation history. +func (s *Session[State]) AddMessages(messages ...*ai.Message) + +// SetMessages replaces the entire conversation history. +func (s *Session[State]) SetMessages(messages []*ai.Message) + +// Custom returns the current user-defined custom state. +func (s *Session[State]) Custom() State + +// SetCustom updates the user-defined custom state. +func (s *Session[State]) SetCustom(custom State) + +// Artifacts returns the current artifacts. +func (s *Session[State]) Artifacts() []*AgentArtifact + +// AddArtifact appends an artifact. +func (s *Session[State]) AddArtifact(artifact *AgentArtifact) + +// SetArtifacts replaces the entire artifact list. +func (s *Session[State]) SetArtifacts(artifacts ...*AgentArtifact) + +// NewSnapshotIDs returns the IDs of snapshots created during this invocation. +func (s *Session[State]) NewSnapshotIDs() []string + +// Context integration +func NewSessionContext[State any](ctx context.Context, s *Session[State]) context.Context +func SessionFromContext[State any](ctx context.Context) *Session[State] +``` + +### 1.5 Responder + +The Responder wraps the output stream with typed methods for sending different kinds of data. + +```go +// Responder provides methods for sending data to the agent's output stream. 
+type Responder[Stream any] struct { + ch chan<- *AgentStreamChunk[Stream] + session *Session[any] // for snapshot notifications +} + +// Send sends a complete stream chunk. Use this for full control over the chunk contents. +func (r *Responder[Stream]) Send(chunk *AgentStreamChunk[Stream]) + +// SendChunk sends a generation chunk (token-level streaming). +func (r *Responder[Stream]) SendChunk(chunk *ai.ModelResponseChunk) + +// SendStatus sends a user-defined status update. +func (r *Responder[Stream]) SendStatus(status Stream) + +// SendArtifact sends an artifact to the stream. +func (r *Responder[Stream]) SendArtifact(artifact *AgentArtifact) + +// EndTurn signals that the agent has finished responding to the current input. +// This triggers a snapshot check (based on the configured callback). +// The consumer's Receive() iterator will exit, allowing them to send the next input. +func (r *Responder[Stream]) EndTurn() +``` + +### 1.6 Agent Function and Parameters + +```go +// AgentParams contains the parameters passed to an agent function. +// This struct may be extended with additional fields in the future. +type AgentParams[Stream, State any] struct { + // Session provides access to the working state. + Session *Session[State] + // Init contains the initialization data provided when starting the invocation. + Init *AgentInit[State] +} + +// AgentFunc is the function signature for agents. +// Type parameters: +// - Stream: Type for status updates sent via the responder +// - State: Type for user-defined state in snapshots +type AgentFunc[Stream, State any] func( + ctx context.Context, + inCh <-chan *AgentInput, + resp *Responder[Stream], + params *AgentParams[Stream, State], +) error +``` + +### 1.7 Agent + +```go +// Agent is a bidirectional streaming action with automatic snapshot management. 
+type Agent[Stream, State any] struct { + *corex.BidiAction[*AgentInit[State], *AgentInput, *AgentResponse[State], *AgentStreamChunk[Stream]] + store SnapshotStore[State] + snapshotCallback SnapshotCallback[State] +} +``` + +--- + +## 2. Snapshot Store + +### 2.1 Store Interface + +```go +// SnapshotStore persists and retrieves snapshots. +type SnapshotStore[State any] interface { + // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. + GetSnapshot(ctx context.Context, snapshotID string) (*AgentSnapshot[State], error) + // SaveSnapshot persists a snapshot. + // The snapshot ID is computed from the content (content-addressed). + // If a snapshot with the same ID exists, this is a no-op. + SaveSnapshot(ctx context.Context, snapshot *AgentSnapshot[State]) error + // ListSnapshots returns snapshots for a session, ordered by creation time. + ListSnapshots(ctx context.Context, sessionID string) ([]*AgentSnapshot[State], error) +} +``` + +### 2.2 In-Memory Implementation + +```go +// InMemorySnapshotStore provides a thread-safe in-memory snapshot store. +type InMemorySnapshotStore[State any] struct { + snapshots map[string]*AgentSnapshot[State] + mu sync.RWMutex +} + +func NewInMemorySnapshotStore[State any]() *InMemorySnapshotStore[State] +``` + +--- + +## 3. Snapshot Callbacks + +### 3.1 Callback Types + +```go +// SnapshotEvent identifies when a snapshot opportunity occurs. +type SnapshotEvent int + +const ( + // SnapshotEventTurnEnd occurs after resp.EndTurn() is called. + SnapshotEventTurnEnd SnapshotEvent = iota + // SnapshotEventInvocationEnd occurs when the agent function returns. + SnapshotEventInvocationEnd +) + +// SnapshotContext provides context for snapshot decision callbacks. +type SnapshotContext[State any] struct { + // Event is the snapshot event that triggered this callback. + Event SnapshotEvent + // State is the current state that will be snapshotted if the callback returns true. 
+ State *AgentState[State] + // PrevState is the state at the last snapshot, or nil if none exists. + PrevState *AgentState[State] + // TurnIndex is the current turn number. + TurnIndex int +} + +// SnapshotCallback decides whether to create a snapshot at a given event. +type SnapshotCallback[State any] func(ctx context.Context, sc *SnapshotContext[State]) bool +``` + +### 3.2 Convenience Callbacks + +```go +// SnapshotAlways returns a callback that always creates snapshots. +func SnapshotAlways[State any]() SnapshotCallback[State] { + return func(ctx context.Context, sc *SnapshotContext[State]) bool { + return true + } +} + +// SnapshotNever returns a callback that never creates snapshots. +func SnapshotNever[State any]() SnapshotCallback[State] { + return func(ctx context.Context, sc *SnapshotContext[State]) bool { + return false + } +} + +// SnapshotOn returns a callback that creates snapshots only for specified events. +func SnapshotOn[State any](events ...SnapshotEvent) SnapshotCallback[State] { + eventSet := make(map[SnapshotEvent]bool) + for _, e := range events { + eventSet[e] = true + } + return func(ctx context.Context, sc *SnapshotContext[State]) bool { + return eventSet[sc.Event] + } +} +``` + +--- + +## 4. API Surface + +### 4.1 Defining Agents + +```go +// DefineAgent creates an Agent with automatic snapshot management and registers it. +func DefineAgent[Stream, State any]( + r api.Registry, + name string, + fn AgentFunc[Stream, State], + opts ...AgentOption[State], +) *Agent[Stream, State] + +// AgentOption configures an Agent. +type AgentOption[State any] interface { + applyAgent(*agentOptions[State]) error +} + +// WithSnapshotStore sets the store for persisting snapshots. +func WithSnapshotStore[State any](store SnapshotStore[State]) AgentOption[State] + +// WithSnapshotCallback configures when snapshots are created. +// If not provided, snapshots are never created automatically. 
+func WithSnapshotCallback[State any](cb SnapshotCallback[State]) AgentOption[State] +``` + +### 4.2 Starting Connections + +```go +// StreamBidiOption configures a StreamBidi call. +type StreamBidiOption[State any] interface { + applyStreamBidi(*streamBidiOptions[State]) error +} + +// WithState sets the initial state for the invocation. +// Use this for client-managed state where the client sends state directly. +func WithState[State any](state *AgentState[State]) StreamBidiOption[State] + +// WithSnapshotID loads state from a persisted snapshot by ID. +// Use this for server-managed state where snapshots are stored. +func WithSnapshotID[State any](id string) StreamBidiOption[State] + +// StreamBidi starts a new agent invocation. +func (a *Agent[Stream, State]) StreamBidi( + ctx context.Context, + opts ...StreamBidiOption[State], +) (*AgentConnection[Stream, State], error) +``` + +### 4.3 Agent Connection + +```go +// AgentConnection wraps BidiConnection with agent-specific functionality. +type AgentConnection[Stream, State any] struct { + conn *corex.BidiConnection[*AgentInput, *AgentResponse[State], *AgentStreamChunk[Stream]] +} + +// Send sends an AgentInput to the agent. +// Use this for full control over the input structure. +func (c *AgentConnection[Stream, State]) Send(input *AgentInput) error + +// SendMessages sends messages to the agent. +// This is a convenience method that wraps messages in an AgentInput. +func (c *AgentConnection[Stream, State]) SendMessages(messages ...*ai.Message) error + +// SendText sends a single user text message to the agent. +// This is a convenience method that creates a user message and wraps it in AgentInput. +func (c *AgentConnection[Stream, State]) SendText(text string) error + +// Close signals that no more inputs will be sent. +func (c *AgentConnection[Stream, State]) Close() error + +// Receive returns an iterator for receiving stream chunks. 
+func (c *AgentConnection[Stream, State]) Receive() iter.Seq2[*AgentStreamChunk[Stream], error] + +// Output returns the final response after the agent completes. +func (c *AgentConnection[Stream, State]) Output() (*AgentResponse[State], error) + +// Done returns a channel closed when the connection completes. +func (c *AgentConnection[Stream, State]) Done() <-chan struct{} +``` + +### 4.4 High-Level Genkit API + +```go +// In go/genkit/agent.go + +func DefineAgent[Stream, State any]( + g *Genkit, + name string, + fn aix.AgentFunc[Stream, State], + opts ...aix.AgentOption[State], +) *aix.Agent[Stream, State] +``` + +--- + +## 5. Snapshot Lifecycle + +### 5.1 Snapshot Points + +Snapshots are created at two points: + +| Event | Trigger | Description | +|-------|---------|-------------| +| `SnapshotEventTurnEnd` | `resp.EndTurn()` | After processing user input and generating a response | +| `SnapshotEventInvocationEnd` | Agent function returns | Final state capture when invocation completes | + +At each point: +1. The snapshot callback is invoked +2. If callback returns true: + - Compute snapshot ID (SHA256 of JSON-serialized state) + - Persist to store (no-op if ID already exists) + - Send `SnapshotCreated` on the stream + - Record ID in `session.newSnapshotIDs` + +### 5.2 Snapshot ID Computation + +Snapshot IDs are content-addressed using SHA256 of the state (not the full snapshot with metadata): + +```go +func computeSnapshotID[State any](state *AgentState[State]) string { + data, _ := json.Marshal(state) + hash := sha256.Sum256(data) + return hex.EncodeToString(hash[:]) +} +``` + +Benefits: +- **Deduplication**: Identical states produce identical IDs +- **Verification**: State integrity can be verified against ID +- **Determinism**: No dependency on timestamps for uniqueness + +### 5.3 Resuming from Snapshots + +When `WithSnapshotID` is provided to `StreamBidi`: + +1. Load the snapshot from the store +2. Extract the `AgentState` from the snapshot +3. 
Initialize the session with that state (messages, custom state, artifacts) +4. New snapshots will reference this as the parent +5. Conversation continues from the restored state + +When `WithState` is provided to `StreamBidi`: + +1. Use the provided `AgentState` directly +2. Initialize the session with that state +3. No parent snapshot reference (client-managed mode) + +--- + +## 6. Internal Flow + +### 6.1 Agent Wrapping + +The user's `AgentFunc` returns `error`. The framework wraps this to produce `AgentResponse`: + +```go +// Simplified internal logic +func (a *Agent[Stream, State]) runWrapped( + ctx context.Context, + init *AgentInit[State], + inCh <-chan *AgentInput, + outCh chan<- *AgentStreamChunk[Stream], +) (*AgentResponse[State], error) { + // Initialize session from init (a snapshot ID or directly provided state) + session := newSessionFromInit(init, a.store) + ctx = NewSessionContext(ctx, session) + + // Create responder with snapshot callback + responder := &Responder[Stream]{ + ch: outCh, + session: session, + snapshotCallback: a.snapshotCallback, + store: a.store, + } + + params := &AgentParams[Stream, State]{ + Session: session, + Init: init, + } + + // Run user function (argument order matches AgentFunc) + err := a.fn(ctx, inCh, responder, params) + if err != nil { + return nil, err + } + + // Trigger invocation-end snapshot + responder.triggerSnapshot(SnapshotEventInvocationEnd) + + // Build response from session state + return &AgentResponse[State]{ + SessionID: session.ID(), + State: session.toState(), + }, nil +} +``` + +### 6.2 Turn Signaling + +When `resp.EndTurn()` is called: + +1. Trigger snapshot callback with `SnapshotEventTurnEnd` +2. If callback returns true, create and persist snapshot +3. Send `SnapshotCreated` notification on stream +4. 
Send end-of-turn signal so consumer's `Receive()` exits + +```go +func (r *Responder[Stream]) EndTurn() { + // Trigger snapshot check + r.triggerSnapshot(SnapshotEventTurnEnd) + + // Signal end of turn (internal channel mechanism) + r.ch <- &AgentStreamChunk[Stream]{endTurn: true} // internal field +} +``` + +--- + +## 7. Example Usage + +### 7.1 Chat Agent with Snapshots + +```go +package main + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/ai" + aix "github.com/firebase/genkit/go/ai/x" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +// ChatState holds user-defined state for the chat agent. +type ChatState struct { + UserPreferences map[string]string `json:"userPreferences,omitempty"` + TopicHistory []string `json:"topicHistory,omitempty"` +} + +// ChatStatus represents status updates streamed to the client. +type ChatStatus struct { + Phase string `json:"phase"` + Details string `json:"details,omitempty"` +} + +func main() { + ctx := context.Background() + store := aix.NewInMemorySnapshotStore[ChatState]() + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{}), + genkit.WithDefaultModel("googleai/gemini-3-flash-preview"), + ) + + chatAgent := genkit.DefineAgent(g, "chatAgent", + func(ctx context.Context, inCh <-chan *aix.AgentInput, resp *aix.Responder[ChatStatus], params *aix.AgentParams[ChatStatus, ChatState]) error { + sess := params.Session + + for input := range inCh { + // Add user messages to session + sess.AddMessages(input.Messages...) 
+ + // Send status update + resp.SendStatus(ChatStatus{Phase: "generating"}) + + // Generate response + for result, err := range genkit.GenerateStream(ctx, g, + ai.WithMessages(sess.Messages()...), + ) { + if err != nil { + return err + } + if result.Done { + // The final value carries the full response and no chunk. + sess.AddMessages(result.Response.Message) + continue + } + resp.SendChunk(result.Chunk) + } + + // Update custom state + custom := sess.Custom() + custom.TopicHistory = append(custom.TopicHistory, extractTopic(input.Messages)) + sess.SetCustom(custom) + + resp.SendStatus(ChatStatus{Phase: "complete"}) + resp.EndTurn() // Triggers snapshot check + } + + return nil + }, + aix.WithSnapshotStore(store), + aix.WithSnapshotCallback(aix.SnapshotOn[ChatState]( + aix.SnapshotEventTurnEnd, + aix.SnapshotEventInvocationEnd, + )), + ) + + // Start new conversation + conn, _ := chatAgent.StreamBidi(ctx) + + conn.SendText("Hello! Tell me about Go programming.") + for chunk, err := range conn.Receive() { + if err != nil { + panic(err) + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + if chunk.SnapshotCreated != "" { + fmt.Printf("\n[Snapshot created: %s]\n", chunk.SnapshotCreated) + } + } + + conn.SendText("What are channels used for?") + for chunk, err := range conn.Receive() { + if err != nil { + panic(err) + } + if chunk.Chunk != nil { + fmt.Print(chunk.Chunk.Text()) + } + } + + conn.Close() + + response, _ := conn.Output() + fmt.Printf("\nSession ID: %s\n", response.SessionID) + fmt.Printf("Messages in history: %d\n", len(response.State.Messages)) + + // List all snapshots for this session + snapshots, _ := store.ListSnapshots(ctx, response.SessionID) + fmt.Printf("Total snapshots in session: %d\n", len(snapshots)) +} +``` + +### 7.2 Resuming from a Snapshot + +```go +// Later, resume from a saved snapshot +snapshotID := "abc123..." 
+ +conn, _ := chatAgent.StreamBidi(ctx, aix.WithSnapshotID[ChatState](snapshotID)) + +conn.SendText("Continue our discussion about channels") +for chunk, err := range conn.Receive() { + // ... handle response ... +} +``` + +### 7.3 Client-Managed State + +For clients that manage their own state (e.g., web apps with local storage): + +```go +// Client sends state directly on each invocation +clientState := &aix.AgentState[ChatState]{ + Messages: previousMessages, + Custom: ChatState{UserPreferences: prefs}, +} + +conn, _ := chatAgent.StreamBidi(ctx, aix.WithState(clientState)) + +// ... interact ... + +response, _ := conn.Output() +// Client stores response.State locally for next invocation +``` + +### 7.4 Agent with Artifacts + +```go +type CodeState struct { + Language string `json:"language"` +} + +type CodeStatus struct { + Phase string `json:"phase"` +} + +codeAgent := genkit.DefineAgent(g, "codeAgent", + func(ctx context.Context, inCh <-chan *aix.AgentInput, resp *aix.Responder[CodeStatus], params *aix.AgentParams[CodeStatus, CodeState]) error { + sess := params.Session + + for input := range inCh { + sess.AddMessages(input.Messages...) + + // Generate code... + generatedCode := "func main() { fmt.Println(\"Hello\") }" + + resp.SendStatus(CodeStatus{Phase: "code_generated"}) + + // Send artifact + resp.SendArtifact(&aix.AgentArtifact{ + Name: "main.go", + Parts: []*ai.Part{ai.NewTextPart(generatedCode)}, + Metadata: map[string]any{"language": "go"}, + }) + + sess.AddMessages(ai.NewModelTextMessage("Here's the code you requested.")) + resp.EndTurn() + } + + return nil + }, + aix.WithSnapshotStore(store), + aix.WithSnapshotCallback(aix.SnapshotAlways[CodeState]()), +) +``` + +--- + +## 8. 
Files to Create/Modify + +### New Files + +| File | Description | +|------|-------------| +| `go/ai/x/agent.go` | Agent, AgentFunc, AgentParams, Responder, Session | +| `go/ai/x/agent_state.go` | AgentState, AgentSnapshot, AgentArtifact, AgentInit, AgentResponse, AgentStreamChunk | +| `go/ai/x/agent_options.go` | AgentOption, StreamBidiOption, SnapshotCallback, SnapshotContext | +| `go/ai/x/agent_store.go` | SnapshotStore interface, InMemorySnapshotStore | +| `go/ai/x/agent_test.go` | Tests | + +### Modified Files + +| File | Change | +|------|--------| +| `go/genkit/agent.go` | Add DefineAgent wrapper | +| `go/core/api/action.go` | Add ActionTypeAgent constant if needed | + +--- + +## 9. Design Decisions + +### Why Separate AgentState from AgentSnapshot? + +**AgentState** is the portable state that flows between client and server: +- Just the data: Messages, Custom, Artifacts +- No IDs, timestamps, or metadata +- Time is implicit: a state is either the input to an invocation or its output +- Clients manage it however they want + +**AgentSnapshot** is a persisted point-in-time capture: +- Has an ID (content-addressed) +- Has timestamps and metadata (ParentID, TurnIndex, Event) +- Used for storage, debugging, branching/restoration +- Managed by the framework and store + +This separation provides: +- **Clarity**: Users know exactly what fields are relevant for their use case +- **Simplicity**: Client-managed state doesn't deal with server metadata +- **Flexibility**: Server can add snapshot metadata without affecting client API + +### Why Mandate Messages in State? + +Messages are fundamental to conversation continuity. Including them in the state schema: + +- Ensures consistent conversation history across invocations +- Prevents common bugs where messages are lost between turns +- Enables the framework to optimize message handling +- Provides a standard structure that tools and middleware can rely on + +### Why Content-Addressed Snapshot IDs? 
+ +- **Deduplication**: Identical states don't create duplicate snapshots +- **Verification**: Snapshot integrity can be verified +- **Determinism**: Same state always produces same ID, regardless of timing + +### Why Callback-Based Snapshotting? + +Rather than always snapshotting or never snapshotting: + +- **Efficiency**: Only snapshot when needed +- **Flexibility**: Different strategies for different use cases +- **User control**: Application decides snapshot granularity + +--- + +## 10. Open Questions + +### Artifact and Session State Relationship + +**TODO**: Should `resp.SendArtifact()` automatically add the artifact to session state? + +Options: +1. **Automatic**: `SendArtifact()` adds to session AND streams to client +2. **Manual**: User must call both `sess.AddArtifact()` and `resp.SendArtifact()` +3. **Configurable**: Option to control the behavior + +Considerations: +- Automatic reduces boilerplate and prevents forgetting one call +- Manual provides more control (e.g., stream without persisting) +- Similar question applies to `SendChunk()` and message accumulation + +--- + +## 11. Future Considerations + +Out of scope for this design: + +- **Snapshot expiration**: Automatic cleanup based on age or count +- **Snapshot compression**: Delta/patch-based storage +- **Snapshot branching**: Tree-structured conversation histories +- **Snapshot annotations**: User-provided labels or descriptions +- **Tool iteration snapshots**: Snapshots after tool execution (could be added as new SnapshotEvent) diff --git a/docs/go-bidi-design.md b/docs/go-bidi-design.md new file mode 100644 index 0000000000..c1551c4c13 --- /dev/null +++ b/docs/go-bidi-design.md @@ -0,0 +1,588 @@ +# Genkit Go Bidirectional Streaming Features - Design Document + +## Overview + +This document describes the design for bidirectional streaming features in Genkit Go. The implementation introduces three new primitives: + +1. **BidiAction** - Core primitive for bidirectional operations (`go/core/x`) +2. 
**BidiFlow** - BidiAction with observability, intended for user definition (`go/core/x`) +3. **BidiModel** - Specialized bidi action for real-time LLM APIs (`go/ai/x`) + +For stateful multi-turn agents with session persistence, see [go-agent-design.md](go-agent-design.md). + +## Package Location + +``` +go/core/x/ +├── bidi.go # BidiAction, BidiFunc, BidiConnection +├── bidi_flow.go # BidiFlow +├── bidi_options.go # Options +├── bidi_test.go # Tests + +go/ai/x/ +├── bidi_model.go # BidiModel, BidiModelFunc +├── bidi_model_test.go +``` + +Import as: +- `corex "github.com/firebase/genkit/go/core/x"` +- `aix "github.com/firebase/genkit/go/ai/x"` + +--- + +## 1. Core Type Definitions + +### 1.1 BidiAction + +```go +// BidiAction represents a bidirectional streaming action. +// Type parameters: +// - Init: Type of initialization data (use struct{} if not needed) +// - In: Type of each message sent to the action +// - Out: Type of the final output +// - Stream: Type of each streamed output chunk +type BidiAction[Init, In, Out, Stream any] struct { + name string + fn BidiFunc[Init, In, Out, Stream] + registry api.Registry + desc *api.ActionDesc +} + +// BidiFunc is the function signature for bidi actions. +type BidiFunc[Init, In, Out, Stream any] func( + ctx context.Context, + init Init, + inCh <-chan In, + outCh chan<- Stream, +) (Out, error) +``` + +### 1.2 BidiConnection + +```go +// BidiConnection represents an active bidirectional streaming session. +type BidiConnection[In, Out, Stream any] struct { + inputCh chan In // Internal, accessed via Send() + streamCh chan Stream // Internal output stream channel + doneCh chan struct{} // Closed when action completes + output Out // Final output (valid after done) + err error // Error if any (valid after done) + ctx context.Context + cancel context.CancelFunc + span tracing.Span // Trace span, ended on completion + mu sync.Mutex + closed bool +} + +// Send sends an input message to the bidi action. 
+func (c *BidiConnection[In, Out, Stream]) Send(input In) error + +// Close signals that no more inputs will be sent. +func (c *BidiConnection[In, Out, Stream]) Close() error + +// Receive returns an iterator for receiving streamed response chunks. +// The iterator completes when the action finishes or signals end of turn. +func (c *BidiConnection[In, Out, Stream]) Receive() iter.Seq2[Stream, error] + +// Output returns the final output after the action completes. +// Blocks until done or context cancelled. +func (c *BidiConnection[In, Out, Stream]) Output() (Out, error) + +// Done returns a channel closed when the connection completes. +func (c *BidiConnection[In, Out, Stream]) Done() <-chan struct{} +``` + +### 1.3 BidiFlow + +```go +type BidiFlow[Init, In, Out, Stream any] struct { + *BidiAction[Init, In, Out, Stream] +} +``` + +--- + +## 2. BidiModel + +### 2.1 Overview + +`BidiModel` is a specialized bidi action for real-time LLM APIs like Gemini Live and OpenAI Realtime. These APIs establish a persistent connection where configuration (temperature, system prompt, tools) must be provided upfront, and then the conversation streams bidirectionally. + +### 2.2 The Role of `init` + +For real-time sessions, the connection to the model API often requires configuration to be established *before* the first user message is received. The `init` payload fulfills this requirement: + +- **`init`**: `ModelRequest` (contains config, tools, system prompt) +- **`inputStream`**: Stream of `ModelRequest` (contains user messages/turns) +- **`stream`**: Stream of `ModelResponseChunk` + +### 2.3 Type Definitions + +```go +// In go/ai/x/bidi_model.go + +// BidiModel represents a bidirectional streaming model for real-time LLM APIs. +type BidiModel struct { + *corex.BidiAction[*ai.ModelRequest, *ai.ModelRequest, *ai.ModelResponse, *ai.ModelResponseChunk] +} + +// BidiModelFunc is the function signature for bidi model implementations. 
+type BidiModelFunc func( + ctx context.Context, + init *ai.ModelRequest, + inCh <-chan *ai.ModelRequest, + outCh chan<- *ai.ModelResponseChunk, +) (*ai.ModelResponse, error) +``` + +### 2.4 Defining a BidiModel + +```go +// DefineBidiModel creates and registers a BidiModel for real-time LLM interactions. +// The opts parameter follows the same pattern as DefineModel for consistency. +func DefineBidiModel(r api.Registry, name string, opts *ai.ModelOptions, fn BidiModelFunc) *BidiModel +``` + +**Example Plugin Implementation:** + +```go +func (g *GoogleAI) defineBidiModel(r api.Registry) *aix.BidiModel { + return aix.DefineBidiModel(r, "googleai/gemini-2.0-flash-live", + &ai.ModelOptions{ + Label: "Gemini 2.0 Flash Live", + Supports: &ai.ModelSupports{ + Multiturn: true, + Tools: true, + SystemRole: true, + Media: true, + }, + }, + func(ctx context.Context, init *ai.ModelRequest, inCh <-chan *ai.ModelRequest, outCh chan<- *ai.ModelResponseChunk) (*ai.ModelResponse, error) { + session, err := g.client.Live.Connect(ctx, "gemini-2.0-flash-live", &genai.LiveConnectConfig{ + SystemInstruction: toContent(init.Messages), + Tools: toTools(init.Tools), + Temperature: toFloat32Ptr(init.Config), + ResponseModalities: []genai.Modality{genai.ModalityText}, + }) + if err != nil { + return nil, err + } + defer session.Close() + + var totalUsage ai.GenerationUsage + + for request := range inCh { + err := session.SendClientContent(genai.LiveClientContentInput{ + Turns: toContents(request.Messages), + TurnComplete: true, + }) + if err != nil { + return nil, err + } + + for { + msg, err := session.Receive() + if err != nil { + return nil, err + } + + if msg.ToolCall != nil { + outCh <- &ai.ModelResponseChunk{ + Content: toToolCallParts(msg.ToolCall), + } + continue + } + + if msg.ServerContent != nil { + if msg.ServerContent.ModelTurn != nil { + outCh <- &ai.ModelResponseChunk{ + Content: fromParts(msg.ServerContent.ModelTurn.Parts), + } + } + if msg.ServerContent.TurnComplete { + 
break + } + } + + if msg.UsageMetadata != nil { + totalUsage.InputTokens += int(msg.UsageMetadata.PromptTokenCount) + totalUsage.OutputTokens += int(msg.UsageMetadata.CandidatesTokenCount) + } + } + } + + return &ai.ModelResponse{ + Usage: &totalUsage, + }, nil + }, + ) +} +``` + +### 2.5 Using BidiModel (`GenerateBidi`) + +`GenerateBidi` is the high-level API for interacting with bidi models. It provides a session-like interface for real-time conversations. + +```go +// In go/genkit/generate.go or go/ai/x/generate_bidi.go + +// ModelBidiConnection wraps BidiConnection with model-specific convenience methods. +type ModelBidiConnection struct { + conn *corex.BidiConnection[*ai.ModelRequest, *ai.ModelResponse, *ai.ModelResponseChunk] +} + +// Send sends a user message to the model. +func (s *ModelBidiConnection) Send(messages ...*ai.Message) error { + return s.conn.Send(&ai.ModelRequest{Messages: messages}) +} + +// SendText is a convenience method for sending a text message. +func (s *ModelBidiConnection) SendText(text string) error { + return s.Send(ai.NewUserTextMessage(text)) +} + +// Receive returns an iterator for receiving response chunks. +func (s *ModelBidiConnection) Receive() iter.Seq2[*ai.ModelResponseChunk, error] { + return s.conn.Receive() +} + +// Close signals that the conversation is complete. +func (s *ModelBidiConnection) Close() error { + return s.conn.Close() +} + +// Output returns the final response after the session completes. +func (s *ModelBidiConnection) Output() (*ai.ModelResponse, error) { + return s.conn.Output() +} +``` + +**Usage:** + +`GenerateBidi` uses the same shared option types as regular `Generate` calls. Options like `WithModel`, `WithConfig`, `WithSystem`, and `WithTools` work the same way - they configure the initial session setup. + +```go +// GenerateBidi starts a bidirectional streaming session with a model. +// Uses the existing shared option types from ai/option.go. 
+func GenerateBidi(ctx context.Context, g *Genkit, opts ...ai.GenerateBidiOption) (*ModelBidiConnection, error) +``` + +**Example:** + +```go +conn, err := genkit.GenerateBidi(ctx, g, + ai.WithModel(geminiLive), + ai.WithConfig(&genai.LiveConnectConfig{Temperature: genai.Ptr[float32](0.7)}), + ai.WithSystem("You are a helpful voice assistant"), + ai.WithTools(weatherTool), +) +if err != nil { + return err +} +defer conn.Close() + +conn.SendText("Hello!") + +for chunk, err := range conn.Receive() { + if err != nil { + return err + } + fmt.Print(chunk.Text()) +} + +conn.SendText("Tell me more about that.") +for chunk, err := range conn.Receive() { + // ... +} + +response, _ := conn.Output() +fmt.Printf("Total tokens: %d\n", response.Usage.TotalTokens) +``` + +### 2.6 Tool Calling in BidiModel + +Real-time models may support tool calling. The pattern follows the standard generate flow but within the streaming context: + +```go +conn, _ := genkit.GenerateBidi(ctx, g, + ai.WithModel(geminiLive), + ai.WithTools(weatherTool, calculatorTool), +) + +conn.SendText("What's the weather in NYC?") + +for chunk, err := range conn.Receive() { + if err != nil { + return err + } + + if toolCall := chunk.ToolCall(); toolCall != nil { + result, _ := toolCall.Tool.Execute(ctx, toolCall.Input) + conn.Send(ai.NewToolResponseMessage(toolCall.ID, result)) + } else { + fmt.Print(chunk.Text()) + } +} +``` + +--- + +## 3. API Surface + +### 3.1 Defining Bidi Actions + +```go +// In go/core/x/bidi.go + +// NewBidiAction creates a BidiAction without registering it. +func NewBidiAction[Init, In, Out, Stream any]( + name string, + fn BidiFunc[Init, In, Out, Stream], +) *BidiAction[Init, In, Out, Stream] + +// DefineBidiAction creates and registers a BidiAction. 
+func DefineBidiAction[Init, In, Out, Stream any]( + r api.Registry, + name string, + fn BidiFunc[Init, In, Out, Stream], +) *BidiAction[Init, In, Out, Stream] +``` + +Schemas for `In`, `Out`, `Init`, and `Stream` types are automatically inferred from the type parameters using the existing JSON schema inference in `go/internal/base/json.go`. + +### 3.2 Defining Bidi Flows + +```go +// In go/core/x/bidi_flow.go + +// DefineBidiFlow creates a BidiFlow with tracing and registers it. +// Use this for user-defined bidirectional streaming operations. +func DefineBidiFlow[Init, In, Out, Stream any]( + r api.Registry, + name string, + fn BidiFunc[Init, In, Out, Stream], +) *BidiFlow[Init, In, Out, Stream] +``` + +### 3.3 Starting Connections + +All bidi types (BidiAction, BidiFlow, BidiModel) use the same `StreamBidi` method to start connections: + +```go +func (ba *BidiAction[Init, In, Out, Stream]) StreamBidi(ctx context.Context, init Init) (*BidiConnection[In, Out, Stream], error) +``` + +### 3.4 High-Level Genkit API + +```go +// In go/genkit/bidi.go + +func DefineBidiFlow[Init, In, Out, Stream any]( + g *Genkit, + name string, + fn corex.BidiFunc[Init, In, Out, Stream], +) *corex.BidiFlow[Init, In, Out, Stream] + +// GenerateBidi uses shared options from ai/option.go +// Options like WithModel, WithConfig, WithSystem, WithTools configure the session init. +func GenerateBidi( + ctx context.Context, + g *Genkit, + opts ...ai.GenerateBidiOption, +) (*ModelBidiConnection, error) +``` + +--- + +## 4. Integration with Existing Infrastructure + +### 4.1 Tracing Integration + +BidiFlows create spans that remain open for the lifetime of the connection, enabling streaming trace visualization in the Dev UI. 
+ +**Key behaviors:** +- Span starts when `StreamBidi()` is called +- Span ends when the bidi function returns (via `defer` in the connection goroutine) +- Flow context is injected so `core.Run()` works inside the bidi function +- Nested spans for sub-operations (e.g., each LLM call) work normally + +**Important**: The span stays open while the connection is active, allowing: +- Streaming traces to the Dev UI in real-time +- Nested spans for sub-operations (e.g., each LLM call) +- Events recorded as they happen + +### 4.2 Action Registration + +Add new action types and schema fields: + +```go +// In go/core/api/action.go +const ( + ActionTypeBidiFlow ActionType = "bidi-flow" + ActionTypeBidiModel ActionType = "bidi-model" +) + +// ActionDesc gets two new optional fields +type ActionDesc struct { + // ... existing fields ... + StreamSchema map[string]any `json:"streamSchema,omitempty"` // NEW: schema for streamed chunks + InitSchema map[string]any `json:"initSchema,omitempty"` // NEW: schema for initialization data +} +``` + +--- + +## 5. 
Example Usage + +### 5.1 Basic Echo Bidi Flow + +```go +package main + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/genkit" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx) + + echoFlow := genkit.DefineBidiFlow(g, "echo", + func(ctx context.Context, init struct{}, inCh <-chan string, outCh chan<- string) (string, error) { + var count int + for input := range inCh { + count++ + outCh <- fmt.Sprintf("echo: %s", input) + } + return fmt.Sprintf("processed %d messages", count), nil + }, + ) + + // StreamBidi always takes an init value; this flow has none, so pass an empty struct. + conn, err := echoFlow.StreamBidi(ctx, struct{}{}) + if err != nil { + panic(err) + } + + // Send from a separate goroutine: the channels are unbuffered (see section 7), + // so Send blocks until the flow reads and the flow blocks until a chunk is received. + go func() { + conn.Send("hello") + conn.Send("world") + conn.Close() + }() + + for chunk, err := range conn.Receive() { + if err != nil { + panic(err) + } + fmt.Println(chunk) + } + + output, _ := conn.Output() + fmt.Println(output) +} +``` + +### 5.2 Bidi Flow with Initialization Data + +```go +type ChatInit struct { + SystemPrompt string `json:"systemPrompt"` + // float32 matches the *float32 Temperature field of genai.GenerateContentConfig. + Temperature float32 `json:"temperature"` +} + +configuredChat := genkit.DefineBidiFlow(g, "configuredChat", + func(ctx context.Context, init ChatInit, inCh <-chan string, outCh chan<- string) (string, error) { + for input := range inCh { + resp, err := genkit.GenerateText(ctx, g, + ai.WithSystem(init.SystemPrompt), + ai.WithConfig(&genai.GenerateContentConfig{Temperature: &init.Temperature}), + ai.WithPrompt(input), + ) + if err != nil { + return "", err + } + outCh <- resp + } + return "done", nil + }, +) + +conn, _ := configuredChat.StreamBidi(ctx, ChatInit{ + SystemPrompt: "You are a helpful assistant.", + Temperature: 0.7, +}) +``` + +--- + +## 6. 
Files to Create/Modify + +### New Files + +| File | Description | +|------|-------------| +| `go/core/x/bidi.go` | BidiAction, BidiFunc, BidiConnection | +| `go/core/x/bidi_flow.go` | BidiFlow with tracing | +| `go/core/x/bidi_options.go` | BidiOption types | +| `go/core/x/bidi_test.go` | Tests | +| `go/ai/x/bidi_model.go` | BidiModel, BidiModelFunc, ModelBidiConnection | +| `go/ai/x/bidi_model_test.go` | Tests | +| `go/genkit/bidi.go` | High-level API wrappers | + +### Modified Files + +| File | Change | +|------|--------| +| `go/core/api/action.go` | Add `ActionTypeBidiFlow` and `ActionTypeBidiModel` constants; add `StreamSchema` and `InitSchema` fields to `ActionDesc` | + +--- + +## 7. Implementation Notes + +### Error Handling +- Errors from the bidi function propagate to both the `Receive()` iterator and `Output()` +- Context cancellation closes all channels and terminates the action +- Send after Close returns an error +- Errors are yielded as the second value in the `iter.Seq2[Stream, error]` iterator + +### Goroutine Management +- BidiConnection spawns a goroutine to run the action +- Proper cleanup on context cancellation using `defer` and `sync.Once` +- Channel closure follows Go idioms (sender closes) +- Trace span is ended in the goroutine's defer + +### Thread Safety +- BidiConnection uses a mutex to guard its state (closed flag) +- Send is safe to call from multiple goroutines + +### Channels and Backpressure +- Both input and output channels are **unbuffered** by default (size 0) +- This provides natural backpressure: `Send()` blocks until the action reads, and output sends block until the consumer reads +- If needed, `WithInputBufferSize` / `WithOutputBufferSize` options could be added later for specific use cases + +### Tracing +- Span is started when connection is created, ended when action completes +- Nested spans work normally within the bidi function +- Events can be recorded throughout the connection lifecycle +- Dev UI can show traces in real-time as they stream +- Implementation uses the existing tracer infrastructure (details 
left to implementation) + +### Shutdown Sequence +When `Close()` is called on a BidiConnection: +1. The input channel is closed, signaling no more inputs +2. The bidi function's `for range inputStream` loop exits +3. The function returns its final output +4. The stream channel is closed +5. The `Done()` channel is closed +6. `Output()` unblocks and returns the result + +On context cancellation: +1. Context error propagates to the bidi function +2. All channels are closed +3. `Output()` returns the context error + +--- + +## 8. Integration with Reflection API + +These features align with **Reflection API V2**, which uses WebSockets to support bidirectional streaming between the Runtime and the CLI/Manager. + +- `runAction` now supports an `input` stream +- `streamChunk` notifications are bidirectional (Manager <-> Runtime) diff --git a/docs/go-session-design.md b/docs/go-session-design.md new file mode 100644 index 0000000000..b64404b03b --- /dev/null +++ b/docs/go-session-design.md @@ -0,0 +1,530 @@ +# Genkit Go Session Snapshots - Design Document + +## Overview + +This document describes the design for session snapshots in Genkit Go. This feature builds on the bidirectional streaming primitives described in [go-bidi-design.md](./go-bidi-design.md), extending the session management system with point-in-time state capture and restoration capabilities. + +Session snapshots enable: +- **Debugging**: Inspect session state at any point in a conversation +- **Restoration**: Resume conversations from previous states +- **Dev UI Integration**: Display state alongside traces for better observability + +--- + +# Part 1: API Definitions + +## 1. Core Types + +### 1.1 Snapshot + +```go +// Snapshot represents a point-in-time capture of session state. +// Snapshots are immutable once created. +type Snapshot[S any] struct { + // ID is the content-addressed identifier (SHA256 of JSON-serialized state). 
+ ID string `json:"id"` + + // ParentID is the ID of the previous snapshot in this session's timeline. + // Empty for the first snapshot in a session. + ParentID string `json:"parentId,omitempty"` + + // SessionID is the session this snapshot belongs to. + SessionID string `json:"sessionId"` + + // CreatedAt is when the snapshot was created. + CreatedAt time.Time `json:"createdAt"` + + // State is the complete session state at the time of the snapshot. + State S `json:"state"` + + // Index is a monotonically increasing sequence number for ordering snapshots + // within a session. This is independent of turn boundaries. + Index int `json:"index"` + + // TurnIndex is the turn number when this snapshot was created (0-indexed). + // Turn 0 is after the first user input and agent response. + TurnIndex int `json:"turnIndex"` + + // Event is the snapshot event that triggered this snapshot. + Event SnapshotEvent `json:"event"` + + // Orphaned indicates this snapshot is no longer on the main timeline. + // This occurs when a user restores from an earlier snapshot, causing + // all subsequent snapshots to be marked as orphaned. + Orphaned bool `json:"orphaned,omitempty"` +} +``` + +### 1.2 SnapshotEvent + +```go +// SnapshotEvent identifies when a snapshot opportunity occurs. +type SnapshotEvent int + +const ( + // SnapshotEventTurnEnd occurs after resp.EndTurn() is called, + // when control returns to the user. + SnapshotEventTurnEnd SnapshotEvent = iota + + // SnapshotEventToolIterationEnd occurs after all tool calls in a single + // model iteration complete, before the results are sent back to the model. + // This captures state after tools have mutated it but before the next + // model response. + SnapshotEventToolIterationEnd + + // SnapshotEventInvocationEnd occurs when the agent function returns, + // capturing the final state of the invocation. 
+ SnapshotEventInvocationEnd +) +``` + +### 1.3 SnapshotContext + +```go +// SnapshotContext provides context for snapshot decision callbacks. +type SnapshotContext[S any] struct { + // Event is the snapshot event that triggered this callback. + Event SnapshotEvent + + // State is the current session state that will be snapshotted if the callback returns true. + State S + + // PrevState is the state at the last snapshot, or nil if no previous snapshot exists. + // Useful for comparing states to decide whether a snapshot is needed. + PrevState *S + + // Index is the sequence number this snapshot would have if created. + Index int + + // TurnIndex is the current turn number. + TurnIndex int +} +``` + +### 1.4 SnapshotCallback + +```go +// SnapshotCallback decides whether to create a snapshot at a given event. +// It receives the context and snapshot context, returning true if a snapshot +// should be created. +// +// The callback is invoked at each snapshot opportunity. Users can filter +// by event type, inspect state, compare with previous state, or apply any +// custom logic to decide. +type SnapshotCallback[S any] = func(ctx context.Context, snap *SnapshotContext[S]) bool +``` + +--- + +## 2. Store Interface + +The existing `Store[S]` interface in `go/core/x/session` is extended with snapshot methods: + +```go +type Store[S any] interface { + // Existing session methods + Get(ctx context.Context, sessionID string) (*Data[S], error) + Save(ctx context.Context, sessionID string, data *Data[S]) error + + // GetSnapshot retrieves a snapshot by ID. Returns nil if not found. + GetSnapshot(ctx context.Context, snapshotID string) (*Snapshot[S], error) + + // SaveSnapshot persists a snapshot. If a snapshot with the same ID already + // exists (content-addressed deduplication), this is a no-op and returns nil. + SaveSnapshot(ctx context.Context, snapshot *Snapshot[S]) error + + // ListSnapshots returns snapshots for a session, ordered by Index ascending. 
+ // If includeOrphaned is false, only active (non-orphaned) snapshots are returned. + ListSnapshots(ctx context.Context, sessionID string, includeOrphaned bool) ([]*Snapshot[S], error) + + // InvalidateSnapshotsAfter marks all snapshots with Index > afterIndex as orphaned. + // Called when restoring from a snapshot to mark "future" snapshots as no longer active. + InvalidateSnapshotsAfter(ctx context.Context, sessionID string, afterIndex int) error +} +``` + +--- + +## 3. Agent Options + +### 3.1 WithSnapshotCallback + +```go +// WithSnapshotCallback configures when snapshots are created. +// The callback is invoked at each snapshot opportunity (turn end, tool iteration +// end, invocation end) and decides whether to create a snapshot. +// +// If no callback is provided, snapshots are never created automatically. +// Requires WithSessionStore to be configured; otherwise snapshots cannot be persisted. +func WithSnapshotCallback[S any](cb SnapshotCallback[S]) AgentOption[S] +``` + +### 3.2 Convenience Callbacks + +```go +// SnapshotAlways returns a callback that always creates snapshots at all events. +func SnapshotAlways[S any]() SnapshotCallback[S] { + return func(ctx context.Context, snap *SnapshotContext[S]) bool { + return true + } +} + +// SnapshotNever returns a callback that never creates snapshots. +// This is the default behavior when no callback is configured. +func SnapshotNever[S any]() SnapshotCallback[S] { + return func(ctx context.Context, snap *SnapshotContext[S]) bool { + return false + } +} + +// SnapshotOn returns a callback that creates snapshots only for the specified events. 
+func SnapshotOn[S any](events ...SnapshotEvent) SnapshotCallback[S] { + eventSet := make(map[SnapshotEvent]bool) + for _, e := range events { + eventSet[e] = true + } + return func(ctx context.Context, snap *SnapshotContext[S]) bool { + return eventSet[snap.Event] + } +} + +// SnapshotOnChange returns a callback that creates snapshots only when state has changed +// since the last snapshot. +func SnapshotOnChange[S any](events ...SnapshotEvent) SnapshotCallback[S] { + eventSet := make(map[SnapshotEvent]bool) + for _, e := range events { + eventSet[e] = true + } + return func(ctx context.Context, snap *SnapshotContext[S]) bool { + if !eventSet[snap.Event] { + return false + } + // Always snapshot if this is the first one + if snap.PrevState == nil { + return true + } + // Compare by computing content-addressed IDs + return computeStateHash(snap.State) != computeStateHash(*snap.PrevState) + } +} +``` + +--- + +## 4. Invocation Options + +### 4.1 WithSnapshotID + +```go +// WithSnapshotID specifies a snapshot to restore from when starting the agent. +// This loads the session state from the snapshot and marks all subsequent +// snapshots in that session as orphaned. +// +// The session continues with the same session ID as the snapshot. +// +// Requires the agent to be configured with WithSessionStore; returns an error +// if no store is available to load the snapshot from. +func WithSnapshotID[Init any](id string) BidiOption[Init] +``` + +--- + +## 5. AgentOutput + +```go +// AgentOutput wraps the output with session info for persistence. +type AgentOutput[State, Out any] struct { + SessionID string `json:"sessionId"` + Output Out `json:"output"` + State State `json:"state"` + Artifacts []Artifact `json:"artifacts,omitempty"` + + // SnapshotIDs contains the IDs of all snapshots created during this agent invocation. + // Empty if no snapshots were created (callback returned false or not configured). 
+ SnapshotIDs []string `json:"snapshotIds,omitempty"` +} +``` + +--- + +# Part 2: Behaviors + +## 6. Snapshot Creation + +Snapshots are created at three points, each corresponding to a `SnapshotEvent`: + +| Event | Trigger | +|-------|---------| +| `SnapshotEventTurnEnd` | When `resp.EndTurn()` is called, signaling control returns to the user | +| `SnapshotEventToolIterationEnd` | After all tool calls in a single model iteration complete | +| `SnapshotEventInvocationEnd` | When the agent function returns | + +At each point, the snapshot callback is invoked. If it returns true: +1. Compute the snapshot ID by hashing the state (SHA256) +2. Create the snapshot with the next sequence index +3. Set the parent snapshot ID to the previous snapshot (if any) +4. Persist to the store (no-op if ID already exists due to identical state) +5. Record the snapshot ID in the current trace span + +### 6.1 Snapshot ID Computation + +Snapshot IDs are content-addressed using SHA256 of the JSON-serialized state: +- **Deduplication**: Identical states produce identical IDs +- **Verification**: State integrity can be verified against the ID +- **Determinism**: No dependency on timestamps for uniqueness + +--- + +## 7. Snapshot Restoration + +When `WithSnapshotID` is provided to `StreamBidi`: + +1. Load the snapshot from the store +2. Call `InvalidateSnapshotsAfter(sessionID, snapshot.Index)` to orphan subsequent snapshots +3. Update the session with the snapshot's state +4. Continue from the snapshot's turn and index +5. Track the snapshot ID as the parent for new snapshots + +### 7.1 Option Validation + +| Combination | Result | +|-------------|--------| +| `WithSnapshotID` + `WithInit` | **Error**: Cannot specify initial state when restoring | +| `WithSnapshotID` + `WithSessionID` (mismatched) | **Error**: Session ID must match snapshot | +| `WithSnapshotID` + `WithSessionID` (matching) | Allowed but redundant | + +--- + +## 8. 
Tracing Integration + +When a snapshot is created, span metadata is recorded: +- `genkit:metadata:snapshotId` - The snapshot ID +- `genkit:metadata:agent` - The agent name (e.g., `chatAgent`) + +This enables the Dev UI to fetch snapshot data via the reflection API. + +--- + +# Part 3: Examples + +## 9. Usage Examples + +### 9.1 Defining an Agent with Snapshots + +```go +type ChatState struct { + Messages []*ai.Message `json:"messages"` +} + +chatAgent := genkit.DefineAgent(g, "chatAgent", + func(ctx context.Context, sess *session.Session[ChatState], inCh <-chan string, resp *corex.Responder[string]) (corex.AgentResult[string], error) { + state := sess.State() + + for input := range inCh { + state.Messages = append(state.Messages, ai.NewUserTextMessage(input)) + // Use a distinct name so the model reply does not shadow the Responder `resp`. + reply := generateResponse(ctx, g, state.Messages) + state.Messages = append(state.Messages, reply.Message) + sess.UpdateState(ctx, state) + resp.EndTurn() // SnapshotEventTurnEnd fires here + } + + return corex.AgentResult[string]{Output: "done"}, nil + // SnapshotEventInvocationEnd fires after return + }, + corex.WithSessionStore(store), + corex.WithSnapshotCallback(session.SnapshotOn[ChatState]( + session.SnapshotEventTurnEnd, + session.SnapshotEventInvocationEnd, + )), +) +``` + +### 9.2 Restoring from a Snapshot + +```go +snapshotID := previousOutput.SnapshotIDs[0] + +conn, _ := chatAgent.StreamBidi(ctx, + corex.WithSnapshotID[ChatState](snapshotID), +) + +conn.Send("Actually, tell me about channels instead") +// ... conversation continues from restored state ... 
+``` + +### 9.3 Custom Snapshot Callback + +```go +// Snapshot every 5 messages at turn end, always at invocation end +corex.WithSnapshotCallback(func(ctx context.Context, snap *session.SnapshotContext[ChatState]) bool { + switch snap.Event { + case session.SnapshotEventTurnEnd: + return len(snap.State.Messages) % 5 == 0 + case session.SnapshotEventInvocationEnd: + return true + default: + return false + } +}) +``` + +### 9.4 Snapshot Only When State Changed + +```go +// Only snapshot if messages have been added since last snapshot +corex.WithSnapshotCallback(func(ctx context.Context, snap *session.SnapshotContext[ChatState]) bool { + if snap.Event != session.SnapshotEventTurnEnd { + return false + } + // Always snapshot if this is the first one + if snap.PrevState == nil { + return true + } + // Only snapshot if message count increased + return len(snap.State.Messages) > len(snap.PrevState.Messages) +}) +``` + +### 9.5 Snapshot Based on Index + +```go +// Snapshot every 3rd snapshot opportunity +corex.WithSnapshotCallback(func(ctx context.Context, snap *session.SnapshotContext[ChatState]) bool { + return snap.Index % 3 == 0 +}) +``` + +### 9.6 Listing Snapshots + +```go +activeSnapshots, _ := store.ListSnapshots(ctx, sessionID, false) +allSnapshots, _ := store.ListSnapshots(ctx, sessionID, true) // includes orphaned +``` + +--- + +# Part 4: Implementation Details + +## 10. Reflection API Integration + +Session stores are exposed via the reflection API for Dev UI access. 
+ +### 10.1 Action Registration + +When `DefineAgent` is called with `WithSessionStore`, actions are registered: + +| Action | Key | Returns | +|--------|-----|---------| +| getSnapshot | `/session-store/{agent}/getSnapshot` | `Snapshot[S]` | +| listSnapshots | `/session-store/{agent}/listSnapshots` | `[]*Snapshot[S]` | +| getSession | `/session-store/{agent}/getSession` | `*Data[S]` | + +### 10.2 Action Type + +```go +const ActionTypeSessionStore api.ActionType = "session-store" +``` + +### 10.3 Dev UI Flow + +1. Dev UI extracts `snapshotId` and `agent` from span metadata +2. Calls `POST /api/runAction` with key `/session-store/{agent}/getSnapshot` +3. Displays the returned state alongside the trace + +--- + +## 11. Session Snapshot Fields + +The `Session` struct is extended with fields to track snapshot state. These are persisted with the session so that loading a session restores the snapshot tracking state. + +```go +type Session[S any] struct { + // ... existing fields (id, state, store, mu) ... + + // LastSnapshot is the most recent snapshot for this session. + // Used to derive ParentID (LastSnapshot.ID), PrevState (LastSnapshot.State), + // and next index (LastSnapshot.Index + 1). + // Nil if no snapshots have been created. + LastSnapshot *Snapshot[S] `json:"lastSnapshot,omitempty"` + + // TurnIndex tracks the current turn number. + TurnIndex int `json:"turnIndex"` +} +``` + +The `snapshotIDs` list (for `AgentOutput.SnapshotIDs`) is tracked transiently during an invocation and does not need to be persisted. + +When building `SnapshotContext` for the callback: +- `PrevState` = `session.LastSnapshot.State` (or nil if `LastSnapshot` is nil) +- `Index` = `session.LastSnapshot.Index + 1` (or 0 if `LastSnapshot` is nil) +- `ParentID` for new snapshot = `session.LastSnapshot.ID` (or empty if nil) + +--- + +## 12. 
Tool Iteration Snapshot Mechanism + +The `SnapshotEventToolIterationEnd` event requires coordination between the agent layer (typed state) and Generate layer (untyped). + +This is accomplished via a context-based trigger: + +1. Agent layer creates a closure capturing the typed callback +2. Stores an untyped trigger function in context +3. Generate calls `TriggerSnapshot(ctx, SnapshotEventToolIterationEnd)` after tool iterations +4. Trigger retrieves session from context, gets state, invokes callback +5. If callback returns true, snapshot is created + +This keeps Generate decoupled from session types. + +--- + +# Part 5: Design Decisions + +## 13. Rationale + +### Why a Single Callback with Event Types? + +Rather than separate options for each trigger: +- **Extensibility**: New events can be added without new options +- **Flexibility**: Filter by event AND inspect state in one place +- **Composability**: Logic like "every N messages at turn end, always at invocation end" is natural + +### Why Content-Addressed IDs? + +- **Automatic deduplication**: Identical states share the same snapshot +- **Verification**: State integrity can be verified against the ID +- **Determinism**: No dependency on sequence numbers or timestamps + +### Why Orphaned Instead of Deleted? + +When restoring from an earlier snapshot, subsequent snapshots are marked orphaned: +- **Audit trail**: Complete history preserved for debugging +- **Recovery**: Accidentally orphaned snapshots can be recovered +- **Visualization**: Dev UI can show the full conversation tree + +### Why IDs Only in Traces? + +- **Lightweight traces**: Avoid bloating with large state objects +- **Single source of truth**: State lives in the session store +- **On-demand retrieval**: Dev UI fetches when needed + +### Why Both Index and TurnIndex? + +- **Index**: Monotonically increasing for ordering and invalidation +- **TurnIndex**: Human-comprehensible ("after turn 3") + +### Why Separate Session and Snapshot? 
+ +- **Session state**: Working copy that changes frequently +- **Snapshots**: Explicit, immutable captures (like git commits) + +This provides efficiency (not every change needs snapshot overhead) and user control via callbacks. + +--- + +## 14. Future Considerations + +Out of scope for this design: + +- **Snapshot expiration**: Automatic cleanup based on age or count +- **Snapshot compression**: Delta/patch-based storage +- **Snapshot annotations**: User-provided labels or descriptions
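
As a closing illustration of the content-addressed IDs described in sections 6.1 and 13, here is a minimal sketch of how an ID could be derived from state. The `snapshotID` helper and this simplified `ChatState` are hypothetical names for illustration only; the design references a `computeStateHash` function without defining it, so its actual signature and encoding may differ.

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
)

// ChatState is an illustrative state type; real state types are user-defined.
type ChatState struct {
	Messages []string `json:"messages"`
}

// snapshotID is a hypothetical helper that returns the hex-encoded SHA-256
// digest of the JSON-serialized state.
func snapshotID[S any](state S) (string, error) {
	data, err := json.Marshal(state)
	if err != nil {
		return "", fmt.Errorf("marshal state: %w", err)
	}
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:]), nil
}

func main() {
	a, _ := snapshotID(ChatState{Messages: []string{"hello"}})
	b, _ := snapshotID(ChatState{Messages: []string{"hello"}})
	c, _ := snapshotID(ChatState{Messages: []string{"hello", "world"}})

	fmt.Println(a == b) // true: identical states share an ID
	fmt.Println(a == c) // false: different states get different IDs
	fmt.Println(len(a)) // 64: a SHA-256 digest is 32 bytes, 64 hex characters
}
```

This relies on `encoding/json` producing stable output for a given Go type: struct fields are emitted in declaration order and map keys are sorted. A state type with a custom `MarshalJSON` would need to preserve that stability for deduplication to work.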