diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 27d746f761..8ac411bcd9 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,7 +17,7 @@ import type { JSONSchema7 } from 'json-schema'; import type * as z from 'zod'; import { getAsyncContext } from './async-context.js'; -import { lazy } from './async.js'; +import { Channel, lazy } from './async.js'; import { getContext, runWithContext, type ActionContext } from './context.js'; import type { ActionType, Registry } from './registry.js'; import { parseSchema } from './schema.js'; @@ -51,14 +51,23 @@ export interface ActionMetadata< O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, > { + /** The type of action (e.g. 'prompt', 'flow'). */ actionType?: ActionType; + /** The name of the action. */ name: string; + /** Description of the action. */ description?: string; + /** Input Zod schema. */ inputSchema?: I; + /** Input JSON schema. */ inputJsonSchema?: JSONSchema7; + /** Output Zod schema. */ outputSchema?: O; + /** Output JSON schema. */ outputJsonSchema?: JSONSchema7; + /** Stream Zod schema. */ streamSchema?: S; + /** Metadata for the action. */ metadata?: Record; } @@ -76,7 +85,7 @@ export interface ActionResult { /** * Options (side channel) data to pass to the model. */ -export interface ActionRunOptions { +export interface ActionRunOptions { /** * Streaming callback (optional). */ @@ -104,12 +113,22 @@ export interface ActionRunOptions { * Note: This only fires once for the root action span, not for nested spans. */ onTraceStart?: (traceInfo: { traceId: string; spanId: string }) => void; + + /** + * Streaming input (optional). + */ + inputStream?: AsyncIterable; + + /** + * Initialization data provided to the action. + */ + init?: Init; } /** * Options (side channel) data to pass to the model. */ -export interface ActionFnArg { +export interface ActionFnArg { /** * Whether the caller of the action requested streaming. */ @@ -139,6 +158,16 @@ export interface ActionFnArg { abortSignal: AbortSignal; registry?: Registry; + + /** + * Streaming input. + */ + inputStream: AsyncIterable; + + /** + * Initialization data provided to the action. + */ + init?: Init; } /** @@ -154,6 +183,24 @@ export interface StreamingResponse< output: Promise>; } +/** + * Streaming response from a bi-directional action. + */ +export interface BidiStreamingResponse< + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + I extends z.ZodTypeAny = z.ZodTypeAny, +> extends StreamingResponse { + /** + * Sends a chunk of data to the action (for bi-directional streaming). + */ + send(chunk: z.infer): void; + /** + * Closes the input stream to the action. + */ + close(): void; +} + /** * Self-describing, validating, observable, locally and remotely callable function. */ @@ -161,19 +208,30 @@ export type Action< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, - RunOptions extends ActionRunOptions = ActionRunOptions, + RunOptions extends ActionRunOptions< + z.infer, + z.infer + > = ActionRunOptions, z.infer>, + Init extends z.ZodTypeAny = z.ZodTypeAny, > = ((input?: z.infer, options?: RunOptions) => Promise>) & { + /** @hidden */ __action: ActionMetadata; + /** @hidden */ __registry?: Registry; run( input?: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer> ): Promise>>; stream( input?: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer> ): StreamingResponse; + + streamBidi( + input?: AsyncIterable>, + opts?: ActionRunOptions, z.infer, z.infer> + ): BidiStreamingResponse; }; /** @@ -184,23 +242,75 @@ export type ActionParams< O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, > = { + /** + * Name of the action, or an object with pluginId and actionId. + */ name: | string | { pluginId: string; actionId: string; }; + /** + * Description of the action. + */ description?: string; + /** + * Input Zod schema. + */ inputSchema?: I; + /** + * Input JSON schema. + */ inputJsonSchema?: JSONSchema7; + /** + * Output Zod schema. + */ outputSchema?: O; + /** + * Output JSON schema. + */ outputJsonSchema?: JSONSchema7; + /** + * Metadata for the action. + */ metadata?: Record; + /** + * Middleware to apply to the action. + */ use?: Middleware, z.infer, z.infer>[]; + /** + * Stream Zod schema. + */ streamSchema?: S; + /** + * The type of action. + */ actionType: ActionType; }; +/** + * Configuration for a bi-directional action. + */ +export interface BidiActionParams< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +> extends ActionParams { + /** + * Zod schema for the initialization data. + */ + initSchema?: Init; + /** + * JSON schema for the initialization data. + */ + initJsonSchema?: JSONSchema7; +} + +/** + * Configuration for an async action (lazy loaded). + */ export type ActionAsyncParams< I extends z.ZodTypeAny, O extends z.ZodTypeAny, @@ -208,19 +318,25 @@ export type ActionAsyncParams< > = ActionParams & { fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer> ) => Promise>; }; +/** + * Simple middleware that only modifies request/response. + */ export type SimpleMiddleware = ( req: I, next: (req?: I) => Promise ) => Promise; +/** + * Middleware that has access to options (including streaming callback). + */ export type MiddlewareWithOptions = ( req: I, - options: ActionRunOptions | undefined, - next: (req?: I, options?: ActionRunOptions) => Promise + options: ActionRunOptions | undefined, + next: (req?: I, options?: ActionRunOptions) => Promise ) => Promise; /** @@ -243,20 +359,20 @@ export function actionWithMiddleware< ): Action { const wrapped = (async ( req: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer> ) => { return (await wrapped.run(req, options)).result; }) as Action; wrapped.__action = action.__action; wrapped.run = async ( req: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer> ): Promise>> => { let telemetry; const dispatch = async ( index: number, req: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer> ) => { if (index === middleware.length) { // end of the chain, call the original model action @@ -283,6 +399,7 @@ export function actionWithMiddleware< } }; wrapped.stream = action.stream; + wrapped.streamBidi = action.streamBidi; return { result: await dispatch(0, req, options), telemetry }; }; @@ -297,10 +414,10 @@ export function action< O extends z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, >( - config: ActionParams, + config: BidiActionParams, fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer> ) => Promise> ): Action> { const actionName = @@ -321,7 +438,7 @@ export function action< const actionFn = (async ( input?: I, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer> ) => { return (await actionFn.run(input, options)).result; }) as Action>; @@ -329,12 +446,42 @@ export function action< actionFn.run = async ( input: z.infer, - options?: ActionRunOptions> + options?: ActionRunOptions, z.infer> ): Promise>> => { - input = parseSchema(input, { - schema: config.inputSchema, - jsonSchema: config.inputJsonSchema, - }); + if (config.inputSchema || config.inputJsonSchema) { + if (!options?.inputStream) { + input = parseSchema(input, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } else { + const inputStream = options.inputStream; + options = { + ...options, + inputStream: (async function* () { + for await (const item of inputStream) { + yield parseSchema(item, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } + })(), + }; + } + } + + if (config.initSchema || config.initJsonSchema) { + const validatedInit = parseSchema(options?.init, { + schema: config.initSchema, + jsonSchema: config.initJsonSchema, + }); + if (options) { + options.init = validatedInit; + } else { + options = { init: validatedInit }; + } + } + let traceId; let spanId; let output = await runInNewSpan( @@ -379,13 +526,15 @@ export function action< !!options?.onChunk && options.onChunk !== sentinelNoopStreamingCallback, sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, + inputStream: + options?.inputStream ?? asyncIterableFromArray([input]), trace: { traceId, spanId, }, registry: actionFn.__registry, abortSignal: options?.abortSignal ?? makeNoopAbortSignal(), - }); + } as ActionFnArg, z.infer>); // if context is explicitly passed in, we run action with the provided context, // otherwise we let upstream context carry through. const output = await runWithContext(options?.context, actFn); @@ -415,7 +564,7 @@ export function action< actionFn.stream = ( input?: z.infer, - opts?: ActionRunOptions> + opts?: ActionRunOptions, z.infer> ): StreamingResponse => { let chunkStreamController: ReadableStreamController>; const chunkStream = new ReadableStream>({ @@ -427,17 +576,24 @@ export function action< }); const invocationPromise = actionFn - .run(config.inputSchema ? config.inputSchema.parse(input) : input, { - onChunk: ((chunk: z.infer) => { - chunkStreamController.enqueue(chunk); - }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - context: { - ...actionFn.__registry?.context, - ...(opts?.context ?? getContext()), - }, - abortSignal: opts?.abortSignal, - telemetryLabels: opts?.telemetryLabels, - }) + .run( + !opts?.inputStream && config.inputSchema + ? config.inputSchema.parse(input) + : input, + { + onChunk: ((chunk: z.infer) => { + chunkStreamController.enqueue(chunk); + }) as S extends z.ZodVoid ? undefined : StreamingCallback>, + context: { + ...actionFn.__registry?.context, + ...(opts?.context ?? getContext()), + }, + inputStream: opts?.inputStream, + abortSignal: opts?.abortSignal, + telemetryLabels: opts?.telemetryLabels, + init: (opts as ActionFnArg>)?.init, + } as ActionRunOptions, z.infer> + ) .then((s) => s.result) .finally(() => { chunkStreamController.close(); @@ -461,6 +617,38 @@ export function action< }; }; + actionFn.streamBidi = ( + inputStream?: AsyncIterable>, + opts?: ActionRunOptions, z.infer> + ): BidiStreamingResponse => { + let channel: Channel> | undefined; + if (!inputStream) { + channel = new Channel>(); + inputStream = channel; + } + + const result = actionFn.stream(undefined, { + ...opts, + inputStream, + } as ActionRunOptions, z.infer>); + + return { + ...result, + send: (chunk) => { + if (!channel) { + throw new Error('Cannot send to a provided stream.'); + } + channel.send(chunk); + }, + close: () => { + if (!channel) { + throw new Error('Cannot close a provided stream.'); + } + channel.close(); + }, + }; + }; + if (config.use) { return actionWithMiddleware(actionFn, config.use); } @@ -483,7 +671,7 @@ export function defineAction< config: ActionParams, fn: ( input: z.infer, - options: ActionFnArg> + options: ActionFnArg, z.infer> ) => Promise> ): Action { if (isInRuntimeContext()) { @@ -501,6 +689,73 @@ export function defineAction< return act; } +/** + * Defines a bi-directional action with the given config and registers it in the registry. + */ +export function defineBidiAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, + config: BidiActionParams, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): Action, z.infer>, Init> { + const act = bidiAction(config, fn); + registry.registerAction(config.actionType, act); + return act; +} + +/** + * Creates a bi-directional action with the given config. + */ +export function bidiAction< + I extends z.ZodTypeAny, + O extends z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + config: BidiActionParams, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): Action, z.infer>, Init> { + const meta = { ...config.metadata, bidi: true }; + return action({ ...config, metadata: meta }, async (input, options) => { + let stream = options.inputStream; + if (!stream) { + if (input !== undefined) { + stream = (async function* () { + yield input; + })(); + } else { + stream = (async function* () {})(); + } + } + + const outputGen = fn({ + ...options, + inputStream: stream, + }); + + // Manually iterate to get chunks and the return value + const iter = outputGen[Symbol.asyncIterator](); + let result: z.infer; + while (true) { + const { value, done } = await iter.next(); + if (done) { + result = value; + break; + } + options.sendChunk(value); + } + return result; + }); +} + /** * Defines an action with the given config promise and registers it in the registry. */ @@ -598,3 +853,9 @@ export function runInActionRuntimeContext(fn: () => R) { export function runOutsideActionRuntimeContext(fn: () => R) { return getAsyncContext().run(runtimeContextAslKey, 'outside', fn); } + +async function* asyncIterableFromArray(array: T[]): AsyncIterable { + for (const item of array) { + yield item; + } +} diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 56c3a3d6b9..bf916cf7ef 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -15,7 +15,14 @@ */ import type { z } from 'zod'; -import { ActionFnArg, action, type Action } from './action.js'; +import { + ActionFnArg, + ActionRunOptions, + JSONSchema7, + action, + bidiAction, + type Action, +} from './action.js'; import { Registry, type HasRegistry } from './registry.js'; import { SPAN_TYPE_ATTR, runInNewSpan } from './tracing.js'; @@ -26,7 +33,8 @@ export interface Flow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, -> extends Action {} + Init extends z.ZodTypeAny = z.ZodTypeAny, +> extends Action, z.infer>, Init> {} /** * Configuration for a streaming flow. @@ -35,6 +43,7 @@ export interface FlowConfig< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, > { /** Name of the flow. */ name: string; @@ -46,6 +55,10 @@ export interface FlowConfig< streamSchema?: S; /** Metadata of the flow used by tooling. */ metadata?: Record; + /** Schema of the initialization data. */ + initSchema?: Init; + /** JSON schema of the initialization data. */ + initJsonSchema?: JSONSchema7; } /** @@ -104,6 +117,50 @@ export function defineFlow< return f; } +/** + * Defines a bi-directional flow and registers the flow in the provided registry. + */ +export function defineBidiFlow< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + registry: Registry, + config: FlowConfig, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): Flow { + const flow = bidiFlow(config, fn); + registry.registerAction('flow', flow); + return flow; +} + +/** + * Defines a bi-directional flow. + */ +export function bidiFlow< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, +>( + config: FlowConfig, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> +): Flow { + const f = bidiAction( + { + ...config, + actionType: 'flow', + }, + fn + ); + return f; +} + /** * Registers a flow as an action in the registry. */ @@ -137,12 +194,18 @@ function flowAction< ); } +/** + * A flow step that executes the provided function. + */ export function run( name: string, func: () => Promise, _?: Registry ): Promise; +/** + * A flow step that executes the provided function with input. + */ export function run( name: string, input: any, diff --git a/js/core/src/index.ts b/js/core/src/index.ts index 340903a614..d5dbe97dc5 100644 --- a/js/core/src/index.ts +++ b/js/core/src/index.ts @@ -73,6 +73,8 @@ export { type StatusName, } from './error.js'; export { + bidiFlow, + defineBidiFlow, defineFlow, flow, run, diff --git a/js/core/tests/bidi-action_test.ts b/js/core/tests/bidi-action_test.ts new file mode 100644 index 0000000000..3bb3beb612 --- /dev/null +++ b/js/core/tests/bidi-action_test.ts @@ -0,0 +1,241 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { beforeEach, describe, it } from 'node:test'; +import { z } from 'zod'; +import { defineBidiAction } from '../src/action.js'; +import { initNodeFeatures } from '../src/node.js'; +import { Registry } from '../src/registry.js'; + +initNodeFeatures(); + +describe('bidi action', () => { + var registry: Registry; + beforeEach(() => { + registry = new Registry(); + }); + + it('streamBidi ergonomic (push)', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const session = act.streamBidi(); + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['echo 1', 'echo 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('streamBidi pull (generator)', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + async function* inputGen() { + yield '1'; + yield '2'; + } + + const session = act.streamBidi(inputGen()); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['echo 1', 'echo 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('classic run works on bidi action', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + const inputs: string[] = []; + for await (const chunk of inputStream) { + inputs.push(chunk); + } + return `done: ${inputs.join(', ')}`; + } + ); + + const result = await act.run('1'); + assert.strictEqual(result.result, 'done: 1'); + }); + + it('classic run works on bidi action with streaming', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const chunks: string[] = []; + const result = await act.run('1', { + onChunk: (c) => chunks.push(c), + }); + + assert.deepStrictEqual(chunks, ['echo 1']); + assert.strictEqual(result.result, 'done'); + }); + + it('validates input stream items', async () => { + const act = defineBidiAction( + registry, + { + name: 'chat', + actionType: 'custom', + inputSchema: z.string(), // Input is string + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + } + ); + + const session = act.streamBidi(); + // Bypass TS check to send invalid data + (session as any).send(123); + session.close(); + + try { + for await (const _ of session.stream) { + // Consume + } + assert.fail('Should have thrown validation error'); + } catch (e: any) { + // Zod validation error or Genkit validation error + assert.ok( + e.message.includes('Expected string, received number') || + e.name === 'ZodError' || + e.code === 'invalid_type' || + e.message.includes('Validation failed') || + e.message.includes('must be string') + ); + } + }); + + it('bidi action receives init data', async () => { + const act = defineBidiAction( + registry, + { + name: 'chatWithInit', + actionType: 'custom', + inputSchema: z.string(), + outputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } + ); + + const session = act.streamBidi(undefined, { init: { prefix: '>> ' } }); + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['>> 1', '>> 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('validates init data', async () => { + const act = defineBidiAction( + registry, + { + name: 'chatWithInitValidation', + actionType: 'custom', + inputSchema: z.string(), + initSchema: z.object({ count: z.number() }), + }, + async function* ({ inputStream, init }) { + yield `count: ${init?.count}`; + } + ); + + // Invalid init (string instead of number) + try { + const session = act.streamBidi(undefined, { + init: { count: '123' } as any, + }); + for await (const _ of session.stream) { + // Consume + } + assert.fail('Should have thrown validation error'); + } catch (e: any) { + assert.ok(e.message.includes('count: must be number')); + } + }); +}); diff --git a/js/core/tests/bidi-flow_test.ts b/js/core/tests/bidi-flow_test.ts new file mode 100644 index 0000000000..7bd150963d --- /dev/null +++ b/js/core/tests/bidi-flow_test.ts @@ -0,0 +1,112 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { describe, it } from 'node:test'; +import { z } from 'zod'; +import { bidiFlow } from '../src/flow.js'; +import { initNodeFeatures } from '../src/node.js'; + +initNodeFeatures(); + +describe('bidi flow', () => { + it('streamBidi ergonomic (push)', async () => { + const flow = bidiFlow( + { + name: 'chatFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } + ); + + const session = flow.streamBidi(); + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['echo 1', 'echo 2']); + assert.strictEqual(await session.output, 'done'); + assert.strictEqual(flow.__action.actionType, 'flow'); + assert.ok(flow.__action.metadata?.bidi); + }); + + it('bidi flow receives init data', async () => { + const flow = bidiFlow( + { + name: 'chatFlowWithInit', + inputSchema: z.string(), + outputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } + ); + + const session = flow.streamBidi(undefined, { init: { prefix: '>> ' } }); + session.send('1'); + session.send('2'); + session.close(); + + const chunks: string[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + assert.deepStrictEqual(chunks, ['>> 1', '>> 2']); + assert.strictEqual(await session.output, 'done'); + }); + + it('validates init data in bidi flow', async () => { + const flow = bidiFlow( + { + name: 'chatFlowWithInitValidation', + inputSchema: z.string(), + initSchema: z.object({ count: z.number() }), + }, + async function* ({ init }) { + yield `count: ${init?.count}`; + } + ); + + try { + const session = flow.streamBidi(undefined, { + init: { count: '123' } as any, + }); + for await (const _ of session.stream) { + // consume + } + assert.fail('Should have thrown validation error'); + } catch (e: any) { + assert.ok(e.message.includes('count: must be number')); + } + }); +}); diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index 5053f6dbfd..ef7caba60a 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -107,9 +107,11 @@ import { } from '@genkit-ai/ai/tool'; import { ActionFnArg, + ActionRunOptions, GenkitError, Operation, ReflectionServer, + defineBidiFlow, defineDynamicActionProvider, defineFlow, defineJsonSchema, @@ -229,6 +231,25 @@ export class Genkit implements HasRegistry { return flow; } + /** + * Defines and registers a bi-directional flow. + */ + defineBidiFlow< + I extends z.ZodTypeAny = z.ZodTypeAny, + O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, + Init extends z.ZodTypeAny = z.ZodTypeAny, + >( + config: FlowConfig, + fn: ( + input: ActionFnArg, z.infer, z.infer> + ) => AsyncGenerator, z.infer, void> + ): Action, z.infer>, Init> { + const flow = defineBidiFlow(this.registry, config, fn); + this.flows.push(flow); + return flow; + } + /** * Defines and registers a tool that can return multiple parts of content. * diff --git a/js/testapps/flow-sample1/src/index.ts b/js/testapps/flow-sample1/src/index.ts index 8c3f55abf1..1c921b4a67 100644 --- a/js/testapps/flow-sample1/src/index.ts +++ b/js/testapps/flow-sample1/src/index.ts @@ -312,3 +312,33 @@ function generateString(length: number) { } return str.substring(0, length); } + +export const chatFlow = ai.defineBidiFlow( + { + name: 'chatFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async function* ({ inputStream }) { + for await (const chunk of inputStream) { + yield `echo ${chunk}`; + } + return 'done'; + } +); + +export const chatFlowWithInit = ai.defineBidiFlow( + { + name: 'chatFlowWithInit', + inputSchema: z.string(), + outputSchema: z.string(), + initSchema: z.object({ prefix: z.string() }), + }, + async function* ({ inputStream, init }) { + const prefix = init?.prefix || ''; + for await (const chunk of inputStream) { + yield `${prefix}${chunk}`; + } + return 'done'; + } +);