From 2ec8fdb701ac420bbeae098fa9a52230d6c78b79 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 5 Feb 2026 08:22:41 +0530 Subject: [PATCH] llm-chain first draft --- .../versions/045_add_llm_chain_job_type.py | 22 ++ backend/app/api/main.py | 3 +- backend/app/api/routes/llm_chain.py | 41 +++ backend/app/models/__init__.py | 3 + backend/app/models/job.py | 1 + backend/app/models/llm/__init__.py | 10 +- backend/app/models/llm/request.py | 216 +++++++++++ backend/app/models/llm/response.py | 7 + backend/app/services/llm/chain_executor.py | 344 ++++++++++++++++++ 9 files changed, 645 insertions(+), 2 deletions(-) create mode 100644 backend/app/alembic/versions/045_add_llm_chain_job_type.py create mode 100644 backend/app/api/routes/llm_chain.py create mode 100644 backend/app/services/llm/chain_executor.py diff --git a/backend/app/alembic/versions/045_add_llm_chain_job_type.py b/backend/app/alembic/versions/045_add_llm_chain_job_type.py new file mode 100644 index 000000000..5c5f25057 --- /dev/null +++ b/backend/app/alembic/versions/045_add_llm_chain_job_type.py @@ -0,0 +1,22 @@ +"""add LLM_CHAIN job type + +Revision ID: 045 +Revises: 044 +Create Date: 2026-02-04 00:35:43.891644 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "045" +down_revision = "044" +branch_labels = None +depends_on = None + + +def upgrade(): + op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_CHAIN'") + + +def downgrade(): + pass diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bcd64eb58..9827433a5 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -23,6 +23,7 @@ fine_tuning, model_evaluation, collection_job, + llm_chain, ) from app.api.routes.evaluations import dataset as evaluation_dataset, evaluation from app.core.config import settings @@ -51,7 +52,7 @@ api_router.include_router(utils.router) api_router.include_router(fine_tuning.router) api_router.include_router(model_evaluation.router) - +api_router.include_router(llm_chain.router) if settings.ENVIRONMENT in ["development", "testing"]: api_router.include_router(private.router) diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py new file mode 100644 index 000000000..34507ec11 --- /dev/null +++ b/backend/app/api/routes/llm_chain.py @@ -0,0 +1,41 @@ +import logging + +from fastapi import APIRouter, Depends + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models import LLMChainRequest, Message +from app.services.llm.chain_executor import start_chain_job +from app.utils import APIResponse, validate_callback_url + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["llm"]) + + +@router.post( + "/llm/chain", + response_model=APIResponse[Message], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def llm_chain( + _current_user: AuthContextDep, _session: SessionDep, request: LLMChainRequest +): + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + start_chain_job( + db=_session, + request=request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message="Chain execution started. Results will be delivered via callback." 
+ ) + ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index a4d76ee2c..419d9df43 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -93,7 +93,10 @@ ConfigBlob, CompletionConfig, LLMCallRequest, + LLMChainRequest, LLMCallResponse, + LLMChainResponse, + LlmCall, ) from .message import Message diff --git a/backend/app/models/job.py b/backend/app/models/job.py index b6a1a5ae7..3b20249f5 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -17,6 +17,7 @@ class JobStatus(str, Enum): class JobType(str, Enum): RESPONSE = "RESPONSE" LLM_API = "LLM_API" + LLM_CHAIN = "LLM_CHAIN" class Job(SQLModel, table=True): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index 8738e2126..b76855754 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -6,5 +6,13 @@ KaapiLLMParams, KaapiCompletionConfig, NativeCompletionConfig, + LlmCall, + LLMChainRequest, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + LLMOutput, + Usage, + LLMChainResponse, ) -from app.models.llm.response import LLMCallResponse, LLMResponse, LLMOutput, Usage diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index fc44235f9..c52efd401 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -4,6 +4,8 @@ from sqlmodel import Field, SQLModel from pydantic import Discriminator, model_validator, HttpUrl +from typing import Dict + class KaapiLLMParams(SQLModel): """ @@ -120,10 +122,17 @@ class KaapiCompletionConfig(SQLModel): ] +class PromptTemplateConfig(SQLModel): + template: str = Field(..., description="prompt template") + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + prompt_template: PromptTemplateConfig | None = Field( + default=None, description="optional prompt template" + ) # Future additions: # classifier: ClassifierConfig | None = None # pre_filter: PreFilterConfig | None = None @@ -223,3 +232,210 @@ class LLMCallRequest(SQLModel): "The exact dictionary provided here will be returned in the response metadata field." ), ) + + +class ChainBlock(SQLModel): + config: LLMCallConfig = Field( + ..., description="LLM call configuration for this block" + ) + intermediate_callback: bool = Field( + default=False, + description="Optional callback URL for intermediary results after this block completes", + ) + include_provider_raw_response: bool = Field( + default=False, + description="Whether to include the raw LLM provider response in the output", + ) + + request_metadata: dict[str, Any] | None = Field( + default=None, + description=( + "Client-provided metadata passed through unchanged in the response. " + "Use this to correlate responses with requests or track request state. " + "The exact dictionary provided here will be returned in the response metadata field." + ), + ) + + +class LLMChainRequest(SQLModel): + # query + query: QueryParams = Field(..., description="Query-specific parameters") + + # blocks + blocks: list[ChainBlock] = Field( + ..., min_length=1, description="Ordered list of blocks to execute" + ) + + # callback_url + callback_url: HttpUrl | None = Field( + default=None, description="Webhook URL for async response delivery" + ) + + +class LlmCall(SQLModel, table=True): + """ + Database model for tracking LLM API call requests and responses. 
+ + Stores both request inputs and response outputs for traceability, + supporting multimodal inputs (text, audio, image) and various completion types. + """ + + __tablename__ = "llm_call" + __table_args__ = ( + Index( + "idx_llm_call_job_id", + "job_id", + postgresql_where=text("deleted_at IS NULL"), + ), + Index( + "idx_llm_call_conversation_id", + "conversation_id", + postgresql_where=text("conversation_id IS NOT NULL AND deleted_at IS NULL"), + ), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the LLM call record"}, + ) + + job_id: UUID = Field( + foreign_key="job.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the parent job (status tracked in job table)" + }, + ) + + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the project this LLM call belongs to" + }, + ) + + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the organization this LLM call belongs to" + }, + ) + + # Request fields + input: str = Field( + ..., + sa_column_kwargs={ + "comment": "User input - text string, binary data, or file path for multimodal" + }, + ) + + input_type: Literal["text", "audio", "image"] = Field( + ..., + sa_column=sa.Column( + sa.String, + nullable=False, + comment="Input type: text, audio, image", + ), + ) + + output_type: Literal["text", "audio", "image"] | None = Field( + default=None, + sa_column=sa.Column( + sa.String, + nullable=True, + comment="Expected output type: text, audio, image", + ), + ) + + # Provider and model info + provider: str = Field( + ..., + sa_column=sa.Column( + sa.String, + nullable=False, + comment="AI provider as sent by user (e.g openai, -native, google)", + ), + ) + + model: str = Field( + ..., + sa_column_kwargs={ + "comment": "Specific model used e.g. 
'gpt-4o', 'gemini-2.5-pro'" + }, + ) + + # Response fields + provider_response_id: str | None = Field( + default=None, + sa_column_kwargs={ + "comment": "Original response ID from the provider (e.g., OpenAI's response ID)" + }, + ) + + content: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Response content: {text: '...'}, {audio_bytes: '...'}, or {image: '...'}", + ), + ) + + usage: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Token usage: {input_tokens, output_tokens, reasoning_tokens}", + ), + ) + + # Conversation tracking + conversation_id: str | None = Field( + default=None, + sa_column_kwargs={ + "comment": "Identifier linking this response to its conversation thread" + }, + ) + + auto_create: bool | None = Field( + default=None, + sa_column_kwargs={ + "comment": "Whether to auto-create conversation if conversation_id doesn't exist (OpenAI specific)" + }, + ) + + # Configuration - stores either {config_id, config_version} or {config_blob} + config: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Configuration: {config_id, config_version} for stored config OR {config_blob} for ad-hoc config", + ), + ) + + # Timestamps + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the LLM call was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the LLM call was last updated"}, + ) + + deleted_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when the record was soft-deleted"}, + ) diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index 34c9b9d9b..b2170a8aa 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -50,3 +50,10 @@ class LLMCallResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider.", ) + + +class LLMChainResponse(SQLModel): + response: LLMCallResponse = Field( + ..., description="Full response from the last block in the chain" + ) + # blocks_executed: int = Field(..., description="Total number of blocks executed") diff --git a/backend/app/services/llm/chain_executor.py b/backend/app/services/llm/chain_executor.py new file mode 100644 index 000000000..54a1b3e6a --- /dev/null +++ b/backend/app/services/llm/chain_executor.py @@ -0,0 +1,344 @@ +import logging +from uuid import UUID + +from asgi_correlation_id import correlation_id +from fastapi import HTTPException +from sqlmodel import Session + +from app.celery.utils import start_high_priority_job +from app.core.db import engine +from app.core.langfuse.langfuse import observe_llm_execution +from app.crud.config import ConfigVersionCrud +from app.crud.credentials import get_provider_credential +from app.crud.jobs import JobCrud +from app.crud.llm import create_llm_call, update_llm_call_response +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainResponse +from app.models.llm.request import ( + ConfigBlob, + KaapiCompletionConfig, + LLMChainRequest, + QueryParams, + TextInput, +) +from app.services.llm.input_resolver import cleanup_temp_file, resolve_input +from app.services.llm.jobs import handle_job_error, resolve_config_blob +from app.services.llm.mappers import transform_kaapi_config_to_native +from 
app.services.llm.providers.registry import get_llm_provider +from app.utils import APIResponse, send_callback + + +logger = logging.getLogger(__name__) + + +def start_chain_job( + db: Session, + request: LLMChainRequest, + project_id: int, + organization_id: int, +) -> UUID: + """Create an LLM chain job and schedule Celery task.""" + trace_id = correlation_id.get() or "N/A" + job_crud = JobCrud(session=db) + job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) + + db.flush() + db.commit() + + logger.info( + f"[start_chain_job] Created chain job | job_id={job.id}, " + f"blocks={len(request.blocks)}, project_id={project_id}" + ) + + try: + task_id = start_high_priority_job( + function_path="app.services.llm.chain_executor.execute_chain_job", + project_id=project_id, + job_id=str(job.id), + trace_id=trace_id, + request_data=request.model_dump(mode="json"), + organization_id=organization_id, + ) + except Exception as e: + logger.error( + f"[start_chain_job] Error starting Celery task: {str(e)} | job_id={job.id}", + exc_info=True, + ) + job_crud.update( + job_id=job.id, + job_update=JobUpdate(status=JobStatus.FAILED, error_message=str(e)), + ) + raise HTTPException( + status_code=500, + detail="Internal server error while starting chain execution", + ) + + logger.info( + f"[start_chain_job] Chain job scheduled | job_id={job.id}, task_id={task_id}" + ) + return job.id + + +def _interpolate_template(text: str, config_blob: ConfigBlob) -> str: + if config_blob.prompt_template: + return config_blob.prompt_template.template.replace("{{input}}", text) + return text + + +def execute_chain_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM chain request asynchronously.""" + request = LLMChainRequest(**request_data) + job_id: UUID = UUID(job_id) + callback_url = str(request.callback_url) if request.callback_url else None + + logger.info( + f"[execute_chain_job] Starting chain | job_id={job_id}, " + f"blocks={len(request.blocks)}, task_id={task_id}" + ) + + try: + with Session(engine) as session: + JobCrud(session=session).update( + job_id=job_id, + job_update=JobUpdate(status=JobStatus.PROCESSING), + ) + + previous_output: str | None = None + last_response = None + + for block_idx, block in enumerate(request.blocks): + is_last = block_idx == len(request.blocks) - 1 + + block_config = block.config + + logger.info(f"[BLOCK CONFIG] ===> {block_config}") + + logger.info( + f"[execute_chain_job] Executing block {block_idx}/{len(request.blocks) - 1} " + f"| job_id={job_id}" + ) + + if block_idx > 0 and not previous_output: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx - 1} returned empty output, cannot continue chain", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + with Session(engine) as session: + config_blob: ConfigBlob | None = None + + if block_config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, + project_id=project_id, + config_id=block_config.id, + ) + config_blob, error = resolve_config_blob(config_crud, block_config) + if error: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: {error}", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + else: + config_blob = block_config.blob + + completion_config = config_blob.completion + original_provider = 
completion_config.provider + + if isinstance(completion_config, KaapiCompletionConfig): + try: + completion_config, warnings = transform_kaapi_config_to_native( + completion_config + ) + except Exception as e: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: Config transformation error: {str(e)}", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + if block_idx == 0: + block_query = request.query + if config_blob.prompt_template and isinstance( + request.query.input, TextInput + ): + interpolated = _interpolate_template( + request.query.input.content, config_blob + ) + block_query = QueryParams(input=interpolated) + else: + block_input = _interpolate_template(previous_output, config_blob) + block_query = QueryParams(input=block_input) + + resolved_config_blob = ConfigBlob(completion=completion_config) + synthetic_request = LLMCallRequest( + query=block_query, + config=block_config, + request_metadata=block.request_metadata, + ) + + try: + llm_call = create_llm_call( + session, + request=synthetic_request, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + resolved_config=resolved_config_blob, + original_provider=original_provider, + ) + llm_call_id = llm_call.id + except Exception as e: + logger.error( + f"[execute_chain_job] Failed to create LLM call record " + f"for block {block_idx}: {str(e)}", + exc_info=True, + ) + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: Failed to create LLM call record", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + try: + provider_instance = get_llm_provider( + session=session, + provider_type=completion_config.provider, + project_id=project_id, + organization_id=organization_id, + ) + except ValueError as ve: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: {str(ve)}", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + conversation_id = None + if block_query.conversation and block_query.conversation.id: + conversation_id = block_query.conversation.id + + resolved_input, resolve_error = resolve_input(block_query.input) + if resolve_error: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: {resolve_error}", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + decorated_execute = observe_llm_execution( + credentials=langfuse_credentials, + session_id=conversation_id, + )(provider_instance.execute) + + try: + response, error = decorated_execute( + completion_config=completion_config, + query=block_query, + resolved_input=resolved_input, + include_provider_raw_response=block.include_provider_raw_response, + ) + finally: + if resolved_input and resolved_input != block_query.input: + cleanup_temp_file(resolved_input) + + if not response: + callback_response = APIResponse.failure_response( + error=f"Block {block_idx}: {error or 'Unknown error'}", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response) + + with Session(engine) as session: + try: + update_llm_call_response( + session, + llm_call_id=llm_call_id, + provider_response_id=response.response.provider_response_id, + 
content=response.response.output.model_dump(), + usage=response.usage.model_dump(), + conversation_id=response.response.conversation_id, + ) + except Exception as e: + logger.error( + f"[execute_chain_job] Failed to update LLM call record " + f"for block {block_idx}: {str(e)}", + exc_info=True, + ) + + previous_output = response.response.output.text + last_response = response + + logger.info( + f"[execute_chain_job] Block {block_idx} completed | job_id={job_id}, " + f"provider={response.response.provider}, model={response.response.model}, " + f"tokens={response.usage.total_tokens}" + ) + + if not is_last and block.intermediate_callback and callback_url: + send_callback( + callback_url=callback_url, + data=APIResponse.success_response( + data={ + "type": "intermediary", + "block_index": block_idx + 1, + "blocks_total": len(request.blocks), + "response": response.model_dump(), + }, + metadata=block.request_metadata, + ).model_dump(), + ) + + chain_response = LLMChainResponse( + response=last_response, + # blocks_executed=len(request.blocks), + ) + + callback_response = APIResponse.success_response( + data=chain_response, metadata=block.request_metadata + ) + + if callback_url: + send_callback( + callback_url=callback_url, + data=callback_response.model_dump(), + ) + + with Session(engine) as session: + JobCrud(session=session).update( + job_id=job_id, + job_update=JobUpdate(status=JobStatus.SUCCESS), + ) + + logger.info( + f"[execute_chain_job] Chain completed | job_id={job_id}, " + f"blocks_executed={len(request.blocks)}" + ) + + return callback_response.model_dump() + + except Exception as e: + logger.error( + f"[execute_chain_job] Unexpected error: {str(e)} | job_id={job_id}, task_id={task_id}", + exc_info=True, + ) + callback_response = APIResponse.failure_response( + error="Unexpected error occurred during chain execution", + metadata=block.request_metadata, + ) + return handle_job_error(job_id, callback_url, callback_response)
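
Example usage of the new POST /llm/chain endpoint (a minimal client sketch, not part of the diff). The payload shape follows LLMChainRequest, ChainBlock and ConfigBlob above, but the exact JSON contract for "query" and the completion config depends on QueryParams and LLMCallConfig, which are not fully shown in this patch, so the base URL, auth header, config keys and model names below are illustrative assumptions only:

import httpx

BASE_URL = "http://localhost:8000/api/v1"  # assumed API prefix

payload = {
    # Assumed QueryParams shape: a plain text input for the first block.
    "query": {"input": "Summarise the attached support ticket."},
    "blocks": [
        {
            # Block 0 consumes the original query input.
            # Assumed ad-hoc config shape ({"blob": {...}}); stored configs
            # ({config_id, version}) should also work per LLMCallConfig.
            "config": {"blob": {"completion": {"provider": "openai", "model": "gpt-4o"}}},
            "intermediate_callback": True,
            "request_metadata": {"step": "summarise"},
        },
        {
            # Block 1 receives block 0's text output; _interpolate_template()
            # substitutes it for the {{input}} placeholder in prompt_template.
            "config": {
                "blob": {
                    "completion": {"provider": "openai", "model": "gpt-4o"},
                    "prompt_template": {"template": "Translate to Hindi:\n\n{{input}}"},
                }
            },
        },
    ],
    "callback_url": "https://example.com/webhooks/llm-chain",
}

resp = httpx.post(
    f"{BASE_URL}/llm/chain",
    json=payload,
    headers={"X-API-KEY": "<project API key>"},  # auth scheme assumed
    timeout=30,
)
print(resp.json())  # acknowledgement only; the chain result arrives later at callback_url

The synchronous response just acknowledges that the job was scheduled. The final LLMChainResponse, and any intermediate block result for blocks with intermediate_callback set, is delivered asynchronously to callback_url via send_callback.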