From 2f2c0eaa0e2615d433bad5aa583e687629f2371f Mon Sep 17 00:00:00 2001 From: Ben Sima Date: Wed, 4 Dec 2024 21:04:57 -0500 Subject: Manage Storybook Images This adds the Images endpoint and related functions for loading and saving images to the filesystem. In the view layer, it also loads the images asynchronously using HTMX, so the images get lazy-loaded only when they are done generating. --- Biz/Storybook.py | 344 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 264 insertions(+), 80 deletions(-) (limited to 'Biz/Storybook.py') diff --git a/Biz/Storybook.py b/Biz/Storybook.py index 8727b57..3659c37 100644 --- a/Biz/Storybook.py +++ b/Biz/Storybook.py @@ -16,6 +16,8 @@ this single file. # : dep uvicorn # : dep starlette # : dep sqids +# : dep requests +# : dep types-requests import json import logging import ludic @@ -28,17 +30,22 @@ import ludic.catalog.typography as typography import ludic.web import Omni.Log as Log import openai +import pathlib +import requests import sqids import starlette.testclient import sys +import time import typing import unittest import uvicorn MOCK = True DEBUG = False +DATA_DIR = pathlib.Path("_/var/storybook/") app = ludic.web.LudicApp(debug=DEBUG) +Sqids = sqids.Sqids() def main() -> None: @@ -52,14 +59,15 @@ def main() -> None: def move() -> None: """Run the application.""" Log.setup(logging.DEBUG if DEBUG else logging.ERROR) - uvicorn.run(app, host="100.127.197.132") + local = "127.0.0.1" + uvicorn.run(app, host=local) def test() -> None: """Run the unittest suite manually.""" Log.setup(logging.DEBUG if DEBUG else logging.ERROR) suite = unittest.TestSuite() - tests = [StorybookTest, IndexTest, StoryTest] + tests = [StorybookTest, IndexTest, StoryTest, ImagesTest] suite.addTests([ unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests ]) @@ -71,31 +79,35 @@ def const(s: str) -> str: return s -class StoryPage(ludic.attrs.Attrs): +class PageContent(ludic.attrs.Attrs): + """Represents the content of a single page in the storybook.""" + + image_prompt: str + text: str + + +class Page(ludic.attrs.Attrs): """Represents a single page in the storybook.""" - text: typing.Annotated[str, const] - image_prompt: typing.Annotated[str, const] - image_url: typing.Annotated[str, const] + story_id: str + page_number: int + content: PageContent -def load_image(prompt: str) -> str: - """Load an image for a given page using the OpenAI API. +class Image(ludic.attrs.Attrs): + """Represents an image associated with a story page.""" - Raises: - ValueError: when OpenAI response is bad - """ - client = openai.OpenAI() - image_response = client.images.generate( - prompt=prompt, - n=1, - size="256x256", - ) - url = image_response.data[0].url - if url is not None: - return url - msg = "error with load_image" - raise ValueError(msg) + story_id: str + page: int + prompt: typing.Annotated[str, const] + original_url: typing.Annotated[str, const] + path: pathlib.Path + + +class Prompt(ludic.attrs.Attrs): + """Represents a prompt for generating an image.""" + + text: typing.Annotated[str, const] class StoryInputs(ludic.attrs.Attrs): @@ -119,7 +131,7 @@ class Story(ludic.attrs.Attrs): """Represents a full generated story.""" id: typing.Annotated[str, const] - pages: typing.Annotated[list[StoryPage], const] + pages: typing.Annotated[list[Page], const] system_prompt: str = ( @@ -140,10 +152,10 @@ def user_prompt(story: StoryInputs) -> str: "image like the following example:", """[{"text": "",""", """"image": ""}...],""", - f"Character: {story["character"]}\n", - f"Setting: {story["setting"]}\n", - f"Theme: {story["theme"]}\n", - f"Moral: {story["moral"]}\n", + f"Character: {story['character']}\n", + f"Setting: {story['setting']}\n", + f"Theme: {story['theme']}\n", + f"Moral: {story['moral']}\n", ]) @@ -166,7 +178,7 @@ def _openai_generate_text( ) -def generate_pages(inputs: StoryInputs) -> list[StoryPage]: +def generate_pages(inputs: StoryInputs) -> list[PageContent]: """Generate the text for a story and update its pages. Raises: @@ -176,12 +188,11 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]: if MOCK: name = inputs["character"] return [ - StoryPage( + PageContent( text=f"A story about {name}...", - image_prompt="lorem ipsum", - image_url="//placehold.co/256x256", + image_prompt="Lorem ipsum..", ) - for _ in range(10) + for n in range(10) ] response = _openai_generate_text(inputs) content = response.choices[0].message.content @@ -190,10 +201,9 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]: raise ValueError(msg) response_messages = json.loads(content) return [ - StoryPage( + PageContent( text=msg["text"], image_prompt=msg["image"], - image_url=load_image(msg["image"]), ) for msg in response_messages ] @@ -260,37 +270,226 @@ class IndexTest(unittest.TestCase): self.assertIn("Storybook Generator", response.text) -db_last_id: str = "bM" # sqid.encode([0]) -db: dict[str, Story] = {} +db_last_story_id: str = "bM" # sqid.encode([0]) + +class Database(ludic.attrs.Attrs): + """Represents a simple in-memory database for storing stories and images.""" -@app.endpoint("/stories/{sqid:str}") + stories: dict[str, Story] + images: dict[str, Image] + + +db: Database = Database(stories={}, images={}) + + +@app.endpoint("/pages/{story_id:str}/{page:int}") +class Pages(ludic.web.Endpoint[Page]): + """Resource for retrieving individual pages in a story.""" + + @classmethod + def get(cls, story_id: str, page: int) -> typing.Self: + """Get a single page.""" + story = Stories.get(story_id) + story_page = Page(**story.attrs["pages"][page]) + return cls(**story_page) + + @typing.override + def render(self) -> ludic.base.BaseElement: + """Render a single page as HTML.""" + return layouts.Box( + layouts.Stack( + ludic.html.img( + src="//placehold.co/256/000000/FFFFFF", + hx_post=app.url_path_for( + "Images", + story_id=self.attrs["story_id"], + page=self.attrs["page_number"], + ), + hx_trigger="load", + hx_swap="outerHTML:beforeend", + hx_vals=json.dumps( + Prompt(text=self.attrs["content"]["image_prompt"]), + ), + width=256, + height=256, + ), + typography.Paragraph(self.attrs["content"]["text"]), + ), + ) + + +@app.endpoint("/images/{story_id:str}/{page:int}") +class Images(ludic.web.Endpoint[Image]): + """Endpoint for handling image-related operations.""" + + @classmethod + def get(cls, story_id: str, page: int) -> ludic.web.responses.Response: + """Load the image from the database, if not found return 404. + + Raises: + NotFoundError: If the image is not found. + """ + if image := Images.load_by_id(story_id, page): + return ludic.web.responses.FileResponse(image["path"]) + msg = "no image found" + logging.error(msg) + raise ludic.web.exceptions.NotFoundError(msg) + + @classmethod + def post( + cls, + story_id: str, + page: int, + data: ludic.web.parsers.Parser[Prompt], + ) -> ludic.web.responses.Response: + """Create a new image, or retrieve an existing one.""" + Prompt(**data.validate()) + path = cls.gen_path(story_id, page) + if path.exists(): + return cls.get(story_id, page) + return cls.put(story_id, page, data) + + @classmethod + def put( + cls, + story_id: str, + page: int, + data: ludic.web.parsers.Parser[Prompt], + ) -> ludic.web.responses.Response: + """Create a new image, overwriting if one exists. + + Raises: + InternalServerError: If there is an error getting the image from the + OpenAI API. + """ + if MOCK: + # Simulate slow image generation + time.sleep(3) + return ludic.web.responses.FileResponse( + DATA_DIR / "images" / "placeholder.jpg", + ) + client = openai.OpenAI() + prompt = Prompt(**data.validate()) + image_response = client.images.generate( + prompt=prompt["text"], + n=1, + size="256x256", + ) + url = image_response.data[0].url + if url is None: + msg = "error getting image from OpenAI" + logging.error(msg) + raise ludic.web.exceptions.InternalServerError(msg) + image = Image( + story_id=story_id, + page=page, + prompt=prompt["text"], + original_url=url, + path=cls.gen_path(story_id, page), + ) + cls.save(image) + return ludic.web.responses.FileResponse(image["path"]) + + @classmethod + def gen_image_id(cls, story_id: str, page: int) -> str: + """Generate a unique image ID based on the story ID and page number.""" + story_id_num = Sqids.decode(story_id)[0] + return Sqids.encode([story_id_num, page]) + + @classmethod + def load_by_id(cls, story_id: str, page: int) -> Image | None: + """Load an image by its story ID and page number.""" + cls.gen_image_id(story_id, page) + path = cls.gen_path(story_id, page) + if path.exists(): + return Image( + story_id=story_id, + page=page, + path=path, + # Consider storing prompt and original_url in sqlite + prompt="", + original_url="", + ) + return None + + @classmethod + def gen_path(cls, story_id: str, page: int) -> pathlib.Path: + """Generate the file path for an image.""" + image_id = cls.gen_image_id(story_id, page) + return pathlib.Path( + DATA_DIR / "images" / story_id / image_id, + ).with_suffix(".jpg") + + @classmethod + def save(cls, image: Image) -> None: + """Save an image to the file system.""" + response = requests.get(image["original_url"], timeout=10) + pathlib.Path(image["path"]).write_bytes(response.content) + + @classmethod + def read(cls, image: Image) -> bytes: + """Read an image from the file system.""" + return pathlib.Path(image["path"]).read_bytes() + + @typing.override + def render(self) -> ludic.base.BaseElement: + return ludic.html.img( + src=app.url_path_for( + "Images", + story_id=self.attrs["story_id"], + page=self.attrs["page"], + ), + ) + + +class ImagesTest(unittest.TestCase): + """Test the Images endpoint.""" + + def setUp(self) -> None: + """Create test client.""" + self.client = starlette.testclient.TestClient(app) + + def test_image_post(self) -> None: + """Can POST an Image successfully.""" + response = self.client.post( + app.url_path_for( + "Images", + story_id="Uk", + page=1, + ), + data={"text": "lorem ipsum"}, + ) + self.assertEqual(response.status_code, 200) + + +@app.endpoint("/stories/{story_id:str}") class Stories(ludic.web.Endpoint[Story]): """Resource for accessing a Story.""" @classmethod - def get(cls, sqid: str) -> typing.Self: + def get(cls, story_id: str) -> typing.Self: """Get a single story. Raises: NotFoundError: if the story doesn't exist. """ - story = db.get(sqid) + story = db["stories"].get(story_id) if story is None: - msg = f"story {sqid} not found" + msg = f"story {story_id} not found" raise ludic.web.exceptions.NotFoundError(msg) return cls(**story) @classmethod - def put(cls, sqid: str, data: list[StoryPage]) -> typing.Self: + def put(cls, story_id: str, data: list[PageContent]) -> typing.Self: """Upsert a new story.""" - pages = data # .validate() - - story = Story(id=sqid, pages=pages) - story_id = story["id"] - + pages = [ + Page(story_id=story_id, page_number=n, content=page_content) + for n, page_content in enumerate(data) + ] + story = Story(id=story_id, pages=pages) # save to the 'database' - db[story_id] = story + db["stories"][story_id] = story return cls(**story) @typing.override @@ -301,30 +500,6 @@ class Stories(ludic.web.Endpoint[Story]): ) -@app.endpoint("/stories/{sqid:str}/{page:int}") -class Pages(ludic.web.Endpoint[StoryPage]): - """Resource for retrieving individual pages in a story.""" - - @classmethod - def get(cls, sqid: str, page: int) -> typing.Self: - """Get a single page.""" - story = Stories.get(sqid) - story_page = StoryPage(**story.attrs["pages"][page]) - return cls(**story_page) - - @typing.override - def render(self) -> ludic.base.BaseElement: - """Render a single page as HTML.""" - return layouts.Box( - layouts.Stack( - ludic.html.img( - src=self.attrs["image_url"], - ), - typography.Paragraph(self.attrs["text"]), - ), - ) - - @app.endpoint("/stories") class StoriesForm(ludic.web.Endpoint[StoryInputs]): """Form for generating new stories.""" @@ -334,11 +509,11 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]): """Upsert a new story.""" inputs = StoryInputs(**data.validate()) # generate story pages + # Consider calling Pages.put for each one after generating the text pages = generate_pages(inputs) # calculate sqid - sqid = sqids.Sqids() - next_id_num = 1 + sqid.decode(db_last_id)[0] - next_id = sqid.encode([next_id_num]) + next_id_num = 1 + Sqids.decode(db_last_story_id)[0] + next_id = Sqids.encode([next_id_num]) return Stories.put(next_id, pages) @typing.override @@ -402,7 +577,7 @@ class StorybookTest(unittest.TestCase): self.character = "Alice" self.data = example_story | {"character": self.character} self.client.post("/stories/", data=self.data) - self.story = next(iter(db.values())) + self.story = next(iter(db["stories"].values())) self.story_id = self.story["id"] def test_stories_post(self) -> None: @@ -413,25 +588,34 @@ class StorybookTest(unittest.TestCase): def test_stories_post_invalid_data(self) -> None: """Invalid POST data.""" - response = self.client.post("/stories/", data={"bad": "data"}) + response = self.client.post( + app.url_path_for("StoriesForm"), + data={"bad": "data"}, + ) self.assertNotEqual(response.status_code, 200) def test_stories_get(self) -> None: """User can access the story directly.""" - response = self.client.get(f"/stories/{self.story_id}") + response = self.client.get( + app.url_path_for("Stories", story_id=self.story_id), + ) self.assertEqual(response.status_code, 200) self.assertIn(self.character, response.text) def test_stories_get_nonexistent(self) -> None: """Returns 404 when a story is not found.""" - response = self.client.get("/stories/nonexistent") + response = self.client.get( + app.url_path_for("Stories", story_id="nonexistent"), + ) self.assertEqual(response.status_code, 404) def test_pages_get(self) -> None: """User can access one page at a time.""" page_num = 1 - self.story["pages"][page_num] - response = self.client.get(f"/stories/{self.story_id}/{page_num}") + _story = self.story["pages"][page_num] + response = self.client.get( + app.url_path_for("Pages", story_id=self.story_id, page=page_num), + ) self.assertEqual(response.status_code, 200) self.assertIn(self.character, response.text) -- cgit v1.2.3