From fc1422f099d95878209c92b3e9e2f509fe8ca77e Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Wed, 4 Dec 2024 21:55:02 -0500 Subject: Add some mock tests of the Image endpoint These were contributed in part by gptme, thanks! --- Biz/Storybook.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 91 insertions(+), 10 deletions(-) (limited to 'Biz/Storybook.py') diff --git a/Biz/Storybook.py b/Biz/Storybook.py index 3659c37..c619ef8 100644 --- a/Biz/Storybook.py +++ b/Biz/Storybook.py @@ -38,6 +38,7 @@ import sys import time import typing import unittest +import unittest.mock as mock import uvicorn MOCK = True @@ -365,7 +366,7 @@ class Images(ludic.web.Endpoint[Image]): """ if MOCK: # Simulate slow image generation - time.sleep(3) + time.sleep(1) return ludic.web.responses.FileResponse( DATA_DIR / "images" / "placeholder.jpg", ) @@ -449,18 +450,98 @@ class ImagesTest(unittest.TestCase): def setUp(self) -> None: """Create test client.""" self.client = starlette.testclient.TestClient(app) + self.story_id = "Uk" + self.page = 1 + self.valid_prompt = {"text": "A beautiful sunset over the ocean"} + + def test_image_get_existing(self) -> None: + """Test retrieving an existing image.""" + # Arrange: Mock the load_by_id method to simulate an existing image + data = {"path": DATA_DIR / "images" / "placeholder.jpg"} + mock_dict = mock.MagicMock() + mock_dict.__getitem__.side_effect = data.__getitem__ + with mock.patch.object( + Images, + "load_by_id", + return_value=mock_dict, + ): + # Act: Send a GET request to retrieve the image + response = self.client.get( + app.url_path_for( + "Images", + story_id=self.story_id, + page=self.page, + ), + ) + # Assert: Check that the response status is 200 + self.assertEqual(response.status_code, 200) - def test_image_post(self) -> None: - """Can POST an Image successfully.""" - response = self.client.post( - app.url_path_for( - "Images", - story_id="Uk", - page=1, + def test_image_get_nonexistent(self) -> None: + """Test retrieving a non-existent image.""" + # Act: Send a GET request for a non-existent image + response = self.client.get( + app.url_path_for("Images", story_id="nonexistent", page=self.page), + ) + # Assert: Check that the response status is 404 + self.assertEqual(response.status_code, 404) + + def test_image_post_valid(self) -> None: + """Test creating an image with valid data.""" + # Arrange: Mock the OpenAI API and file system operations + with ( + mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai, + mock.patch( + "Biz.Storybook.pathlib.Path.write_bytes", ), - data={"text": "lorem ipsum"}, + ): + mock_openai.return_value.images.generate.return_value.data = [ + mock.MagicMock(url="http://example.com/image.jpg"), + ] + # Act: Send a POST request with valid data + response = self.client.post( + app.url_path_for( + "Images", + story_id=self.story_id, + page=self.page, + ), + data=self.valid_prompt, + ) + # Assert: Check that the response status is 200 + self.assertEqual(response.status_code, 200) + + def test_image_post_invalid(self) -> None: + """Test creating an image with invalid data.""" + # Act: Send a POST request with invalid data + response = self.client.post( + app.url_path_for("Images", story_id=self.story_id, page=self.page), + data={"invalid": "data"}, ) - self.assertEqual(response.status_code, 200) + # Assert: Check that the response status indicates an error + self.assertNotEqual(response.status_code, 200) + + def test_image_put_overwrite(self) -> None: + """Test overwriting an existing image.""" + # Arrange: Mock the OpenAI API and file system operations + with ( + mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai, + mock.patch( + "Biz.Storybook.pathlib.Path.write_bytes", + ), + ): + mock_openai.return_value.images.generate.return_value.data = [ + mock.MagicMock(url="http://example.com/image.jpg"), + ] + # Act: Send a PUT request to overwrite the image + response = self.client.put( + app.url_path_for( + "Images", + story_id=self.story_id, + page=self.page, + ), + data=self.valid_prompt, + ) + # Assert: Check that the response status is 200 + self.assertEqual(response.status_code, 200) @app.endpoint("/stories/{story_id:str}") -- cgit v1.2.3