summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2024-12-04 21:04:57 -0500
committerBen Sima <ben@bsima.me>2024-12-21 10:08:07 -0500
commit2f2c0eaa0e2615d433bad5aa583e687629f2371f (patch)
tree2a86215545e057e272d6ba00d29bbd04ecd08a7d
parent17ff0b65feedd890391e67319e2a6f127dc93f33 (diff)
Manage Storybook Images
This adds the Images endpoint and related functions for loading and saving images to the filesystem. In the view layer, it also loads the images asynchronously using HTMX, so the images get lazy-loaded only when they are done generating.
-rw-r--r--Biz/Storybook.py344
-rw-r--r--Omni/Bild/Deps/Python.nix2
2 files changed, 266 insertions, 80 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py
index 8727b57..3659c37 100644
--- a/Biz/Storybook.py
+++ b/Biz/Storybook.py
@@ -16,6 +16,8 @@ this single file.
# : dep uvicorn
# : dep starlette
# : dep sqids
+# : dep requests
+# : dep types-requests
import json
import logging
import ludic
@@ -28,17 +30,22 @@ import ludic.catalog.typography as typography
import ludic.web
import Omni.Log as Log
import openai
+import pathlib
+import requests
import sqids
import starlette.testclient
import sys
+import time
import typing
import unittest
import uvicorn
MOCK = True
DEBUG = False
+DATA_DIR = pathlib.Path("_/var/storybook/")
app = ludic.web.LudicApp(debug=DEBUG)
+Sqids = sqids.Sqids()
def main() -> None:
@@ -52,14 +59,15 @@ def main() -> None:
def move() -> None:
"""Run the application."""
Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
- uvicorn.run(app, host="100.127.197.132")
+ local = "127.0.0.1"
+ uvicorn.run(app, host=local)
def test() -> None:
"""Run the unittest suite manually."""
Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
suite = unittest.TestSuite()
- tests = [StorybookTest, IndexTest, StoryTest]
+ tests = [StorybookTest, IndexTest, StoryTest, ImagesTest]
suite.addTests([
unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests
])
@@ -71,31 +79,35 @@ def const(s: str) -> str:
return s
-class StoryPage(ludic.attrs.Attrs):
+class PageContent(ludic.attrs.Attrs):
+ """Represents the content of a single page in the storybook."""
+
+ image_prompt: str
+ text: str
+
+
+class Page(ludic.attrs.Attrs):
"""Represents a single page in the storybook."""
- text: typing.Annotated[str, const]
- image_prompt: typing.Annotated[str, const]
- image_url: typing.Annotated[str, const]
+ story_id: str
+ page_number: int
+ content: PageContent
-def load_image(prompt: str) -> str:
- """Load an image for a given page using the OpenAI API.
+class Image(ludic.attrs.Attrs):
+ """Represents an image associated with a story page."""
- Raises:
- ValueError: when OpenAI response is bad
- """
- client = openai.OpenAI()
- image_response = client.images.generate(
- prompt=prompt,
- n=1,
- size="256x256",
- )
- url = image_response.data[0].url
- if url is not None:
- return url
- msg = "error with load_image"
- raise ValueError(msg)
+ story_id: str
+ page: int
+ prompt: typing.Annotated[str, const]
+ original_url: typing.Annotated[str, const]
+ path: pathlib.Path
+
+
+class Prompt(ludic.attrs.Attrs):
+ """Represents a prompt for generating an image."""
+
+ text: typing.Annotated[str, const]
class StoryInputs(ludic.attrs.Attrs):
@@ -119,7 +131,7 @@ class Story(ludic.attrs.Attrs):
"""Represents a full generated story."""
id: typing.Annotated[str, const]
- pages: typing.Annotated[list[StoryPage], const]
+ pages: typing.Annotated[list[Page], const]
system_prompt: str = (
@@ -140,10 +152,10 @@ def user_prompt(story: StoryInputs) -> str:
"image like the following example:",
"""[{"text": "<text of the story>",""",
""""image": "<description of illustration>"}...],""",
- f"Character: {story["character"]}\n",
- f"Setting: {story["setting"]}\n",
- f"Theme: {story["theme"]}\n",
- f"Moral: {story["moral"]}\n",
+ f"Character: {story['character']}\n",
+ f"Setting: {story['setting']}\n",
+ f"Theme: {story['theme']}\n",
+ f"Moral: {story['moral']}\n",
])
@@ -166,7 +178,7 @@ def _openai_generate_text(
)
-def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
+def generate_pages(inputs: StoryInputs) -> list[PageContent]:
"""Generate the text for a story and update its pages.
Raises:
@@ -176,12 +188,11 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
if MOCK:
name = inputs["character"]
return [
- StoryPage(
+ PageContent(
text=f"A story about {name}...",
- image_prompt="lorem ipsum",
- image_url="//placehold.co/256x256",
+ image_prompt="Lorem ipsum..",
)
- for _ in range(10)
+ for n in range(10)
]
response = _openai_generate_text(inputs)
content = response.choices[0].message.content
@@ -190,10 +201,9 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
raise ValueError(msg)
response_messages = json.loads(content)
return [
- StoryPage(
+ PageContent(
text=msg["text"],
image_prompt=msg["image"],
- image_url=load_image(msg["image"]),
)
for msg in response_messages
]
@@ -260,37 +270,226 @@ class IndexTest(unittest.TestCase):
self.assertIn("Storybook Generator", response.text)
-db_last_id: str = "bM" # sqid.encode([0])
-db: dict[str, Story] = {}
+db_last_story_id: str = "bM" # sqid.encode([0])
+
+class Database(ludic.attrs.Attrs):
+ """Represents a simple in-memory database for storing stories and images."""
-@app.endpoint("/stories/{sqid:str}")
+ stories: dict[str, Story]
+ images: dict[str, Image]
+
+
+db: Database = Database(stories={}, images={})
+
+
+@app.endpoint("/pages/{story_id:str}/{page:int}")
+class Pages(ludic.web.Endpoint[Page]):
+ """Resource for retrieving individual pages in a story."""
+
+ @classmethod
+ def get(cls, story_id: str, page: int) -> typing.Self:
+ """Get a single page."""
+ story = Stories.get(story_id)
+ story_page = Page(**story.attrs["pages"][page])
+ return cls(**story_page)
+
+ @typing.override
+ def render(self) -> ludic.base.BaseElement:
+ """Render a single page as HTML."""
+ return layouts.Box(
+ layouts.Stack(
+ ludic.html.img(
+ src="//placehold.co/256/000000/FFFFFF",
+ hx_post=app.url_path_for(
+ "Images",
+ story_id=self.attrs["story_id"],
+ page=self.attrs["page_number"],
+ ),
+ hx_trigger="load",
+ hx_swap="outerHTML:beforeend",
+ hx_vals=json.dumps(
+ Prompt(text=self.attrs["content"]["image_prompt"]),
+ ),
+ width=256,
+ height=256,
+ ),
+ typography.Paragraph(self.attrs["content"]["text"]),
+ ),
+ )
+
+
+@app.endpoint("/images/{story_id:str}/{page:int}")
+class Images(ludic.web.Endpoint[Image]):
+ """Endpoint for handling image-related operations."""
+
+ @classmethod
+ def get(cls, story_id: str, page: int) -> ludic.web.responses.Response:
+ """Load the image from the database, if not found return 404.
+
+ Raises:
+ NotFoundError: If the image is not found.
+ """
+ if image := Images.load_by_id(story_id, page):
+ return ludic.web.responses.FileResponse(image["path"])
+ msg = "no image found"
+ logging.error(msg)
+ raise ludic.web.exceptions.NotFoundError(msg)
+
+ @classmethod
+ def post(
+ cls,
+ story_id: str,
+ page: int,
+ data: ludic.web.parsers.Parser[Prompt],
+ ) -> ludic.web.responses.Response:
+ """Create a new image, or retrieve an existing one."""
+ Prompt(**data.validate())
+ path = cls.gen_path(story_id, page)
+ if path.exists():
+ return cls.get(story_id, page)
+ return cls.put(story_id, page, data)
+
+ @classmethod
+ def put(
+ cls,
+ story_id: str,
+ page: int,
+ data: ludic.web.parsers.Parser[Prompt],
+ ) -> ludic.web.responses.Response:
+ """Create a new image, overwriting if one exists.
+
+ Raises:
+ InternalServerError: If there is an error getting the image from the
+ OpenAI API.
+ """
+ if MOCK:
+ # Simulate slow image generation
+ time.sleep(3)
+ return ludic.web.responses.FileResponse(
+ DATA_DIR / "images" / "placeholder.jpg",
+ )
+ client = openai.OpenAI()
+ prompt = Prompt(**data.validate())
+ image_response = client.images.generate(
+ prompt=prompt["text"],
+ n=1,
+ size="256x256",
+ )
+ url = image_response.data[0].url
+ if url is None:
+ msg = "error getting image from OpenAI"
+ logging.error(msg)
+ raise ludic.web.exceptions.InternalServerError(msg)
+ image = Image(
+ story_id=story_id,
+ page=page,
+ prompt=prompt["text"],
+ original_url=url,
+ path=cls.gen_path(story_id, page),
+ )
+ cls.save(image)
+ return ludic.web.responses.FileResponse(image["path"])
+
+ @classmethod
+ def gen_image_id(cls, story_id: str, page: int) -> str:
+ """Generate a unique image ID based on the story ID and page number."""
+ story_id_num = Sqids.decode(story_id)[0]
+ return Sqids.encode([story_id_num, page])
+
+ @classmethod
+ def load_by_id(cls, story_id: str, page: int) -> Image | None:
+ """Load an image by its story ID and page number."""
+ cls.gen_image_id(story_id, page)
+ path = cls.gen_path(story_id, page)
+ if path.exists():
+ return Image(
+ story_id=story_id,
+ page=page,
+ path=path,
+ # Consider storing prompt and original_url in sqlite
+ prompt="",
+ original_url="",
+ )
+ return None
+
+ @classmethod
+ def gen_path(cls, story_id: str, page: int) -> pathlib.Path:
+ """Generate the file path for an image."""
+ image_id = cls.gen_image_id(story_id, page)
+ return pathlib.Path(
+ DATA_DIR / "images" / story_id / image_id,
+ ).with_suffix(".jpg")
+
+ @classmethod
+ def save(cls, image: Image) -> None:
+ """Save an image to the file system."""
+ response = requests.get(image["original_url"], timeout=10)
+ pathlib.Path(image["path"]).write_bytes(response.content)
+
+ @classmethod
+ def read(cls, image: Image) -> bytes:
+ """Read an image from the file system."""
+ return pathlib.Path(image["path"]).read_bytes()
+
+ @typing.override
+ def render(self) -> ludic.base.BaseElement:
+ return ludic.html.img(
+ src=app.url_path_for(
+ "Images",
+ story_id=self.attrs["story_id"],
+ page=self.attrs["page"],
+ ),
+ )
+
+
+class ImagesTest(unittest.TestCase):
+ """Test the Images endpoint."""
+
+ def setUp(self) -> None:
+ """Create test client."""
+ self.client = starlette.testclient.TestClient(app)
+
+ def test_image_post(self) -> None:
+ """Can POST an Image successfully."""
+ response = self.client.post(
+ app.url_path_for(
+ "Images",
+ story_id="Uk",
+ page=1,
+ ),
+ data={"text": "lorem ipsum"},
+ )
+ self.assertEqual(response.status_code, 200)
+
+
+@app.endpoint("/stories/{story_id:str}")
class Stories(ludic.web.Endpoint[Story]):
"""Resource for accessing a Story."""
@classmethod
- def get(cls, sqid: str) -> typing.Self:
+ def get(cls, story_id: str) -> typing.Self:
"""Get a single story.
Raises:
NotFoundError: if the story doesn't exist.
"""
- story = db.get(sqid)
+ story = db["stories"].get(story_id)
if story is None:
- msg = f"story {sqid} not found"
+ msg = f"story {story_id} not found"
raise ludic.web.exceptions.NotFoundError(msg)
return cls(**story)
@classmethod
- def put(cls, sqid: str, data: list[StoryPage]) -> typing.Self:
+ def put(cls, story_id: str, data: list[PageContent]) -> typing.Self:
"""Upsert a new story."""
- pages = data # .validate()
-
- story = Story(id=sqid, pages=pages)
- story_id = story["id"]
-
+ pages = [
+ Page(story_id=story_id, page_number=n, content=page_content)
+ for n, page_content in enumerate(data)
+ ]
+ story = Story(id=story_id, pages=pages)
# save to the 'database'
- db[story_id] = story
+ db["stories"][story_id] = story
return cls(**story)
@typing.override
@@ -301,30 +500,6 @@ class Stories(ludic.web.Endpoint[Story]):
)
-@app.endpoint("/stories/{sqid:str}/{page:int}")
-class Pages(ludic.web.Endpoint[StoryPage]):
- """Resource for retrieving individual pages in a story."""
-
- @classmethod
- def get(cls, sqid: str, page: int) -> typing.Self:
- """Get a single page."""
- story = Stories.get(sqid)
- story_page = StoryPage(**story.attrs["pages"][page])
- return cls(**story_page)
-
- @typing.override
- def render(self) -> ludic.base.BaseElement:
- """Render a single page as HTML."""
- return layouts.Box(
- layouts.Stack(
- ludic.html.img(
- src=self.attrs["image_url"],
- ),
- typography.Paragraph(self.attrs["text"]),
- ),
- )
-
-
@app.endpoint("/stories")
class StoriesForm(ludic.web.Endpoint[StoryInputs]):
"""Form for generating new stories."""
@@ -334,11 +509,11 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]):
"""Upsert a new story."""
inputs = StoryInputs(**data.validate())
# generate story pages
+ # Consider calling Pages.put for each one after generating the text
pages = generate_pages(inputs)
# calculate sqid
- sqid = sqids.Sqids()
- next_id_num = 1 + sqid.decode(db_last_id)[0]
- next_id = sqid.encode([next_id_num])
+ next_id_num = 1 + Sqids.decode(db_last_story_id)[0]
+ next_id = Sqids.encode([next_id_num])
return Stories.put(next_id, pages)
@typing.override
@@ -402,7 +577,7 @@ class StorybookTest(unittest.TestCase):
self.character = "Alice"
self.data = example_story | {"character": self.character}
self.client.post("/stories/", data=self.data)
- self.story = next(iter(db.values()))
+ self.story = next(iter(db["stories"].values()))
self.story_id = self.story["id"]
def test_stories_post(self) -> None:
@@ -413,25 +588,34 @@ class StorybookTest(unittest.TestCase):
def test_stories_post_invalid_data(self) -> None:
"""Invalid POST data."""
- response = self.client.post("/stories/", data={"bad": "data"})
+ response = self.client.post(
+ app.url_path_for("StoriesForm"),
+ data={"bad": "data"},
+ )
self.assertNotEqual(response.status_code, 200)
def test_stories_get(self) -> None:
"""User can access the story directly."""
- response = self.client.get(f"/stories/{self.story_id}")
+ response = self.client.get(
+ app.url_path_for("Stories", story_id=self.story_id),
+ )
self.assertEqual(response.status_code, 200)
self.assertIn(self.character, response.text)
def test_stories_get_nonexistent(self) -> None:
"""Returns 404 when a story is not found."""
- response = self.client.get("/stories/nonexistent")
+ response = self.client.get(
+ app.url_path_for("Stories", story_id="nonexistent"),
+ )
self.assertEqual(response.status_code, 404)
def test_pages_get(self) -> None:
"""User can access one page at a time."""
page_num = 1
- self.story["pages"][page_num]
- response = self.client.get(f"/stories/{self.story_id}/{page_num}")
+ _story = self.story["pages"][page_num]
+ response = self.client.get(
+ app.url_path_for("Pages", story_id=self.story_id, page=page_num),
+ )
self.assertEqual(response.status_code, 200)
self.assertIn(self.character, response.text)
diff --git a/Omni/Bild/Deps/Python.nix b/Omni/Bild/Deps/Python.nix
index 9af4630..bb01139 100644
--- a/Omni/Bild/Deps/Python.nix
+++ b/Omni/Bild/Deps/Python.nix
@@ -6,8 +6,10 @@
"mypy"
"nltk"
"openai"
+ "requests"
"slixmpp"
"sqids"
"starlette"
+ "types-requests"
"uvicorn"
]