summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2024-12-04 21:55:02 -0500
committerBen Sima <ben@bsima.me>2024-12-21 10:08:08 -0500
commitfc1422f099d95878209c92b3e9e2f509fe8ca77e (patch)
treecb92424ff16b8192117680baea82925511852e1e /Biz
parent2f2c0eaa0e2615d433bad5aa583e687629f2371f (diff)
Add some mock tests of the Image endpoint
These were contributed in part by gptme, thanks!
Diffstat (limited to 'Biz')
-rw-r--r--Biz/Storybook.py101
1 files changed, 91 insertions, 10 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py
index 3659c37..c619ef8 100644
--- a/Biz/Storybook.py
+++ b/Biz/Storybook.py
@@ -38,6 +38,7 @@ import sys
import time
import typing
import unittest
+import unittest.mock as mock
import uvicorn
MOCK = True
@@ -365,7 +366,7 @@ class Images(ludic.web.Endpoint[Image]):
"""
if MOCK:
# Simulate slow image generation
- time.sleep(3)
+ time.sleep(1)
return ludic.web.responses.FileResponse(
DATA_DIR / "images" / "placeholder.jpg",
)
@@ -449,18 +450,98 @@ class ImagesTest(unittest.TestCase):
def setUp(self) -> None:
"""Create test client."""
self.client = starlette.testclient.TestClient(app)
+ self.story_id = "Uk"
+ self.page = 1
+ self.valid_prompt = {"text": "A beautiful sunset over the ocean"}
+
+ def test_image_get_existing(self) -> None:
+ """Test retrieving an existing image."""
+ # Arrange: Mock the load_by_id method to simulate an existing image
+ data = {"path": DATA_DIR / "images" / "placeholder.jpg"}
+ mock_dict = mock.MagicMock()
+ mock_dict.__getitem__.side_effect = data.__getitem__
+ with mock.patch.object(
+ Images,
+ "load_by_id",
+ return_value=mock_dict,
+ ):
+ # Act: Send a GET request to retrieve the image
+ response = self.client.get(
+ app.url_path_for(
+ "Images",
+ story_id=self.story_id,
+ page=self.page,
+ ),
+ )
+ # Assert: Check that the response status is 200
+ self.assertEqual(response.status_code, 200)
- def test_image_post(self) -> None:
- """Can POST an Image successfully."""
- response = self.client.post(
- app.url_path_for(
- "Images",
- story_id="Uk",
- page=1,
+ def test_image_get_nonexistent(self) -> None:
+ """Test retrieving a non-existent image."""
+ # Act: Send a GET request for a non-existent image
+ response = self.client.get(
+ app.url_path_for("Images", story_id="nonexistent", page=self.page),
+ )
+ # Assert: Check that the response status is 404
+ self.assertEqual(response.status_code, 404)
+
+ def test_image_post_valid(self) -> None:
+ """Test creating an image with valid data."""
+ # Arrange: Mock the OpenAI API and file system operations
+ with (
+ mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
+ mock.patch(
+ "Biz.Storybook.pathlib.Path.write_bytes",
),
- data={"text": "lorem ipsum"},
+ ):
+ mock_openai.return_value.images.generate.return_value.data = [
+ mock.MagicMock(url="http://example.com/image.jpg"),
+ ]
+ # Act: Send a POST request with valid data
+ response = self.client.post(
+ app.url_path_for(
+ "Images",
+ story_id=self.story_id,
+ page=self.page,
+ ),
+ data=self.valid_prompt,
+ )
+ # Assert: Check that the response status is 200
+ self.assertEqual(response.status_code, 200)
+
+ def test_image_post_invalid(self) -> None:
+ """Test creating an image with invalid data."""
+ # Act: Send a POST request with invalid data
+ response = self.client.post(
+ app.url_path_for("Images", story_id=self.story_id, page=self.page),
+ data={"invalid": "data"},
)
- self.assertEqual(response.status_code, 200)
+ # Assert: Check that the response status indicates an error
+ self.assertNotEqual(response.status_code, 200)
+
+ def test_image_put_overwrite(self) -> None:
+ """Test overwriting an existing image."""
+ # Arrange: Mock the OpenAI API and file system operations
+ with (
+ mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
+ mock.patch(
+ "Biz.Storybook.pathlib.Path.write_bytes",
+ ),
+ ):
+ mock_openai.return_value.images.generate.return_value.data = [
+ mock.MagicMock(url="http://example.com/image.jpg"),
+ ]
+ # Act: Send a PUT request to overwrite the image
+ response = self.client.put(
+ app.url_path_for(
+ "Images",
+ story_id=self.story_id,
+ page=self.page,
+ ),
+ data=self.valid_prompt,
+ )
+ # Assert: Check that the response status is 200
+ self.assertEqual(response.status_code, 200)
@app.endpoint("/stories/{story_id:str}")