17 changes: 17 additions & 0 deletions README.md
@@ -189,6 +189,23 @@ text_gen = llm(provider="openai", api_base="http://localhost:8000/v1", api_key=
lida = Manager(text_gen = text_gen)
```

## Using LIDA with Custom Text Generation

LIDA provides the `CustomTextGenerator` class (in `lida/components/cutom_textgen.py`) as its interface for custom text generation. `CustomTextGenerator` works with any model: pass it a function that maps a prompt string to generated text, and use it directly as LIDA's text generator.

```python
from lida import Manager, CustomTextGenerator

def generate_text(prompt: str) -> str:
    # Your custom method for text generation
pass

text_gen = CustomTextGenerator(text_generation_function=generate_text)
lida = Manager(text_gen=text_gen)
# now you can call lida methods as above e.g.
summary = lida.summarize("data/cars.csv") # ....
```
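
For example, `generate_text` can wrap a locally hosted model. The sketch below is one possible implementation, not part of LIDA itself; the Hugging Face `transformers` pipeline, model name, and generation parameters are illustrative assumptions:

```python
from transformers import pipeline
from lida import Manager, CustomTextGenerator

# Illustrative assumption: any local text-generation pipeline works here.
generator = pipeline("text-generation", model="gpt2")

def generate_text(prompt: str) -> str:
    # Return only the newly generated continuation, not the echoed prompt.
    output = generator(prompt, max_new_tokens=256, do_sample=True)
    return output[0]["generated_text"][len(prompt):]

text_gen = CustomTextGenerator(text_generation_function=generate_text)
lida = Manager(text_gen=text_gen)
```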

## Important Notes / Caveats / FAQs

- LIDA generates and executes code based on provided input. Ensure that you run LIDA in a secure environment with appropriate permissions.
3 changes: 2 additions & 1 deletion lida/__init__.py
@@ -1,5 +1,6 @@
from llmx import TextGenerationConfig, llm, TextGenerator
from .components.manager import Manager
from .components.cutom_textgen import CustomTextGenerator


__all__ = ["TextGenerationConfig", "llm", "TextGenerator", "Manager"]
__all__ = ["TextGenerationConfig", "llm", "TextGenerator", "Manager", "CustomTextGenerator"]
59 changes: 59 additions & 0 deletions lida/components/cutom_textgen.py
@@ -0,0 +1,59 @@
from typing import Union, List, Dict, Callable
from dataclasses import asdict
from llmx import TextGenerator
from lida.datamodel import TextGenerationConfig
from llmx import TextGenerationResponse, Message
from llmx.utils import cache_request, num_tokens_from_messages


class CustomTextGenerator(TextGenerator):
def __init__(
self,
text_generation_function: Callable[[str], str],
provider: str = "custom",
**kwargs
):
super().__init__(provider=provider, **kwargs)
self.text_generation_function = text_generation_function

def generate(
self,
messages: Union[List[Dict], str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs
) -> TextGenerationResponse:
use_cache = config.use_cache
messages = self.format_messages(messages)

if use_cache:
response = cache_request(cache=self.cache, params={"messages": messages})
if response:
return TextGenerationResponse(**response)

generation_response = self.text_generation_function(messages)
        response = TextGenerationResponse(
            # The generated reply is the model's output, so mark it as the
            # assistant's message.
            text=[Message(role="assistant", content=generation_response)],
            logprobs=[],  # log probabilities are not available from a plain text function
            usage={},
            config={},
        )

if use_cache:
cache_request(
cache=self.cache, params={"messages": messages}, values=asdict(response)
)

return response

    def format_messages(self, messages) -> str:
        # A plain string is already a prompt; only lists of message dicts
        # need to be flattened into a single prompt string.
        if isinstance(messages, str):
            return messages

        prompt = ""
        for message in messages:
            if message["role"] == "system":
                prompt += message["content"] + "\n"
            else:
                prompt += message["role"] + ": " + message["content"] + "\n"

        return prompt

    def count_tokens(self, text) -> int:
        # num_tokens_from_messages expects a list of OpenAI-style message
        # dicts rather than a raw string, so wrap the text in one message.
        return num_tokens_from_messages([{"role": "user", "content": text}])
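
To see how `CustomTextGenerator` flattens chat-style messages before calling the user-supplied function, here is a minimal sketch. The echo function and message list are illustrative, and it assumes llmx's `Message` exposes a `.content` attribute (consistent with the `asdict` call above):

```python
from lida import CustomTextGenerator, TextGenerationConfig

# Illustrative stand-in for a real model: just describe the prompt it got.
def echo(prompt: str) -> str:
    return f"[model reply to {len(prompt)} chars of prompt]"

gen = CustomTextGenerator(text_generation_function=echo)

messages = [
    {"role": "system", "content": "You are a data analyst."},
    {"role": "user", "content": "Summarize the dataset."},
]

# format_messages joins system content and "role: content" lines into one prompt.
print(gen.format_messages(messages))

# generate wraps the function's return value in a TextGenerationResponse.
response = gen.generate(messages, config=TextGenerationConfig(use_cache=False))
print(response.text[0].content)
```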