From e195dc9450c6366d0b777906fa2c3a7a0af45d72 Mon Sep 17 00:00:00 2001 From: rootflo-hardik Date: Wed, 25 Feb 2026 16:00:18 +0530 Subject: [PATCH 1/4] flo_ai aws bedrock integration --- flo_ai/flo_ai/llm/__init__.py | 2 + flo_ai/flo_ai/llm/aws_bedrock_llm.py | 265 +++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 flo_ai/flo_ai/llm/aws_bedrock_llm.py diff --git a/flo_ai/flo_ai/llm/__init__.py b/flo_ai/flo_ai/llm/__init__.py index 934e4da4..1b8d2c42 100644 --- a/flo_ai/flo_ai/llm/__init__.py +++ b/flo_ai/flo_ai/llm/__init__.py @@ -6,6 +6,7 @@ from .openai_vllm import OpenAIVLLM from .vertexai_llm import VertexAI from .rootflo_llm import RootFloLLM +from .aws_bedrock_llm import AWSBedrock __all__ = [ 'BaseLLM', @@ -16,4 +17,5 @@ 'OpenAIVLLM', 'VertexAI', 'RootFloLLM', + 'AWSBedrock', ] diff --git a/flo_ai/flo_ai/llm/aws_bedrock_llm.py b/flo_ai/flo_ai/llm/aws_bedrock_llm.py new file mode 100644 index 00000000..99cfa040 --- /dev/null +++ b/flo_ai/flo_ai/llm/aws_bedrock_llm.py @@ -0,0 +1,265 @@ +import json +import re +from typing import Dict, Any, List, AsyncIterator, Optional +import boto3 +import asyncio +from .base_llm import BaseLLM +from flo_ai.models.chat_message import ImageMessageContent +from flo_ai.tool.base_tool import Tool +from flo_ai.telemetry.instrumentation import ( + trace_llm_call, + trace_llm_stream, + llm_metrics, + add_span_attributes, +) +from flo_ai.telemetry import get_tracer +from opentelemetry import trace + + +class AWSBedrock(BaseLLM): + def __init__( + self, + model: str = 'openai.gpt-oss-20b-1:0', + temperature: float = 0.7, + **kwargs, + ): + super().__init__(model=model, temperature=temperature, **kwargs) + self.boto_client = boto3.client('bedrock-runtime') + self.model = model + self.kwargs = kwargs + + @staticmethod + def _strip_reasoning(text: str) -> str: + return re.sub(r'.*?', '', text, flags=re.DOTALL).strip() + + def _convert_messages( + self, messages: list[dict], output_schema: dict = None + ) -> list[dict]: + result = [] + + if output_schema: + result.append( + { + 'role': 'system', + 'content': f'Provide output in the following JSON schema:\n{json.dumps(output_schema, indent=2)}', + } + ) + + for msg in messages: + if msg['role'] == 'function': + result.append( + { + 'role': 'tool', + 'tool_call_id': msg.get('tool_use_id', 'unknown'), + 'content': msg['content'], + 'name': msg.get('name', ''), + } + ) + else: + result.append(msg) + + return result + + @trace_llm_call(provider='bedrock') + async def generate( + self, + messages: list[dict], + functions: Optional[List[Dict[str, Any]]] = None, + output_schema: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + converted = self._convert_messages(messages, output_schema) + + request_body: Dict[str, Any] = { + 'model': self.model, + 'messages': converted, + 'temperature': self.temperature, + } + if 'max_tokens' in self.kwargs: + request_body['max_completion_tokens'] = self.kwargs['max_tokens'] + if functions: + request_body['tools'] = functions + + response = await asyncio.to_thread( + self.boto_client.invoke_model, + modelId=self.model, + body=json.dumps(request_body), + ) + response_body = json.loads(response['body'].read().decode('utf-8')) + + usage = response_body.get('usage', {}) + if usage: + llm_metrics.record_tokens( + total_tokens=usage.get('total_tokens', 0), + prompt_tokens=usage.get('prompt_tokens', 0), + completion_tokens=usage.get('completion_tokens', 0), + model=self.model, + provider='bedrock', + ) + tracer = get_tracer() + if tracer: + add_span_attributes( + trace.get_current_span(), + { + 'llm.tokens.prompt': usage.get('prompt_tokens', 0), + 'llm.tokens.completion': usage.get('completion_tokens', 0), + 'llm.tokens.total': usage.get('total_tokens', 0), + }, + ) + + choices = response_body.get('choices', []) + if not choices: + return {'content': '', 'raw_message': response_body} + + message = choices[0].get('message', {}) + if 'content' in message and message['content']: + message['content'] = self._strip_reasoning(message['content']) + text_content = message.get('content', '') or '' + tool_call = None + + tool_calls = message.get('tool_calls', []) + if tool_calls: + tc = tool_calls[0] + tool_call = { + 'name': tc['function']['name'], + 'arguments': tc['function']['arguments'], + 'id': tc['id'], + } + + if tool_call: + return { + 'content': text_content, + 'function_call': tool_call, + 'raw_message': message, + } + return {'content': text_content, 'raw_message': message} + + @trace_llm_stream(provider='bedrock') + async def stream( + self, + messages: List[Dict[str, Any]], + functions: Optional[List[Dict[str, Any]]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + converted = self._convert_messages(messages) + + request_body: Dict[str, Any] = { + 'model': self.model, + 'messages': converted, + 'temperature': self.temperature, + 'stream': True, + } + if 'max_tokens' in self.kwargs: + request_body['max_completion_tokens'] = self.kwargs['max_tokens'] + if functions: + request_body['tools'] = functions + + response = await asyncio.to_thread( + self.boto_client.invoke_model_with_response_stream, + modelId=self.model, + body=json.dumps(request_body), + ) + + buffer = '' + + for event in response['body']: + chunk_bytes = event.get('chunk', {}).get('bytes', b'') + if not chunk_bytes: + continue + text = chunk_bytes.decode('utf-8').strip() + # Try direct JSON first (some Bedrock models skip SSE envelope) + try: + data = json.loads(text) + content = data.get('choices', [{}])[0].get('delta', {}).get('content') + if content: + buffer += content + continue + except json.JSONDecodeError: + pass + # Fall back to SSE format: "data: {...}" + for line in text.split('\n'): + line = line.strip() + if line.startswith('data: ') and line != 'data: [DONE]': + try: + data = json.loads(line[6:]) + content = ( + data.get('choices', [{}])[0].get('delta', {}).get('content') + ) + if content: + buffer += content + except json.JSONDecodeError: + pass + + clean = self._strip_reasoning(buffer) + if clean: + yield {'content': clean} + + def get_message_content(self, response: Dict[str, Any]) -> str: + content = ( + response.get('content', '') if isinstance(response, dict) else str(response) + ) + return self._strip_reasoning(content) + + def format_tool_for_llm(self, tool: 'Tool') -> Dict[str, Any]: + return { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': { + 'type': 'object', + 'properties': { + name: { + 'type': info.get('type', 'string'), + 'description': info.get('description', ''), + **( + {'items': info['items']} + if info.get('type') == 'array' and 'items' in info + else {} + ), + } + for name, info in tool.parameters.items() + }, + 'required': [ + name + for name, info in tool.parameters.items() + if info.get('required', True) + ], + }, + }, + } + + def format_tools_for_llm(self, tools: List['Tool']) -> List[Dict[str, Any]]: + return [self.format_tool_for_llm(tool) for tool in tools] + + def format_image_in_message(self, image: ImageMessageContent) -> dict: + if image.base64: + return { + 'type': 'image_url', + 'image_url': { + 'url': f'data:{image.mime_type or "image/jpeg"};base64,{image.base64}' + }, + } + raise NotImplementedError( + 'AWS Bedrock LLM requires image base64 data to format image content.' + ) + + def get_assistant_message_for_tool_call( + self, response: Dict[str, Any] + ) -> Optional[Any]: + if isinstance(response, dict) and 'raw_message' in response: + return response['raw_message'] + return None + + def get_tool_use_id(self, function_call: Dict[str, Any]) -> Optional[str]: + return function_call.get('id') + + def format_function_result_message( + self, function_name: str, content: str, tool_use_id: Optional[str] = None + ) -> Dict[str, Any]: + return { + 'role': 'tool', + 'tool_call_id': tool_use_id or 'unknown', + 'content': content, + 'name': function_name, + } From 9c583b615d4574047a99748a60ee12a9191baddb Mon Sep 17 00:00:00 2001 From: rootflo-hardik Date: Fri, 27 Feb 2026 16:26:15 +0530 Subject: [PATCH 2/4] resolved review comments --- flo_ai/flo_ai/llm/aws_bedrock_llm.py | 64 ++++++++++++++++------------ flo_ai/pyproject.toml | 3 +- flo_ai/uv.lock | 6 +-- 3 files changed, 39 insertions(+), 34 deletions(-) diff --git a/flo_ai/flo_ai/llm/aws_bedrock_llm.py b/flo_ai/flo_ai/llm/aws_bedrock_llm.py index 99cfa040..1dcee58e 100644 --- a/flo_ai/flo_ai/llm/aws_bedrock_llm.py +++ b/flo_ai/flo_ai/llm/aws_bedrock_llm.py @@ -13,10 +13,11 @@ add_span_attributes, ) from flo_ai.telemetry import get_tracer +from flo_ai.utils.logger import logger from opentelemetry import trace -class AWSBedrock(BaseLLM): +class AWSBedrock(BaseLLM): # Only openai compatible for now def __init__( self, model: str = 'openai.gpt-oss-20b-1:0', @@ -160,39 +161,46 @@ async def stream( body=json.dumps(request_body), ) - buffer = '' + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.get_event_loop() - for event in response['body']: - chunk_bytes = event.get('chunk', {}).get('bytes', b'') - if not chunk_bytes: - continue + def _iter_events(): + try: + for event in response['body']: + chunk_bytes = event.get('chunk', {}).get('bytes', b'') + if chunk_bytes: + loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel + + loop.run_in_executor(None, _iter_events) + + while True: + chunk_bytes = await queue.get() + if chunk_bytes is None: + break text = chunk_bytes.decode('utf-8').strip() - # Try direct JSON first (some Bedrock models skip SSE envelope) + content = None try: data = json.loads(text) content = data.get('choices', [{}])[0].get('delta', {}).get('content') - if content: - buffer += content - continue except json.JSONDecodeError: - pass - # Fall back to SSE format: "data: {...}" - for line in text.split('\n'): - line = line.strip() - if line.startswith('data: ') and line != 'data: [DONE]': - try: - data = json.loads(line[6:]) - content = ( - data.get('choices', [{}])[0].get('delta', {}).get('content') - ) - if content: - buffer += content - except json.JSONDecodeError: - pass - - clean = self._strip_reasoning(buffer) - if clean: - yield {'content': clean} + for line in text.split('\n'): + line = line.strip() + if line.startswith('data: ') and line != 'data: [DONE]': + try: + data = json.loads(line[6:]) + content = ( + data.get('choices', [{}])[0] + .get('delta', {}) + .get('content') + ) + except json.JSONDecodeError: + logger.debug('Skipping malformed SSE line: %s', line) + if content: + clean = self._strip_reasoning(content) + if clean: + yield {'content': clean} def get_message_content(self, response: Dict[str, Any]) -> str: content = ( diff --git a/flo_ai/pyproject.toml b/flo_ai/pyproject.toml index 2b976988..7e2ff661 100644 --- a/flo_ai/pyproject.toml +++ b/flo_ai/pyproject.toml @@ -10,6 +10,7 @@ license = "MIT" dependencies = [ "aiohttp>=3.12.14,<4", "anthropic>=0.57.1,<0.58", + "boto3>=1.36.1,<2", "chardet>=3.dev0,<4.dev0", "cryptography>=46.0.3", "google-cloud-aiplatform>=1.109.0,<2", @@ -35,8 +36,6 @@ vizualize = [ [dependency-groups] dev = [ - "boto3>=1.36.1,<2", - "botocore>=1.36.1,<2", "db-sqlite3>=0.0.1,<0.0.2", "ipykernel>=6.29.5,<7", "peewee>=3.17.6,<4", diff --git a/flo_ai/uv.lock b/flo_ai/uv.lock index f9b039b9..36b047ce 100644 --- a/flo_ai/uv.lock +++ b/flo_ai/uv.lock @@ -902,6 +902,7 @@ source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "anthropic" }, + { name = "boto3" }, { name = "chardet" }, { name = "cryptography" }, { name = "google-cloud-aiplatform" }, @@ -928,8 +929,6 @@ vizualize = [ [package.dev-dependencies] dev = [ - { name = "boto3" }, - { name = "botocore" }, { name = "db-sqlite3" }, { name = "ipykernel" }, { name = "peewee" }, @@ -947,6 +946,7 @@ dev = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.12.14,<4" }, { name = "anthropic", specifier = ">=0.57.1,<0.58" }, + { name = "boto3", specifier = ">=1.36.1,<2" }, { name = "chardet", specifier = ">=3.dev0,<4.dev0" }, { name = "cryptography", specifier = ">=46.0.3" }, { name = "google-cloud-aiplatform", specifier = ">=1.109.0,<2" }, @@ -969,8 +969,6 @@ provides-extras = ["vizualize"] [package.metadata.requires-dev] dev = [ - { name = "boto3", specifier = ">=1.36.1,<2" }, - { name = "botocore", specifier = ">=1.36.1,<2" }, { name = "db-sqlite3", specifier = ">=0.0.1,<0.0.2" }, { name = "ipykernel", specifier = ">=6.29.5,<7" }, { name = "peewee", specifier = ">=3.17.6,<4" }, From 41903a5ef96835f55d224a3576ab7421258b0ae5 Mon Sep 17 00:00:00 2001 From: rootflo-hardik Date: Fri, 27 Feb 2026 16:36:12 +0530 Subject: [PATCH 3/4] fixed minor review comments --- flo_ai/flo_ai/llm/aws_bedrock_llm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flo_ai/flo_ai/llm/aws_bedrock_llm.py b/flo_ai/flo_ai/llm/aws_bedrock_llm.py index 1dcee58e..78c17d9c 100644 --- a/flo_ai/flo_ai/llm/aws_bedrock_llm.py +++ b/flo_ai/flo_ai/llm/aws_bedrock_llm.py @@ -34,7 +34,7 @@ def _strip_reasoning(text: str) -> str: return re.sub(r'.*?', '', text, flags=re.DOTALL).strip() def _convert_messages( - self, messages: list[dict], output_schema: dict = None + self, messages: list[dict], output_schema: dict | None = None ) -> list[dict]: result = [] @@ -162,7 +162,7 @@ async def stream( ) queue: asyncio.Queue = asyncio.Queue() - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() def _iter_events(): try: @@ -185,6 +185,7 @@ def _iter_events(): data = json.loads(text) content = data.get('choices', [{}])[0].get('delta', {}).get('content') except json.JSONDecodeError: + # Not valid JSON, try SSE format below for line in text.split('\n'): line = line.strip() if line.startswith('data: ') and line != 'data: [DONE]': From dfd19af9329ddcd7ad3ac71f169c887da6dfca17 Mon Sep 17 00:00:00 2001 From: rootflo-hardik Date: Fri, 27 Feb 2026 16:50:48 +0530 Subject: [PATCH 4/4] resolved review comments --- flo_ai/flo_ai/llm/aws_bedrock_llm.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/flo_ai/flo_ai/llm/aws_bedrock_llm.py b/flo_ai/flo_ai/llm/aws_bedrock_llm.py index 78c17d9c..741ec9c3 100644 --- a/flo_ai/flo_ai/llm/aws_bedrock_llm.py +++ b/flo_ai/flo_ai/llm/aws_bedrock_llm.py @@ -71,13 +71,14 @@ async def generate( ) -> Any: converted = self._convert_messages(messages, output_schema) + request_options = {**self.kwargs, **kwargs} request_body: Dict[str, Any] = { 'model': self.model, 'messages': converted, 'temperature': self.temperature, } - if 'max_tokens' in self.kwargs: - request_body['max_completion_tokens'] = self.kwargs['max_tokens'] + if 'max_tokens' in request_options: + request_body['max_completion_tokens'] = request_options['max_tokens'] if functions: request_body['tools'] = functions @@ -144,14 +145,15 @@ async def stream( ) -> AsyncIterator[Dict[str, Any]]: converted = self._convert_messages(messages) + request_options = {**self.kwargs, **kwargs} request_body: Dict[str, Any] = { 'model': self.model, 'messages': converted, 'temperature': self.temperature, 'stream': True, } - if 'max_tokens' in self.kwargs: - request_body['max_completion_tokens'] = self.kwargs['max_tokens'] + if 'max_tokens' in request_options: + request_body['max_completion_tokens'] = request_options['max_tokens'] if functions: request_body['tools'] = functions @@ -170,6 +172,8 @@ def _iter_events(): chunk_bytes = event.get('chunk', {}).get('bytes', b'') if chunk_bytes: loop.call_soon_threadsafe(queue.put_nowait, chunk_bytes) + except Exception as exc: + loop.call_soon_threadsafe(queue.put_nowait, exc) finally: loop.call_soon_threadsafe(queue.put_nowait, None) # sentinel @@ -177,6 +181,8 @@ def _iter_events(): while True: chunk_bytes = await queue.get() + if isinstance(chunk_bytes, Exception): + raise chunk_bytes if chunk_bytes is None: break text = chunk_bytes.decode('utf-8').strip() @@ -196,6 +202,11 @@ def _iter_events(): .get('delta', {}) .get('content') ) + if content: + clean = self._strip_reasoning(content) + if clean: + yield {'content': clean} + content = None except json.JSONDecodeError: logger.debug('Skipping malformed SSE line: %s', line) if content: