diff --git a/README.md b/README.md
index 482c34b8..b5eb6a33 100644
--- a/README.md
+++ b/README.md
@@ -250,7 +250,6 @@ ollama.embed(model='gemma3', input=['The sky is blue because of rayleigh scatter
 ollama.ps()
 ```
 
-
 ## Errors
 
 Errors are raised if requests return an error status or if an error is detected while streaming.
diff --git a/examples/README.md b/examples/README.md
index 1d8c9bdc..1df713ea 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -78,6 +78,12 @@ Configuration to use with an MCP client:
 - [multimodal-chat.py](multimodal-chat.py)
 - [multimodal-generate.py](multimodal-generate.py)
 
+### Image Generation (Experimental) - Generate images with a model
+
+> **Note:** Image generation is experimental and currently only available on macOS.
+
+- [generate-image.py](generate-image.py)
+
 ### Structured Outputs - Generate structured outputs with a model
 
 - [structured-outputs.py](structured-outputs.py)
diff --git a/examples/generate-image.py b/examples/generate-image.py
new file mode 100644
index 00000000..f27dc5a3
--- /dev/null
+++ b/examples/generate-image.py
@@ -0,0 +1,18 @@
+# Image generation is experimental and currently only available on macOS
+
+import base64
+
+from ollama import generate
+
+prompt = 'a sunset over mountains'
+print(f'Prompt: {prompt}')
+
+for response in generate(model='x/z-image-turbo', prompt=prompt, stream=True):
+  if response.image:
+    # Final response contains the image
+    with open('output.png', 'wb') as f:
+      f.write(base64.b64decode(response.image))
+    print('\nImage saved to output.png')
+  elif response.total:
+    # Progress update
+    print(f'Progress: {response.completed or 0}/{response.total}', end='\r')
diff --git a/ollama/_client.py b/ollama/_client.py
index ae755eb2..18cb0fb4 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -217,6 +217,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> GenerateResponse: ...
 
   @overload
@@ -238,6 +241,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Iterator[GenerateResponse]: ...
 
   def generate(
@@ -258,6 +264,9 @@ def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
     """
     Create a response using the requested model.
@@ -289,6 +298,9 @@ def generate(
         images=list(_copy_images(images)) if images else None,
         options=options,
         keep_alive=keep_alive,
+        width=width,
+        height=height,
+        steps=steps,
       ).model_dump(exclude_none=True),
       stream=stream,
     )
@@ -838,6 +850,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> GenerateResponse: ...
 
   @overload
@@ -859,6 +874,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> AsyncIterator[GenerateResponse]: ...
 
   async def generate(
@@ -879,6 +897,9 @@ async def generate(
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
+    width: Optional[int] = None,
+    height: Optional[int] = None,
+    steps: Optional[int] = None,
   ) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
     """
     Create a response using the requested model.
@@ -909,6 +930,9 @@ async def generate(
         images=list(_copy_images(images)) if images else None,
         options=options,
         keep_alive=keep_alive,
+        width=width,
+        height=height,
+        steps=steps,
       ).model_dump(exclude_none=True),
       stream=stream,
     )
diff --git a/ollama/_types.py b/ollama/_types.py
index 8931ceac..96529d63 100644
--- a/ollama/_types.py
+++ b/ollama/_types.py
@@ -216,6 +216,16 @@ class GenerateRequest(BaseGenerateRequest):
   top_logprobs: Optional[int] = None
   'Number of alternative tokens and log probabilities to include per position (0-20).'
 
+  # Experimental image generation parameters
+  width: Optional[int] = None
+  'Width of the generated image in pixels (for image generation models).'
+
+  height: Optional[int] = None
+  'Height of the generated image in pixels (for image generation models).'
+
+  steps: Optional[int] = None
+  'Number of diffusion steps (for image generation models).'
+
 
 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -267,7 +277,7 @@ class GenerateResponse(BaseGenerateResponse):
   Response returned by generate requests.
   """
 
-  response: str
+  response: Optional[str] = None
   'Response content. When streaming, this contains a fragment of the response.'
 
   thinking: Optional[str] = None
@@ -279,6 +289,17 @@ class GenerateResponse(BaseGenerateResponse):
   logprobs: Optional[Sequence[Logprob]] = None
   'Log probabilities for generated tokens.'
 
+  # Image generation response fields
+  image: Optional[str] = None
+  'Base64-encoded generated image data (for image generation models).'
+
+  # Streaming progress fields (for image generation)
+  completed: Optional[int] = None
+  'Number of completed steps (for image generation streaming).'
+
+  total: Optional[int] = None
+  'Total number of steps (for image generation streaming).'
+
 
 class Message(SubscriptableBaseModel):
   """
diff --git a/tests/test_client.py b/tests/test_client.py
index 24a8af34..34657513 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -568,6 +568,115 @@ class ResponseFormat(BaseModel):
   assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
 
 
+def test_client_generate_image(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a sunset over mountains',
+      'stream': False,
+      'width': 1024,
+      'height': 768,
+      'steps': 20,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy-image',
+      'image': PNG_BASE64,
+      'done': True,
+      'done_reason': 'stop',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy-image', 'a sunset over mountains', width=1024, height=768, steps=20)
+  assert response['model'] == 'dummy-image'
+  assert response['image'] == PNG_BASE64
+  assert response['done'] is True
+
+
+def test_client_generate_image_stream(httpserver: HTTPServer):
+  def stream_handler(_: Request):
+    def generate():
+      # Progress updates
+      for i in range(1, 4):
+        yield (
+          json.dumps(
+            {
+              'model': 'dummy-image',
+              'completed': i,
+              'total': 3,
+              'done': False,
+            }
+          )
+          + '\n'
+        )
+      # Final response with image
+      yield (
+        json.dumps(
+          {
+            'model': 'dummy-image',
+            'image': PNG_BASE64,
+            'done': True,
+            'done_reason': 'stop',
+          }
+        )
+        + '\n'
+      )
+
+    return Response(generate())
+
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a sunset over mountains',
+      'stream': True,
+      'width': 512,
+      'height': 512,
+    },
+  ).respond_with_handler(stream_handler)
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy-image', 'a sunset over mountains', stream=True, width=512, height=512)
+
+  parts = list(response)
+  # Check progress updates
+  assert parts[0]['completed'] == 1
+  assert parts[0]['total'] == 3
+  assert parts[0]['done'] is False
+  # Check final response
+  assert parts[-1]['image'] == PNG_BASE64
+  assert parts[-1]['done'] is True
+
+
+async def test_async_client_generate_image(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy-image',
+      'prompt': 'a robot painting',
+      'stream': False,
+      'width': 1024,
+      'height': 1024,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy-image',
+      'image': PNG_BASE64,
+      'done': True,
+    }
+  )
+
+  client = AsyncClient(httpserver.url_for('/'))
+  response = await client.generate('dummy-image', 'a robot painting', width=1024, height=1024)
+  assert response['model'] == 'dummy-image'
+  assert response['image'] == PNG_BASE64
+
+
 def test_client_pull(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/pull',
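
The diff wires `width`, `height`, and `steps` through both the sync and async clients, but the bundled example only exercises the synchronous path. As a minimal usage sketch (not part of the diff), the async path mirrors `examples/generate-image.py`; the model name below is the placeholder from that example, not a guaranteed model:

```python
# Sketch: async streaming image generation with the new width/height/steps
# parameters. Assumes 'x/z-image-turbo' (the placeholder model from
# examples/generate-image.py) is available locally.
import asyncio
import base64

from ollama import AsyncClient


async def main():
  client = AsyncClient()
  # With stream=True, the async client returns an AsyncIterator[GenerateResponse]
  stream = await client.generate(
    model='x/z-image-turbo',
    prompt='a sunset over mountains',
    stream=True,
    width=512,
    height=512,
    steps=20,
  )
  async for response in stream:
    if response.image:
      # Final chunk carries the base64-encoded image
      with open('output.png', 'wb') as f:
        f.write(base64.b64decode(response.image))
      print('\nImage saved to output.png')
    elif response.total:
      # Intermediate chunks report diffusion progress via completed/total
      print(f'Progress: {response.completed or 0}/{response.total}', end='\r')


if __name__ == '__main__':
  asyncio.run(main())
```

Note the design choice visible in `_types.py`: the image parameters are top-level fields on `GenerateRequest` (serialized with `exclude_none=True`, so text-only requests are unchanged on the wire), and `GenerateResponse.response` becomes `Optional[str]` because image responses carry `image` instead of text.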