Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 9 additions & 25 deletions src/scribae/brief.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import asyncio
import json
import re
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, cast

from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from pydantic_ai import Agent, NativeOutput, UnexpectedModelBehavior
from pydantic_ai.settings import ModelSettings

from .common import current_timestamp, report, slugify
from .idea import Idea, IdeaList
from .io_utils import NoteDetails, Reporter, load_note
from .language import LanguageMismatchError, LanguageResolutionError, ensure_language_output, resolve_output_language
Expand Down Expand Up @@ -174,7 +173,7 @@ def prepare_context(
except OSError as exc: # pragma: no cover - surfaced by CLI
raise BriefFileError(f"Unable to read note: {exc}") from exc

_report(reporter, f"Loaded note '{note.title}' from {note.path}")
report(reporter, f"Loaded note '{note.title}' from {note.path}")

try:
language_resolution = resolve_output_language(
Expand All @@ -187,7 +186,7 @@ def prepare_context(
except LanguageResolutionError as exc:
raise BriefValidationError(str(exc)) from exc

_report(
report(
reporter,
f"Resolved output language: {language_resolution.language} (source: {language_resolution.source})",
)
Expand All @@ -200,7 +199,7 @@ def prepare_context(
idea_selector=idea_selector,
metadata=note.metadata,
)
_report(reporter, f"Selected idea '{selected_idea.title}' (id={selected_idea.id}).")
report(reporter, f"Selected idea '{selected_idea.title}' (id={selected_idea.id}).")

prompts = build_prompt_bundle(
project=project,
Expand All @@ -209,7 +208,7 @@ def prepare_context(
language=language_resolution.language,
idea=selected_idea,
)
_report(reporter, "Prepared structured prompt.")
report(reporter, "Prepared structured prompt.")

return BriefingContext(
note=note,
Expand Down Expand Up @@ -241,7 +240,7 @@ def generate_brief(
else agent
)

_report(
report(
reporter,
f"Calling model '{model_name}' via {resolved_settings.base_url}",
)
Expand Down Expand Up @@ -273,7 +272,7 @@ def generate_brief(
except Exception as exc: # pragma: no cover - surfaced to CLI
raise BriefLLMError(f"LLM request failed: {exc}") from exc

_report(reporter, "LLM call complete, structured brief validated.")
report(reporter, "LLM call complete, structured brief validated.")
return brief


Expand Down Expand Up @@ -311,8 +310,8 @@ def save_prompt_artifacts(
) -> tuple[Path, Path]:
"""Persist the system prompt and truncated note for debugging."""
destination.mkdir(parents=True, exist_ok=True)
stamp = timestamp or _current_timestamp()
slug = _slugify(project_label or "default") or "default"
stamp = timestamp or current_timestamp()
slug = slugify(project_label or "default") or "default"

prompt_path = destination / f"{stamp}-{slug}-note.prompt.txt"
note_path = destination / f"{stamp}-note.txt"
Expand Down Expand Up @@ -361,15 +360,6 @@ async def _call() -> SeoBrief:
return asyncio.run(asyncio.wait_for(_call(), timeout_seconds))


def _current_timestamp() -> str:
return datetime.now().strftime("%Y%m%d-%H%M%S")


def _slugify(value: str) -> str:
lowered = value.lower()
return re.sub(r"[^a-z0-9]+", "-", lowered).strip("-")


def _brief_language_text(brief: SeoBrief) -> str:
faq_text = "\n".join(f"{item.question} {item.answer}" for item in brief.faq)
outline_text = "\n".join(brief.outline)
Expand Down Expand Up @@ -418,9 +408,3 @@ def _metadata_idea_id(metadata: dict[str, Any]) -> str | None:
return None
value = raw.strip() if isinstance(raw, str) else str(raw).strip()
return value or None


def _report(reporter: Reporter, message: str) -> None:
"""Send verbose output when enabled."""
if reporter:
reporter(message)
24 changes: 24 additions & 0 deletions src/scribae/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import re
from collections.abc import Callable
from datetime import datetime

# Optional callback used for verbose progress output; ``None`` disables reporting.
Reporter = Callable[[str], None] | None


def slugify(value: str) -> str:
    """Lowercase *value* and collapse each non-alphanumeric run into a single dash."""
    collapsed = re.sub(r"[^a-z0-9]+", "-", value.lower())
    return collapsed.strip("-")


def report(reporter: Reporter, message: str) -> None:
    """Forward *message* to *reporter* when one is configured; no-op otherwise."""
    if not reporter:
        return
    reporter(message)


def current_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    return f"{datetime.now():%Y%m%d-%H%M%S}"


# Public API of this shared-helpers module.
__all__ = ["Reporter", "current_timestamp", "report", "slugify"]
20 changes: 8 additions & 12 deletions src/scribae/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic_ai.settings import ModelSettings

from .brief import SeoBrief
from .common import report
from .io_utils import NoteDetails, Reporter, load_note, truncate
from .language import LanguageMismatchError, LanguageResolutionError, ensure_language_output, resolve_output_language
from .llm import LLM_OUTPUT_RETRIES, LLM_TIMEOUT_SECONDS, OpenAISettings, apply_optional_settings, make_model
Expand Down Expand Up @@ -306,9 +307,9 @@ def prepare_context(
brief = _load_brief(brief_path)
note = _load_note(note_path, max_chars=max_note_chars) if note_path else None

_report(reporter, f"Loaded draft '{body.path.name}' and brief '{brief.title}'.")
report(reporter, f"Loaded draft '{body.path.name}' and brief '{brief.title}'.")
if note is not None:
_report(reporter, f"Loaded source note '{note.title}'.")
report(reporter, f"Loaded source note '{note.title}'.")

try:
language_resolution = resolve_output_language(
Expand All @@ -321,7 +322,7 @@ def prepare_context(
except LanguageResolutionError as exc:
raise FeedbackValidationError(str(exc)) from exc

_report(
report(
reporter,
f"Resolved output language: {language_resolution.language} (source: {language_resolution.source})",
)
Expand Down Expand Up @@ -373,10 +374,10 @@ def generate_feedback_report(
agent if agent is not None else _create_agent(model_name, temperature=temperature, top_p=top_p, seed=seed)
)

_report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")
report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")

try:
report = cast(
result = cast(
FeedbackReport,
ensure_language_output(
prompt=prompts.user_prompt,
Expand All @@ -401,8 +402,8 @@ def generate_feedback_report(
raise FeedbackLLMError(f"LLM request failed: {exc}") from exc

# Remap any out-of-scope categories to "other"
report = _normalize_finding_categories(report, context.focus)
return report
result = _normalize_finding_categories(result, context.focus)
return result


def render_json(report: FeedbackReport) -> str:
Expand Down Expand Up @@ -726,11 +727,6 @@ def _format_location(location: FeedbackLocation | None) -> str:
return f" ({'; '.join(details)})" if details else ""


def _report(reporter: Reporter, message: str) -> None:
if reporter:
reporter(message)


__all__ = [
"BriefAlignment",
"FeedbackFinding",
Expand Down
33 changes: 8 additions & 25 deletions src/scribae/idea.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@

import asyncio
import json
import re
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import cast

from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_ai import Agent, NativeOutput, UnexpectedModelBehavior
from pydantic_ai.settings import ModelSettings

from .common import current_timestamp, report, slugify
from .io_utils import NoteDetails, Reporter, load_note
from .language import LanguageMismatchError, LanguageResolutionError, ensure_language_output, resolve_output_language
from .llm import (
Expand Down Expand Up @@ -108,7 +107,7 @@ def prepare_context(
except OSError as exc: # pragma: no cover - surfaced by CLI
raise IdeaFileError(f"Unable to read note: {exc}") from exc

_report(reporter, f"Loaded note '{note.title}' from {note.path}")
report(reporter, f"Loaded note '{note.title}' from {note.path}")

try:
language_resolution = resolve_output_language(
Expand All @@ -121,7 +120,7 @@ def prepare_context(
except LanguageResolutionError as exc:
raise IdeaValidationError(str(exc)) from exc

_report(
report(
reporter,
f"Resolved output language: {language_resolution.language} (source: {language_resolution.source})",
)
Expand All @@ -132,7 +131,7 @@ def prepare_context(
note_content=note.body,
language=language_resolution.language,
)
_report(reporter, "Prepared idea-generation prompt.")
report(reporter, "Prepared idea-generation prompt.")

return IdeaContext(note=note, project=project, prompts=prompts, language=language_resolution.language)

Expand All @@ -159,7 +158,7 @@ def generate_ideas(
else agent
)

_report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")
report(reporter, f"Calling model '{model_name}' via {resolved_settings.base_url}")

try:
ideas = cast(
Expand Down Expand Up @@ -188,7 +187,7 @@ def generate_ideas(
except Exception as exc: # pragma: no cover - surfaced to CLI
raise IdeaLLMError(f"LLM request failed: {exc}") from exc

_report(reporter, "LLM call complete, ideas validated.")
report(reporter, "LLM call complete, ideas validated.")
return ideas


Expand All @@ -208,8 +207,8 @@ def save_prompt_artifacts(
"""Persist the system prompt and truncated note for debugging."""

destination.mkdir(parents=True, exist_ok=True)
stamp = timestamp or _current_timestamp()
slug = _slugify(project_label or "default") or "default"
stamp = timestamp or current_timestamp()
slug = slugify(project_label or "default") or "default"

prompt_path = destination / f"{stamp}-{slug}-ideas.prompt.txt"
note_path = destination / f"{stamp}-note.txt"
Expand Down Expand Up @@ -264,22 +263,6 @@ async def _call() -> IdeaList:
return asyncio.run(asyncio.wait_for(_call(), timeout_seconds))


def _current_timestamp() -> str:
return datetime.now().strftime("%Y%m%d-%H%M%S")


def _slugify(value: str) -> str:
lowered = value.lower()
return re.sub(r"[^a-z0-9]+", "-", lowered).strip("-")


def _report(reporter: Reporter, message: str) -> None:
"""Send verbose output when enabled."""

if reporter:
reporter(message)


__all__ = [
"Idea",
"IdeaContext",
Expand Down
4 changes: 1 addition & 3 deletions src/scribae/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import frontmatter

# Type alias for verbose output callbacks used across modules.
Reporter = Callable[[str], None] | None
from .common import Reporter


@dataclass(frozen=True)
Expand Down
9 changes: 3 additions & 6 deletions src/scribae/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Any

from .common import report


class LanguageResolutionError(Exception):
"""Raised when the output language cannot be determined."""
Expand Down Expand Up @@ -96,7 +98,7 @@ def ensure_language_output(
_validate_language(extract_text(first_result), expected_language, language_detector=language_detector)
return first_result
except LanguageMismatchError as first_error:
_report(reporter, str(first_error) + " Retrying with language correction.")
report(reporter, str(first_error) + " Retrying with language correction.")

corrective_prompt = _append_language_correction(prompt, expected_language)
second_result = invoke(corrective_prompt)
Expand Down Expand Up @@ -174,11 +176,6 @@ def _clean_language(value: Any) -> str | None:
return cleaned or None


def _report(reporter: Callable[[str], None] | None, message: str) -> None:
if reporter:
reporter(message)


__all__ = [
"LanguageResolution",
"LanguageResolutionError",
Expand Down
Loading