Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions databricks-builder-app/client/src/pages/ProjectPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -555,17 +555,18 @@ export default function ProjectPage() {
if (isForeground) setStreamingText(fullText);
}
} else if (type === 'tool_use') {
const toolName = event.tool_name as string;
const newItem: ActivityItem = {
id: event.tool_id as string,
type: 'tool_use',
content: '',
toolName: event.tool_name as string,
toolName,
toolInput: event.tool_input as Record<string, unknown>,
timestamp: Date.now(),
};
if (stream) {
stream.activityItems = [...stream.activityItems, newItem];
stream.tools = [...stream.tools, event.tool_name as string];
stream.tools = [...stream.tools, toolName];
}
if (isForeground) setActivityItems(prev => [...prev, newItem]);
} else if (type === 'tool_result') {
Expand Down
42 changes: 26 additions & 16 deletions databricks-builder-app/server/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,26 +557,36 @@ async def get_conversation_executions(
detail=f'Project {project_id} not found'
)

# First check in-memory streams for this conversation (always works)
# Check in-memory streams for this conversation. Prefer a running stream;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When testing locally, the poller (Lemma) would sometimes miss final events: the endpoint previously ignored completed streams and fell through to the DB before persist_loop had flushed them. This change falls back to the most recently completed in-memory stream (which lingers for 5 minutes before cleanup), giving the poller time to drain all events. It doesn't affect the stream itself; it only makes polling more robust.

This race condition mostly surfaces locally, where slower calls widen the timing window, but it can still occur in production under load.

# fall back to the most recently completed one so the poller can read all
# events before the 5-minute cleanup removes them from memory.
stream_manager = get_stream_manager()
in_memory_active = None
async with stream_manager._lock:
running_stream = None
completed_stream = None
for stream in stream_manager._streams.values():
if (
stream.conversation_id == conversation_id
and not stream.is_complete
and not stream.is_cancelled
):
in_memory_active = {
'id': stream.execution_id,
'conversation_id': stream.conversation_id,
'project_id': stream.project_id,
'status': 'running',
'events': [e.data for e in stream.events],
'error': stream.error,
'created_at': None,
}
break
if stream.conversation_id == conversation_id and not stream.is_cancelled:
if not stream.is_complete:
running_stream = stream
break # Running stream always wins
else:
completed_stream = stream
chosen = running_stream or completed_stream
if chosen:
if chosen.is_complete:
status = 'error' if chosen.error else 'completed'
else:
status = 'running'
in_memory_active = {
'id': chosen.execution_id,
'conversation_id': chosen.conversation_id,
'project_id': chosen.project_id,
'status': status,
'events': [e.data for e in chosen.events],
'error': chosen.error,
'created_at': None,
}

# Try to get executions from database (may fail if table doesn't exist yet)
active = None
Expand Down
152 changes: 117 additions & 35 deletions databricks-builder-app/server/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import asyncio
import json
import logging
import os
import queue
Expand All @@ -29,7 +28,7 @@
from pathlib import Path
from typing import AsyncIterator

from claude_agent_sdk import ClaudeAgentOptions, query, HookMatcher
from claude_agent_sdk import ClaudeAgentOptions, HookMatcher, query
from claude_agent_sdk.types import (
AssistantMessage,
ResultMessage,
Expand All @@ -41,10 +40,10 @@
ToolUseBlock,
UserMessage,
)
from databricks_tools_core.auth import set_databricks_auth, clear_databricks_auth
from databricks_tools_core.auth import clear_databricks_auth, set_databricks_auth

from .backup_manager import ensure_project_directory as _ensure_project_directory
from .databricks_tools import load_databricks_tools, create_filtered_databricks_server
from .databricks_tools import create_filtered_databricks_server, load_databricks_tools
from .system_prompt import get_system_prompt

logger = logging.getLogger(__name__)
Expand All @@ -54,9 +53,10 @@
'Read',
'Write',
'Edit',
# 'Bash',
# 'Bash',
'Glob',
'Grep',
'AskUserQuestion',
]

# Cached Databricks tools (loaded once)
Expand Down Expand Up @@ -174,7 +174,9 @@ async def mlflow_stop_hook(input_data: dict, tool_use_id: str | None, context) -
client = mlflow.MlflowClient()
trace_id = trace.info.trace_id
requested_model = os.environ.get('ANTHROPIC_MODEL', 'databricks-claude-opus-4-5')
requested_model_mini = os.environ.get('ANTHROPIC_MODEL_MINI', 'databricks-claude-sonnet-4-5')
requested_model_mini = os.environ.get(
'ANTHROPIC_MODEL_MINI', 'databricks-claude-sonnet-4-5'
)
base_url = os.environ.get('ANTHROPIC_BASE_URL', '')

# Set tags to clarify the Databricks model endpoint used
Expand Down Expand Up @@ -207,7 +209,41 @@ async def mlflow_stop_hook(input_data: dict, tool_use_id: str | None, context) -
return None


def _run_agent_in_fresh_loop(message, options, result_queue, context, is_cancelled_fn, mlflow_experiment=None):
def _create_ask_user_hook():
"""Create a PreToolUse hook that blocks AskUserQuestion.

The AskUserQuestion tool fails silently in headless mode. This hook
denies it and tells the agent to ask the question as regular text
in its response instead, so the user can reply in the next message.
"""

async def ask_user_pretool_hook(input_data: dict, tool_use_id: str | None, context) -> dict:
tool_name = input_data.get('tool_name', '')
if tool_name != 'AskUserQuestion':
return {}

logger.info('AskUserQuestion blocked — redirecting to text-based question')

return {
'systemMessage': (
'The AskUserQuestion tool is not available in this environment. '
'Instead, ask your question directly as text in your response. '
'The user will reply in their next message. '
'Do NOT re-attempt the AskUserQuestion tool.'
),
'hookSpecificOutput': {
'hookEventName': input_data.get('hook_event_name', 'PreToolUse'),
'permissionDecision': 'deny',
'permissionDecisionReason': 'Redirected to text-based question',
},
}

return ask_user_pretool_hook


def _run_agent_in_fresh_loop(
message, options, result_queue, context, is_cancelled_fn, mlflow_experiment=None
):
"""Run agent in a fresh event loop (workaround for issue #462).

This function runs in a separate thread with a fresh event loop to avoid
Expand All @@ -226,11 +262,23 @@ def _run_agent_in_fresh_loop(message, options, result_queue, context, is_cancell

See: https://github.com/anthropics/claude-agent-sdk-python/issues/462
"""

# Run in the copied context to preserve contextvars (like Databricks auth)
def run_with_context():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# Block AskUserQuestion — it fails silently in headless mode.
# The hook denies the tool and tells the agent to ask via text instead.
ask_user_hook = _create_ask_user_hook()
if options.hooks is None:
options.hooks = {}
if 'PreToolUse' not in options.hooks:
options.hooks['PreToolUse'] = []
options.hooks['PreToolUse'].append(
HookMatcher(matcher='AskUserQuestion', hooks=[ask_user_hook], timeout=30)
)

# Add MLflow Stop hook for tracing if experiment is configured
exp_name = mlflow_experiment or os.environ.get('MLFLOW_EXPERIMENT_NAME')
if exp_name:
Expand All @@ -246,6 +294,7 @@ def run_with_context():

async def run_query():
"""Run agent using query() for proper streaming."""

# Create prompt generator in the fresh event loop context
async def prompt_generator():
yield {'type': 'user', 'message': {'role': 'user', 'content': message}}
Expand All @@ -255,37 +304,54 @@ async def prompt_generator():
async for msg in query(prompt=prompt_generator(), options=options):
msg_count += 1
msg_type = type(msg).__name__
logger.info(f"[AGENT DEBUG] Received message #{msg_count}: {msg_type}")
logger.info(f'[AGENT DEBUG] Received message #{msg_count}: {msg_type}')

# Log more details for specific message types
if hasattr(msg, 'content'):
content = msg.content
if isinstance(content, list):
block_types = [type(b).__name__ for b in content]
logger.info(f"[AGENT DEBUG] Content blocks: {block_types}")
logger.info(f'[AGENT DEBUG] Content blocks: {block_types}')
if hasattr(msg, 'is_error') and msg.is_error:
logger.error(f"[AGENT DEBUG] is_error=True")
logger.error('[AGENT DEBUG] is_error=True')
if hasattr(msg, 'session_id'):
logger.info(f"[AGENT DEBUG] session_id={msg.session_id}")
logger.info(f'[AGENT DEBUG] session_id={msg.session_id}')

# Check for cancellation before processing each message
if is_cancelled_fn():
logger.info("Agent cancelled by user request")
logger.info('Agent cancelled by user request')
result_queue.put(('cancelled', None))
return
result_queue.put(('message', msg))
logger.info(f"[AGENT DEBUG] query() loop completed normally after {msg_count} messages")
logger.info(f'[AGENT DEBUG] query() loop completed normally after {msg_count} messages')
except asyncio.CancelledError:
logger.warning("Agent query was cancelled (asyncio.CancelledError)")
result_queue.put(('error', Exception("Agent query cancelled - likely due to stream timeout or connection issue")))
logger.warning('Agent query was cancelled (asyncio.CancelledError)')
result_queue.put(
(
'error',
Exception('Agent query cancelled - likely due to stream timeout or connection issue'),
)
)
except ConnectionError as e:
logger.error(f"Connection error in agent query: {e}")
result_queue.put(('error', Exception(f"Connection error: {e}. This may occur when tools take longer than the stream timeout (50s).")))
logger.error(f'Connection error in agent query: {e}')
result_queue.put(
(
'error',
Exception(
f'Connection error: {e}. This may occur when tools take longer than the stream timeout (50s).'
),
)
)
except BrokenPipeError as e:
logger.error(f"Broken pipe in agent query: {e}")
result_queue.put(('error', Exception(f"Broken pipe: {e}. The agent subprocess communication was interrupted.")))
logger.error(f'Broken pipe in agent query: {e}')
result_queue.put(
(
'error',
Exception(f'Broken pipe: {e}. The agent subprocess communication was interrupted.'),
)
)
except Exception as e:
logger.exception(f"Unexpected error in agent query: {type(e).__name__}: {e}")
logger.exception(f'Unexpected error in agent query: {type(e).__name__}: {e}')
result_queue.put(('error', e))
finally:
result_queue.put(('done', None))
Expand Down Expand Up @@ -364,20 +430,27 @@ async def stream_agent_response(
allowed_tools = BUILTIN_TOOLS.copy()

# Sync project skills directory before running agent
from .skills_manager import sync_project_skills, get_available_skills, get_allowed_mcp_tools
from .skills_manager import get_allowed_mcp_tools, get_available_skills, sync_project_skills

sync_project_skills(project_dir, enabled_skills=enabled_skills)

# Get Databricks tools and filter based on enabled skills.
# We must create a filtered MCP server (not just filter allowed_tools)
# because bypassPermissions mode exposes all tools in registered MCP servers.
databricks_server, databricks_tool_names = get_databricks_tools()
filtered_tool_names = get_allowed_mcp_tools(databricks_tool_names, enabled_skills=enabled_skills)
filtered_tool_names = get_allowed_mcp_tools(
databricks_tool_names, enabled_skills=enabled_skills
)

if len(filtered_tool_names) < len(databricks_tool_names):
# Some tools are blocked — create a filtered MCP server with only allowed tools
databricks_server, filtered_tool_names = create_filtered_databricks_server(filtered_tool_names)
databricks_server, filtered_tool_names = create_filtered_databricks_server(
filtered_tool_names
)
blocked_count = len(databricks_tool_names) - len(filtered_tool_names)
logger.info(f'Databricks MCP server: {len(filtered_tool_names)} tools allowed, {blocked_count} blocked by disabled skills')
logger.info(
f'Databricks MCP server: {len(filtered_tool_names)} tools allowed, {blocked_count} blocked by disabled skills'
)
else:
logger.info(f'Databricks MCP server configured with {len(filtered_tool_names)} tools')

Expand Down Expand Up @@ -429,11 +502,16 @@ async def stream_agent_response(
# Disable beta headers for Databricks FMAPI compatibility
claude_env['ANTHROPIC_CUSTOM_HEADERS'] = 'x-databricks-disable-beta-headers: true'

logger.info(f'Configured Databricks model serving: {anthropic_base_url} with model {anthropic_model}')
logger.info(f'Claude env vars: BASE_URL={claude_env.get("ANTHROPIC_BASE_URL")}, MODEL={claude_env.get("ANTHROPIC_MODEL")}')
logger.info(
f'Configured Databricks model serving: {anthropic_base_url} with model {anthropic_model}'
)
logger.info(
f'Claude env vars: BASE_URL={claude_env.get("ANTHROPIC_BASE_URL")}, MODEL={claude_env.get("ANTHROPIC_MODEL")}'
)

# Databricks SDK upstream tracking for subprocess user-agent attribution
from databricks_tools_core.identity import PRODUCT_NAME, PRODUCT_VERSION

claude_env['DATABRICKS_SDK_UPSTREAM'] = PRODUCT_NAME
claude_env['DATABRICKS_SDK_UPSTREAM_VERSION'] = PRODUCT_VERSION

Expand All @@ -454,7 +532,7 @@ def stderr_callback(line: str):
resume=session_id, # Resume from previous session if provided
mcp_servers={'databricks': databricks_server}, # In-process SDK tools
system_prompt=system_prompt, # Databricks-focused system prompt
setting_sources=["user", "project"], # Load Skills from filesystem
setting_sources=['user', 'project'], # Load Skills from filesystem
env=claude_env, # Pass Databricks auth settings (ANTHROPIC_AUTH_TOKEN, etc.)
include_partial_messages=True, # Enable token-by-token streaming
stderr=stderr_callback, # Capture stderr for debugging
Expand All @@ -473,7 +551,7 @@ def stderr_callback(line: str):
agent_thread = threading.Thread(
target=_run_agent_in_fresh_loop,
args=(message, options, result_queue, ctx, cancel_check, mlflow_experiment),
daemon=True
daemon=True,
)
agent_thread.start()

Expand All @@ -489,9 +567,7 @@ def get_with_timeout():
except queue.Empty:
return ('keepalive', None)

msg_type, msg = await asyncio.get_event_loop().run_in_executor(
None, get_with_timeout
)
msg_type, msg = await asyncio.get_event_loop().run_in_executor(None, get_with_timeout)

if msg_type == 'keepalive':
# Emit keepalive event to keep the stream active during long tool execution
Expand All @@ -509,7 +585,7 @@ def get_with_timeout():
if msg_type == 'done':
break
elif msg_type == 'cancelled':
logger.info("Agent execution cancelled")
logger.info('Agent execution cancelled')
yield {'type': 'cancelled'}
break
elif msg_type == 'error':
Expand Down Expand Up @@ -640,7 +716,13 @@ def get_with_timeout():
'thinking': thinking,
}
# Pass through other stream events if needed
elif event_type not in ('content_block_start', 'content_block_stop', 'message_start', 'message_delta', 'message_stop'):
elif event_type not in (
'content_block_start',
'content_block_stop',
'message_start',
'message_delta',
'message_stop',
):
yield {
'type': 'stream_event',
'event': event_data,
Expand All @@ -653,7 +735,7 @@ def get_with_timeout():
full_traceback = traceback.format_exc()

# Use print to stderr for immediate visibility
print(f'\n{"="*60}', file=sys.stderr)
print(f'\n{"=" * 60}', file=sys.stderr)
print(f'AGENT ERROR: {error_msg}', file=sys.stderr)
print(full_traceback, file=sys.stderr)

Expand All @@ -670,7 +752,7 @@ def get_with_timeout():
logger.error(f'Sub-exception {i}: {sub_exc}')
logger.error(sub_tb)

print(f'{"="*60}\n', file=sys.stderr)
print(f'{"=" * 60}\n', file=sys.stderr)

yield {
'type': 'error',
Expand Down