Commit b0bae89f authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

feat: AI SQL Trace - HITL + Chart.js + StarRocks-optimized prompt

- HITL: Accept/Reject SQL before execution
- Chart.js: Auto-detect data shape -> render number cards/bar/doughnut/line
- Prompt: Verified StarRocks partition pruning rules (no DATE() on column)
- Summary: Agent provides 2-3 sentence Vietnamese explanation of results
- Fix: renderTraceDocument -> renderDataResults for history loading
- Auth: Login/admin pages, middleware, JWT auth
- Report Agent: Full report generation system
- New pages: stress-test, regression-test, user-simulator, competitor-research
parent eba01fb3
Pipeline #3393 failed with stage
......@@ -56,7 +56,7 @@ class CANIFAGraph:
self._cached_prompt_hash: str | None = None
def _build_chain(self, system_prompt_template: str):
"""Build chain with dynamic system prompt (fetched from Langfuse per request).
"""Build chain with dynamic system prompt (from local file).
Caches the chain and only rebuilds when prompt content changes.
"""
# Check if prompt changed via hash comparison
......@@ -120,7 +120,7 @@ class CANIFAGraph:
or "⚠️ TRẠNG THÁI KHỞI TẠO: Chưa có User Insight từ lịch sử. Hãy bắt đầu thu thập thông tin mới (Nếu thiếu thông tin thì ghi 'Chưa rõ')."
)
# Fetch prompt from Langfuse (cached 5min by SDK) + build chain (cached by hash)
# Fetch prompt from file (cached) + build chain (cached by hash)
current_date_str = datetime.now().strftime("%d/%m/%Y")
system_prompt_template = get_system_prompt_template()
chain = self._build_chain(system_prompt_template)
......@@ -176,7 +176,7 @@ class CANIFAGraph:
workflow.add_edge("collect_tools", "agent")
self._compiled_graph = workflow.compile(cache=self.cache) # No Checkpointer
logger.info("✅ Graph compiled (Langfuse callback will be per-run)")
logger.info("✅ Graph compiled")
return self._compiled_graph
......@@ -201,13 +201,13 @@ def get_graph_manager(
) -> CANIFAGraph:
"""Get CANIFAGraph instance (Auto-rebuild if model config changes).
Prompt is now fetched dynamically per request from Langfuse,
Prompt is now fetched dynamically per request from local file,
so no need to rebuild graph when prompt changes.
"""
# 1. New Instance if Empty
if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools)
logger.info(f"✨ Graph Created: {_instance[0].config.model_name} (prompts from Langfuse)")
logger.info(f"✨ Graph Created: {_instance[0].config.model_name} (prompts from local files)")
return _instance[0]
# 2. Check for Model Config Changes only
......
"""
Prompt Utilities — ALL prompts from Langfuse.
Prompt Utilities — ALL prompts from local files.
System prompt + Tool prompts, single source of truth.
Cache strategy:
- SDK cache TTL = 300s (5 min) — giảm HTTP calls tới Langfuse
- Langfuse server có cache riêng phía nó
- Gọi force_refresh_prompts() khi cần update tức thì
- Graph hash-check giữ chain cache ổn định
Reads from:
- agent/system_prompt.txt (system prompt)
- agent/tool_prompts/*.txt (tool prompts)
"""
import logging
import os
from datetime import datetime
from langfuse import Langfuse
logger = logging.getLogger(__name__)
LANGFUSE_SYSTEM_PROMPT_NAME = "canifa-stylist-system-prompt"
# Cache 5 phút — balance giữa update nhanh vs performance
# Gọi force_refresh_prompts() nếu cần update ngay lập tức
CACHE_TTL = 300
# Paths
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
_SYSTEM_PROMPT_PATH = os.path.join(_BASE_DIR, "system_prompt.txt")
_TOOL_PROMPTS_DIR = os.path.join(_BASE_DIR, "tool_prompts")
LANGFUSE_TOOL_PROMPT_MAP = {
"brand_knowledge_tool": "canifa-tool-brand-knowledge",
"check_is_stock": "canifa-tool-check-stock",
"data_retrieval_tool": "canifa-tool-data-retrieval",
"promotion_canifa_tool": "canifa-tool-promotion",
"store_search_tool": "canifa-tool-store-search",
}
# In-memory cache
_system_prompt_cache: str | None = None
_tool_prompt_cache: dict[str, str] = {}
_langfuse_client: Langfuse | None = None
def _get_langfuse() -> Langfuse:
global _langfuse_client
if _langfuse_client is None:
_langfuse_client = Langfuse()
return _langfuse_client
def _read_file(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()
def get_system_prompt() -> str:
"""System prompt với ngày hiện tại đã inject."""
lf = _get_langfuse()
prompt = lf.get_prompt(LANGFUSE_SYSTEM_PROMPT_NAME, label="production", cache_ttl_seconds=CACHE_TTL)
return prompt.compile(date_str=datetime.now().strftime("%d/%m/%Y"))
template = _get_system_prompt_raw()
return template.replace("{{date_str}}", datetime.now().strftime("%d/%m/%Y"))
def _get_system_prompt_raw() -> str:
    """Return the raw system prompt template, reading the file at most once.

    The content is memoized in the module-level ``_system_prompt_cache``;
    ``force_refresh_prompts()`` clears that cache to force a re-read.
    """
    global _system_prompt_cache
    if _system_prompt_cache is not None:
        return _system_prompt_cache
    _system_prompt_cache = _read_file(_SYSTEM_PROMPT_PATH)
    logger.info(f"📄 System prompt loaded from {_SYSTEM_PROMPT_PATH} ({len(_system_prompt_cache):,} chars)")
    return _system_prompt_cache
def get_system_prompt_template() -> str:
"""Template chưa replace date_str — dùng cho ChatPromptTemplate.
Langfuse SDK `.prompt` un-escapes {{ → { for non-variable content,
but LangChain ChatPromptTemplate needs {{ for literal braces.
LangChain ChatPromptTemplate needs {{ for literal braces.
So we re-escape ALL { } first, then convert only {{date_str}} → {date_str}.
"""
lf = _get_langfuse()
prompt = lf.get_prompt(LANGFUSE_SYSTEM_PROMPT_NAME, label="production", cache_ttl_seconds=CACHE_TTL)
raw = _get_system_prompt_raw()
# 1) Re-escape all curly braces for LangChain (literal { → {{, } → }})
raw = prompt.prompt.replace("{", "{{").replace("}", "}}")
escaped = raw.replace("{", "{{").replace("}", "}}")
# 2) Convert only the date_str variable back to LangChain format
# After step 1, {{date_str}} became {{{{date_str}}}} → convert to {date_str}
return raw.replace("{{{{date_str}}}}", "{date_str}")
return escaped.replace("{{{{date_str}}}}", "{date_str}")
def read_tool_prompt(filename: str, default_prompt: str = "") -> str:
"""Read tool prompt from Langfuse."""
"""Read tool prompt from local file."""
name_key = filename.replace(".txt", "")
langfuse_name = LANGFUSE_TOOL_PROMPT_MAP.get(name_key)
if not langfuse_name:
logger.warning(f"⚠️ No Langfuse mapping for tool prompt '{name_key}'")
if name_key in _tool_prompt_cache:
return _tool_prompt_cache[name_key]
txt_path = os.path.join(_TOOL_PROMPTS_DIR, f"{name_key}.txt")
if not os.path.isfile(txt_path):
logger.warning(f"⚠️ Tool prompt file not found: {txt_path}")
return default_prompt
lf = _get_langfuse()
prompt = lf.get_prompt(langfuse_name, label="production", cache_ttl_seconds=CACHE_TTL)
return prompt.prompt
content = _read_file(txt_path)
_tool_prompt_cache[name_key] = content
logger.info(f"📄 Tool prompt loaded: {name_key} ({len(content):,} chars)")
return content
def write_tool_prompt(filename: str, content: str) -> bool:
"""Push tool prompt to Langfuse as new version."""
"""Write tool prompt to local file."""
name_key = filename.replace(".txt", "")
langfuse_name = LANGFUSE_TOOL_PROMPT_MAP.get(name_key)
if not langfuse_name:
logger.error(f"No Langfuse mapping for '{name_key}'")
txt_path = os.path.join(_TOOL_PROMPTS_DIR, f"{name_key}.txt")
try:
with open(txt_path, "w", encoding="utf-8") as f:
f.write(content)
# Invalidate cache
_tool_prompt_cache.pop(name_key, None)
logger.info(f"✅ Tool prompt written: {txt_path}")
return True
except Exception as e:
logger.error(f"❌ Failed to write tool prompt '{name_key}': {e}")
return False
lf = _get_langfuse()
lf.create_prompt(
name=langfuse_name,
prompt=content,
labels=["production"],
tags=["canifa", "tool-prompt"],
type="text",
)
logger.info(f"✅ Tool prompt '{name_key}' pushed to Langfuse")
return True
def list_tool_prompts() -> list[str]:
"""List available tool prompt names."""
return sorted(LANGFUSE_TOOL_PROMPT_MAP.keys())
if not os.path.isdir(_TOOL_PROMPTS_DIR):
return []
return sorted(
f.replace(".txt", "")
for f in os.listdir(_TOOL_PROMPTS_DIR)
if f.endswith(".txt")
)
def force_refresh_prompts() -> str:
"""Force refresh ALL prompt caches by fetching with cache_ttl=0.
Call this after updating prompts on Langfuse to take effect immediately.
"""Force refresh ALL prompt caches by re-reading files.
Returns the new system prompt template (LangChain-ready).
"""
lf = _get_langfuse()
# 1) Force refresh system prompt (bypasses SDK cache)
prompt = lf.get_prompt(LANGFUSE_SYSTEM_PROMPT_NAME, label="production", cache_ttl_seconds=0)
logger.info(f"🔄 Force refreshed system prompt: {LANGFUSE_SYSTEM_PROMPT_NAME} (v{prompt.version}, {len(prompt.prompt):,} chars)")
raw = prompt.prompt.replace("{", "{{").replace("}", "}}")
new_template = raw.replace("{{{{date_str}}}}", "{date_str}")
# 2) Force refresh all tool prompts
for name_key, langfuse_name in LANGFUSE_TOOL_PROMPT_MAP.items():
try:
lf.get_prompt(langfuse_name, label="production", cache_ttl_seconds=0)
logger.info(f"🔄 Force refreshed tool prompt: {name_key}")
except Exception as e:
logger.warning(f"⚠️ Failed to refresh tool prompt '{name_key}': {e}")
logger.info(f"✅ All prompts force refreshed (version: {prompt.version})")
return new_template
global _system_prompt_cache
_system_prompt_cache = None
_tool_prompt_cache.clear()
# Re-read system prompt
template = get_system_prompt_template()
logger.info(f"🔄 All prompts force refreshed from local files ({len(template):,} chars)")
# Pre-load all tool prompts
for name in list_tool_prompts():
read_tool_prompt(name)
return template
# Report Agent package — JSON Report + HTML Report
This diff is collapsed.
This diff is collapsed.
"""
Inline Edit Agent Graph — LangGraph StateGraph for report section editing.
┌── simple_edit ──→ rewrite → END (rewrite/shorten/fix — no SQL needed)
think ──┤
└── agent_edit ──→ query_data → rewrite_with_data → END (enrich with real data)
Used by api/report_html_route.py via `run_inline_agent()`.
"""
import json
import logging
import re
from typing import Any, TypedDict
from langgraph.graph import StateGraph, START, END
from agent.report_agent.core import (
call_llm,
parse_json,
execute_tools_parallel,
summarize_results,
)
from agent.report_agent.prompts.inline_prompt import (
INLINE_EDIT_PROMPT,
AGENT_SECTION_PROMPT,
AGENT_WRITER_PROMPT,
)
logger = logging.getLogger(__name__)
# ─── State ───────────────────────────────────────────────────────────
class InlineState(TypedDict):
    """Working state for one inline-edit request as it flows through the graph."""
    # Input
    selected_text: str  # text span the user highlighted in the report
    action: str  # rewrite | enrich | shorten | fix | agent_rewrite
    context: str  # surrounding report text (nodes truncate it to 500 chars)
    model: str  # LLM identifier forwarded to call_llm
    codex_token: str | None  # optional credential forwarded to the LLM layer
    openai_key: str | None  # optional credential forwarded to the LLM layer
    # Internal
    needs_data: bool  # set by think_node: True routes through the SQL query node
    tools_to_run: list[dict]  # tool specs planned by the think step
    data_summary: str  # summarized SQL results consumed by rewrite_with_data_node
    thinking: str  # model's reasoning from the think step (not returned to caller)
    # Output
    new_text: str  # the rewritten section
    explanation: str  # short Vietnamese note on what changed
    error: str | None  # initialized to None; no node in this file writes it
# ─── Nodes ───────────────────────────────────────────────────────────
async def think_node(state: InlineState) -> dict:
    """Decide whether this edit needs real data (agent_rewrite) or is a plain rewrite."""
    # Every action except agent_rewrite is handled without touching the database.
    if state["action"] != "agent_rewrite":
        return {"needs_data": False, "tools_to_run": [], "thinking": ""}
    # Ask the planner LLM which SQL queries (if any) would enrich this section.
    prompt_input = (
        f"Section text: \"{state['selected_text']}\"\n"
        f"Surrounding context: {state['context'][:500]}\n\n"
        f"Generate SQL queries to fetch data for enriching this section.\n"
        f"Return JSON only."
    )
    reply = await call_llm(
        AGENT_SECTION_PROMPT,
        prompt_input,
        state["model"],
        codex_token=state.get("codex_token"),
        openai_key=state.get("openai_key"),
        json_mode=True,
    )
    plan = parse_json(reply)
    planned_tools = plan.get("tools", [])
    skipped = plan.get("action") == "skip"
    return {
        "needs_data": bool(planned_tools) and not skipped,
        "tools_to_run": planned_tools,
        "thinking": plan.get("thinking", ""),
    }
async def query_node(state: InlineState) -> dict:
    """Run the planned SQL tools in parallel and summarize results for the writer."""
    specs = state.get("tools_to_run", [])
    if not specs:
        return {"data_summary": ""}
    raw_results = await execute_tools_parallel(specs)
    collected: dict[str, Any] = {}
    # Keep one entry per tool; a failed tool contributes an error stub instead
    # of aborting the whole enrichment.
    for idx, (spec, outcome) in enumerate(zip(specs, raw_results)):
        if isinstance(outcome, Exception):
            outcome = {"error": str(outcome)[:200], "data": []}
        collected[f"{spec.get('name', 'q')}_{idx}"] = outcome
    return {"data_summary": summarize_results(collected)}
async def simple_rewrite_node(state: InlineState) -> dict:
    """Simple rewrite without data: rewrite/shorten/fix.

    Asks the LLM for a JSON answer; if the reply contains no brace-delimited
    span — or the span is not valid JSON — the raw reply text is returned
    with a generic explanation instead of crashing the graph.
    """
    user_input = (
        f"Selected text: \"{state['selected_text']}\"\n"
        f"Action: {state['action']}\n"
        f"Surrounding context: {state['context'][:500]}\n\n"
        f"Return JSON only."
    )
    raw = await call_llm(
        INLINE_EDIT_PROMPT, user_input, state["model"],
        codex_token=state.get("codex_token"),
        openai_key=state.get("openai_key"),
    )
    fallback = {"new_text": raw.strip(), "explanation": "AI đã chỉnh sửa văn bản"}
    json_match = re.search(r'\{[\s\S]*\}', raw)
    if not json_match:
        return fallback
    try:
        # FIX: json.loads on a regex-extracted span previously raised
        # JSONDecodeError unhandled when the LLM emitted malformed JSON.
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError:
        return fallback
    return {
        "new_text": parsed.get("new_text", raw.strip()),
        "explanation": parsed.get("explanation", "AI đã chỉnh sửa văn bản"),
    }
async def rewrite_with_data_node(state: InlineState) -> dict:
    """Rewrite the section using real data from SQL queries.

    If the query step produced no usable data the original text is returned
    unchanged. If the writer LLM replies with malformed JSON, the raw reply
    is returned instead of letting JSONDecodeError propagate out of the graph.
    """
    data_summary = state.get("data_summary", "")
    if not data_summary.strip() or "no data" in data_summary.lower():
        return {
            "new_text": state["selected_text"],
            "explanation": "Không có dữ liệu mới để bổ sung",
        }
    write_input = (
        f"Original section:\n\"{state['selected_text']}\"\n\n"
        f"New data from queries:\n{data_summary}\n\n"
        f"Rewrite this section incorporating the new data. Return JSON only."
    )
    write_raw = await call_llm(
        AGENT_WRITER_PROMPT, write_input, state["model"],
        codex_token=state.get("codex_token"),
        openai_key=state.get("openai_key"),
    )
    fallback = {"new_text": write_raw.strip(), "explanation": "AI đã bổ sung dữ liệu mới"}
    json_match = re.search(r'\{[\s\S]*\}', write_raw)
    if not json_match:
        return fallback
    try:
        # FIX: previously an unguarded json.loads crashed on malformed LLM JSON.
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError:
        return fallback
    return {
        "new_text": parsed.get("new_text", write_raw.strip()),
        "explanation": parsed.get("explanation", "AI đã bổ sung dữ liệu mới"),
    }
# ─── Routing Functions ──────────────────────────────────────────────
def route_after_think(state: InlineState) -> str:
    """Route out of `think`: run SQL first when data is required, else rewrite directly."""
    return "query" if state.get("needs_data") else "simple_rewrite"
# ─── Build Graph ─────────────────────────────────────────────────────
def build_inline_graph() -> StateGraph:
    """Wire and compile the inline-edit agent graph.

    Topology:
        START → think ─┬→ simple_rewrite → END
                       └→ query → rewrite_with_data → END
    """
    graph = StateGraph(InlineState)
    for node_name, node_fn in (
        ("think", think_node),
        ("query", query_node),
        ("simple_rewrite", simple_rewrite_node),
        ("rewrite_with_data", rewrite_with_data_node),
    ):
        graph.add_node(node_name, node_fn)
    graph.add_edge(START, "think")
    # think picks the branch via route_after_think
    graph.add_conditional_edges("think", route_after_think, ["query", "simple_rewrite"])
    graph.add_edge("query", "rewrite_with_data")
    graph.add_edge("rewrite_with_data", END)
    graph.add_edge("simple_rewrite", END)
    # NOTE(review): compile() returns a compiled graph object, so the
    # `-> StateGraph` annotation is loose — confirm against langgraph's API.
    return graph.compile()
# Compiled graph instance
inline_graph = build_inline_graph()
# ─── Public API ──────────────────────────────────────────────────────
async def run_inline_agent(
    *,
    selected_text: str,
    action: str = "rewrite",
    context: str = "",
    model: str = "codex/gpt-5.3-codex",
    codex_token: str | None = None,
    openai_key: str | None = None,
) -> dict:
    """
    Run the inline edit agent and return the result.

    Args:
        selected_text: the report text span to edit.
        action: rewrite | enrich | shorten | fix | agent_rewrite.
        context: surrounding report text (nodes truncate it to 500 chars).
        model: LLM identifier passed through to the nodes.
        codex_token / openai_key: optional credentials forwarded to the LLM layer.

    Returns: {"new_text": str, "explanation": str} or {"error": str}
    """
    initial_state: InlineState = {
        "selected_text": selected_text,
        "action": action,
        "context": context,
        "model": model,
        "codex_token": codex_token,
        "openai_key": openai_key,
        "needs_data": False,
        "tools_to_run": [],
        "data_summary": "",
        "thinking": "",
        "new_text": "",
        "explanation": "",
        "error": None,
    }
    try:
        result = await inline_graph.ainvoke(initial_state)
        return {
            # Fall back to the untouched selection if the graph produced nothing.
            "new_text": result.get("new_text", selected_text),
            "explanation": result.get("explanation", ""),
        }
    except Exception as e:
        # FIX: logger.exception (vs logger.error) records the traceback,
        # which is essential for diagnosing failures inside the graph nodes.
        logger.exception("Inline agent error: %s", e)
        return {"error": str(e)}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
"""
Follow-up Prompts — Prompts for the follow-up report section agent.
Contains prompts for:
- Follow-up context loading and analysis
- Section writing for follow-up queries ("Phụ lục A, B, C...")
Imported by: agent/report_agent/follow_up_graph.py
"""
from agent.report_agent.prompts.db_schemas import (
LANGFUSE_TABLE, LANGFUSE_SCHEMA,
CLEVERTAP_TABLE, CLEVERTAP_SCHEMA,
)
# ═══════════════════════════════════════════════════════════════════
# FOLLOW-UP AGENT PROMPT — Analyzes follow-up question + decides tools
# ═══════════════════════════════════════════════════════════════════
# NOTE(review): this is NOT an f-string, yet it embeds {LANGFUSE_TABLE},
# {LANGFUSE_SCHEMA}, {CLEVERTAP_TABLE}, {CLEVERTAP_SCHEMA} and doubles braces
# ({{ ... }}) exactly like the f-string AGENT_SECTION_PROMPT in
# inline_prompt.py does. As written, the placeholders reach the LLM verbatim
# unless a consumer calls .format(...) on this constant — confirm the caller
# in follow_up_graph.py, or the `f` prefix is missing here.
FOLLOWUP_AGENT_PROMPT = """You are the Canifa AI Report Follow-up Agent.
You receive a follow-up question about an EXISTING report. Your job is to:
1. Understand what the previous report already covered (from the outline provided)
2. Determine what NEW data is needed to answer the follow-up question
3. Generate SQL queries to fetch that data
## CONTEXT YOU RECEIVE:
- Previous report outline (headings, KPIs, key numbers)
- The follow-up question from the user
- Previously used tools
## RESPONSE FORMAT (JSON only):
{{
"action": "execute",
"thinking": "What the user wants + what's already in the report + what new data I need",
"tools": [
{{"name": "sql_starrocks", "params": {{"sql": "SELECT ..."}}, "purpose": "why this query"}}
]
}}
If the question can be answered from the outline alone (no new data needed):
{{
"action": "direct_answer",
"thinking": "This can be answered from existing report data",
"answer": "Your Vietnamese answer based on the outline data"
}}
## AVAILABLE TOOLS:
1. **sql_langfuse** — Query `{LANGFUSE_TABLE}` on StarRocks: chatbot traces, costs, latency, models, errors
{LANGFUSE_SCHEMA}
2. **sql_clevertap** — Query `{CLEVERTAP_TABLE}` on StarRocks: frontend user behavior, chatbot events
{CLEVERTAP_SCHEMA}
3. **sql_postgres** — Query PostgreSQL: reports, feedback, user data, chat history
4. **read_report_section** — Read a specific section from the previous report
params: {{"report_id": int, "section": "section heading text"}}
## SQL RULES:
- SELECT only
- Always use LIMIT 20
- Use aggregate functions (SUM, COUNT, AVG, GROUP BY)
## KEY PRINCIPLES:
1. NEVER repeat analysis that the previous report already covers
2. Focus on NEW angles, deeper drill-downs, or comparisons
3. If comparing with previous data, query the contrasting period/dimension
4. Keep it focused — 1-3 queries max
"""
# ═══════════════════════════════════════════════════════════════════
# FOLLOW-UP WRITER PROMPT — Writes the follow-up section HTML
# ═══════════════════════════════════════════════════════════════════
# NOTE(review): plain (non-f) string — the doubled braces around
# {{appendix_label}}, {{section_title}} and {{user_question}} therefore reach
# the LLM as literal "{{...}}" tokens unless a consumer .format()s this
# constant first. Verify against the caller in follow_up_graph.py.
FOLLOWUP_WRITER_PROMPT = """You are the Canifa AI Report Section Writer.
You write a SINGLE follow-up section (Phụ lục) to append to an existing report.
You receive real data from SQL queries and must create a complete, beautiful section.
## OUTPUT FORMAT:
Generate a `<div class="followup-section">` containing:
- Section header with the appendix label (provided)
- Analysis text with specific numbers from the data
- Charts (Chart.js) if data is chart-worthy
- Tables if data has multiple rows
- Key insights / takeaways
## HTML STRUCTURE:
```html
<div class="followup-section">
<div class="section">
<h2>📎 {{appendix_label}} — {{section_title}}</h2>
<p class="followup-question">Yêu cầu: "{{user_question}}"</p>
<hr>
<!-- Your analysis content here -->
</div>
</div>
```
## RULES:
1. Vietnamese only
2. Use the same CSS classes as the main report (.section, .kpi-grid, .kpi-card, etc.)
3. Every claim must have a specific number from the data
4. Write like a senior analyst — concise, data-driven
5. Include Chart.js charts when data has 3+ data points
6. If comparing with previous report data, highlight differences with ↑↓ arrows
7. Output RAW HTML ONLY — no markdown fences
## DATA FORMAT:
Data is tab-separated tables. Parse columns for Chart.js labels[] and data[] arrays.
"""
"""
Inline Edit Prompts — Prompts for the inline report editing agent.
Contains prompts for:
- Simple edits (rewrite, enrich, shorten, fix)
- Agent-powered section rewrite (think → query → rewrite with data)
Imported by: agent/report_agent/inline_graph.py
"""
from agent.report_agent.prompts.db_schemas import (
LANGFUSE_TABLE, LANGFUSE_SCHEMA,
CLEVERTAP_TABLE, CLEVERTAP_SCHEMA,
)
# ═══════════════════════════════════════════════════════════════════
# SIMPLE INLINE EDIT PROMPT
# ═══════════════════════════════════════════════════════════════════
# Plain (non-f) string: the single braces in the JSON example below are
# intentional and need no escaping because nothing interpolates this constant.
INLINE_EDIT_PROMPT = """You are an expert Vietnamese business report editor.
You receive a text selection from an A4 business report and an edit `action`.
## ACTIONS:
1. **rewrite**: Improve clarity, flow, and professionalism. Write like a McKinsey analyst.
2. **enrich**: Make the tone more analytical. Highlight the "so what?" behind the numbers. Add logical transitions ("Do đó", "Dẫn đến", "Ngược lại").
3. **shorten**: Compress to essential insights only. Remove fluff, adjectives, and filler words. Keep all hard numbers intact. (Aim for 50-70% original length).
4. **fix**: Fix spelling, grammar, punctuation, and formatting inconsistencies ONLY. Do not change the tone or meaning.
## STRICT RULES:
1. Output in Vietnamese ONLY.
2. KEEP ALL ORIGINAL NUMBERS/DATA. Do NOT hallucinate new numbers.
3. Keep the exact same HTML structure/tags (e.g. `<strong>`, `<li>`, `<ul>`, `<p>`) if they exist in the selection.
4. REMOVE ALL EMOJIS from the final text. Business reports must be strictly professional.
5. Do NOT use chatbot phrases like "Dưới đây là...", "Chắc chắn rồi...", "Tôi đã sửa...".
6. Return ONLY valid JSON block.
## RESPONSE FORMAT:
```json
{
"new_text": "the rewritten text maintaining HTML structure",
"explanation": "A very short (1 sentence) Vietnamese explanation of what you changed"
}
```
"""
# ═══════════════════════════════════════════════════════════════════
# AGENT-POWERED SECTION ANALYSIS PROMPT (agent_rewrite mode)
# ═══════════════════════════════════════════════════════════════════
# f-string, interpolated once at import time: {LANGFUSE_TABLE} etc. are filled
# from db_schemas, and the doubled braces ({{ ... }}) collapse to single
# literal braces so the JSON examples render correctly in the final prompt.
AGENT_SECTION_PROMPT = f"""You are a data analyst mini-agent.
You receive a report SECTION that the user wants to ENRICH with real data.
## YOUR TASK:
1. Analyze the section text to understand what data is being discussed
2. Generate 1-3 SQL queries to fetch additional data from these databases:
### Available Tools:
1. **sql_langfuse** — Query `{LANGFUSE_TABLE}` on StarRocks
{LANGFUSE_SCHEMA}
2. **sql_clevertap** — Query `{CLEVERTAP_TABLE}` on StarRocks
{CLEVERTAP_SCHEMA}
3. **sql_postgres** — Query PostgreSQL: ai_reports, feedback data, chat history
## RESPONSE FORMAT (JSON only):
```json
{{
"action": "query",
"thinking": "What data this section needs to verify its claims or enrich its analysis",
"tools": [
{{"name": "sql_langfuse", "params": {{"sql": "SELECT ..."}}, "purpose": "Why this query?"}}
]
}}
```
If no SQL queries are needed (text is purely conceptual and requires no data):
```json
{{
"action": "skip",
"thinking": "Why no queries are needed"
}}
```
## SQL RULES:
- SELECT only (no mutations)
- Always use LIMIT 20
- Use aggregate functions (SUM, COUNT, AVG) when summarizing
- Never query embedding/vector columns
"""
# ═══════════════════════════════════════════════════════════════════
# AGENT-POWERED SECTION REWRITE PROMPT
# ═══════════════════════════════════════════════════════════════════
# Plain (non-f) string: the JSON example's single braces are literal and safe
# because this constant is never interpolated or .format()ed.
AGENT_WRITER_PROMPT = """You are an expert Vietnamese business report editor.
You have the ORIGINAL TEXT and NEW DATA from database queries.
Your task is to rewrite the section to incorporate the REAL data.
## RULES:
1. Output in Vietnamese ONLY.
2. Keep the EXACT SAME HTML structure/tags (div, p, ul, strong) as the original section.
3. Replace placeholder or incorrect numbers in the text with the REAL numbers from the data.
4. If the data contradicts the original text, change the narrative to reflect the truth of the data.
5. REMOVE ALL EMOJIS from the final text.
6. Write with the authoritative tone of a senior analyst.
7. Return ONLY valid JSON, no markdown fences.
## RESPONSE FORMAT:
```json
{
"new_text": "the rewritten HTML section",
"explanation": "What data points were updated/added in this rewrite"
}
```
"""
"""
Report HTML Prompt — Re-export shim for backward compatibility.
The actual prompts have been split into:
- agent_prompt.py → HTML_AGENT_PROMPT, LANGFUSE_TABLE, LANGFUSE_SCHEMA
- writer_prompt.py → HTML_WRITER_PROMPT
This file re-exports all symbols so existing imports keep working.
"""
from agent.report_agent.prompts.agent_prompt import ( # noqa: F401
HTML_AGENT_PROMPT,
LANGFUSE_TABLE,
LANGFUSE_SCHEMA,
)
from agent.report_agent.prompts.writer_prompt import HTML_WRITER_PROMPT # noqa: F401
This diff is collapsed.
This diff is collapsed.
"""
SQL Agent — AI-powered SQL analysis with ReAct agent pattern.
Modules:
- trace_agent: Core ReAct loop (THINK→PLAN→ACT→OBSERVE→REFLECT)
- tools: SQL execution, calculator, result serialization
- session_manager: In-memory session state management
- persistence: PostgreSQL save/load for trace sessions
- prompts: LLM system prompts
- core: Legacy shared utilities (validate_sql, call_llm, etc.)
"""
This diff is collapsed.
"""
SQL Trace Persistence — Save/load trace sessions to/from PostgreSQL.
"""
import json
import logging
from typing import Any
from common.pool_wrapper import get_pooled_connection_compat
from agent.sql_agent.session_manager import public_session
logger = logging.getLogger(__name__)
def ensure_sql_tables() -> None:
    """Create the sql_trace_sessions table and its conversation index if missing.

    Best-effort: any database error is logged and swallowed so the
    import-time auto-init at the bottom of this module can never crash the app.
    """
    conn = None
    try:
        conn = get_pooled_connection_compat()
        cur = conn.cursor()
        # NOTE(review): there is no conn.commit() after this DDL — this is only
        # correct if the pooled connection is in autocommit mode. Verify against
        # common.pool_wrapper; otherwise the CREATE may be rolled back when the
        # connection is returned to the pool.
        cur.execute("""
            CREATE TABLE IF NOT EXISTS dashboard_canifa.sql_trace_sessions (
                id SERIAL PRIMARY KEY,
                conversation_id UUID NOT NULL,
                session_id INT NOT NULL,
                question TEXT NOT NULL,
                model VARCHAR(100),
                status VARCHAR(20) DEFAULT 'running',
                session_data JSONB,
                created_at TIMESTAMPTZ DEFAULT NOW(),
                completed_at TIMESTAMPTZ
            );
            CREATE INDEX IF NOT EXISTS idx_sql_trace_conv_id
            ON dashboard_canifa.sql_trace_sessions(conversation_id);
        """)
        cur.close()
    except Exception as e:
        logger.error("Error creating sql_trace_sessions table: %s", e)
    finally:
        # Return the connection to the pool even when the DDL failed.
        if conn:
            conn.close()
def persist_session_to_db(session: dict[str, Any]) -> None:
    """Save a completed trace session to PostgreSQL (best-effort, errors logged).

    The session dict is reduced to its public shape via public_session() and
    stored as JSONB alongside the identifying columns.
    """
    conn = None
    try:
        conn = get_pooled_connection_compat()
        cur = conn.cursor()
        public = public_session(session)
        # NOTE(review): no conn.commit() after the INSERT — assumes the pooled
        # connection autocommits; verify, or the row is lost on conn.close().
        # Also: ON CONFLICT DO NOTHING only guards constraint violations, and
        # the visible schema's sole constraint is the SERIAL primary key, so
        # re-persisting the same session will insert a duplicate row — confirm
        # whether a unique (conversation_id, session_id) constraint exists.
        cur.execute(
            """
            INSERT INTO dashboard_canifa.sql_trace_sessions
            (conversation_id, session_id, question, model, status, session_data, created_at, completed_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT DO NOTHING
            """,
            (
                session["conversation_id"],
                session["id"],
                session["question"],
                session.get("model", ""),
                session["status"],
                # default=str stringifies datetimes and any other non-JSON value
                json.dumps(public, ensure_ascii=False, default=str),
                session["created_at"],
                session.get("completed_at"),
            ),
        )
        cur.close()
        logger.info(
            "Persisted SQL trace session %s (conv=%s)",
            session["id"],
            session["conversation_id"],
        )
    except Exception as e:
        logger.error("Error persisting SQL trace session: %s", e)
    finally:
        if conn:
            conn.close()
def load_conversations(limit: int = 50) -> list[dict[str, Any]]:
    """Return up to *limit* SQL-trace conversations (grouped by id), newest first.

    Each entry carries a truncated title, aggregate status, timestamps and the
    number of sessions in the conversation. On any DB error an empty list is
    returned and the error is logged.
    """
    conn = None
    try:
        conn = get_pooled_connection_compat()
        cur = conn.cursor()
        cur.execute(
            """
            SELECT conversation_id,
            MIN(question) AS title,
            MAX(status) AS status,
            MIN(created_at) AS created_at,
            MAX(completed_at) AS updated_at,
            COUNT(*) AS session_count
            FROM dashboard_canifa.sql_trace_sessions
            GROUP BY conversation_id
            ORDER BY MAX(created_at) DESC
            LIMIT %s
            """,
            (limit,),
        )
        rows = cur.fetchall()
        cur.close()
        conversations: list[dict[str, Any]] = []
        for conv_id, raw_title, status, created, updated, count in rows:
            label = raw_title or "Untitled"
            if len(label) > 80:
                label = label[:80] + "..."
            conversations.append(
                {
                    "conversation_id": str(conv_id),
                    "title": label,
                    "status": status,
                    "created_at": created.isoformat() if created else None,
                    "updated_at": updated.isoformat() if updated else None,
                    "session_count": count,
                }
            )
        return conversations
    except Exception as e:
        logger.error("Error listing SQL conversations: %s", e)
        return []
    finally:
        if conn:
            conn.close()
def load_conversation_detail(conv_id: str) -> list[dict[str, Any]]:
    """Return every session row for one conversation, oldest first.

    On any DB error an empty list is returned and the error is logged.
    """
    conn = None
    try:
        conn = get_pooled_connection_compat()
        cur = conn.cursor()
        cur.execute(
            """
            SELECT session_id, question, model, status, session_data, created_at, completed_at
            FROM dashboard_canifa.sql_trace_sessions
            WHERE conversation_id = %s
            ORDER BY created_at ASC
            """,
            (conv_id,),
        )
        rows = cur.fetchall()
        cur.close()
        sessions: list[dict[str, Any]] = []
        for sid, question, model, status, payload, created, completed in rows:
            # The JSONB column may come back already decoded as a dict
            # (presumably depending on the driver's type adaptation — verify);
            # otherwise decode the text, treating NULL/empty as {}.
            if isinstance(payload, dict):
                session_data = payload
            elif payload:
                session_data = json.loads(payload)
            else:
                session_data = {}
            sessions.append(
                {
                    "session_id": sid,
                    "question": question,
                    "model": model,
                    "status": status,
                    "session_data": session_data,
                    "created_at": created.isoformat() if created else None,
                    "completed_at": completed.isoformat() if completed else None,
                }
            )
        return sessions
    except Exception as e:
        logger.error("Error fetching SQL conversation %s: %s", conv_id, e)
        return []
    finally:
        if conn:
            conn.close()
# Auto-init tables on import.
try:
    ensure_sql_tables()
except Exception:
    # Never let table bootstrap break module import; ensure_sql_tables already
    # logs its own errors, so this guard only absorbs truly unexpected failures
    # (e.g. the pool module itself being unavailable).
    pass
"""
DashboardAI Prompt — Generates diverse dashboard layouts from natural language.
Supports 2 databases: StarRocks (product catalog) + PostgreSQL (chat history).
Imported by sql_chat_route.py.
"""
# Fully-qualified table names interpolated into DB_SCHEMA and DASHBOARD_PROMPT.
STARROCKS_TABLE = "shared_source.magento_product_dimension_with_text_embedding"
POSTGRES_TABLE = "public.langgraph_chat_histories"
# f-string, rendered once at import time. The doubled braces in the AI-response
# JSON example ({{"ai_response": ...}}) collapse to single literal braces.
DB_SCHEMA = f"""
## Database 1: StarRocks (db: "starrocks")
## Table: {STARROCKS_TABLE}
### Columns (SELECT only, KHÔNG query cột 'vector'):
- internal_ref_code (VARCHAR) — Mã sản phẩm nội bộ, VD: "8TP25A005"
- magento_ref_code (VARCHAR) — Mã ref = internal_ref_code + color, VD: "8TP25A005-SW011"
- product_color_code (VARCHAR) — Mã màu sản phẩm
- product_name (VARCHAR) — Tên sản phẩm, VD: "ÁO THUN NAM"
- color_code (VARCHAR) — Mã màu
- master_color (VARCHAR) — Nhóm màu chính: "Trắng", "Đen", "Đỏ", etc.
- product_color_name (VARCHAR) — Tên màu đầy đủ
- season_sale (VARCHAR) — Mùa sale
- season (VARCHAR) — Mùa: "Xuân Hè", "Thu Đông", etc.
- style (VARCHAR) — Phong cách sản phẩm
- fitting (VARCHAR) — Kiểu form: "Regular Fit", "Slim Fit", etc.
- size_scale (VARCHAR) — Thang size: "XS-XL", "1-6", etc.
- graphic (VARCHAR) — Họa tiết
- pattern (VARCHAR) — Hoa văn
- weaving (VARCHAR) — Kiểu dệt
- shape_detail (VARCHAR) — Chi tiết kiểu dáng
- form_neckline (VARCHAR) — Kiểu cổ: "Cổ tròn", "Cổ V", etc.
- form_sleeve (VARCHAR) — Kiểu tay: "Ngắn tay", "Dài tay", etc.
- form_length (VARCHAR) — Kiểu dài
- form_waistline (VARCHAR) — Kiểu eo
- form_shoulderline (VARCHAR) — Kiểu vai
- form_pants (VARCHAR) — Kiểu quần
- material (VARCHAR) — Chất liệu
- specific_material (VARCHAR) — Chất liệu chi tiết
- sale_price (DECIMAL) — Giá bán (VND)
- original_price (DECIMAL) — Giá gốc (VND)
- discount_amount (DECIMAL) — Giá giảm (VND)
- quantity_sold (INT) — Số lượng đã bán
- is_new_product (TINYINT) — Sản phẩm mới (1=new)
- gender_by_product (VARCHAR) — Giới tính: "men", "women", "boy", "girl", "unisex"
- product_line_vn (VARCHAR) — Dòng sản phẩm: "Áo phông", "Quần short", etc.
- product_web_url (VARCHAR) — URL web
---
## Database 2: PostgreSQL (db: "postgres")
## Table: {POSTGRES_TABLE}
This table stores ALL chatbot conversations (user messages + AI responses).
### Columns:
- id (SERIAL) — Auto-increment ID
- identity_key (VARCHAR 255) — User/session identifier
- message (TEXT) — Content: plain text for user messages, JSON string for AI responses
- AI response JSON format: {{"ai_response": "...", "product_ids": [...]}}
- User message: plain text, e.g. "tôi muốn mua áo polo"
- is_human (BOOLEAN) — true = user message, false = AI response
- timestamp (TIMESTAMPTZ) — Message timestamp
### Important Notes for querying chat history:
- To search user questions: WHERE is_human = true AND message ILIKE '%keyword%'
- To count unique users: COUNT(DISTINCT identity_key)
- To count conversations about a topic: COUNT(*) WHERE is_human = true AND message ILIKE '%topic%'
- message for AI responses is JSON, use message::json->>'ai_response' to extract AI text
- Common topics: "áo polo", "áo phông", "quần", "váy", "đầm", "áo khoác", "áo lót"
- Use timestamp for date filtering: WHERE timestamp >= '2026-01-01'
"""
# System prompt for the dashboard-generator LLM. Built as an f-string over
# DB_SCHEMA and the two table-name constants; the model must reply with a raw
# JSON widget spec only (parsed downstream — do not add markdown to it).
DASHBOARD_PROMPT = f"""You are DashboardAI, an intelligent analytics dashboard generator for Canifa (Vietnamese fashion brand).
When a user describes a report or dashboard, respond ONLY with a raw JSON object (no markdown, no explanation).
{DB_SCHEMA}
## ⚡ CRITICAL: Each widget MUST include "db" field ⚡
- "db": "starrocks" — for product catalog queries (table: {STARROCKS_TABLE})
- "db": "postgres" — for chat history queries (table: {POSTGRES_TABLE})
## RULES:
- ONLY SELECT queries. NEVER INSERT/UPDATE/DELETE/DROP/ALTER/CREATE.
- Always LIMIT (max 200 rows per widget).
- NEVER query 'embedding' or 'vector' columns.
- Each widget MUST have a "db" field indicating which database to query.
- For product data: use StarRocks table {STARROCKS_TABLE}
- For chat/conversation data: use PostgreSQL table {POSTGRES_TABLE}
- You can combine both databases in one dashboard (e.g., product KPIs + chat analytics)
## Widget JSON Schema:
Each widget has: id, type, title, size, color, sql, x_key, y_key, db
## Available Types:
- **kpi**: Single metric (COUNT/SUM/AVG/MAX/MIN). SQL returns 1 row 1 col.
- **bar**: Vertical bar chart. SQL returns x_key + y_key columns.
- **horizontal-bar**: Horizontal bar chart. Same data as bar.
- **stacked-bar**: Stacked bar chart. SQL returns x_key + multiple numeric cols. Use y_keys (array).
- **line**: Line chart. SQL returns x_key + y_key.
- **area**: Filled area chart. Same as line.
- **donut**: Donut/pie chart. SQL returns label + value.
- **scatter**: Scatter plot. SQL returns x_key + y_key (both numeric).
- **table**: Data table. SQL returns any columns.
- **progress**: Progress bars showing completion. SQL returns label + current + max columns. y_keys: ["current_col","max_col"].
- **number-row**: Row of 2-3 inline metrics. SQL returns 1 row with 2-3 numeric cols.
## Available Sizes (12-column grid):
- "xs": 2 columns — tiny metric
- "sm": 3 columns — compact KPI
- "md": 4 columns — standard card
- "half": 6 columns — half width
- "lg": 8 columns — wide chart
- "full": 12 columns — full width
## Available Colors:
indigo, emerald, amber, red, purple, cyan, pink, orange, teal, blue
## ⚡ CRITICAL: Layout MUST BE DIVERSE! ⚡
You MUST vary the layout. NEVER produce the same layout twice. Here are 6 patterns — randomly pick one or create your own mix:
### Pattern A: "KPI Row + Mixed Charts"
Row 1: 4× kpi (sm) | Row 2: 1× bar (half) + 1× donut (half) | Row 3: 1× line (full) | Row 4: 1× table (full)
### Pattern B: "Hero KPI + Detail Grid"
Row 1: 1× kpi (md) + 1× kpi (md) + 1× kpi (md) | Row 2: 1× area (lg) + 1× number-row (md) | Row 3: 1× horizontal-bar (half) + 1× donut (half) | Row 4: 1× table (full)
### Pattern C: "Wide Charts Focus"
Row 1: 3× kpi (md) | Row 2: 1× bar (full) | Row 3: 1× donut (md) + 1× line (lg) | Row 4: 1× table (full)
### Pattern D: "Compact Dashboard"
Row 1: 2× kpi (half) | Row 2: 1× horizontal-bar (half) + 1× progress (half) | Row 3: 1× area (half) + 1× donut (half) | Row 4: 1× table (full)
### Pattern E: "Analytics Deep Dive"
Row 1: 4× kpi (sm) | Row 2: 1× scatter (half) + 1× bar (half) | Row 3: 1× stacked-bar (full) | Row 4: 1× table (full)
### Pattern F: "Executive Summary"
Row 1: 1× kpi (half) + 1× number-row (half) | Row 2: 1× area (full) | Row 3: 1× donut (md) + 1× horizontal-bar (lg) | Row 4: 1× table (full)
## Tips:
- Mix sizes! Don't make all KPIs the same size.
- Use horizontal-bar for categorical comparisons (top 10 products, etc.)
- Use scatter when comparing two numeric dimensions (price vs quantity)
- Use progress for goal-tracking or top-N with % share
- Use stacked-bar to show composition across categories
- number-row is great for showing 2-3 metrics in one compact widget
- Always produce 5-10 widgets total
- RANDOMIZE which pattern you use. Be creative!
- Respond ONLY with the raw JSON. No other text.
"""
"""
AI SQL Trace Prompt.
Dedicated prompt for a standalone analyst agent whose primary output is
the reasoning trace and SQL execution plan, not a business report.
"""
from agent.sql_agent.prompts.dashboard_prompt import POSTGRES_TABLE, STARROCKS_TABLE
from agent.report_agent.prompts.db_schemas import (
LANGFUSE_TABLE, LANGFUSE_SCHEMA,
CLEVERTAP_TABLE, CLEVERTAP_SCHEMA,
)
# System prompt for the standalone SQL-trace analyst agent. F-string over the
# table-name constants and schema strings imported above; the quadruple braces
# ({{{{ / }}}}) in the JSON examples render as literal double braces so the
# template survives a later .format() pass on the resulting string.
SQL_TRACE_AGENT_PROMPT = f"""You are an AI SQL Analyst for Canifa.
Your job: read the user's question, generate SQL, run it, check results.
If SQL errors → fix and retry. If data is enough → stop.
## LOOP
1. THINK: What does the user want?
2. EXECUTE: Generate SQL tools
3. OBSERVE: Check results
4. If error → fix SQL and retry (same cycle)
5. If 0 rows → try different date range
6. If data sufficient → DONE
Maximum 4 cycles. Aim for 1 cycle.
## ⚡ QUERY STRATEGY (CRITICAL)
### Rule 1: ONE comprehensive query, NOT multiple small ones
BAD: 3 queries → count users, count sessions, count traces (separately)
GOOD: 1 query → all metrics in ONE SELECT with multiple aggregates
### Rule 2: Date filtering (CRITICAL for partition pruning)
StarRocks partitions data by date. To enable partition pruning, ALWAYS use range
filters directly on the DATETIME column. NEVER apply functions on the column itself:
```sql
-- ✅ GOOD: Direct range filter — enables partition pruning
WHERE traced_at >= CURDATE()
AND traced_at < DATE_ADD(CURDATE(), INTERVAL 1 DAY)
-- ✅ GOOD: Last 7 days
WHERE traced_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)
-- ❌ BAD: Function on column — disables partition pruning, full scan!
WHERE DATE(traced_at) = CURDATE()
```
### Rule 3: "Today" data may be empty (T-1 batch)
If query returns 0 rows:
- Cycle 2: Try yesterday → `WHERE traced_at >= DATE_SUB(CURDATE(), INTERVAL 1 DAY) AND traced_at < CURDATE()`
- Cycle 3: Try last 7 days → `WHERE traced_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)`
### Rule 4: Include DETAIL and BREAKDOWNS
Don't just return bare counts. Include breakdowns so humans can evaluate:
Example — "bao nhiêu user vào chatbot hôm nay":
```sql
SELECT DATE(traced_at) AS ngay,
COUNT(DISTINCT device_id) AS unique_devices,
COUNT(DISTINCT session_id) AS unique_sessions,
COUNT(*) AS total_traces,
SUM(CASE WHEN is_guest = TRUE THEN 1 ELSE 0 END) AS guest_traces,
SUM(CASE WHEN is_user = TRUE THEN 1 ELSE 0 END) AS user_traces,
COUNT(DISTINCT CASE WHEN is_user = TRUE THEN customer_id END) AS logged_in_users,
ROUND(AVG(trace_latency), 2) AS avg_latency_s,
ROUND(SUM(total_cost), 4) AS total_cost_usd
FROM {LANGFUSE_TABLE}
WHERE traced_at >= CURDATE()
AND traced_at < DATE_ADD(CURDATE(), INTERVAL 1 DAY)
GROUP BY ngay LIMIT 20;
```
### Rule 5: For CleverTap queries
CleverTap uses `event_date_group` (DATE type) as partition key — direct comparison is fine:
```sql
SELECT event_date_group AS ngay,
COUNT(DISTINCT dwh_clevertap_profile_id) AS unique_users,
COUNT(DISTINCT session_id) AS unique_sessions,
SUM(nb_event) AS total_chatbot_opens,
ROUND(AVG(surfing_seconds), 0) AS avg_browsing_secs
FROM {CLEVERTAP_TABLE}
WHERE event_name = 'Chatbot Viewed'
AND event_date_group >= CURDATE()
GROUP BY ngay LIMIT 20;
```
## DATA SOURCES
### 1. sql_langfuse — Chatbot traces on StarRocks
Table: `{LANGFUSE_TABLE}`
{LANGFUSE_SCHEMA}
### 2. sql_langfuse — CleverTap events on StarRocks
Table: `{CLEVERTAP_TABLE}`
{CLEVERTAP_SCHEMA}
### 3. sql_starrocks — Product catalog on StarRocks
Table: `{STARROCKS_TABLE}`
Product data: name, color, size, price, quantity_sold, gender, season, etc.
NEVER query embedding/vector columns.
### 4. sql_postgres — Chat history on PostgreSQL
Table: `{POSTGRES_TABLE}`
Columns: identity_key, message, is_human, timestamp.
### 5. calculator — Pure arithmetic only.
## STARROCKS SQL OPTIMIZATION RULES (from official docs)
### Partition Pruning (MOST IMPORTANT)
StarRocks uses CBO + partition pruning to skip irrelevant data blocks.
For this to work, WHERE must use DIRECT comparisons on partition columns:
- ✅ `WHERE traced_at >= '2026-03-01'` — partition pruning works
- ✅ `WHERE traced_at >= CURDATE()` — partition pruning works
- ❌ `WHERE DATE(traced_at) = CURDATE()` — function on column = FULL SCAN
- ❌ `WHERE CAST(traced_at AS VARCHAR) > '2026'` — CAST = FULL SCAN
### General Rules
1. **SELECT only**. Never INSERT/UPDATE/DELETE.
2. **Always LIMIT 20** at the end.
3. **Prefer aggregates** over raw rows. Use COUNT, SUM, AVG, MIN, MAX.
4. **Filter first, aggregate second**. Always use WHERE before GROUP BY.
5. **Avoid SELECT ***. Only select needed columns.
6. **BOOLEAN columns**: Use `WHERE is_guest = TRUE`, NOT `is_guest = 1`.
In CASE: `CASE WHEN is_guest = TRUE THEN ...`
7. **ARRAY columns** (CleverTap): Use `array_contains(column, value)`
8. **Use NULLIF to avoid division by zero**: `SUM(x) / NULLIF(SUM(y), 0)`
9. **Subqueries are OK**: StarRocks CBO rewrites scalar subqueries to efficient joins.
You CAN use `WHERE event_date_group = (SELECT MAX(event_date_group) FROM ...) `
10. **Avoid excessive DISTINCT** on high-cardinality string columns if not needed.
## RESPONSE FORMAT
Always respond with raw JSON only.
### action = "execute" — Run SQL queries:
{{{{
"action": "execute",
"thinking": "Ngắn gọn lý do (Vietnamese)",
"tools": [
{{{{
"name": "sql_langfuse",
"params": {{{{"sql": "SELECT ... LIMIT 20"}}}},
"purpose": "Mô tả ngắn query này lấy gì"
}}}}
]
}}}}
### action = "reflect" — Evaluate results:
{{{{
"action": "reflect",
"thinking": "Đánh giá: data đủ chưa (Vietnamese)",
"data_sufficient": true,
"summary": "Tóm tắt ngắn gọn 2-3 câu cho người dùng hiểu kết quả. Ví dụ: 'Ngày 25/03, có 27 thiết bị mở chatbot, 159 lượt chat. Trong đó 108 lượt từ user đăng nhập, 51 lượt khách vãng lai. Chi phí AI trung bình 0.015$/lượt.'",
"missing": [],
"next_tools": []
}}}}
If data_sufficient = false, include next_tools with corrected queries.
## QUALITY
- **1 query per tool call. 1 tool call per cycle is ideal.**
- If SQL errors → fix and retry.
- If 0 rows → try broader date range (yesterday, then last 7 days).
- Include breakdowns and detail, not just bare counts.
- **summary field is REQUIRED** when data_sufficient = true. Write in Vietnamese, 2-3 sentences explaining key findings.
- Respond in Vietnamese. Raw JSON only.
"""
"""
SQL Trace Session Manager — In-memory session state management.
Manages trace session lifecycle: create, update cycles, record events, SSE formatting.
Mirrors the pattern used in report_agent/report_queue.py.
"""
import asyncio
import json
import time
import uuid
from datetime import datetime, timezone
from itertools import count
from typing import Any
# Monotonic session-id generator (process-local, starts at 1).
_TRACE_SESSION_SEQ = count(1)
# In-memory session store: session_id -> session dict. Not persisted.
_TRACE_SESSIONS: dict[int, dict[str, Any]] = {}
# Oldest sessions are pruned once the store exceeds this many entries.
_MAX_STORED_SESSIONS = 80
# Human-in-the-Loop: approval events keyed by session_id
_APPROVAL_EVENTS: dict[int, asyncio.Event] = {}
_APPROVAL_ACTIONS: dict[int, str] = {}  # "accept" or "reject"
def utc_now_iso() -> str:
    """Return the current UTC moment as a timezone-aware ISO-8601 string."""
    now = datetime.now(timezone.utc)
    return now.isoformat()
def prune_sessions() -> None:
    """Evict the oldest sessions once the store grows past _MAX_STORED_SESSIONS."""
    overflow = len(_TRACE_SESSIONS) - _MAX_STORED_SESSIONS
    if overflow <= 0:
        return
    # Oldest first, ordered by internal creation timestamp.
    oldest_first = sorted(
        _TRACE_SESSIONS,
        key=lambda sid: _TRACE_SESSIONS[sid].get("_created_ts", 0.0),
    )
    for sid in oldest_first[:overflow]:
        _TRACE_SESSIONS.pop(sid, None)
def new_cycle(cycle_no: int) -> dict[str, Any]:
    """Build an empty per-cycle record for a trace session."""
    cycle: dict[str, Any] = {"cycle": cycle_no}
    cycle.update(
        thinking="",
        tool_calls=[],
        tool_results=[],
        missing=[],
        data_sufficient=False,
        reflect_thinking="",
        status="pending",
    )
    return cycle
def create_session(
    question: str,
    model: str,
    max_cycles: int,
    conversation_id: str | None = None,
) -> dict[str, Any]:
    """Build, register, and return a fresh trace session dict.

    A random conversation id is generated when the caller does not supply one.
    The store is pruned afterwards so it never exceeds its capacity.
    """
    sid = next(_TRACE_SESSION_SEQ)
    created = utc_now_iso()
    session: dict[str, Any] = {
        "id": sid,
        "conversation_id": conversation_id or str(uuid.uuid4()),
        "question": question,
        "model": model,
        "max_cycles": max_cycles,
        "status": "running",
        "created_at": created,
        "updated_at": created,
        "completed_at": None,
        "events": [],
        "cycles": [],
        "summary": {},
        "total_queries": 0,
        "successful_queries": 0,
        "failed_queries": 0,
        "_created_ts": time.time(),  # internal: ordering key for pruning
    }
    _TRACE_SESSIONS[sid] = session
    prune_sessions()
    return session
# ─── Human-in-the-Loop Approval ─────────────────────────────────────
def create_approval_event(session_id: int) -> asyncio.Event:
    """Register a fresh approval Event for *session_id*, replacing any prior one."""
    event = asyncio.Event()
    _APPROVAL_EVENTS[session_id] = event
    # Reset the recorded action until the human decides.
    _APPROVAL_ACTIONS[session_id] = ""
    return event
def _signal_approval(session_id: int, action: str) -> bool:
    """Record *action* for the session and wake its waiting agent coroutine.

    Returns False when no approval event is registered for the session.
    (Shared implementation for approve_session / reject_session, which were
    previously duplicated.)
    """
    evt = _APPROVAL_EVENTS.get(session_id)
    if not evt:
        return False
    _APPROVAL_ACTIONS[session_id] = action
    evt.set()
    return True


def approve_session(session_id: int) -> bool:
    """Accept the pending SQL. Returns False if no event found."""
    return _signal_approval(session_id, "accept")


def reject_session(session_id: int) -> bool:
    """Reject the pending SQL. Returns False if no event found."""
    return _signal_approval(session_id, "reject")
def get_approval_action(session_id: int) -> str:
    """Return the recorded action ("accept"/"reject") or "" when none is set."""
    action = _APPROVAL_ACTIONS.get(session_id)
    return action if action else ""
def cleanup_approval(session_id: int) -> None:
    """Drop all HITL approval state associated with *session_id*."""
    for registry in (_APPROVAL_EVENTS, _APPROVAL_ACTIONS):
        registry.pop(session_id, None)
def public_session(session: dict[str, Any]) -> dict[str, Any]:
    """Shallow copy of *session* with internal (underscore-prefixed) keys stripped."""
    visible: dict[str, Any] = {}
    for key, value in session.items():
        if not key.startswith("_"):
            visible[key] = value
    return visible
def ensure_cycle(session: dict[str, Any], cycle_no: int) -> dict[str, Any]:
    """Return cycle *cycle_no* (1-based), appending blank cycles as needed."""
    cycles = session["cycles"]  # alias: appends mutate the session in place
    while len(cycles) < cycle_no:
        cycles.append(new_cycle(len(cycles) + 1))
    return cycles[cycle_no - 1]
def record_event(session: dict[str, Any], payload: dict[str, Any]) -> None:
    """Append a timestamped *payload* to the session log (capped at 800 events)."""
    stamp = utc_now_iso()
    session["updated_at"] = stamp
    session["events"].append({"ts": stamp, **payload})
    if len(session["events"]) > 800:
        session["events"] = session["events"][-800:]
def sse(session: dict[str, Any], payload: dict[str, Any]) -> str:
    """Record *payload* and return it framed as a text/event-stream chunk."""
    data: dict[str, Any] = {"session_id": session["id"], **payload}
    record_event(session, data)
    body = json.dumps(data, ensure_ascii=False, default=str)
    return f"data: {body}\n\n"
def get_session(session_id: int) -> dict[str, Any] | None:
    """Look up a session by id, or None when unknown."""
    try:
        return _TRACE_SESSIONS[session_id]
    except KeyError:
        return None
def delete_session(session_id: int) -> bool:
    """Delete session by ID. Returns True if a session was removed."""
    # Single pop instead of membership-check-then-del: one lookup, EAFP style.
    return _TRACE_SESSIONS.pop(session_id, None) is not None
def list_sessions() -> list[dict[str, Any]]:
    """List all sessions as compact summary dicts, newest first."""

    def summarize(s: dict[str, Any]) -> dict[str, Any]:
        # Derived fields fall back to live counts when no final summary exists.
        digest = s.get("summary", {})
        return {
            "id": s["id"],
            "question": s["question"],
            "model": s["model"],
            "status": s["status"],
            "created_at": s["created_at"],
            "updated_at": s["updated_at"],
            "completed_at": s["completed_at"],
            "total_queries": s["total_queries"],
            "successful_queries": s["successful_queries"],
            "failed_queries": s["failed_queries"],
            "cycles_completed": digest.get("cycles_completed", len(s["cycles"])),
            "data_sufficient": digest.get("data_sufficient"),
            "final_task": digest.get("final_task", ""),
        }

    summaries = [summarize(s) for s in _TRACE_SESSIONS.values()]
    summaries.sort(key=lambda item: item.get("created_at") or "", reverse=True)
    return summaries
"""
SQL Agent Tools — Execute SQL queries and calculator operations.
Extracted from api/ai_sql_trace_route.py for clean separation.
"""
import asyncio
import re
from typing import Any
from common.starrocks_connection import StarRocksConnection
from common.postgres_readonly import PostgresReadonly
# ─── SQL Safety ──────────────────────────────────────────────────────
_FORBIDDEN = re.compile(
r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|TRUNCATE|REPLACE|MERGE|GRANT|REVOKE)\b",
re.IGNORECASE,
)
def validate_sql(sql: str, *, require_limit: bool = True) -> str | None:
"""Validate SQL is safe (SELECT only). Returns error message or None."""
sql_clean = sql.strip().rstrip(";")
upper_sql = sql_clean.upper()
if not any(upper_sql.startswith(k) for k in ("SELECT", "WITH", "SHOW")):
return "Only SELECT, WITH, or SHOW queries allowed"
if _FORBIDDEN.search(sql_clean):
return "Forbidden SQL keyword"
check = sql_clean.lower().replace("_text_embedding", "").replace("_embedding", "")
if re.search(r"\b(embedding|vector)\b", check):
return "Cannot query embedding/vector column"
if require_limit and "limit" not in sql_clean.lower():
return "Query must include LIMIT"
return None
# ─── Row Serialization ──────────────────────────────────────────────
def serialize_row(row: dict[str, Any]) -> dict[str, Any]:
    """Make *row* JSON-safe in place and return it.

    Conversions: bytes -> hex string, objects with ``isoformat`` (datetime,
    date, time) -> ISO string, NaN -> 0, objects with ``as_tuple`` (Decimal)
    -> float. All other values pass through unchanged.
    """
    for key in row:
        value = row[key]
        if isinstance(value, (bytes, bytearray)):
            row[key] = value.hex()
        elif hasattr(value, "isoformat"):  # datetime / date / time
            row[key] = value.isoformat()
        elif isinstance(value, float) and value != value:  # NaN never equals itself
            row[key] = 0
        elif hasattr(value, "as_tuple"):  # Decimal duck-typing
            row[key] = float(value)
    return row
# ─── Tool Executors ──────────────────────────────────────────────────
async def _run_sql(sql: str, run_query) -> dict[str, Any]:
    """Validate *sql*, run it via the async *run_query* callable, serialize rows.

    Shared implementation for the StarRocks and Postgres branches of
    exec_tool (which previously duplicated this logic). Errors are returned
    in the result dict, never raised.
    """
    err = validate_sql(sql)
    if err:
        return {"error": err, "data": []}
    try:
        rows = await run_query(sql)
        data = [serialize_row(dict(r)) for r in rows]
        return {
            "data": data,
            "columns": list(rows[0].keys()) if rows else [],
            "row_count": len(data),
        }
    except Exception as e:
        return {"error": str(e)[:300], "data": []}


def _safe_arith(expr: str):
    """Evaluate a pure-arithmetic expression without eval().

    Supports numbers, + - * / //, unary +/-, parentheses and tuples — the
    character set the calculator tool accepts. Exponentiation is intentionally
    unsupported: the old char whitelist let ``9**9**9``-style CPU/memory bombs
    through to eval().
    """
    import ast

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.Tuple):
            return tuple(walk(elt) for elt in node.elts)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
            operand = walk(node.operand)
            return -operand if isinstance(node.op, ast.USub) else +operand
        if isinstance(node, ast.BinOp):
            left, right = walk(node.left), walk(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            if isinstance(node.op, ast.Sub):
                return left - right
            if isinstance(node.op, ast.Mult):
                return left * right
            if isinstance(node.op, ast.Div):
                return left / right
            if isinstance(node.op, ast.FloorDiv):
                return left // right
        raise ValueError(f"Unsupported expression node: {type(node).__name__}")

    return walk(ast.parse(expr, mode="eval"))


async def exec_tool(tool_name: str, params: dict[str, Any]) -> dict[str, Any]:
    """Execute a single tool and return its result dict.

    Known tools: sql_langfuse / sql_starrocks (StarRocks), sql_postgres,
    calculator. Unknown tool names yield an error result; nothing is raised.
    """
    if tool_name in ("sql_langfuse", "sql_starrocks"):
        # Lazy lambda: the connection is only constructed after validation.
        return await _run_sql(
            params.get("sql", ""),
            lambda q: StarRocksConnection().execute_query_async(q),
        )
    if tool_name == "sql_postgres":
        return await _run_sql(params.get("sql", ""), PostgresReadonly.execute_query_async)
    if tool_name == "calculator":
        expr = params.get("expression", "")
        description = params.get("description", "")
        allowed = set("0123456789.+-*/() ,")
        if not all(c in allowed for c in expr.replace(" ", "")):
            return {"error": "Invalid expression", "expression": expr}
        try:
            # AST-based evaluation replaces eval(): same arithmetic results,
            # but no exponentiation bombs and no interpreter surface exposed.
            result = _safe_arith(expr)
            return {"expression": expr, "result": result, "description": description}
        except Exception as e:
            return {"error": str(e), "expression": expr}
    return {"error": f"Unknown tool: {tool_name}", "data": []}
async def execute_tools_parallel(tools: list[dict[str, Any]]) -> list[dict[str, Any] | Exception]:
    """Run every tool call concurrently; exceptions are returned, not raised."""
    pending = [
        exec_tool(spec.get("name", ""), spec.get("params", {}))
        for spec in tools
    ]
    return await asyncio.gather(*pending, return_exceptions=True)
# ─── Result Preview & Summary ────────────────────────────────────────
def make_result_preview(result: dict[str, Any] | Exception) -> dict[str, Any]:
"""Create a compact preview of a tool result for SSE streaming."""
preview: dict[str, Any] = {}
if isinstance(result, Exception):
return {"error": str(result)[:200]}
if "data" in result:
data = result["data"]
if isinstance(data, list):
preview["row_count"] = len(data)
preview["columns"] = result.get("columns", [])
preview["preview_rows"] = data[:5] if data else []
if "error" in result:
preview["error"] = result["error"]
if "result" in result:
preview["result"] = result["result"]
return preview
def summarize_results(all_results: dict[str, Any]) -> str:
    """Compress tool results into a concise, token-efficient text digest.

    Per result: exceptions and error dicts are echoed as ERROR lines; scalar
    results are echoed as-is; small row sets (<= 20 rows) are rendered as a
    TSV table; larger sets get a 10-row sample plus per-column statistics
    (numeric min/max/avg/sum, or unique-value counts for text columns).
    """
    parts = []
    for key, result in all_results.items():
        if isinstance(result, Exception):
            parts.append(f"\n## {key}\nERROR: {str(result)[:200]}")
            continue
        result_dict = dict(result) if isinstance(result, dict) else {"error": str(result)}
        if "error" in result_dict:
            parts.append(f"\n## {key}\nERROR: {result_dict['error']}")
            continue
        if "result" in result_dict:
            parts.append(f"\n## {key}\nResult: {result_dict['result']}")
            continue
        data = result_dict.get("data", [])
        if not isinstance(data, list) or not data:
            parts.append(f"\n## {key}\n(no data)")
            continue
        columns = result_dict.get("columns", list(data[0].keys()) if data else [])
        row_count = len(data)
        if row_count <= 20:
            # Small result: emit every row as TSV (cells truncated at 60 chars).
            lines = [f"\n## {key} ({row_count} rows)"]
            lines.append("\t".join(str(c) for c in columns))
            for row in data:
                cells = []
                for col in columns:
                    val = row.get(col, "")
                    s = str(val) if val is not None else ""
                    if len(s) > 60:
                        s = s[:57] + "..."
                    cells.append(s)
                lines.append("\t".join(cells))
            parts.append("\n".join(lines))
        else:
            # Large result: sample the first 10 rows, then per-column stats.
            lines = [f"\n## {key} ({row_count} rows → auto-summarized)"]
            lines.append("### Sample (first 10):")
            lines.append("\t".join(str(c) for c in columns))
            for row in data[:10]:
                cells = []
                for col in columns:
                    val = row.get(col, "")
                    s = str(val) if val is not None else ""
                    if len(s) > 40:
                        s = s[:37] + "..."
                    cells.append(s)
                lines.append("\t".join(cells))
            # Compute stats for numeric columns
            lines.append("### Statistics:")
            for col in columns:
                values = [row.get(col) for row in data if row.get(col) is not None]
                if not values:
                    continue
                try:
                    nums = [float(v) for v in values if v != "" and v is not None]
                    # Only report numeric stats when most values parsed as numbers.
                    if nums and len(nums) > len(values) * 0.5:
                        lines.append(
                            f"  {col}: min={min(nums):.2f}, max={max(nums):.2f}, "
                            f"avg={sum(nums)/len(nums):.2f}, sum={sum(nums):.2f}, count={len(nums)}"
                        )
                except (ValueError, TypeError):
                    # Fix: dict.fromkeys preserves first-seen order so the "top"
                    # sample is deterministic (set() ordering was not).
                    unique = list(dict.fromkeys(str(v) for v in values[:100]))
                    top = unique[:5]
                    lines.append(f"  {col}: {len(unique)} unique values, top: {', '.join(top)}")
            parts.append("\n".join(lines))
    return "\n".join(parts)
This diff is collapsed.
......@@ -4,7 +4,6 @@ import logging
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.prompt_utils import read_tool_prompt
from common.embedding_service import create_embedding_async
from common.starrocks_connection import get_db_connection
......@@ -66,4 +65,3 @@ async def canifa_knowledge_search(query: str) -> str:
return "Tôi đang gặp khó khăn khi truy cập kho kiến thức. Bạn muốn hỏi về sản phẩm gì khác không?"
canifa_knowledge_search.__doc__ = read_tool_prompt("brand_knowledge_tool") or canifa_knowledge_search.__doc__
......@@ -9,7 +9,6 @@ import httpx
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.prompt_utils import read_tool_prompt
logger = logging.getLogger(__name__)
......@@ -85,8 +84,3 @@ async def check_is_stock(skus: str) -> str:
return json.dumps(stock_data, ensure_ascii=False)
# Load dynamic docstring from file
dynamic_prompt = read_tool_prompt("check_is_stock")
if dynamic_prompt:
check_is_stock.__doc__ = dynamic_prompt
check_is_stock.description = dynamic_prompt
......@@ -27,7 +27,7 @@ from common.starrocks_connection import get_db_connection
# Setup Logger
logger = logging.getLogger(__name__)
from agent.prompt_utils import read_tool_prompt
class SearchItem(BaseModel):
......@@ -302,8 +302,3 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
return json.dumps(output, ensure_ascii=False, default=str)
# Load dynamic docstring
dynamic_prompt = read_tool_prompt("data_retrieval_tool")
if dynamic_prompt:
data_retrieval_tool.__doc__ = dynamic_prompt
data_retrieval_tool.description = dynamic_prompt
......@@ -5,7 +5,6 @@ from langchain_core.tools import tool
from pydantic import BaseModel, Field
from common.starrocks_connection import get_db_connection
from agent.prompt_utils import read_tool_prompt
logger = logging.getLogger(__name__)
......@@ -111,5 +110,3 @@ async def canifa_get_promotions(check_date: str = None) -> str:
logger.error(f"❌ Error in canifa_get_promotions: {e}")
return "Xin lỗi, tôi không thể lấy danh sách khuyến mãi lúc này."
# Load dynamic docstring
canifa_get_promotions.__doc__ = read_tool_prompt("promotion_canifa_tool") or canifa_get_promotions.__doc__
......@@ -3,7 +3,6 @@ import logging
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.prompt_utils import read_tool_prompt
from common.starrocks_connection import get_db_connection
logger = logging.getLogger(__name__)
......@@ -86,4 +85,3 @@ async def canifa_store_search(location: str) -> str:
return "Tôi đang gặp khó khăn khi tìm kiếm cửa hàng. Bạn có thể liên hệ hotline 1800 6061 để được hỗ trợ."
canifa_store_search.__doc__ = read_tool_prompt("store_search_tool") or canifa_store_search.__doc__
"""
AI SQL Trace API — Thin FastAPI route layer.
All business logic lives in agent/sql_agent/.
This file only handles HTTP routing, auth, and SSE response formatting.
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from auth.middleware import get_current_user
from agent.sql_agent.trace_agent import run_trace_agent
from agent.sql_agent.session_manager import (
get_session,
delete_session,
list_sessions,
public_session,
approve_session,
reject_session,
)
from agent.sql_agent.persistence import (
load_conversations,
load_conversation_detail,
)
logger = logging.getLogger(__name__)
router = APIRouter()
class AISQLTraceRequest(BaseModel):
    """Request body for starting an AI SQL trace run."""

    question: str  # natural-language analytics question
    model: str = "codex/gpt-5.3-codex"  # LLM identifier passed to the agent
    max_cycles: int = Field(default=4, ge=1, le=4)  # think/execute/observe loops
    conversation_id: str | None = None  # reuse an existing conversation thread
class ApprovalRequest(BaseModel):
    """HITL approval payload; action is validated to "accept" or "reject"."""

    action: str = Field(..., pattern="^(accept|reject)$")
# ─── Main Trace Endpoint ─────────────────────────────────────────────
@router.post("/api/ai-sql/trace", summary="AI SQL Trace Agent (SSE Stream)")
async def ai_sql_trace_stream(req: AISQLTraceRequest, user: dict = Depends(get_current_user)):
    """Start a trace run and stream agent events back as text/event-stream."""
    settings = user.get("settings", {})
    stream = run_trace_agent(
        question=req.question,
        model=req.model,
        max_cycles=req.max_cycles,
        conversation_id=req.conversation_id,
        codex_token=settings.get("codex_token"),
        openai_key=settings.get("openai_key"),
    )
    return StreamingResponse(stream, media_type="text/event-stream")
# ─── Approval Endpoint (Human-in-the-Loop) ───────────────────────────
@router.post("/api/ai-sql/approve/{session_id}", summary="Approve or reject pending SQL")
async def approve_ai_sql(session_id: int, req: ApprovalRequest, user: dict = Depends(get_current_user)):
    """Resolve a pending HITL approval for the given trace session."""
    handler = approve_session if req.action == "accept" else reject_session
    if not handler(session_id):
        return {"error": "No pending approval for this session", "success": False}
    return {"success": True, "action": req.action}
# ─── Session Management Endpoints ────────────────────────────────────
@router.get("/api/ai-sql/sessions", summary="List AI SQL trace sessions")
async def list_ai_sql_sessions(user: dict = Depends(get_current_user)):
    """Return summary info for all in-memory trace sessions."""
    return {"sessions": list_sessions()}
@router.get("/api/ai-sql/sessions/{session_id}", summary="Get AI SQL trace session detail")
async def get_ai_sql_session(session_id: int, user: dict = Depends(get_current_user)):
    """Fetch one trace session (internal keys stripped) or an error payload."""
    session = get_session(session_id)
    if session is None:
        return {"error": "Session not found"}
    return {"session": public_session(session)}
@router.delete("/api/ai-sql/sessions/{session_id}", summary="Delete AI SQL trace session")
async def delete_ai_sql_session(session_id: int, user: dict = Depends(get_current_user)):
    """Remove one trace session from the in-memory store."""
    deleted = delete_session(session_id)
    return {"success": True} if deleted else {"error": "Session not found"}
# ─── Conversation History Endpoints ──────────────────────────────────
@router.get("/api/ai-sql/conversations", summary="List SQL trace conversations")
async def list_sql_conversations(user: dict = Depends(get_current_user)):
    """List stored SQL-trace conversations (via the persistence layer)."""
    conversations = load_conversations()
    return {"conversations": conversations}
@router.get("/api/ai-sql/conversations/{conv_id}", summary="Get SQL trace conversation detail")
async def get_sql_conversation(conv_id: str, user: dict = Depends(get_current_user)):
    """Return every stored session belonging to one conversation."""
    return {
        "conversation_id": conv_id,
        "sessions": load_conversation_detail(conv_id),
    }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment