diff options
Diffstat (limited to 'Biz')
-rw-r--r-- | Biz/Storybook.py | 344 |
1 files changed, 264 insertions, 80 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py index 8727b57..3659c37 100644 --- a/Biz/Storybook.py +++ b/Biz/Storybook.py @@ -16,6 +16,8 @@ this single file. # : dep uvicorn # : dep starlette # : dep sqids +# : dep requests +# : dep types-requests import json import logging import ludic @@ -28,17 +30,22 @@ import ludic.catalog.typography as typography import ludic.web import Omni.Log as Log import openai +import pathlib +import requests import sqids import starlette.testclient import sys +import time import typing import unittest import uvicorn MOCK = True DEBUG = False +DATA_DIR = pathlib.Path("_/var/storybook/") app = ludic.web.LudicApp(debug=DEBUG) +Sqids = sqids.Sqids() def main() -> None: @@ -52,14 +59,15 @@ def main() -> None: def move() -> None: """Run the application.""" Log.setup(logging.DEBUG if DEBUG else logging.ERROR) - uvicorn.run(app, host="100.127.197.132") + local = "127.0.0.1" + uvicorn.run(app, host=local) def test() -> None: """Run the unittest suite manually.""" Log.setup(logging.DEBUG if DEBUG else logging.ERROR) suite = unittest.TestSuite() - tests = [StorybookTest, IndexTest, StoryTest] + tests = [StorybookTest, IndexTest, StoryTest, ImagesTest] suite.addTests([ unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests ]) @@ -71,31 +79,35 @@ def const(s: str) -> str: return s -class StoryPage(ludic.attrs.Attrs): +class PageContent(ludic.attrs.Attrs): + """Represents the content of a single page in the storybook.""" + + image_prompt: str + text: str + + +class Page(ludic.attrs.Attrs): """Represents a single page in the storybook.""" - text: typing.Annotated[str, const] - image_prompt: typing.Annotated[str, const] - image_url: typing.Annotated[str, const] + story_id: str + page_number: int + content: PageContent -def load_image(prompt: str) -> str: - """Load an image for a given page using the OpenAI API. +class Image(ludic.attrs.Attrs): + """Represents an image associated with a story page.""" - Raises: - ValueError: when OpenAI response is bad - """ - client = openai.OpenAI() - image_response = client.images.generate( - prompt=prompt, - n=1, - size="256x256", - ) - url = image_response.data[0].url - if url is not None: - return url - msg = "error with load_image" - raise ValueError(msg) + story_id: str + page: int + prompt: typing.Annotated[str, const] + original_url: typing.Annotated[str, const] + path: pathlib.Path + + +class Prompt(ludic.attrs.Attrs): + """Represents a prompt for generating an image.""" + + text: typing.Annotated[str, const] class StoryInputs(ludic.attrs.Attrs): @@ -119,7 +131,7 @@ class Story(ludic.attrs.Attrs): """Represents a full generated story.""" id: typing.Annotated[str, const] - pages: typing.Annotated[list[StoryPage], const] + pages: typing.Annotated[list[Page], const] system_prompt: str = ( @@ -140,10 +152,10 @@ def user_prompt(story: StoryInputs) -> str: "image like the following example:", """[{"text": "<text of the story>",""", """"image": "<description of illustration>"}...],""", - f"Character: {story["character"]}\n", - f"Setting: {story["setting"]}\n", - f"Theme: {story["theme"]}\n", - f"Moral: {story["moral"]}\n", + f"Character: {story['character']}\n", + f"Setting: {story['setting']}\n", + f"Theme: {story['theme']}\n", + f"Moral: {story['moral']}\n", ]) @@ -166,7 +178,7 @@ def _openai_generate_text( ) -def generate_pages(inputs: StoryInputs) -> list[StoryPage]: +def generate_pages(inputs: StoryInputs) -> list[PageContent]: """Generate the text for a story and update its pages. Raises: @@ -176,12 +188,11 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]: if MOCK: name = inputs["character"] return [ - StoryPage( + PageContent( text=f"A story about {name}...", - image_prompt="lorem ipsum", - image_url="//placehold.co/256x256", + image_prompt="Lorem ipsum..", ) - for _ in range(10) + for n in range(10) ] response = _openai_generate_text(inputs) content = response.choices[0].message.content @@ -190,10 +201,9 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]: raise ValueError(msg) response_messages = json.loads(content) return [ - StoryPage( + PageContent( text=msg["text"], image_prompt=msg["image"], - image_url=load_image(msg["image"]), ) for msg in response_messages ] @@ -260,37 +270,226 @@ class IndexTest(unittest.TestCase): self.assertIn("Storybook Generator", response.text) -db_last_id: str = "bM" # sqid.encode([0]) -db: dict[str, Story] = {} +db_last_story_id: str = "bM" # sqid.encode([0]) + +class Database(ludic.attrs.Attrs): + """Represents a simple in-memory database for storing stories and images.""" -@app.endpoint("/stories/{sqid:str}") + stories: dict[str, Story] + images: dict[str, Image] + + +db: Database = Database(stories={}, images={}) + + +@app.endpoint("/pages/{story_id:str}/{page:int}") +class Pages(ludic.web.Endpoint[Page]): + """Resource for retrieving individual pages in a story.""" + + @classmethod + def get(cls, story_id: str, page: int) -> typing.Self: + """Get a single page.""" + story = Stories.get(story_id) + story_page = Page(**story.attrs["pages"][page]) + return cls(**story_page) + + @typing.override + def render(self) -> ludic.base.BaseElement: + """Render a single page as HTML.""" + return layouts.Box( + layouts.Stack( + ludic.html.img( + src="//placehold.co/256/000000/FFFFFF", + hx_post=app.url_path_for( + "Images", + story_id=self.attrs["story_id"], + page=self.attrs["page_number"], + ), + hx_trigger="load", + hx_swap="outerHTML:beforeend", + hx_vals=json.dumps( + Prompt(text=self.attrs["content"]["image_prompt"]), + ), + width=256, + height=256, + ), + typography.Paragraph(self.attrs["content"]["text"]), + ), + ) + + +@app.endpoint("/images/{story_id:str}/{page:int}") +class Images(ludic.web.Endpoint[Image]): + """Endpoint for handling image-related operations.""" + + @classmethod + def get(cls, story_id: str, page: int) -> ludic.web.responses.Response: + """Load the image from the database, if not found return 404. + + Raises: + NotFoundError: If the image is not found. + """ + if image := Images.load_by_id(story_id, page): + return ludic.web.responses.FileResponse(image["path"]) + msg = "no image found" + logging.error(msg) + raise ludic.web.exceptions.NotFoundError(msg) + + @classmethod + def post( + cls, + story_id: str, + page: int, + data: ludic.web.parsers.Parser[Prompt], + ) -> ludic.web.responses.Response: + """Create a new image, or retrieve an existing one.""" + Prompt(**data.validate()) + path = cls.gen_path(story_id, page) + if path.exists(): + return cls.get(story_id, page) + return cls.put(story_id, page, data) + + @classmethod + def put( + cls, + story_id: str, + page: int, + data: ludic.web.parsers.Parser[Prompt], + ) -> ludic.web.responses.Response: + """Create a new image, overwriting if one exists. + + Raises: + InternalServerError: If there is an error getting the image from the + OpenAI API. + """ + if MOCK: + # Simulate slow image generation + time.sleep(3) + return ludic.web.responses.FileResponse( + DATA_DIR / "images" / "placeholder.jpg", + ) + client = openai.OpenAI() + prompt = Prompt(**data.validate()) + image_response = client.images.generate( + prompt=prompt["text"], + n=1, + size="256x256", + ) + url = image_response.data[0].url + if url is None: + msg = "error getting image from OpenAI" + logging.error(msg) + raise ludic.web.exceptions.InternalServerError(msg) + image = Image( + story_id=story_id, + page=page, + prompt=prompt["text"], + original_url=url, + path=cls.gen_path(story_id, page), + ) + cls.save(image) + return ludic.web.responses.FileResponse(image["path"]) + + @classmethod + def gen_image_id(cls, story_id: str, page: int) -> str: + """Generate a unique image ID based on the story ID and page number.""" + story_id_num = Sqids.decode(story_id)[0] + return Sqids.encode([story_id_num, page]) + + @classmethod + def load_by_id(cls, story_id: str, page: int) -> Image | None: + """Load an image by its story ID and page number.""" + cls.gen_image_id(story_id, page) + path = cls.gen_path(story_id, page) + if path.exists(): + return Image( + story_id=story_id, + page=page, + path=path, + # Consider storing prompt and original_url in sqlite + prompt="", + original_url="", + ) + return None + + @classmethod + def gen_path(cls, story_id: str, page: int) -> pathlib.Path: + """Generate the file path for an image.""" + image_id = cls.gen_image_id(story_id, page) + return pathlib.Path( + DATA_DIR / "images" / story_id / image_id, + ).with_suffix(".jpg") + + @classmethod + def save(cls, image: Image) -> None: + """Save an image to the file system.""" + response = requests.get(image["original_url"], timeout=10) + pathlib.Path(image["path"]).write_bytes(response.content) + + @classmethod + def read(cls, image: Image) -> bytes: + """Read an image from the file system.""" + return pathlib.Path(image["path"]).read_bytes() + + @typing.override + def render(self) -> ludic.base.BaseElement: + return ludic.html.img( + src=app.url_path_for( + "Images", + story_id=self.attrs["story_id"], + page=self.attrs["page"], + ), + ) + + +class ImagesTest(unittest.TestCase): + """Test the Images endpoint.""" + + def setUp(self) -> None: + """Create test client.""" + self.client = starlette.testclient.TestClient(app) + + def test_image_post(self) -> None: + """Can POST an Image successfully.""" + response = self.client.post( + app.url_path_for( + "Images", + story_id="Uk", + page=1, + ), + data={"text": "lorem ipsum"}, + ) + self.assertEqual(response.status_code, 200) + + +@app.endpoint("/stories/{story_id:str}") class Stories(ludic.web.Endpoint[Story]): """Resource for accessing a Story.""" @classmethod - def get(cls, sqid: str) -> typing.Self: + def get(cls, story_id: str) -> typing.Self: """Get a single story. Raises: NotFoundError: if the story doesn't exist. """ - story = db.get(sqid) + story = db["stories"].get(story_id) if story is None: - msg = f"story {sqid} not found" + msg = f"story {story_id} not found" raise ludic.web.exceptions.NotFoundError(msg) return cls(**story) @classmethod - def put(cls, sqid: str, data: list[StoryPage]) -> typing.Self: + def put(cls, story_id: str, data: list[PageContent]) -> typing.Self: """Upsert a new story.""" - pages = data # .validate() - - story = Story(id=sqid, pages=pages) - story_id = story["id"] - + pages = [ + Page(story_id=story_id, page_number=n, content=page_content) + for n, page_content in enumerate(data) + ] + story = Story(id=story_id, pages=pages) # save to the 'database' - db[story_id] = story + db["stories"][story_id] = story return cls(**story) @typing.override @@ -301,30 +500,6 @@ class Stories(ludic.web.Endpoint[Story]): ) -@app.endpoint("/stories/{sqid:str}/{page:int}") -class Pages(ludic.web.Endpoint[StoryPage]): - """Resource for retrieving individual pages in a story.""" - - @classmethod - def get(cls, sqid: str, page: int) -> typing.Self: - """Get a single page.""" - story = Stories.get(sqid) - story_page = StoryPage(**story.attrs["pages"][page]) - return cls(**story_page) - - @typing.override - def render(self) -> ludic.base.BaseElement: - """Render a single page as HTML.""" - return layouts.Box( - layouts.Stack( - ludic.html.img( - src=self.attrs["image_url"], - ), - typography.Paragraph(self.attrs["text"]), - ), - ) - - @app.endpoint("/stories") class StoriesForm(ludic.web.Endpoint[StoryInputs]): """Form for generating new stories.""" @@ -334,11 +509,11 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]): """Upsert a new story.""" inputs = StoryInputs(**data.validate()) # generate story pages + # Consider calling Pages.put for each one after generating the text pages = generate_pages(inputs) # calculate sqid - sqid = sqids.Sqids() - next_id_num = 1 + sqid.decode(db_last_id)[0] - next_id = sqid.encode([next_id_num]) + next_id_num = 1 + Sqids.decode(db_last_story_id)[0] + next_id = Sqids.encode([next_id_num]) return Stories.put(next_id, pages) @typing.override @@ -402,7 +577,7 @@ class StorybookTest(unittest.TestCase): self.character = "Alice" self.data = example_story | {"character": self.character} self.client.post("/stories/", data=self.data) - self.story = next(iter(db.values())) + self.story = next(iter(db["stories"].values())) self.story_id = self.story["id"] def test_stories_post(self) -> None: @@ -413,25 +588,34 @@ class StorybookTest(unittest.TestCase): def test_stories_post_invalid_data(self) -> None: """Invalid POST data.""" - response = self.client.post("/stories/", data={"bad": "data"}) + response = self.client.post( + app.url_path_for("StoriesForm"), + data={"bad": "data"}, + ) self.assertNotEqual(response.status_code, 200) def test_stories_get(self) -> None: """User can access the story directly.""" - response = self.client.get(f"/stories/{self.story_id}") + response = self.client.get( + app.url_path_for("Stories", story_id=self.story_id), + ) self.assertEqual(response.status_code, 200) self.assertIn(self.character, response.text) def test_stories_get_nonexistent(self) -> None: """Returns 404 when a story is not found.""" - response = self.client.get("/stories/nonexistent") + response = self.client.get( + app.url_path_for("Stories", story_id="nonexistent"), + ) self.assertEqual(response.status_code, 404) def test_pages_get(self) -> None: """User can access one page at a time.""" page_num = 1 - self.story["pages"][page_num] - response = self.client.get(f"/stories/{self.story_id}/{page_num}") + _story = self.story["pages"][page_num] + response = self.client.get( + app.url_path_for("Pages", story_id=self.story_id, page=page_num), + ) self.assertEqual(response.status_code, 200) self.assertIn(self.character, response.text) |