402 lines
12 KiB
Rust
402 lines
12 KiB
Rust
use rusqlite::{params, Connection, Result};
|
|
use serde::{Deserialize, Serialize};
|
|
use uuid::Uuid;
|
|
|
|
/// Id of the default analyst agent: the agent `Agent::get_default_by_role`
/// returns for `AgentRole::Analyst`, and the fallback that trackers are moved
/// to when a custom analyst agent is deleted.
pub const DEFAULT_ANALYST_AGENT_ID: &str = "default-analyst-agent";
/// Id of the default developer agent: the agent `Agent::get_default_by_role`
/// returns for `AgentRole::Developer`, and the fallback that trackers are moved
/// to when a custom developer agent is deleted.
pub const DEFAULT_DEVELOPER_AGENT_ID: &str = "default-developer-agent";
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum AgentRole {
|
|
Analyst,
|
|
Developer,
|
|
}
|
|
|
|
impl AgentRole {
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
AgentRole::Analyst => "analyst",
|
|
AgentRole::Developer => "developer",
|
|
}
|
|
}
|
|
|
|
pub fn from_str(value: &str) -> Result<Self> {
|
|
match value {
|
|
"analyst" => Ok(AgentRole::Analyst),
|
|
"developer" => Ok(AgentRole::Developer),
|
|
_ => Err(rusqlite::Error::InvalidParameterName(format!(
|
|
"Invalid agent role: {}",
|
|
value
|
|
))),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum AgentTool {
|
|
Codex,
|
|
ClaudeCode,
|
|
}
|
|
|
|
impl AgentTool {
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
AgentTool::Codex => "codex",
|
|
AgentTool::ClaudeCode => "claude_code",
|
|
}
|
|
}
|
|
|
|
pub fn from_str(value: &str) -> Result<Self> {
|
|
match value {
|
|
"codex" => Ok(AgentTool::Codex),
|
|
"claude_code" => Ok(AgentTool::ClaudeCode),
|
|
_ => Err(rusqlite::Error::InvalidParameterName(format!(
|
|
"Invalid agent tool: {}",
|
|
value
|
|
))),
|
|
}
|
|
}
|
|
|
|
pub fn to_command(&self) -> &'static str {
|
|
match self {
|
|
AgentTool::Codex => "codex",
|
|
AgentTool::ClaudeCode => "claude",
|
|
}
|
|
}
|
|
|
|
pub fn to_non_interactive_args(&self) -> Vec<String> {
|
|
match self {
|
|
AgentTool::Codex => vec![
|
|
"exec".to_string(),
|
|
"--ephemeral".to_string(),
|
|
"-c".to_string(),
|
|
"mcp_servers.tuleap.enabled=false".to_string(),
|
|
"-".to_string(),
|
|
],
|
|
AgentTool::ClaudeCode => vec!["-p".to_string()],
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Agent {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub role: AgentRole,
|
|
pub tool: AgentTool,
|
|
pub custom_prompt: String,
|
|
pub is_default: bool,
|
|
pub created_at: String,
|
|
pub updated_at: String,
|
|
}
|
|
|
|
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Agent> {
|
|
let role_raw: String = row.get(2)?;
|
|
let tool_raw: String = row.get(3)?;
|
|
let is_default_int: i32 = row.get(5)?;
|
|
|
|
Ok(Agent {
|
|
id: row.get(0)?,
|
|
name: row.get(1)?,
|
|
role: AgentRole::from_str(&role_raw)?,
|
|
tool: AgentTool::from_str(&tool_raw)?,
|
|
custom_prompt: row.get(4)?,
|
|
is_default: is_default_int != 0,
|
|
created_at: row.get(6)?,
|
|
updated_at: row.get(7)?,
|
|
})
|
|
}
|
|
|
|
impl Agent {
|
|
pub fn insert(
|
|
conn: &Connection,
|
|
name: &str,
|
|
role: AgentRole,
|
|
tool: AgentTool,
|
|
custom_prompt: &str,
|
|
) -> Result<Agent> {
|
|
let id = Uuid::new_v4().to_string();
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
conn.execute(
|
|
"INSERT INTO agents (id, name, role, tool, custom_prompt, is_default, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5, 0, ?6, ?7)",
|
|
params![id, name, role.as_str(), tool.as_str(), custom_prompt, now, now],
|
|
)?;
|
|
|
|
Ok(Agent {
|
|
id,
|
|
name: name.to_string(),
|
|
role,
|
|
tool,
|
|
custom_prompt: custom_prompt.to_string(),
|
|
is_default: false,
|
|
created_at: now.clone(),
|
|
updated_at: now,
|
|
})
|
|
}
|
|
|
|
pub fn list(conn: &Connection) -> Result<Vec<Agent>> {
|
|
let mut stmt = conn.prepare(
|
|
"SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at
|
|
FROM agents
|
|
ORDER BY role ASC, is_default DESC, created_at DESC",
|
|
)?;
|
|
let rows = stmt.query_map([], from_row)?;
|
|
rows.collect()
|
|
}
|
|
|
|
pub fn get_by_id(conn: &Connection, id: &str) -> Result<Agent> {
|
|
conn.query_row(
|
|
"SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at FROM agents WHERE id = ?1",
|
|
params![id],
|
|
from_row,
|
|
)
|
|
}
|
|
|
|
pub fn get_default_by_role(conn: &Connection, role: AgentRole) -> Result<Agent> {
|
|
let default_id = match role {
|
|
AgentRole::Analyst => DEFAULT_ANALYST_AGENT_ID,
|
|
AgentRole::Developer => DEFAULT_DEVELOPER_AGENT_ID,
|
|
};
|
|
|
|
conn.query_row(
|
|
"SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at
|
|
FROM agents
|
|
WHERE id = ?1 AND role = ?2 AND is_default = 1
|
|
LIMIT 1",
|
|
params![default_id, role.as_str()],
|
|
from_row,
|
|
)
|
|
}
|
|
|
|
pub fn update(
|
|
conn: &Connection,
|
|
id: &str,
|
|
name: &str,
|
|
role: AgentRole,
|
|
tool: AgentTool,
|
|
custom_prompt: &str,
|
|
) -> Result<()> {
|
|
let existing = Self::get_by_id(conn, id)?;
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
if existing.is_default {
|
|
if existing.name != name || existing.role != role {
|
|
return Err(rusqlite::Error::InvalidParameterName(
|
|
"Default agents cannot change name or role".to_string(),
|
|
));
|
|
}
|
|
|
|
conn.execute(
|
|
"UPDATE agents SET tool = ?1, custom_prompt = ?2, updated_at = ?3 WHERE id = ?4",
|
|
params![tool.as_str(), custom_prompt, now, id],
|
|
)?;
|
|
return Ok(());
|
|
}
|
|
|
|
let affected = conn.execute(
|
|
"UPDATE agents SET name = ?1, role = ?2, tool = ?3, custom_prompt = ?4, updated_at = ?5 WHERE id = ?6",
|
|
params![name, role.as_str(), tool.as_str(), custom_prompt, now, id],
|
|
)?;
|
|
|
|
if affected == 0 {
|
|
return Err(rusqlite::Error::QueryReturnedNoRows);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn delete(conn: &Connection, id: &str) -> Result<()> {
|
|
let agent = Self::get_by_id(conn, id)?;
|
|
|
|
if agent.is_default {
|
|
return Err(rusqlite::Error::InvalidParameterName(
|
|
"Default agents cannot be deleted".to_string(),
|
|
));
|
|
}
|
|
|
|
let default_agent = Self::get_default_by_role(conn, agent.role.clone()).map_err(|_| {
|
|
rusqlite::Error::InvalidParameterName(format!(
|
|
"No default agent found for role '{}'",
|
|
agent.role.as_str()
|
|
))
|
|
})?;
|
|
|
|
match agent.role {
|
|
AgentRole::Analyst => {
|
|
conn.execute(
|
|
"UPDATE watched_trackers SET analyst_agent_id = ?1 WHERE analyst_agent_id = ?2",
|
|
params![default_agent.id, id],
|
|
)?;
|
|
}
|
|
AgentRole::Developer => {
|
|
conn.execute(
|
|
"UPDATE watched_trackers SET developer_agent_id = ?1 WHERE developer_agent_id = ?2",
|
|
params![default_agent.id, id],
|
|
)?;
|
|
}
|
|
}
|
|
|
|
conn.execute(
|
|
"UPDATE watched_trackers
|
|
SET status = CASE
|
|
WHEN analyst_agent_id IS NULL OR developer_agent_id IS NULL THEN 'invalid'
|
|
ELSE 'valid'
|
|
END",
|
|
[],
|
|
)?;
|
|
|
|
let affected = conn.execute("DELETE FROM agents WHERE id = ?1", params![id])?;
|
|
if affected == 0 {
|
|
return Err(rusqlite::Error::QueryReturnedNoRows);
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db;
    use crate::models::project::Project;
    use crate::models::tracker::{NewWatchedTracker, WatchedTracker};

    /// Opens a fresh in-memory database initialised by `db::init_in_memory`.
    fn setup() -> Connection {
        db::init_in_memory().expect("db init should succeed")
    }

    /// Turns string literals into the owned argument vector the CLI helpers return.
    fn owned_args(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| part.to_string()).collect()
    }

    #[test]
    fn test_defaults_exist() {
        let conn = setup();

        let expectations = [
            (AgentRole::Analyst, DEFAULT_ANALYST_AGENT_ID),
            (AgentRole::Developer, DEFAULT_DEVELOPER_AGENT_ID),
        ];
        for (role, expected_id) in expectations {
            let default_agent = Agent::get_default_by_role(&conn, role).unwrap();
            assert_eq!(default_agent.id, expected_id);
            assert!(default_agent.is_default);
        }
    }

    #[test]
    fn test_insert_and_get_agent() {
        let conn = setup();

        let inserted = Agent::insert(
            &conn,
            "Analyst Codex",
            AgentRole::Analyst,
            AgentTool::Codex,
            "Focus on root cause.",
        )
        .expect("insert should succeed");
        let fetched = Agent::get_by_id(&conn, &inserted.id).expect("get_by_id should succeed");

        assert_eq!(fetched.name, "Analyst Codex");
        assert_eq!(fetched.role, AgentRole::Analyst);
        assert_eq!(fetched.tool, AgentTool::Codex);
        assert_eq!(fetched.custom_prompt, "Focus on root cause.");
        assert!(!fetched.is_default);
    }

    #[test]
    fn test_non_interactive_args_match_cli_expectations() {
        let codex_expected = owned_args(&[
            "exec",
            "--ephemeral",
            "-c",
            "mcp_servers.tuleap.enabled=false",
            "-",
        ]);
        assert_eq!(AgentTool::Codex.to_non_interactive_args(), codex_expected);

        let claude_expected = owned_args(&["-p"]);
        assert_eq!(AgentTool::ClaudeCode.to_non_interactive_args(), claude_expected);
    }

    #[test]
    fn test_update_default_agent_allows_tool_and_prompt_only() {
        let conn = setup();
        let default_analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();

        // Renaming or re-roling a default agent must be refused...
        let rejection = Agent::update(
            &conn,
            &default_analyst.id,
            "Renamed",
            AgentRole::Developer,
            AgentTool::ClaudeCode,
            "new script",
        )
        .unwrap_err()
        .to_string();
        assert!(rejection.contains("Default agents cannot change name or role"));

        // ...while swapping its tool and prompt is allowed.
        Agent::update(
            &conn,
            &default_analyst.id,
            &default_analyst.name,
            default_analyst.role.clone(),
            AgentTool::ClaudeCode,
            "Prompt override",
        )
        .unwrap();

        let reloaded = Agent::get_by_id(&conn, &default_analyst.id).unwrap();
        assert_eq!(reloaded.tool, AgentTool::ClaudeCode);
        assert_eq!(reloaded.custom_prompt, "Prompt override");
    }

    #[test]
    fn test_delete_default_agent_is_rejected() {
        let conn = setup();
        let default_analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();

        let message = Agent::delete(&conn, &default_analyst.id)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Default agents cannot be deleted"));
    }

    #[test]
    fn test_delete_agent_reassigns_trackers_to_default() {
        let conn = setup();
        let project = Project::insert(&conn, "P", "/tmp/p", None, "main").unwrap();

        let default_analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();
        let default_developer = Agent::get_default_by_role(&conn, AgentRole::Developer).unwrap();

        let custom_analyst =
            Agent::insert(&conn, "Analyst", AgentRole::Analyst, AgentTool::Codex, "").unwrap();

        let new_tracker = NewWatchedTracker {
            project_id: project.id.clone(),
            tracker_id: 100,
            tracker_label: "Bugs".to_string(),
            polling_interval: 10,
            analyst_agent_id: custom_analyst.id.clone(),
            developer_agent_id: default_developer.id.clone(),
            filters: vec![],
        };
        let tracker = WatchedTracker::insert(&conn, new_tracker).unwrap();

        Agent::delete(&conn, &custom_analyst.id).unwrap();

        let reloaded = WatchedTracker::get_by_id(&conn, &tracker.id).unwrap();
        assert_eq!(reloaded.status, "valid");
        assert_eq!(
            reloaded.analyst_agent_id.as_deref(),
            Some(default_analyst.id.as_str())
        );
        assert_eq!(
            reloaded.developer_agent_id.as_deref(),
            Some(default_developer.id.as_str())
        );
    }
}
|