diff --git a/private_gpt/utils/ollama.py b/private_gpt/utils/ollama.py
index 41c7ecc..9c75a87 100644
--- a/private_gpt/utils/ollama.py
+++ b/private_gpt/utils/ollama.py
@@ -1,4 +1,9 @@
 import logging
+from collections import deque
+from collections.abc import Iterator, Mapping
+from typing import Any
+
+from tqdm import tqdm  # type: ignore
 
 try:
     from ollama import Client  # type: ignore
@@ -19,12 +24,55 @@ def check_connection(client: Client) -> bool:
         return False
 
 
+def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
+    progress_bars = {}
+    queue = deque()  # type: ignore
+
+    def create_progress_bar(dgt: str, total: int) -> Any:
+        return tqdm(
+            total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
+        )
+
+    current_digest = None
+
+    for chunk in generator:
+        digest = chunk.get("digest")
+        completed_size = chunk.get("completed", 0)
+        total_size = chunk.get("total")
+
+        if digest and total_size is not None:
+            if digest not in progress_bars and completed_size > 0:
+                progress_bars[digest] = create_progress_bar(digest, total=total_size)
+                if current_digest is None:
+                    current_digest = digest
+                else:
+                    queue.append(digest)
+
+            if digest in progress_bars:
+                progress_bar = progress_bars[digest]
+                progress = completed_size - progress_bar.n
+                if completed_size > 0 and total_size >= progress != progress_bar.n:
+                    if digest == current_digest:
+                        progress_bar.update(progress)
+                        if progress_bar.n >= total_size:
+                            progress_bar.close()
+                            current_digest = queue.popleft() if queue else None
+                    else:
+                        # Store progress for later update
+                        progress_bars[digest].total = total_size
+                        progress_bars[digest].n = completed_size
+
+    # Close any remaining progress bars at the end
+    for progress_bar in progress_bars.values():
+        progress_bar.close()
+
+
 def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
     try:
         installed_models = [model["name"] for model in client.list().get("models", {})]
         if model_name not in installed_models:
             logger.info(f"Pulling model {model_name}. Please wait...")
-            client.pull(model_name)
+            process_streaming(client.pull(model_name, stream=True))
             logger.info(f"Model {model_name} pulled successfully")
     except Exception as e:
         logger.error(f"Failed to pull model {model_name}: {e!s}")
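
Note (not part of the patch): a minimal sketch of how process_streaming can be exercised without a live Ollama server. It assumes the streaming pull yields dict-like chunks carrying "digest", "completed", and "total" keys, which is the shape the function above reads; the digest value and synthetic chunks below are hypothetical.

    # Hypothetical driver: feeds synthetic pull chunks through
    # process_streaming to exercise the tqdm progress-bar logic.
    from private_gpt.utils.ollama import process_streaming

    digest = "sha256:0123456789abcdef"  # made-up layer digest
    fake_chunks = [
        {"status": "pulling manifest"},  # no digest/total -> ignored
        {"digest": digest, "total": 1000, "completed": 100},
        {"digest": digest, "total": 1000, "completed": 1000},
    ]
    process_streaming(iter(fake_chunks))  # one bar is drawn, filled, and closed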