fix: improve ollama docs, s/ollama_dir/ollama_home/g (#619)
alexanderankin authored Jun 27, 2024
1 parent 5442d05 · commit 27f2a6b
Showing 2 changed files with 47 additions and 7 deletions.
modules/ollama/testcontainers/ollama/__init__.py (50 changes: 45 additions & 5 deletions)
@@ -34,27 +34,67 @@ class OllamaContainer(DockerContainer):
     """
     Ollama Container

-    Example:
+    :param: image - the ollama image to use (default: :code:`ollama/ollama:0.1.44`)
+    :param: ollama_home - the directory to mount for model data (default: None)
+        you may pass :code:`pathlib.Path.home() / ".ollama"` to re-use models
+        that have already been pulled with ollama running on this host outside the container.
+
+    Examples:

         .. doctest::

             >>> from testcontainers.ollama import OllamaContainer
             >>> with OllamaContainer() as ollama:
             ...     ollama.list_models()
             []

+        .. code-block:: python
+
+            >>> from json import loads
+            >>> from pathlib import Path
+            >>> from requests import post
+            >>> from testcontainers.ollama import OllamaContainer
+            >>> def split_by_line(generator):
+            ...     data = b''
+            ...     for each_item in generator:
+            ...         for line in each_item.splitlines(True):
+            ...             data += line
+            ...             if data.endswith((b'\\r\\r', b'\\n\\n', b'\\r\\n\\r\\n', b'\\n')):
+            ...                 yield from data.splitlines()
+            ...                 data = b''
+            ...     if data:
+            ...         yield from data.splitlines()
+            >>> with OllamaContainer(ollama_home=Path.home() / ".ollama") as ollama:
+            ...     if "llama3:latest" not in [e["name"] for e in ollama.list_models()]:
+            ...         print("did not find 'llama3:latest', pulling")
+            ...         ollama.pull_model("llama3:latest")
+            ...     endpoint = ollama.get_endpoint()
+            ...     for chunk in split_by_line(
+            ...         post(url=f"{endpoint}/api/chat", stream=True, json={
+            ...             "model": "llama3:latest",
+            ...             "messages": [{
+            ...                 "role": "user",
+            ...                 "content": "what color is the sky? MAX ONE WORD"
+            ...             }]
+            ...         })
+            ...     ):
+            ...         print(loads(chunk)["message"]["content"], end="")
+            Blue.
     """

     OLLAMA_PORT = 11434

     def __init__(
         self,
         image: str = "ollama/ollama:0.1.44",
-        ollama_dir: Optional[Union[str, PathLike]] = None,
+        ollama_home: Optional[Union[str, PathLike]] = None,
         **kwargs,
         #
     ):
         super().__init__(image=image, **kwargs)
-        self.ollama_dir = ollama_dir
+        self.ollama_home = ollama_home
         self.with_exposed_ports(OllamaContainer.OLLAMA_PORT)
         self._check_and_add_gpu_capabilities()

@@ -67,8 +107,8 @@ def start(self) -> "OllamaContainer":
         """
         Start the Ollama server
         """
-        if self.ollama_dir:
-            self.with_volume_mapping(self.ollama_dir, "/root/.ollama", "rw")
+        if self.ollama_home:
+            self.with_volume_mapping(self.ollama_home, "/root/.ollama", "rw")
         super().start()
         wait_for_logs(self, "Listening on ", timeout=30)
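Taken together, the two hunks above mean that ollama_home, when set, is bind-mounted read-write at /root/.ollama inside the container. A minimal usage sketch under that assumption (the host path is your choice; ~/.ollama merely matches a host-side Ollama install):

    from pathlib import Path
    from testcontainers.ollama import OllamaContainer

    # Reuse models already pulled by an Ollama install on this host.
    host_models = Path.home() / ".ollama"

    with OllamaContainer(ollama_home=host_models) as ollama:
        # host_models is mounted at /root/.ollama, so models pulled inside
        # the container also persist on the host after it exits.
        print([model["name"] for model in ollama.list_models()])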

modules/ollama/tests/test_ollama.py (4 changes: 2 additions & 2 deletions)
@@ -49,12 +49,12 @@ def test_download_model_and_commit_to_image():


 def test_models_saved_in_folder(tmp_path: Path):
-    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
+    with OllamaContainer("ollama/ollama:0.1.26", ollama_home=tmp_path) as ollama:
         assert len(ollama.list_models()) == 0
         ollama.pull_model("all-minilm")
         assert len(ollama.list_models()) == 1
         assert "all-minilm" in ollama.list_models()[0].get("name")

-    with OllamaContainer("ollama/ollama:0.1.26", ollama_dir=tmp_path) as ollama:
+    with OllamaContainer("ollama/ollama:0.1.26", ollama_home=tmp_path) as ollama:
         assert len(ollama.list_models()) == 1
         assert "all-minilm" in ollama.list_models()[0].get("name")
