"""
Llama.cpp Provider Implementation

Implements the LLMProvider interface using a local llama.cpp model.

This provider uses llama-cpp-python to run inference on quantized
models in GGUF format.
"""

from threading import Lock

from llama_cpp import Llama

from api.config.loader import CONFIG
from api.models.llm_provider import LLMProvider
from utils import LoggerFactory

llm_config = CONFIG["llm"]
logger = LoggerFactory.instance().get_logger("llm")
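
# The CONFIG["llm"] keys read below imply a config.yml section roughly like
# the following. The values shown are hypothetical, for illustration only:
#
#   llm:
#     model_path: models/model.Q4_K_M.gguf
#     context_length: 4096
#     threads: 8
#     gpu_layers: 0
#     verbose: false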

# pylint: disable=too-few-public-methods
class LlamaCppProvider(LLMProvider):
    """
    LLMProvider implementation for local llama.cpp models.
    """

    def __init__(self):
        """
        Initializes the Llama model with configuration from config.yml
        and sets up a lock for thread-safe usage.
        """
        self.llm = Llama(
            model_path=llm_config["model_path"],
            n_ctx=llm_config["context_length"],
            n_threads=llm_config["threads"],
            n_gpu_layers=llm_config["gpu_layers"],
            verbose=llm_config["verbose"],
        )
        # A llama.cpp context is not safe for concurrent calls, so all
        # inference is serialized through this lock.
        self.lock = Lock()

    def generate(self, prompt: str, max_tokens: int) -> str:
        """
        Generate a response from the model given a prompt.

        Args:
            prompt (str): Prompt to feed into the model.
            max_tokens (int): Maximum number of tokens to generate.

        Returns:
            str: The generated text response.
        """
        try:
            # Hold the lock only for the inference call itself.
            with self.lock:
                output = self.llm(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    echo=False,
                )
            return output["choices"][0]["text"].strip()
        except ValueError as e:
            # llama-cpp-python raises ValueError for invalid generation
            # requests, e.g. a prompt that exceeds the context window.
            logger.error("Invalid generation request: %s", e)
            raise RuntimeError("LLM generation failed: invalid prompt or parameters.") from e
        except Exception as e:  # pylint: disable=broad-exception-caught
            logger.error("Unexpected error during LLM generation: %s", e)
            return "Sorry, something went wrong during generation."

# Module-level singleton so the model is loaded once per process.
llm_provider = LlamaCppProvider()
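
# A minimal smoke test (not part of the provider's API), assuming the GGUF
# model configured in config.yml exists locally. Run this module directly
# to try it; the prompt and token budget below are arbitrary examples.
if __name__ == "__main__":
    print(llm_provider.generate("Q: What is llama.cpp?\nA:", max_tokens=64))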