diff --git a/README.md b/README.md index c7f49b77a..b74659955 100644 --- a/README.md +++ b/README.md @@ -189,6 +189,23 @@ text_gen = llm(provider="openai", api_base="http://localhost:8000/v1", api_key= lida = Manager(text_gen = text_gen) ``` +## Using LIDA with Custom Text Generation + +LIDA uses the `custom_textgen` module as its interface for custom text generation. `custom_textgen` supports any model, so you can use your custom text generators directly. + +```python +from lida import Manager, CustomTextGenerator + +def generate_text(prompt: str) -> str: + # Your custom method for text generation + pass + +text_gen = CustomTextGenerator(text_generation_function=generate_text) +lida = Manager(text_gen=text_gen) +# now you can call lida methods as above e.g. +summary = lida.summarize("data/cars.csv") # .... +``` + ## Important Notes / Caveats / FAQs - LIDA generates and executes code based on provided input. Ensure that you run LIDA in a secure environment with appropriate permissions. diff --git a/lida/__init__.py b/lida/__init__.py index b74cbaa26..b51d2c089 100644 --- a/lida/__init__.py +++ b/lida/__init__.py @@ -1,5 +1,6 @@ from llmx import TextGenerationConfig, llm, TextGenerator from .components.manager import Manager +from .components.custom_textgen import CustomTextGenerator -__all__ = ["TextGenerationConfig", "llm", "TextGenerator", "Manager"] +__all__ = ["TextGenerationConfig", "llm", "TextGenerator", "Manager", "CustomTextGenerator"] diff --git a/lida/components/custom_textgen.py b/lida/components/custom_textgen.py new file mode 100644 index 000000000..7f8521bb2 --- /dev/null +++ b/lida/components/custom_textgen.py @@ -0,0 +1,59 @@ +from typing import Union, List, Dict, Callable +from dataclasses import asdict +from llmx import TextGenerator +from lida.datamodel import TextGenerationConfig +from llmx import TextGenerationResponse, Message +from llmx.utils import cache_request, num_tokens_from_messages + + +class CustomTextGenerator(TextGenerator): + def __init__( + self, 
+ text_generation_function: Callable[[str], str], + provider: str = "custom", + **kwargs + ): + super().__init__(provider=provider, **kwargs) + self.text_generation_function = text_generation_function + + def generate( + self, + messages: Union[List[Dict], str], + config: TextGenerationConfig = TextGenerationConfig(), + **kwargs + ) -> TextGenerationResponse: + use_cache = config.use_cache + messages = self.format_messages(messages) + + if use_cache: + response = cache_request(cache=self.cache, params={"messages": messages}) + if response: + return TextGenerationResponse(**response) + + generation_response = self.text_generation_function(messages) + response = TextGenerationResponse( + text=[Message(role="system", content=generation_response)], + logprobs=[], # Populate with log probabilities extracted from the response if your generator provides them + usage={}, + config={}, + ) + + if use_cache: + cache_request( + cache=self.cache, params={"messages": messages}, values=asdict(response) + ) + + return response + + def format_messages(self, messages) -> str: + prompt = "" + for message in messages: + if message["role"] == "system": + prompt += message["content"] + "\n" + else: + prompt += message["role"] + ": " + message["content"] + "\n" + + return prompt + + def count_tokens(self, text) -> int: + return num_tokens_from_messages(text) \ No newline at end of file