diff options
author | Ben Sima <ben@bsima.me> | 2024-12-01 20:33:09 -0500 |
---|---|---|
committer | Ben Sima <ben@bsima.me> | 2024-12-21 10:08:02 -0500 |
commit | f87c5f05444b53f2fee033b99a400888538941a4 (patch) | |
tree | c8663cd03f59a0cbdf68ae7fc7ed8675f0571aa2 /Biz/Storybook.py | |
parent | aafa73c47325185ed733da387c9649d934f6529c (diff) |
Implement storybook prototype
This paritally used gptme to create a storybook generator. The problem I ran
into is that gptme doesn't do any architecting or considerations for
maintainable code, or even readable code, so it just wrote a long script. I
couldn't test it. Also, it didn't actually generate a 10-page story, it
generated 10 separate stories. So, I ended up writing it myself and using gptme
to fixup TODOs that I wrote along the way.
Diffstat (limited to 'Biz/Storybook.py')
-rw-r--r-- | Biz/Storybook.py | 291 |
1 files changed, 291 insertions, 0 deletions
diff --git a/Biz/Storybook.py b/Biz/Storybook.py new file mode 100644 index 0000000..2f362c4 --- /dev/null +++ b/Biz/Storybook.py @@ -0,0 +1,291 @@ +"""Storybook Generator Application. + +This application generates a children's storybook using the OpenAI API. + +The user can select a theme, specify the main character's name, and choose a +setting. The app generates a 10-page storybook with images. + +The tech stack is: Python, Flask, HTMX, and bootstrap. All of the code is in +this single file. + +""" + +# : out storybook +# : dep flask +# : dep openai +import flask +import json +import openai +import os +import pydantic +import sys +import unittest + +app = flask.Flask(__name__) +app.secret_key = os.urandom(24) + + +def main() -> None: + """Run the Flask application.""" + if sys.argv[1] == "test": + test() + else: + move() + + +def move() -> None: + """Run the application.""" + app.run() + + +def test() -> None: + """Run the unittest suite manually.""" + suite = unittest.TestSuite() + tests = [StorybookTestCase, StoryTestCase] + suite.addTests([ + unittest.defaultTestLoader.loadTestsFromTestCase(tc) for tc in tests + ]) + unittest.TextTestRunner(verbosity=2).run(suite) + + +@app.route("/") +def index() -> str: + """Render the main page.""" + return flask.render_template_string(f""" +<!DOCTYPE html> +<html lang="en"> +<head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <title>Storybook Generator</title> + <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" + rel="stylesheet"> + <script src="https://unpkg.com/htmx.org@1.3.3"></script> +</head> +<body> + <div class="container mt-5"> + <h1 class="text-center">Storybook Generator</h1> + <form hx-post="{flask.url_for("generate_story")}" hx-target="#story" + class="mt-4"> + <div class="form-group"> + <label for="theme">Select Theme:</label> + <select class="form-control" id="theme" name="theme"> + <option value="Christian">Christian</option> + <option value="Secular">Secular</option> + </select> + </div> + <div class="form-group"> + <label for="character">Main Character's Name:</label> + <input type="text" class="form-control" id="character" + name="character" required> + </div> + <div class="form-group"> + <label for="setting">Select Setting:</label> + <select class="form-control" id="setting" name="setting"> + <option value="rural">Rural</option> + <option value="urban">Urban</option> + <option value="beach">Beach</option> + <option value="forest">Forest</option> + </select> + </div> + <button type="submit" class="btn btn-primary"> + Generate Story + </button> + </form> + <div id="story" class="mt-5"></div> + </div> +</body> +</html> +""") + + +class Page(pydantic.BaseModel): + """Represents a single page in the storybook.""" + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + text: str + image_prompt: str + image_url: str | None + + +def load_image(page: Page) -> Page: + """Load an image for a given page using the OpenAI API.""" + if page.image_url is not None: + return page + prompt = page.image_prompt + client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + image_response = client.images.generate( + prompt=prompt, + n=1, + size="256x256", + ) + page.image_url = image_response.data[0].url + # Handle if image is None + return page + + +class Story(pydantic.BaseModel): + """Represents a story with multiple pages.""" + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + theme: str + character: str + setting: str + moral: str + pages: list[Page] | None = None + + +def system_prompt() -> str: + """Generate the system prompt for the OpenAI API.""" + return ( + "You are an author and illustrator of childrens books. " + "Each book is 10 pages long. " + "All output must be in valid JSON. " + "Don't add explanation or markdown formatting beyond the JSON. " + "In your output, include the text on the page and a description of the " + "image to be generated with an AI image generator." + ) + + +def user_prompt(story: Story) -> str: + """Generate the user prompt based on the story details.""" + return " ".join([ + "Write a story with the following details.", + "Output must be in valid JSON where each page is an array of text and" + "image like the following example:", + """[{"text": "<text of the story>",""", + """"image": "<description of illustration>"}...],""", + f"Character: {story.character}\n", + f"Setting: {story.setting}\n", + f"Theme: {story.theme}\n", + f"Moral: {story.moral}\n", + ]) + + +def _openai_generate_text(story: Story) -> openai.types.chat.ChatCompletion: + """Generate story text using the OpenAI API.""" + messages: list[ + openai.types.chat.ChatCompletionUserMessageParam + | openai.types.chat.ChatCompletionSystemMessageParam + ] = [ + {"role": "system", "content": system_prompt()}, + {"role": "user", "content": user_prompt(story)}, + ] + client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return client.chat.completions.create( + model="gpt-4o-mini", + messages=messages, + max_tokens=1500, + ) + + +def generate_text(story: Story) -> Story: + """Generate the text for a story and update its pages. + + Raises: + ValueError: If the content is None or JSON parsing fails. + """ + response = _openai_generate_text(story) + content = response.choices[0].message.content + if content is None: + error_message = "No content in response" + raise ValueError(error_message) + try: + response_messages = json.loads(content) + except (json.JSONDecodeError, ValueError) as e: + error_message = f"Failed to parse story JSON: {e}" + raise ValueError(error_message) from e + story.pages = [ + Page(text=msg["text"], image_prompt=msg["image"], image_url=None) + for msg in response_messages + ] + return story + + +@app.route("/generate/story", methods=["POST"]) +def generate_story() -> str: + """Generate a story based on user input.""" + story = Story( + theme=flask.request.form["theme"], + character=flask.request.form["character"], + setting=flask.request.form["setting"], + moral="Honor thy mother and father", # request.form["moral"], + ) + story = generate_text(story) + if story.pages is None: + return "<p>error: no story pages found</p>" + flask.session["story"] = story.model_dump_json() + return "".join( + f"<div class='card mb-3'>" + f"<img src='/static/placeholder.png' data-src='" + f"{flask.url_for('generate_image', n=i)}'" + f"class='card-img-top' hx-trigger='load' hx-swap='outerHTML' " + f"""hx-get='{flask.url_for("generate_image", n=i)}' alt='Loading...'>""" + f"<div class='card-body'>" + f"<p class='card-text'>{page.text}</p></div></div>" + for i, page in enumerate(story.pages) + ) + + +@app.route("/generate/image/<int:n>", methods=["GET"]) +def generate_image(n: int) -> str: + """Generate an image for a specific page in the story.""" + story_data = flask.session.get("story") + if story_data is None: + return "<p>error: no story data found</p>" + + try: + story = Story.model_validate_json(story_data) + except pydantic.ValidationError as e: + return f"<p>error: story validation failed: {e}</p>" + if story.pages is not None and 0 <= n < len(story.pages): + page = load_image(story.pages[n]) + return f"<img src='{page.image_url}' class='card-img-top'>" + return "<p>Image not found</p>" + + +class StoryTestCase(unittest.TestCase): + """Unit test case for the Story class and related functions.""" + + def test_story_creation(self) -> None: + """Test the creation of a story and its text generation.""" + s = Story( + theme="Christian", + character="Lia and her pet bunny", + setting="A suburban park", + moral="Honor thy mother and father", + ) + s = generate_text(s) + self.assertIsNotNone(s.pages) + self.assertEqual(len(s.pages), 10) # type: ignore[arg-type] + + +class StorybookTestCase(unittest.TestCase): + """Unit test case for the Storybook application.""" + + def setUp(self) -> None: + """Set up the test client.""" + self.app = app.test_client() + + def test_index_page(self) -> None: + """Test that the index page loads successfully.""" + response = self.app.get("/") + self.assertEqual(response.status_code, 200) + self.assertIn(b"Storybook Generator", response.data) + + def test_generate_story(self) -> None: + """Test the story generation endpoint.""" + response = self.app.post( + "/generate/story", + data={ + "theme": "Christian", + "character": "Alice", + "setting": "forest", + }, + ) + self.assertEqual(response.status_code, 200) + self.assertIn(b"<div class='card mb-3'>", response.data) + + +if __name__ == "__main__": + main() |