summaryrefslogtreecommitdiff
path: root/Biz
diff options
context:
space:
mode:
authorBen Sima <ben@bsima.me>2024-12-11 22:40:34 -0500
committerBen Sima <ben@bsima.me>2024-12-21 10:08:09 -0500
commite7d6505ff6bfefa927466361570cedde799e94a6 (patch)
tree82e88f4fe504723b919b38f6733d39cf02ea53af /Biz
parentfc1422f099d95878209c92b3e9e2f509fe8ca77e (diff)
Async end-to-end Storybook working
I deleted the tests because they were overspecifying the functionality. My mistake was to try and build out the objects and endpoints before the end-to-end sync thing was fully working. And then I misunderstood how to do async with HTMX, I was overcomplicating it trying to create objects and endpoints for everything instead of just focusing on the HTML that I should be generating. This all just led to a clusterfuck of code doing all the wrong things in the wrong places. So far this is much better architected. And it turns out that using image n-1 with OpenAI's create_variation function doesn't work very well anyway, so I scrapped that too; I'll have to look into different image gen services in the future.
Diffstat (limited to 'Biz')
-rw-r--r--Biz/Storybook.py693
1 files changed, 354 insertions, 339 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py
index c619ef8..80f746a 100644
--- a/Biz/Storybook.py
+++ b/Biz/Storybook.py
@@ -18,6 +18,7 @@ this single file.
# : dep sqids
# : dep requests
# : dep types-requests
+import enum
import json
import logging
import ludic
@@ -25,50 +26,80 @@ import ludic.catalog.buttons as buttons
import ludic.catalog.forms as forms
import ludic.catalog.headers as headers
import ludic.catalog.layouts as layouts
+import ludic.catalog.loaders as loaders
import ludic.catalog.pages as pages
import ludic.catalog.typography as typography
import ludic.web
import Omni.Log as Log
import openai
+import os
import pathlib
import requests
import sqids
import starlette.testclient
import sys
+import threading
import time
import typing
import unittest
-import unittest.mock as mock
+import uuid
import uvicorn
-MOCK = True
-DEBUG = False
-DATA_DIR = pathlib.Path("_/var/storybook/")
+VPN = True
+CODEROOT = pathlib.Path(os.getenv("CODEROOT", "."))
+DATA_DIR = pathlib.Path(CODEROOT / "_/var/storybook/")
+
+
+class Area(enum.Enum):
+ """The area we are running."""
+
+ Test = "Test"
+ Live = "Live"
+
+
+def from_env() -> Area:
+ """Load AREA from environment variable.
+
+ Raises:
+ ValueError: if AREA is not defined
+ """
+ var = os.getenv("AREA", "Test")
+ if var == "Test":
+ return Area.Test
+ if var == "Live":
+ return Area.Live
+ msg = "AREA not defined"
+ raise ValueError(msg)
+
+
+area = from_env()
+app = ludic.web.LudicApp(debug=area == Area.Test)
-app = ludic.web.LudicApp(debug=DEBUG)
Sqids = sqids.Sqids()
def main() -> None:
"""Run the Ludic application."""
- if sys.argv[1] == "test":
- test()
+ area = from_env()
+ if "test" in sys.argv:
+ test(area)
else:
- move()
+ move(area)
-def move() -> None:
+def move(area: Area) -> None:
"""Run the application."""
- Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
- local = "127.0.0.1"
- uvicorn.run(app, host=local)
+ Log.setup(logging.DEBUG if area.Test else logging.ERROR)
+ logging.info("area: %s", area)
+ host = "100.127.197.132" if VPN else "127.0.0.1"
+ uvicorn.run(app, host=host)
-def test() -> None:
+def test(area: Area = Area.Test) -> None:
"""Run the unittest suite manually."""
- Log.setup(logging.DEBUG if DEBUG else logging.ERROR)
+ Log.setup(logging.DEBUG if area.Test else logging.ERROR)
suite = unittest.TestSuite()
- tests = [StorybookTest, IndexTest, StoryTest, ImagesTest]
+ tests = [IndexTest, StoryTest]
suite.addTests([
unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests
])
@@ -80,33 +111,67 @@ def const(s: str) -> str:
return s
-class PageContent(ludic.attrs.Attrs):
- """Represents the content of a single page in the storybook."""
+class Image(ludic.attrs.Attrs):
+ """Represents an image associated with a story page."""
+
+ story_id: str
+ page: typing.Annotated[int, const]
+ path: pathlib.Path
+
+
+class OpenAIOverview(ludic.attrs.Attrs):
+ """Part of OpenAIStoryResponse."""
+
+ character: str
+ setting: str
+ summary: str
+
+
+class OpenAIPage(ludic.attrs.Attrs):
+ """Part of OpenAIStoryResponse."""
- image_prompt: str
text: str
+ image: str
-class Page(ludic.attrs.Attrs):
- """Represents a single page in the storybook."""
+class OpenAIStoryResponse(ludic.attrs.Attrs):
+ """The message content of the API response."""
- story_id: str
- page_number: int
- content: PageContent
+ overview: OpenAIOverview
+ pages: list[OpenAIPage]
-class Image(ludic.attrs.Attrs):
- """Represents an image associated with a story page."""
+example_openai_story_response: OpenAIStoryResponse = {
+ "overview": {
+ "character": (
+ "Alice is a blond haired girl, age 5, "
+ "wearing a blue dress, white shoes, and a whie wide-brimmed hat."
+ ),
+ "setting": (
+ "A farm on a hill. "
+ "There is a red barn, a grain silo, and lots of pasture."
+ ),
+ "summary": "<brief summary of the story>",
+ },
+ "pages": [{"text": "<story text>", "image": "<image prompt>"}],
+}
+
+
+class Page(ludic.attrs.Attrs):
+ """Represents a single page in the storybook."""
story_id: str
- page: int
- prompt: typing.Annotated[str, const]
- original_url: typing.Annotated[str, const]
- path: pathlib.Path
+ page_number: int
+ text: str
+ image_prompt: str
class Prompt(ludic.attrs.Attrs):
- """Represents a prompt for generating an image."""
+ """Represents a prompt for generating an image.
+
+ This datatype is overkill except that we need to validate it over the wire,
+ so its actually useful in that sense.
+ """
text: typing.Annotated[str, const]
@@ -120,7 +185,7 @@ class StoryInputs(ludic.attrs.Attrs):
moral: typing.Annotated[str, const]
-example_story: dict[str, str] = {
+example_story: StoryInputs = {
"theme": "Christian",
"character": "Lia and her pet bunny",
"setting": "A suburban park",
@@ -132,7 +197,7 @@ class Story(ludic.attrs.Attrs):
"""Represents a full generated story."""
id: typing.Annotated[str, const]
- pages: typing.Annotated[list[Page], const]
+ inputs: StoryInputs
system_prompt: str = (
@@ -145,18 +210,40 @@ system_prompt: str = (
)
-def user_prompt(story: StoryInputs) -> str:
- """Generate the user prompt based on the story details."""
- return " ".join([
- "Write a story with the following details.",
- "Output must be in valid JSON where each page is an array of text and"
- "image like the following example:",
- """[{"text": "<text of the story>",""",
- """"image": "<description of illustration>"}...],""",
- f"Character: {story['character']}\n",
- f"Setting: {story['setting']}\n",
- f"Theme: {story['theme']}\n",
- f"Moral: {story['moral']}\n",
+def story_meta(story: StoryInputs) -> list[str]:
+ """Format the `StoryInputs` for submission to an LLM."""
+ return [
+ f"Character: {story['character']}",
+ f"Setting: {story['setting']}",
+ f"Theme: {story['theme']}",
+ f"Moral: {story['moral']}",
+ ]
+
+
+user_prompt: str = " ".join([
+ "Write a children's story with the following details.",
+ "Output must be in valid JSON.",
+ "The overview key must contain a character sketch and setting description.",
+ "The pages key must an array of objects.",
+ "Each object must have the text of the page and image prompt.",
+ "Here is an example:",
+ json.dumps(example_openai_story_response),
+])
+
+
+def gen_image_prompt(
+ story: StoryInputs,
+ image_prompt: str,
+ story_text: str,
+) -> str:
+ """Format and return the full image prompt with additional context."""
+ return "\n".join([
+ f"Illustration: {image_prompt}",
+ f"Narrative text: {story_text}",
+ *story_meta(story),
+ "Style: a hand-drawn children's cartoon from the 1990s",
+ "Use soft pastel colors.",
+ "Do not include any text in the generated image.",
])
@@ -169,9 +256,13 @@ def _openai_generate_text(
| openai.types.chat.ChatCompletionSystemMessageParam
] = [
{"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt(story)},
+ {
+ "role": "user",
+ "content": "\n".join([user_prompt, *story_meta(story)]),
+ },
]
client = openai.OpenAI()
+ logging.debug("calling openai.chat.completions.create")
return client.chat.completions.create(
model="gpt-4o-mini",
messages=messages,
@@ -179,35 +270,147 @@ def _openai_generate_text(
)
-def generate_pages(inputs: StoryInputs) -> list[PageContent]:
+def generate_pages(inputs: StoryInputs) -> OpenAIStoryResponse:
"""Generate the text for a story and update its pages.
Raises:
ValueError: when openAI response is bad
"""
# when developing, don't run up the OpenAI tab
- if MOCK:
+ if area == Area.Test:
name = inputs["character"]
- return [
- PageContent(
+ ret = example_openai_story_response.copy()
+ ret["pages"] = [
+ OpenAIPage(
text=f"A story about {name}...",
- image_prompt="Lorem ipsum..",
+ image="Lorem ipsum..",
)
for n in range(10)
]
+ return ret
response = _openai_generate_text(inputs)
content = response.choices[0].message.content
if content is None:
msg = "content is none"
raise ValueError(msg)
- response_messages = json.loads(content)
- return [
- PageContent(
- text=msg["text"],
- image_prompt=msg["image"],
+ parsed_content = json.loads(content)
+ overview = parsed_content["overview"]
+ pages = parsed_content["pages"]
+ return OpenAIStoryResponse(
+ overview=OpenAIOverview(
+ character=overview["character"],
+ setting=overview["setting"],
+ summary=overview["summary"],
+ ),
+ pages=[OpenAIPage(text=p["text"], image=p["image"]) for p in pages],
+ )
+
+
+def generate_image(
+ area: Area,
+ image_prompt: str,
+ story_id: str,
+ page: int,
+) -> Image:
+ """Generate an image with OpenAI.
+
+ Raises:
+ InternalServerError: when OpenAI API fails
+ """
+ logging.info("generating image %s.%s", story_id, page)
+ url = None
+ if area == Area.Test:
+ time.sleep(1)
+ url = "https://placehold.co/1024.png"
+ else:
+ client = openai.OpenAI()
+ logging.debug("calling openai.images.generate")
+ logging.debug("prompt: %s", image_prompt)
+ image_response = client.images.generate(
+ model="dall-e-3",
+ prompt=image_prompt,
+ n=1,
+ size="1024x1024",
+ quality="standard",
)
- for msg in response_messages
+ url = image_response.data[0].url
+ if url is None:
+ msg = "error getting image from OpenAI"
+ logging.error(msg)
+ raise ludic.web.exceptions.InternalServerError(msg)
+ image = Image(
+ story_id=story_id,
+ path=Images.gen_path(story_id, page),
+ page=page,
+ )
+ Images.save(image, url)
+ return image
+
+
+class Job(ludic.attrs.Attrs):
+ """Simple wrapper class for background jobs.
+
+ This will become more useful when I need to store and track jobs in the
+ database.
+ """
+
+ id: str
+
+
+def generate_story_in_background(
+ area: Area,
+ story_id: str,
+ inputs: StoryInputs,
+) -> Story:
+ """Kick off `generate_story_pages` in a background thread."""
+ job_id = str(uuid.uuid4())
+ job = Job(id=job_id)
+ thread = threading.Thread(
+ target=generate_story_pages,
+ args=(
+ area,
+ story_id,
+ inputs,
+ ),
+ )
+ logging.info("starting job %s", job_id)
+ thread.start()
+ story = Story(id=story_id, inputs=inputs)
+ # save stuff
+ db["jobs"][job_id] = job
+ db["stories"][story_id] = story
+ return story
+
+
+def generate_story_pages(
+ area: Area,
+ story_id: str,
+ inputs: StoryInputs,
+) -> list[Page]:
+ """Upsert a new story."""
+ logging.info("generating story pages %s", story_id)
+ story_resp = generate_pages(inputs)
+ pages = [
+ Page(
+ page_number=i + 1,
+ text=sr["text"],
+ story_id=story_id,
+ image_prompt=sr["image"],
+ )
+ for i, sr in enumerate(story_resp["pages"])
]
+ db["pages"][story_id] = pages
+ for page in pages:
+ image_prompt = gen_image_prompt(
+ inputs,
+ page["image_prompt"],
+ page["text"],
+ )
+ n = page["page_number"]
+ generate_image(area, image_prompt, story_id, n)
+ # I *would* save the Image to the database here, but i'm not actually
+ # tracking that currenlty, just putting them in a known location on disk
+ return pages
class StoryTest(unittest.TestCase):
@@ -215,9 +418,10 @@ class StoryTest(unittest.TestCase):
def test_story_creation(self) -> None:
"""Creates a story with 10 pages."""
- s = StoryInputs(example_story) # type: ignore[misc]
- pages = generate_pages(s)
- self.assertIsNotNone(pages)
+ story_id = "Uk"
+ story = generate_story_pages(Area.Test, story_id, example_story)
+ pages = db["pages"][story_id]
+ self.assertIsNotNone(story)
self.assertEqual(len(pages), 10)
@@ -228,6 +432,8 @@ class AppPage(
@typing.override
def render(self) -> pages.HtmlPage:
+ dark = ludic.styles.themes.DarkTheme()
+ ludic.styles.themes.set_default_theme(dark)
return pages.HtmlPage(
pages.Head(
ludic.html.meta(charset="utf-8"),
@@ -252,7 +458,7 @@ def index(_: ludic.web.Request) -> AppPage:
"""Render the main page."""
return AppPage(
headers.H1("Storybook Generator"),
- StoriesForm(),
+ Generate(),
ludic.html.div(id="story"),
)
@@ -277,11 +483,14 @@ db_last_story_id: str = "bM" # sqid.encode([0])
class Database(ludic.attrs.Attrs):
"""Represents a simple in-memory database for storing stories and images."""
+ # each of these corresponds to a table in a SQL database
stories: dict[str, Story]
+ pages: dict[str, list[Page]]
images: dict[str, Image]
+ jobs: dict[str, Job]
-db: Database = Database(stories={}, images={})
+db: Database = Database(stories={}, images={}, jobs={}, pages={})
@app.endpoint("/pages/{story_id:str}/{page:int}")
@@ -289,143 +498,102 @@ class Pages(ludic.web.Endpoint[Page]):
"""Resource for retrieving individual pages in a story."""
@classmethod
- def get(cls, story_id: str, page: int) -> typing.Self:
- """Get a single page."""
- story = Stories.get(story_id)
- story_page = Page(**story.attrs["pages"][page])
- return cls(**story_page)
+ async def get(cls, story_id: str, page: int) -> typing.Self:
+ """Get a single page.
+
+ Raises:
+ NotFoundError: when the requested page is not found
+ """
+ pages = db["pages"].get(story_id, None)
+ if pages is None:
+ msg = "story: %s"
+ raise ludic.web.exceptions.NotFoundError(msg.format(story_id))
+ this_page = pages[page]
+ return cls(**this_page)
@typing.override
def render(self) -> ludic.base.BaseElement:
"""Render a single page as HTML."""
+ story_id = self.attrs["story_id"]
+ page = self.attrs["page_number"]
+ image_url = app.url_path_for(
+ "Images",
+ story_id=story_id,
+ page=page,
+ )
return layouts.Box(
layouts.Stack(
- ludic.html.img(
- src="//placehold.co/256/000000/FFFFFF",
- hx_post=app.url_path_for(
- "Images",
- story_id=self.attrs["story_id"],
- page=self.attrs["page_number"],
- ),
- hx_trigger="load",
- hx_swap="outerHTML:beforeend",
- hx_vals=json.dumps(
- Prompt(text=self.attrs["content"]["image_prompt"]),
- ),
- width=256,
- height=256,
+ ludic.html.div(
+ loaders.Loading(),
+ hx_get=image_url,
+ hx_trigger="every 1s",
+ hx_swap="outerHTML",
),
- typography.Paragraph(self.attrs["content"]["text"]),
+ typography.Paragraph(self.attrs["text"]),
),
)
+@app.get("/images/{story_id:str}/{page:int}.png")
+def images_static(story_id: str, page: int) -> ludic.web.responses.Response:
+ """Endpoint for accessing static images.
+
+ This does no generation, it only loads static images from the
+ filesystem. This must be separate to match on the `.png` suffix.
+
+ For generation use the `Images` class/endpoint.
+
+ Raises:
+ NotFoundError: when the image doesn't exist
+
+ """
+ image = Images.by_id(story_id, page)
+ if image["path"].exists():
+ return ludic.web.responses.FileResponse(image["path"])
+ msg = "images_static: image not found"
+ logging.error(msg)
+ raise ludic.web.exceptions.NotFoundError(msg)
+
+
@app.endpoint("/images/{story_id:str}/{page:int}")
class Images(ludic.web.Endpoint[Image]):
"""Endpoint for handling image-related operations."""
@classmethod
- def get(cls, story_id: str, page: int) -> ludic.web.responses.Response:
+ async def get(cls, story_id: str, page: int) -> typing.Self:
"""Load the image from the database, if not found return 404.
Raises:
NotFoundError: If the image is not found.
"""
- if image := Images.load_by_id(story_id, page):
- return ludic.web.responses.FileResponse(image["path"])
- msg = "no image found"
- logging.error(msg)
+ image = Images.by_id(story_id, page)
+ if image["path"].exists():
+ return cls(**image)
+ msg = "image not found"
raise ludic.web.exceptions.NotFoundError(msg)
@classmethod
- def post(
- cls,
- story_id: str,
- page: int,
- data: ludic.web.parsers.Parser[Prompt],
- ) -> ludic.web.responses.Response:
- """Create a new image, or retrieve an existing one."""
- Prompt(**data.validate())
+ def by_id(cls, story_id: str, page: int) -> Image:
+ """Load an image by its story ID and page number."""
path = cls.gen_path(story_id, page)
- if path.exists():
- return cls.get(story_id, page)
- return cls.put(story_id, page, data)
-
- @classmethod
- def put(
- cls,
- story_id: str,
- page: int,
- data: ludic.web.parsers.Parser[Prompt],
- ) -> ludic.web.responses.Response:
- """Create a new image, overwriting if one exists.
-
- Raises:
- InternalServerError: If there is an error getting the image from the
- OpenAI API.
- """
- if MOCK:
- # Simulate slow image generation
- time.sleep(1)
- return ludic.web.responses.FileResponse(
- DATA_DIR / "images" / "placeholder.jpg",
- )
- client = openai.OpenAI()
- prompt = Prompt(**data.validate())
- image_response = client.images.generate(
- prompt=prompt["text"],
- n=1,
- size="256x256",
- )
- url = image_response.data[0].url
- if url is None:
- msg = "error getting image from OpenAI"
- logging.error(msg)
- raise ludic.web.exceptions.InternalServerError(msg)
- image = Image(
+ return Image(
story_id=story_id,
page=page,
- prompt=prompt["text"],
- original_url=url,
- path=cls.gen_path(story_id, page),
+ path=path,
)
- cls.save(image)
- return ludic.web.responses.FileResponse(image["path"])
-
- @classmethod
- def gen_image_id(cls, story_id: str, page: int) -> str:
- """Generate a unique image ID based on the story ID and page number."""
- story_id_num = Sqids.decode(story_id)[0]
- return Sqids.encode([story_id_num, page])
-
- @classmethod
- def load_by_id(cls, story_id: str, page: int) -> Image | None:
- """Load an image by its story ID and page number."""
- cls.gen_image_id(story_id, page)
- path = cls.gen_path(story_id, page)
- if path.exists():
- return Image(
- story_id=story_id,
- page=page,
- path=path,
- # Consider storing prompt and original_url in sqlite
- prompt="",
- original_url="",
- )
- return None
@classmethod
def gen_path(cls, story_id: str, page: int) -> pathlib.Path:
"""Generate the file path for an image."""
- image_id = cls.gen_image_id(story_id, page)
return pathlib.Path(
- DATA_DIR / "images" / story_id / image_id,
- ).with_suffix(".jpg")
+ DATA_DIR / "images" / story_id / str(page),
+ ).with_suffix(".png")
@classmethod
- def save(cls, image: Image) -> None:
+ def save(cls, image: Image, original_url: str) -> None:
"""Save an image to the file system."""
- response = requests.get(image["original_url"], timeout=10)
+ response = requests.get(original_url, timeout=10)
+ image["path"].parent.mkdir(parents=True, exist_ok=True)
pathlib.Path(image["path"]).write_bytes(response.content)
@classmethod
@@ -437,119 +605,20 @@ class Images(ludic.web.Endpoint[Image]):
def render(self) -> ludic.base.BaseElement:
return ludic.html.img(
src=app.url_path_for(
- "Images",
+ "images_static",
story_id=self.attrs["story_id"],
page=self.attrs["page"],
),
+ loading="lazy",
)
-class ImagesTest(unittest.TestCase):
- """Test the Images endpoint."""
-
- def setUp(self) -> None:
- """Create test client."""
- self.client = starlette.testclient.TestClient(app)
- self.story_id = "Uk"
- self.page = 1
- self.valid_prompt = {"text": "A beautiful sunset over the ocean"}
-
- def test_image_get_existing(self) -> None:
- """Test retrieving an existing image."""
- # Arrange: Mock the load_by_id method to simulate an existing image
- data = {"path": DATA_DIR / "images" / "placeholder.jpg"}
- mock_dict = mock.MagicMock()
- mock_dict.__getitem__.side_effect = data.__getitem__
- with mock.patch.object(
- Images,
- "load_by_id",
- return_value=mock_dict,
- ):
- # Act: Send a GET request to retrieve the image
- response = self.client.get(
- app.url_path_for(
- "Images",
- story_id=self.story_id,
- page=self.page,
- ),
- )
- # Assert: Check that the response status is 200
- self.assertEqual(response.status_code, 200)
-
- def test_image_get_nonexistent(self) -> None:
- """Test retrieving a non-existent image."""
- # Act: Send a GET request for a non-existent image
- response = self.client.get(
- app.url_path_for("Images", story_id="nonexistent", page=self.page),
- )
- # Assert: Check that the response status is 404
- self.assertEqual(response.status_code, 404)
-
- def test_image_post_valid(self) -> None:
- """Test creating an image with valid data."""
- # Arrange: Mock the OpenAI API and file system operations
- with (
- mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
- mock.patch(
- "Biz.Storybook.pathlib.Path.write_bytes",
- ),
- ):
- mock_openai.return_value.images.generate.return_value.data = [
- mock.MagicMock(url="http://example.com/image.jpg"),
- ]
- # Act: Send a POST request with valid data
- response = self.client.post(
- app.url_path_for(
- "Images",
- story_id=self.story_id,
- page=self.page,
- ),
- data=self.valid_prompt,
- )
- # Assert: Check that the response status is 200
- self.assertEqual(response.status_code, 200)
-
- def test_image_post_invalid(self) -> None:
- """Test creating an image with invalid data."""
- # Act: Send a POST request with invalid data
- response = self.client.post(
- app.url_path_for("Images", story_id=self.story_id, page=self.page),
- data={"invalid": "data"},
- )
- # Assert: Check that the response status indicates an error
- self.assertNotEqual(response.status_code, 200)
-
- def test_image_put_overwrite(self) -> None:
- """Test overwriting an existing image."""
- # Arrange: Mock the OpenAI API and file system operations
- with (
- mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai,
- mock.patch(
- "Biz.Storybook.pathlib.Path.write_bytes",
- ),
- ):
- mock_openai.return_value.images.generate.return_value.data = [
- mock.MagicMock(url="http://example.com/image.jpg"),
- ]
- # Act: Send a PUT request to overwrite the image
- response = self.client.put(
- app.url_path_for(
- "Images",
- story_id=self.story_id,
- page=self.page,
- ),
- data=self.valid_prompt,
- )
- # Assert: Check that the response status is 200
- self.assertEqual(response.status_code, 200)
-
-
@app.endpoint("/stories/{story_id:str}")
class Stories(ludic.web.Endpoint[Story]):
"""Resource for accessing a Story."""
@classmethod
- def get(cls, story_id: str) -> typing.Self:
+ async def get(cls, story_id: str) -> typing.Self:
"""Get a single story.
Raises:
@@ -561,41 +630,38 @@ class Stories(ludic.web.Endpoint[Story]):
raise ludic.web.exceptions.NotFoundError(msg)
return cls(**story)
- @classmethod
- def put(cls, story_id: str, data: list[PageContent]) -> typing.Self:
- """Upsert a new story."""
- pages = [
- Page(story_id=story_id, page_number=n, content=page_content)
- for n, page_content in enumerate(data)
- ]
- story = Story(id=story_id, pages=pages)
- # save to the 'database'
- db["stories"][story_id] = story
- return cls(**story)
-
@typing.override
def render(self) -> ludic.base.BaseElement:
+ story_id = self.attrs["id"]
return layouts.Stack(
- headers.H1(str(self.attrs["id"])),
- *(Pages(**page) for page in self.attrs["pages"]),
+ typography.Paragraph(f"Story id: {story_id}"),
+ *(
+ loaders.LazyLoader(
+ load_url=app.url_path_for(
+ "Pages",
+ story_id=story_id,
+ page=n,
+ ),
+ hx_trigger="every 2s",
+ )
+ for n in range(1, 10)
+ ),
+ id="#story",
)
-@app.endpoint("/stories")
-class StoriesForm(ludic.web.Endpoint[StoryInputs]):
+@app.endpoint("/generate")
+class Generate(ludic.web.Endpoint[StoryInputs]):
"""Form for generating new stories."""
@classmethod
- def post(cls, data: ludic.web.parsers.Parser[StoryInputs]) -> Stories:
+ async def post(cls, data: ludic.web.parsers.Parser[StoryInputs]) -> Stories:
"""Upsert a new story."""
inputs = StoryInputs(**data.validate())
- # generate story pages
- # Consider calling Pages.put for each one after generating the text
- pages = generate_pages(inputs)
- # calculate sqid
next_id_num = 1 + Sqids.decode(db_last_story_id)[0]
next_id = Sqids.encode([next_id_num])
- return Stories.put(next_id, pages)
+ story = generate_story_in_background(area, next_id, inputs)
+ return Stories(**story)
@typing.override
def render(self) -> ludic.base.BaseElement:
@@ -643,63 +709,12 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]):
type="submit",
classes=["large"],
),
- hx_post=self.url_for(StoriesForm),
+ hx_post=self.url_for(Generate),
hx_target="#story",
+ hx_trigger="submit",
),
)
-class StorybookTest(unittest.TestCase):
- """Unit test case for the Storybook application."""
-
- def setUp(self) -> None:
- """Set up the test client and seed database."""
- self.client = starlette.testclient.TestClient(app)
- self.character = "Alice"
- self.data = example_story | {"character": self.character}
- self.client.post("/stories/", data=self.data)
- self.story = next(iter(db["stories"].values()))
- self.story_id = self.story["id"]
-
- def test_stories_post(self) -> None:
- """User can create a story."""
- response = self.client.post("/stories/", data=self.data)
- self.assertEqual(response.status_code, 200)
- self.assertIn(self.character, response.text)
-
- def test_stories_post_invalid_data(self) -> None:
- """Invalid POST data."""
- response = self.client.post(
- app.url_path_for("StoriesForm"),
- data={"bad": "data"},
- )
- self.assertNotEqual(response.status_code, 200)
-
- def test_stories_get(self) -> None:
- """User can access the story directly."""
- response = self.client.get(
- app.url_path_for("Stories", story_id=self.story_id),
- )
- self.assertEqual(response.status_code, 200)
- self.assertIn(self.character, response.text)
-
- def test_stories_get_nonexistent(self) -> None:
- """Returns 404 when a story is not found."""
- response = self.client.get(
- app.url_path_for("Stories", story_id="nonexistent"),
- )
- self.assertEqual(response.status_code, 404)
-
- def test_pages_get(self) -> None:
- """User can access one page at a time."""
- page_num = 1
- _story = self.story["pages"][page_num]
- response = self.client.get(
- app.url_path_for("Pages", story_id=self.story_id, page=page_num),
- )
- self.assertEqual(response.status_code, 200)
- self.assertIn(self.character, response.text)
-
-
if __name__ == "__main__":
main()