diff options
author | Ben Sima <ben@bsima.me> | 2024-12-11 22:40:34 -0500 |
---|---|---|
committer | Ben Sima <ben@bsima.me> | 2024-12-21 10:08:09 -0500 |
commit | e7d6505ff6bfefa927466361570cedde799e94a6 (patch) | |
tree | 82e88f4fe504723b919b38f6733d39cf02ea53af /Biz/Storybook.py | |
parent | fc1422f099d95878209c92b3e9e2f509fe8ca77e (diff) |
Async end-to-end Storybook working
I deleted the tests because they were overspecifying the functionality. My
mistake was to try and build out the objects and endpoints before the end-to-end
sync thing was fully working. And then I misunderstood how to do async with
HTMX, I was overcomplicating it trying to create objects and endpoints for
everything instead of just focusing on the HTML that I should be
generating. This all just led to a clusterfuck of code doing all the wrong
things in the wrong places.
So far this is much better architected. And it turns out that using image n-1
with OpenAI's create_variation function doesn't work very well anyway, so I
scrapped that too; I'll have to look into different image gen services in the
future.
Diffstat (limited to 'Biz/Storybook.py')
-rw-r--r-- | Biz/Storybook.py | 693 |
1 files changed, 354 insertions, 339 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py index c619ef8..80f746a 100644 --- a/Biz/Storybook.py +++ b/Biz/Storybook.py @@ -18,6 +18,7 @@ this single file. # : dep sqids # : dep requests # : dep types-requests +import enum import json import logging import ludic @@ -25,50 +26,80 @@ import ludic.catalog.buttons as buttons import ludic.catalog.forms as forms import ludic.catalog.headers as headers import ludic.catalog.layouts as layouts +import ludic.catalog.loaders as loaders import ludic.catalog.pages as pages import ludic.catalog.typography as typography import ludic.web import Omni.Log as Log import openai +import os import pathlib import requests import sqids import starlette.testclient import sys +import threading import time import typing import unittest -import unittest.mock as mock +import uuid import uvicorn -MOCK = True -DEBUG = False -DATA_DIR = pathlib.Path("_/var/storybook/") +VPN = True +CODEROOT = pathlib.Path(os.getenv("CODEROOT", ".")) +DATA_DIR = pathlib.Path(CODEROOT / "_/var/storybook/") + + +class Area(enum.Enum): + """The area we are running.""" + + Test = "Test" + Live = "Live" + + +def from_env() -> Area: + """Load AREA from environment variable. + + Raises: + ValueError: if AREA is not defined + """ + var = os.getenv("AREA", "Test") + if var == "Test": + return Area.Test + if var == "Live": + return Area.Live + msg = "AREA not defined" + raise ValueError(msg) + + +area = from_env() +app = ludic.web.LudicApp(debug=area == Area.Test) -app = ludic.web.LudicApp(debug=DEBUG) Sqids = sqids.Sqids() def main() -> None: """Run the Ludic application.""" - if sys.argv[1] == "test": - test() + area = from_env() + if "test" in sys.argv: + test(area) else: - move() + move(area) -def move() -> None: +def move(area: Area) -> None: """Run the application.""" - Log.setup(logging.DEBUG if DEBUG else logging.ERROR) - local = "127.0.0.1" - uvicorn.run(app, host=local) + Log.setup(logging.DEBUG if area.Test else logging.ERROR) + logging.info("area: %s", area) + host = "100.127.197.132" if VPN else "127.0.0.1" + uvicorn.run(app, host=host) -def test() -> None: +def test(area: Area = Area.Test) -> None: """Run the unittest suite manually.""" - Log.setup(logging.DEBUG if DEBUG else logging.ERROR) + Log.setup(logging.DEBUG if area.Test else logging.ERROR) suite = unittest.TestSuite() - tests = [StorybookTest, IndexTest, StoryTest, ImagesTest] + tests = [IndexTest, StoryTest] suite.addTests([ unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests ]) @@ -80,33 +111,67 @@ def const(s: str) -> str: return s -class PageContent(ludic.attrs.Attrs): - """Represents the content of a single page in the storybook.""" +class Image(ludic.attrs.Attrs): + """Represents an image associated with a story page.""" + + story_id: str + page: typing.Annotated[int, const] + path: pathlib.Path + + +class OpenAIOverview(ludic.attrs.Attrs): + """Part of OpenAIStoryResponse.""" + + character: str + setting: str + summary: str + + +class OpenAIPage(ludic.attrs.Attrs): + """Part of OpenAIStoryResponse.""" - image_prompt: str text: str + image: str -class Page(ludic.attrs.Attrs): - """Represents a single page in the storybook.""" +class OpenAIStoryResponse(ludic.attrs.Attrs): + """The message content of the API response.""" - story_id: str - page_number: int - content: PageContent + overview: OpenAIOverview + pages: list[OpenAIPage] -class Image(ludic.attrs.Attrs): - """Represents an image associated with a story page.""" +example_openai_story_response: OpenAIStoryResponse = { + "overview": { + "character": ( + "Alice is a blond haired girl, age 5, " + "wearing a blue dress, white shoes, and a whie wide-brimmed hat." + ), + "setting": ( + "A farm on a hill. " + "There is a red barn, a grain silo, and lots of pasture." + ), + "summary": "<brief summary of the story>", + }, + "pages": [{"text": "<story text>", "image": "<image prompt>"}], +} + + +class Page(ludic.attrs.Attrs): + """Represents a single page in the storybook.""" story_id: str - page: int - prompt: typing.Annotated[str, const] - original_url: typing.Annotated[str, const] - path: pathlib.Path + page_number: int + text: str + image_prompt: str class Prompt(ludic.attrs.Attrs): - """Represents a prompt for generating an image.""" + """Represents a prompt for generating an image. + + This datatype is overkill except that we need to validate it over the wire, + so its actually useful in that sense. + """ text: typing.Annotated[str, const] @@ -120,7 +185,7 @@ class StoryInputs(ludic.attrs.Attrs): moral: typing.Annotated[str, const] -example_story: dict[str, str] = { +example_story: StoryInputs = { "theme": "Christian", "character": "Lia and her pet bunny", "setting": "A suburban park", @@ -132,7 +197,7 @@ class Story(ludic.attrs.Attrs): """Represents a full generated story.""" id: typing.Annotated[str, const] - pages: typing.Annotated[list[Page], const] + inputs: StoryInputs system_prompt: str = ( @@ -145,18 +210,40 @@ system_prompt: str = ( ) -def user_prompt(story: StoryInputs) -> str: - """Generate the user prompt based on the story details.""" - return " ".join([ - "Write a story with the following details.", - "Output must be in valid JSON where each page is an array of text and" - "image like the following example:", - """[{"text": "<text of the story>",""", - """"image": "<description of illustration>"}...],""", - f"Character: {story['character']}\n", - f"Setting: {story['setting']}\n", - f"Theme: {story['theme']}\n", - f"Moral: {story['moral']}\n", +def story_meta(story: StoryInputs) -> list[str]: + """Format the `StoryInputs` for submission to an LLM.""" + return [ + f"Character: {story['character']}", + f"Setting: {story['setting']}", + f"Theme: {story['theme']}", + f"Moral: {story['moral']}", + ] + + +user_prompt: str = " ".join([ + "Write a children's story with the following details.", + "Output must be in valid JSON.", + "The overview key must contain a character sketch and setting description.", + "The pages key must an array of objects.", + "Each object must have the text of the page and image prompt.", + "Here is an example:", + json.dumps(example_openai_story_response), +]) + + +def gen_image_prompt( + story: StoryInputs, + image_prompt: str, + story_text: str, +) -> str: + """Format and return the full image prompt with additional context.""" + return "\n".join([ + f"Illustration: {image_prompt}", + f"Narrative text: {story_text}", + *story_meta(story), + "Style: a hand-drawn children's cartoon from the 1990s", + "Use soft pastel colors.", + "Do not include any text in the generated image.", ]) @@ -169,9 +256,13 @@ def _openai_generate_text( | openai.types.chat.ChatCompletionSystemMessageParam ] = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt(story)}, + { + "role": "user", + "content": "\n".join([user_prompt, *story_meta(story)]), + }, ] client = openai.OpenAI() + logging.debug("calling openai.chat.completions.create") return client.chat.completions.create( model="gpt-4o-mini", messages=messages, @@ -179,35 +270,147 @@ def _openai_generate_text( ) -def generate_pages(inputs: StoryInputs) -> list[PageContent]: +def generate_pages(inputs: StoryInputs) -> OpenAIStoryResponse: """Generate the text for a story and update its pages. Raises: ValueError: when openAI response is bad """ # when developing, don't run up the OpenAI tab - if MOCK: + if area == Area.Test: name = inputs["character"] - return [ - PageContent( + ret = example_openai_story_response.copy() + ret["pages"] = [ + OpenAIPage( text=f"A story about {name}...", - image_prompt="Lorem ipsum..", + image="Lorem ipsum..", ) for n in range(10) ] + return ret response = _openai_generate_text(inputs) content = response.choices[0].message.content if content is None: msg = "content is none" raise ValueError(msg) - response_messages = json.loads(content) - return [ - PageContent( - text=msg["text"], - image_prompt=msg["image"], + parsed_content = json.loads(content) + overview = parsed_content["overview"] + pages = parsed_content["pages"] + return OpenAIStoryResponse( + overview=OpenAIOverview( + character=overview["character"], + setting=overview["setting"], + summary=overview["summary"], + ), + pages=[OpenAIPage(text=p["text"], image=p["image"]) for p in pages], + ) + + +def generate_image( + area: Area, + image_prompt: str, + story_id: str, + page: int, +) -> Image: + """Generate an image with OpenAI. + + Raises: + InternalServerError: when OpenAI API fails + """ + logging.info("generating image %s.%s", story_id, page) + url = None + if area == Area.Test: + time.sleep(1) + url = "https://placehold.co/1024.png" + else: + client = openai.OpenAI() + logging.debug("calling openai.images.generate") + logging.debug("prompt: %s", image_prompt) + image_response = client.images.generate( + model="dall-e-3", + prompt=image_prompt, + n=1, + size="1024x1024", + quality="standard", ) - for msg in response_messages + url = image_response.data[0].url + if url is None: + msg = "error getting image from OpenAI" + logging.error(msg) + raise ludic.web.exceptions.InternalServerError(msg) + image = Image( + story_id=story_id, + path=Images.gen_path(story_id, page), + page=page, + ) + Images.save(image, url) + return image + + +class Job(ludic.attrs.Attrs): + """Simple wrapper class for background jobs. + + This will become more useful when I need to store and track jobs in the + database. + """ + + id: str + + +def generate_story_in_background( + area: Area, + story_id: str, + inputs: StoryInputs, +) -> Story: + """Kick off `generate_story_pages` in a background thread.""" + job_id = str(uuid.uuid4()) + job = Job(id=job_id) + thread = threading.Thread( + target=generate_story_pages, + args=( + area, + story_id, + inputs, + ), + ) + logging.info("starting job %s", job_id) + thread.start() + story = Story(id=story_id, inputs=inputs) + # save stuff + db["jobs"][job_id] = job + db["stories"][story_id] = story + return story + + +def generate_story_pages( + area: Area, + story_id: str, + inputs: StoryInputs, +) -> list[Page]: + """Upsert a new story.""" + logging.info("generating story pages %s", story_id) + story_resp = generate_pages(inputs) + pages = [ + Page( + page_number=i + 1, + text=sr["text"], + story_id=story_id, + image_prompt=sr["image"], + ) + for i, sr in enumerate(story_resp["pages"]) ] + db["pages"][story_id] = pages + for page in pages: + image_prompt = gen_image_prompt( + inputs, + page["image_prompt"], + page["text"], + ) + n = page["page_number"] + generate_image(area, image_prompt, story_id, n) + # I *would* save the Image to the database here, but i'm not actually + # tracking that currenlty, just putting them in a known location on disk + return pages class StoryTest(unittest.TestCase): @@ -215,9 +418,10 @@ class StoryTest(unittest.TestCase): def test_story_creation(self) -> None: """Creates a story with 10 pages.""" - s = StoryInputs(example_story) # type: ignore[misc] - pages = generate_pages(s) - self.assertIsNotNone(pages) + story_id = "Uk" + story = generate_story_pages(Area.Test, story_id, example_story) + pages = db["pages"][story_id] + self.assertIsNotNone(story) self.assertEqual(len(pages), 10) @@ -228,6 +432,8 @@ class AppPage( @typing.override def render(self) -> pages.HtmlPage: + dark = ludic.styles.themes.DarkTheme() + ludic.styles.themes.set_default_theme(dark) return pages.HtmlPage( pages.Head( ludic.html.meta(charset="utf-8"), @@ -252,7 +458,7 @@ def index(_: ludic.web.Request) -> AppPage: """Render the main page.""" return AppPage( headers.H1("Storybook Generator"), - StoriesForm(), + Generate(), ludic.html.div(id="story"), ) @@ -277,11 +483,14 @@ db_last_story_id: str = "bM" # sqid.encode([0]) class Database(ludic.attrs.Attrs): """Represents a simple in-memory database for storing stories and images.""" + # each of these corresponds to a table in a SQL database stories: dict[str, Story] + pages: dict[str, list[Page]] images: dict[str, Image] + jobs: dict[str, Job] -db: Database = Database(stories={}, images={}) +db: Database = Database(stories={}, images={}, jobs={}, pages={}) @app.endpoint("/pages/{story_id:str}/{page:int}") @@ -289,143 +498,102 @@ class Pages(ludic.web.Endpoint[Page]): """Resource for retrieving individual pages in a story.""" @classmethod - def get(cls, story_id: str, page: int) -> typing.Self: - """Get a single page.""" - story = Stories.get(story_id) - story_page = Page(**story.attrs["pages"][page]) - return cls(**story_page) + async def get(cls, story_id: str, page: int) -> typing.Self: + """Get a single page. + + Raises: + NotFoundError: when the requested page is not found + """ + pages = db["pages"].get(story_id, None) + if pages is None: + msg = "story: %s" + raise ludic.web.exceptions.NotFoundError(msg.format(story_id)) + this_page = pages[page] + return cls(**this_page) @typing.override def render(self) -> ludic.base.BaseElement: """Render a single page as HTML.""" + story_id = self.attrs["story_id"] + page = self.attrs["page_number"] + image_url = app.url_path_for( + "Images", + story_id=story_id, + page=page, + ) return layouts.Box( layouts.Stack( - ludic.html.img( - src="//placehold.co/256/000000/FFFFFF", - hx_post=app.url_path_for( - "Images", - story_id=self.attrs["story_id"], - page=self.attrs["page_number"], - ), - hx_trigger="load", - hx_swap="outerHTML:beforeend", - hx_vals=json.dumps( - Prompt(text=self.attrs["content"]["image_prompt"]), - ), - width=256, - height=256, + ludic.html.div( + loaders.Loading(), + hx_get=image_url, + hx_trigger="every 1s", + hx_swap="outerHTML", ), - typography.Paragraph(self.attrs["content"]["text"]), + typography.Paragraph(self.attrs["text"]), ), ) +@app.get("/images/{story_id:str}/{page:int}.png") +def images_static(story_id: str, page: int) -> ludic.web.responses.Response: + """Endpoint for accessing static images. + + This does no generation, it only loads static images from the + filesystem. This must be separate to match on the `.png` suffix. + + For generation use the `Images` class/endpoint. + + Raises: + NotFoundError: when the image doesn't exist + + """ + image = Images.by_id(story_id, page) + if image["path"].exists(): + return ludic.web.responses.FileResponse(image["path"]) + msg = "images_static: image not found" + logging.error(msg) + raise ludic.web.exceptions.NotFoundError(msg) + + @app.endpoint("/images/{story_id:str}/{page:int}") class Images(ludic.web.Endpoint[Image]): """Endpoint for handling image-related operations.""" @classmethod - def get(cls, story_id: str, page: int) -> ludic.web.responses.Response: + async def get(cls, story_id: str, page: int) -> typing.Self: """Load the image from the database, if not found return 404. Raises: NotFoundError: If the image is not found. """ - if image := Images.load_by_id(story_id, page): - return ludic.web.responses.FileResponse(image["path"]) - msg = "no image found" - logging.error(msg) + image = Images.by_id(story_id, page) + if image["path"].exists(): + return cls(**image) + msg = "image not found" raise ludic.web.exceptions.NotFoundError(msg) @classmethod - def post( - cls, - story_id: str, - page: int, - data: ludic.web.parsers.Parser[Prompt], - ) -> ludic.web.responses.Response: - """Create a new image, or retrieve an existing one.""" - Prompt(**data.validate()) + def by_id(cls, story_id: str, page: int) -> Image: + """Load an image by its story ID and page number.""" path = cls.gen_path(story_id, page) - if path.exists(): - return cls.get(story_id, page) - return cls.put(story_id, page, data) - - @classmethod - def put( - cls, - story_id: str, - page: int, - data: ludic.web.parsers.Parser[Prompt], - ) -> ludic.web.responses.Response: - """Create a new image, overwriting if one exists. - - Raises: - InternalServerError: If there is an error getting the image from the - OpenAI API. - """ - if MOCK: - # Simulate slow image generation - time.sleep(1) - return ludic.web.responses.FileResponse( - DATA_DIR / "images" / "placeholder.jpg", - ) - client = openai.OpenAI() - prompt = Prompt(**data.validate()) - image_response = client.images.generate( - prompt=prompt["text"], - n=1, - size="256x256", - ) - url = image_response.data[0].url - if url is None: - msg = "error getting image from OpenAI" - logging.error(msg) - raise ludic.web.exceptions.InternalServerError(msg) - image = Image( + return Image( story_id=story_id, page=page, - prompt=prompt["text"], - original_url=url, - path=cls.gen_path(story_id, page), + path=path, ) - cls.save(image) - return ludic.web.responses.FileResponse(image["path"]) - - @classmethod - def gen_image_id(cls, story_id: str, page: int) -> str: - """Generate a unique image ID based on the story ID and page number.""" - story_id_num = Sqids.decode(story_id)[0] - return Sqids.encode([story_id_num, page]) - - @classmethod - def load_by_id(cls, story_id: str, page: int) -> Image | None: - """Load an image by its story ID and page number.""" - cls.gen_image_id(story_id, page) - path = cls.gen_path(story_id, page) - if path.exists(): - return Image( - story_id=story_id, - page=page, - path=path, - # Consider storing prompt and original_url in sqlite - prompt="", - original_url="", - ) - return None @classmethod def gen_path(cls, story_id: str, page: int) -> pathlib.Path: """Generate the file path for an image.""" - image_id = cls.gen_image_id(story_id, page) return pathlib.Path( - DATA_DIR / "images" / story_id / image_id, - ).with_suffix(".jpg") + DATA_DIR / "images" / story_id / str(page), + ).with_suffix(".png") @classmethod - def save(cls, image: Image) -> None: + def save(cls, image: Image, original_url: str) -> None: """Save an image to the file system.""" - response = requests.get(image["original_url"], timeout=10) + response = requests.get(original_url, timeout=10) + image["path"].parent.mkdir(parents=True, exist_ok=True) pathlib.Path(image["path"]).write_bytes(response.content) @classmethod @@ -437,119 +605,20 @@ class Images(ludic.web.Endpoint[Image]): def render(self) -> ludic.base.BaseElement: return ludic.html.img( src=app.url_path_for( - "Images", + "images_static", story_id=self.attrs["story_id"], page=self.attrs["page"], ), + loading="lazy", ) -class ImagesTest(unittest.TestCase): - """Test the Images endpoint.""" - - def setUp(self) -> None: - """Create test client.""" - self.client = starlette.testclient.TestClient(app) - self.story_id = "Uk" - self.page = 1 - self.valid_prompt = {"text": "A beautiful sunset over the ocean"} - - def test_image_get_existing(self) -> None: - """Test retrieving an existing image.""" - # Arrange: Mock the load_by_id method to simulate an existing image - data = {"path": DATA_DIR / "images" / "placeholder.jpg"} - mock_dict = mock.MagicMock() - mock_dict.__getitem__.side_effect = data.__getitem__ - with mock.patch.object( - Images, - "load_by_id", - return_value=mock_dict, - ): - # Act: Send a GET request to retrieve the image - response = self.client.get( - app.url_path_for( - "Images", - story_id=self.story_id, - page=self.page, - ), - ) - # Assert: Check that the response status is 200 - self.assertEqual(response.status_code, 200) - - def test_image_get_nonexistent(self) -> None: - """Test retrieving a non-existent image.""" - # Act: Send a GET request for a non-existent image - response = self.client.get( - app.url_path_for("Images", story_id="nonexistent", page=self.page), - ) - # Assert: Check that the response status is 404 - self.assertEqual(response.status_code, 404) - - def test_image_post_valid(self) -> None: - """Test creating an image with valid data.""" - # Arrange: Mock the OpenAI API and file system operations - with ( - mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai, - mock.patch( - "Biz.Storybook.pathlib.Path.write_bytes", - ), - ): - mock_openai.return_value.images.generate.return_value.data = [ - mock.MagicMock(url="http://example.com/image.jpg"), - ] - # Act: Send a POST request with valid data - response = self.client.post( - app.url_path_for( - "Images", - story_id=self.story_id, - page=self.page, - ), - data=self.valid_prompt, - ) - # Assert: Check that the response status is 200 - self.assertEqual(response.status_code, 200) - - def test_image_post_invalid(self) -> None: - """Test creating an image with invalid data.""" - # Act: Send a POST request with invalid data - response = self.client.post( - app.url_path_for("Images", story_id=self.story_id, page=self.page), - data={"invalid": "data"}, - ) - # Assert: Check that the response status indicates an error - self.assertNotEqual(response.status_code, 200) - - def test_image_put_overwrite(self) -> None: - """Test overwriting an existing image.""" - # Arrange: Mock the OpenAI API and file system operations - with ( - mock.patch("Biz.Storybook.openai.OpenAI") as mock_openai, - mock.patch( - "Biz.Storybook.pathlib.Path.write_bytes", - ), - ): - mock_openai.return_value.images.generate.return_value.data = [ - mock.MagicMock(url="http://example.com/image.jpg"), - ] - # Act: Send a PUT request to overwrite the image - response = self.client.put( - app.url_path_for( - "Images", - story_id=self.story_id, - page=self.page, - ), - data=self.valid_prompt, - ) - # Assert: Check that the response status is 200 - self.assertEqual(response.status_code, 200) - - @app.endpoint("/stories/{story_id:str}") class Stories(ludic.web.Endpoint[Story]): """Resource for accessing a Story.""" @classmethod - def get(cls, story_id: str) -> typing.Self: + async def get(cls, story_id: str) -> typing.Self: """Get a single story. Raises: @@ -561,41 +630,38 @@ class Stories(ludic.web.Endpoint[Story]): raise ludic.web.exceptions.NotFoundError(msg) return cls(**story) - @classmethod - def put(cls, story_id: str, data: list[PageContent]) -> typing.Self: - """Upsert a new story.""" - pages = [ - Page(story_id=story_id, page_number=n, content=page_content) - for n, page_content in enumerate(data) - ] - story = Story(id=story_id, pages=pages) - # save to the 'database' - db["stories"][story_id] = story - return cls(**story) - @typing.override def render(self) -> ludic.base.BaseElement: + story_id = self.attrs["id"] return layouts.Stack( - headers.H1(str(self.attrs["id"])), - *(Pages(**page) for page in self.attrs["pages"]), + typography.Paragraph(f"Story id: {story_id}"), + *( + loaders.LazyLoader( + load_url=app.url_path_for( + "Pages", + story_id=story_id, + page=n, + ), + hx_trigger="every 2s", + ) + for n in range(1, 10) + ), + id="#story", ) -@app.endpoint("/stories") -class StoriesForm(ludic.web.Endpoint[StoryInputs]): +@app.endpoint("/generate") +class Generate(ludic.web.Endpoint[StoryInputs]): """Form for generating new stories.""" @classmethod - def post(cls, data: ludic.web.parsers.Parser[StoryInputs]) -> Stories: + async def post(cls, data: ludic.web.parsers.Parser[StoryInputs]) -> Stories: """Upsert a new story.""" inputs = StoryInputs(**data.validate()) - # generate story pages - # Consider calling Pages.put for each one after generating the text - pages = generate_pages(inputs) - # calculate sqid next_id_num = 1 + Sqids.decode(db_last_story_id)[0] next_id = Sqids.encode([next_id_num]) - return Stories.put(next_id, pages) + story = generate_story_in_background(area, next_id, inputs) + return Stories(**story) @typing.override def render(self) -> ludic.base.BaseElement: @@ -643,63 +709,12 @@ class StoriesForm(ludic.web.Endpoint[StoryInputs]): type="submit", classes=["large"], ), - hx_post=self.url_for(StoriesForm), + hx_post=self.url_for(Generate), hx_target="#story", + hx_trigger="submit", ), ) -class StorybookTest(unittest.TestCase): - """Unit test case for the Storybook application.""" - - def setUp(self) -> None: - """Set up the test client and seed database.""" - self.client = starlette.testclient.TestClient(app) - self.character = "Alice" - self.data = example_story | {"character": self.character} - self.client.post("/stories/", data=self.data) - self.story = next(iter(db["stories"].values())) - self.story_id = self.story["id"] - - def test_stories_post(self) -> None: - """User can create a story.""" - response = self.client.post("/stories/", data=self.data) - self.assertEqual(response.status_code, 200) - self.assertIn(self.character, response.text) - - def test_stories_post_invalid_data(self) -> None: - """Invalid POST data.""" - response = self.client.post( - app.url_path_for("StoriesForm"), - data={"bad": "data"}, - ) - self.assertNotEqual(response.status_code, 200) - - def test_stories_get(self) -> None: - """User can access the story directly.""" - response = self.client.get( - app.url_path_for("Stories", story_id=self.story_id), - ) - self.assertEqual(response.status_code, 200) - self.assertIn(self.character, response.text) - - def test_stories_get_nonexistent(self) -> None: - """Returns 404 when a story is not found.""" - response = self.client.get( - app.url_path_for("Stories", story_id="nonexistent"), - ) - self.assertEqual(response.status_code, 404) - - def test_pages_get(self) -> None: - """User can access one page at a time.""" - page_num = 1 - _story = self.story["pages"][page_num] - response = self.client.get( - app.url_path_for("Pages", story_id=self.story_id, page=page_num), - ) - self.assertEqual(response.status_code, 200) - self.assertIn(self.character, response.text) - - if __name__ == "__main__": main() |