summaryrefslogtreecommitdiff
path: root/Biz/Storybook.py
diff options
context:
space:
mode:
Diffstat (limited to 'Biz/Storybook.py')
-rw-r--r--Biz/Storybook.py344
1 files changed, 264 insertions, 80 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py
index 8727b57..3659c37 100644
--- a/Biz/Storybook.py
+++ b/Biz/Storybook.py
@@ -16,6 +16,8 @@ this single file.
# : dep uvicorn
# : dep starlette
# : dep sqids
+# : dep requests
+# : dep types-requests
import json
import logging
import ludic
@@ -28,17 +30,22 @@ import ludic.catalog.typography as typography
import ludic.web
import Omni.Log as Log
import openai
+import pathlib
+import requests
import sqids
import starlette.testclient
import sys
+import time
import typing
import unittest
import uvicorn
MOCK = True
DEBUG = False
+DATA_DIR = pathlib.Path("_/var/storybook/")
app = ludic.web.LudicApp(debug=DEBUG)
+Sqids = sqids.Sqids()
def main() -> None:
@@ -52,14 +59,15 @@ def main() -> None:
def move() -> None:
"""Run the application."""
Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
- uvicorn.run(app, host="100.127.197.132")
+ local = "127.0.0.1"
+ uvicorn.run(app, host=local)
def test() -> None:
"""Run the unittest suite manually."""
Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
suite = unittest.TestSuite()
- tests = [StorybookTest, IndexTest, StoryTest]
+ tests = [StorybookTest, IndexTest, StoryTest, ImagesTest]
suite.addTests([
unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests
])
@@ -71,31 +79,35 @@ def const(s: str) -> str:
return s
-class StoryPage(ludic.attrs.Attrs):
+class PageContent(ludic.attrs.Attrs):
+ """Represents the content of a single page in the storybook."""
+
+ image_prompt: str
+ text: str
+
+
+class Page(ludic.attrs.Attrs):
"""Represents a single page in the storybook."""
- text: typing.Annotated[str, const]
- image_prompt: typing.Annotated[str, const]
- image_url: typing.Annotated[str, const]
+ story_id: str
+ page_number: int
+ content: PageContent
-def load_image(prompt: str) -> str:
- """Load an image for a given page using the OpenAI API.
+class Image(ludic.attrs.Attrs):
+ """Represents an image associated with a story page."""
- Raises:
- ValueError: when OpenAI response is bad
- """
- client = openai.OpenAI()
- image_response = client.images.generate(
- prompt=prompt,
- n=1,
- size="256x256",
- )
- url = image_response.data[0].url
- if url is not None:
- return url
- msg = "error with load_image"
- raise ValueError(msg)
+ story_id: str
+ page: int
+ prompt: typing.Annotated[str, const]
+ original_url: typing.Annotated[str, const]
+ path: pathlib.Path
+
+
+class Prompt(ludic.attrs.Attrs):
+ """Represents a prompt for generating an image."""
+
+ text: typing.Annotated[str, const]
class StoryInputs(ludic.attrs.Attrs):
@@ -119,7 +131,7 @@ class Story(ludic.attrs.Attrs):
"""Represents a full generated story."""
id: typing.Annotated[str, const]
- pages: typing.Annotated[list[StoryPage], const]
+ pages: typing.Annotated[list[Page], const]
system_prompt: str = (
@@ -140,10 +152,10 @@ def user_prompt(story: StoryInputs) -> str:
"image like the following example:",
"""[{"text": "<text of the story>",""",
""""image": "<description of illustration>"}...],""",
- f"Character: {story["character"]}\n",
- f"Setting: {story["setting"]}\n",
- f"Theme: {story["theme"]}\n",
- f"Moral: {story["moral"]}\n",
+ f"Character: {story['character']}\n",
+ f"Setting: {story['setting']}\n",
+ f"Theme: {story['theme']}\n",
+ f"Moral: {story['moral']}\n",
])
@@ -166,7 +178,7 @@ def _openai_generate_text(
)
-def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
+def generate_pages(inputs: StoryInputs) -> list[PageContent]:
"""Generate the text for a story and update its pages.
Raises:
@@ -176,12 +188,11 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
if MOCK:
name = inputs["character"]
return [
- StoryPage(
+ PageContent(
text=f"A story about {name}...",
- image_prompt="lorem ipsum",
- image_url="//placehold.co/256x256",
+ image_prompt="Lorem ipsum..",
)
- for _ in range(10)
+ for n in range(10)
]
response = _openai_generate_text(inputs)
content = response.choices[0].message.content
@@ -190,10 +201,9 @@ def generate_pages(inputs: StoryInputs) -> list[StoryPage]:
raise ValueError(msg)
response_messages = json.loads(content)
return [
- StoryPage(
+ PageContent(
text=msg["text"],
image_prompt=msg["image"],
- image_url=load_image(msg["image"]),
)
for msg in response_messages
]
@@ -260,37 +270,226 @@ class IndexTest(unittest.TestCase):
self.assertIn("Storybook Generator", response.text)
-db_last_id: str = "bM" # sqid.encode([0])
-db: dict[str, Story] = {}
+db_last_story_id: str = "bM" # sqid.encode([0])
+
+class Database(ludic.attrs.Attrs):
+ """Represents a simple in-memory database for storing stories and images."""
-@app.endpoint("/stories/{sqid:str}")
+ stories: dict[str, Story]
+ images: dict[str, Image]
+
+
+db: Database = Database(stories={}, images={})
+
+
+@app.endpoint("/pages/{story_id:str}/{page:int}")
+class Pages(ludic.web.Endpoint[Page]):
+ """Resource for retrieving individual pages in a story."""
+
+ @classmethod
+ def get(cls, story_id: str, page: int) -> typing.Self:
+ """Get a single page."""
+ story = Stories.get(story_id)
+ story_page = Page(**story.attrs["pages"][page])
+ return cls(**story_page)
+
+ @typing.override
+ def render(self) -> ludic.base.BaseElement:
+ """Render a single page as HTML."""
+ return layouts.Box(
+ layouts.Stack(
+ ludic.html.img(
+ src="//placehold.co/256/000000/FFFFFF",
+ hx_post=app.url_path_for(
+ "Images",
+ story_id=self.attrs["story_id"],
+ page=self.attrs["page_number"],
+ ),
+ hx_trigger="load",
+ hx_swap="outerHTML:beforeend",
+ hx_vals=json.dumps(
+ Prompt(text=self.attrs["content"]["image_prompt"]),
+ ),
+ width=256,
+ height=256,
+ ),
+ typography.Paragraph(self.attrs["content"]["text"]),
+ ),
+ )
+
+
+@app.endpoint("/images/{story_id:str}/{page:int}")
+class Images(ludic.web.Endpoint[Image]):
+ """Endpoint for handling image-related operations."""
+
+ @classmethod
+ def get(cls, story_id: str, page: int) -> ludic.web.responses.Response:
+ """Load the image from the database, if not found return 404.
+
+ Raises:
+ NotFoundError: If the image is not found.
+ """
+ if image := Images.load_by_id(story_id, page):
+ return ludic.web.responses.FileResponse(image["path"])
+ msg = "no image found"
+ logging.error(msg)
+ raise ludic.web.exceptions.NotFoundError(msg)
+
+ @classmethod
+ def post(
+ cls,
+ story_id: str,
+ page: int,
+ data: ludic.web.parsers.Parser[Prompt],
+ ) -> ludic.web.responses.Response:
+ """Create a new image, or retrieve an existing one."""
+ Prompt(**data.validate())
+ path = cls.gen_path(story_id, page)
+ if path.exists():
+ return cls.get(story_id, page)
+ return cls.put(story_id, page, data)
+
+ @classmethod
+ def put(
+ cls,
+ story_id: str,
+ page: int,
+ data: ludic.web.parsers.Parser[Prompt],
+ ) -> ludic.web.responses.Response:
+ """Create a new image, overwriting if one exists.
+
+ Raises:
+ InternalServerError: If there is an error getting the image from the
+ OpenAI API.
+ """
+ if MOCK:
+ # Simulate slow image generation
+ time.sleep(3)
+ return ludic.web.responses.FileResponse(
+ DATA_DIR / "images" / "placeholder.jpg",
+ )
+ client = openai.OpenAI()
+ prompt = Prompt(**data.validate())
+ image_response = client.images.generate(
+ prompt=prompt["text"],
+ n=1,
+ size="256x256",
+ )
+ url = image_response.data[0].url
+ if url is None:
+ msg = "error getting image from OpenAI"
+ logging.error(msg)
+ raise ludic.web.exceptions.InternalServerError(msg)
+ image = Image(
+ story_id=story_id,
+ page=page,
+ prompt=prompt["text"],
+ original_url=url,
+ path=cls.gen_path(story_id, page),
+ )
+ cls.save(image)
+ return ludic.web.responses.FileResponse(image["path"])
+
+ @classmethod
+ def gen_image_id(cls, story_id: str, page: int) -> str:
+ """Generate a unique image ID based on the story ID and page number."""
+ story_id_num = Sqids.decode(story_id)[0]
+ return Sqids.encode([story_id_num, page])
+
+ @classmethod
+ def load_by_id(cls, story_id: str, page: int) -> Image | None:
+ """Load an image by its story ID and page number."""
+ cls.gen_image_id(story_id, page)
+ path = cls.gen_path(story_id, page)
+ if path.exists():
+ return Image(
+ story_id=story_id,
+ page=page,
+ path=path,
+ # Consider storing prompt and original_url in sqlite
+ prompt="",
+ original_url="",
+ )
+ return None
+
+ @classmethod
+ def gen_path(cls, story_id: str, page: int) -> pathlib.Path:
+ """Generate the file path for an image."""
+ image_id = cls.gen_image_id(story_id, page)
+ return pathlib.Path(
+ DATA_DIR / "images" / story_id / image_id,
+ ).with_suffix(".jpg")
+
+ @classmethod
+ def save(cls, image: Image) -> None:
+ """Save an image to the file system."""
+ response = requests.get(image["original_url"], timeout=10)
+ pathlib.Path(image["path"]).write_bytes(response.content)
+
+ @classmethod
+ def read(cls, image: Image) -> bytes:
+ """Read an image from the file system."""
+ return pathlib.Path(image["path"]).read_bytes()
+
+ @typing.override
+ def render(self) -> ludic.base.BaseElement:
+ return ludic.html.img(
+ src=app.url_path_for(
+ "Images",
+ story_id=self.attrs["story_id"],
+ page=self.attrs["page"],
+ ),
+ )
+
+
+class ImagesTest(unittest.TestCase):
+ """Test the Images endpoint."""
+
+ def setUp(self) -> None:
+ """Create test client."""
+ self.client = starlette.testclient.TestClient(app)
+
+ def test_image_post(self) -> None:
+ """Can POST an Image successfully."""
+ response = self.client.post(
+ app.url_path_for(
+ "Images",
+ story_id="Uk",
+ page=1,
+ ),
+ data={"text": "lorem ipsum"},
+ )
+ self.assertEqual(response.status_code, 200)
+
+
+@app.endpoint("/stories/{story_id:str}")
class Stories(ludic.web.Endpoint[Story]):
"""Resource for accessing a Story."""
@classmethod
- def get(cls, sqid: str) -> typing.Self:
+ def get(cls, story_id: str) -> typing.Self:
"""Get a single story.
Raises:
NotFoundError: if the story doesn't exist.
"""
- story = db.get(sqid)
+ story = db["stories"].get(story_id)
if story is None:
- msg = f"story {sqid} not found"
+ msg = f"story {story_id} not found"
raise ludic.web.exceptions.NotFoundError(msg)
return cls(**story)
@classmethod
- def put(cls, sqid: str, data: list[StoryPage]) -> typing.Self:
+ def put(cls, story_id: str, data: list[PageContent]) -> typing.Self:
"""Upsert a new story."""
- pages = data # .validate()
-
- story = Story(id=sqid, pages=pages)
- story_id = story["id"]
-
+ pages = [
+ Page(story_id=story_id, page_number=n, content=page_content)
+ for n, page_content in enumerate(data)
+ ]
+ story = Story(id=story_id, pages=pages)
# save to the 'database'
- db[story_id] = story
+ db["stories"][story_id] = story
return cls(**story)
@typing.override
@@ -301,30 +500,6 @@ class Stories(ludic.web.Endpoint[Story]):
)
-@app.endpoint("/stories/{sqid:str}/{page:int}")
-class Pages(ludic.web.Endpoint[StoryPage]):
- """Resource for retrieving individual pages in a story."""
-
- @classmethod
- def get(cls, sqid: str, page: int) -> typing.Self:
- """Get a single page."""
- story = Stories.get(sqid)
- story_page = StoryPage(**story.attrs["pages"][page])
- return cls(**story_page)
-
- @typing.override
- def render(self) -> ludic.base.BaseElement:
- """Render a single page as HTML."""
- return layouts.Box(
- layouts.Stack(
- ludic.html.img(
- src=self.attrs["image_url"],
- ),
- typography.Paragraph(self.attrs["text"]),
- ),
- )
-
-
@app.endpoint("/stories")
class StoriesForm(ludic.web.Endpoint[StoryInputs]):
"""Form for generating new stories."""
@@ -334,11 +509,11 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]):
"""Upsert a new story."""
inputs = StoryInputs(**data.validate())
# generate story pages
+ # Consider calling Pages.put for each one after generating the text
pages = generate_pages(inputs)
# calculate sqid
- sqid = sqids.Sqids()
- next_id_num = 1 + sqid.decode(db_last_id)[0]
- next_id = sqid.encode([next_id_num])
+ next_id_num = 1 + Sqids.decode(db_last_story_id)[0]
+ next_id = Sqids.encode([next_id_num])
return Stories.put(next_id, pages)
@typing.override
@@ -402,7 +577,7 @@ class StorybookTest(unittest.TestCase):
self.character = "Alice"
self.data = example_story | {"character": self.character}
self.client.post("/stories/", data=self.data)
- self.story = next(iter(db.values()))
+ self.story = next(iter(db["stories"].values()))
self.story_id = self.story["id"]
def test_stories_post(self) -> None:
@@ -413,25 +588,34 @@ class StorybookTest(unittest.TestCase):
def test_stories_post_invalid_data(self) -> None:
"""Invalid POST data."""
- response = self.client.post("/stories/", data={"bad": "data"})
+ response = self.client.post(
+ app.url_path_for("StoriesForm"),
+ data={"bad": "data"},
+ )
self.assertNotEqual(response.status_code, 200)
def test_stories_get(self) -> None:
"""User can access the story directly."""
- response = self.client.get(f"/stories/{self.story_id}")
+ response = self.client.get(
+ app.url_path_for("Stories", story_id=self.story_id),
+ )
self.assertEqual(response.status_code, 200)
self.assertIn(self.character, response.text)
def test_stories_get_nonexistent(self) -> None:
"""Returns 404 when a story is not found."""
- response = self.client.get("/stories/nonexistent")
+ response = self.client.get(
+ app.url_path_for("Stories", story_id="nonexistent"),
+ )
self.assertEqual(response.status_code, 404)
def test_pages_get(self) -> None:
"""User can access one page at a time."""
page_num = 1
- self.story["pages"][page_num]
- response = self.client.get(f"/stories/{self.story_id}/{page_num}")
+ _story = self.story["pages"][page_num]
+ response = self.client.get(
+ app.url_path_for("Pages", story_id=self.story_id, page=page_num),
+ )
self.assertEqual(response.status_code, 200)
self.assertIn(self.character, response.text)