orchai/src-tauri/src/models/agent.rs

402 lines
12 KiB
Rust

use rusqlite::{params, Connection, Result};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Fixed primary key of the built-in analyst agent, looked up by `Agent::get_default_by_role`.
pub const DEFAULT_ANALYST_AGENT_ID: &str = "default-analyst-agent";
/// Fixed primary key of the built-in developer agent, looked up by `Agent::get_default_by_role`.
pub const DEFAULT_DEVELOPER_AGENT_ID: &str = "default-developer-agent";
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AgentRole {
Analyst,
Developer,
}
impl AgentRole {
pub fn as_str(&self) -> &'static str {
match self {
AgentRole::Analyst => "analyst",
AgentRole::Developer => "developer",
}
}
pub fn from_str(value: &str) -> Result<Self> {
match value {
"analyst" => Ok(AgentRole::Analyst),
"developer" => Ok(AgentRole::Developer),
_ => Err(rusqlite::Error::InvalidParameterName(format!(
"Invalid agent role: {}",
value
))),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AgentTool {
Codex,
ClaudeCode,
}
impl AgentTool {
pub fn as_str(&self) -> &'static str {
match self {
AgentTool::Codex => "codex",
AgentTool::ClaudeCode => "claude_code",
}
}
pub fn from_str(value: &str) -> Result<Self> {
match value {
"codex" => Ok(AgentTool::Codex),
"claude_code" => Ok(AgentTool::ClaudeCode),
_ => Err(rusqlite::Error::InvalidParameterName(format!(
"Invalid agent tool: {}",
value
))),
}
}
pub fn to_command(&self) -> &'static str {
match self {
AgentTool::Codex => "codex",
AgentTool::ClaudeCode => "claude",
}
}
pub fn to_non_interactive_args(&self) -> Vec<String> {
match self {
AgentTool::Codex => vec![
"exec".to_string(),
"--ephemeral".to_string(),
"-c".to_string(),
"mcp_servers.tuleap.enabled=false".to_string(),
"-".to_string(),
],
AgentTool::ClaudeCode => vec!["-p".to_string()],
}
}
}
/// A configured agent: one row of the `agents` table.
///
/// Field order is kept stable because it determines the order of serialized fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
    /// Primary key. New agents get a UUID v4; built-in agents use the `DEFAULT_*_AGENT_ID` ids.
    pub id: String,
    /// Display name.
    pub name: String,
    /// Pipeline stage the agent serves.
    pub role: AgentRole,
    /// CLI tool the agent runs.
    pub tool: AgentTool,
    /// User-supplied prompt text stored with the agent.
    pub custom_prompt: String,
    /// Whether this is a built-in agent (stored as an integer; non-zero means `true`).
    pub is_default: bool,
    /// Creation timestamp (RFC 3339, UTC, when written by this module).
    pub created_at: String,
    /// Last-update timestamp (RFC 3339, UTC, when written by this module).
    pub updated_at: String,
}
/// Builds an [`Agent`] from a row selected with the column order
/// `id, name, role, tool, custom_prompt, is_default, created_at, updated_at`.
///
/// Fails if a column cannot be read as the expected type, or if the stored role/tool string
/// is not recognised by [`AgentRole::from_str`] / [`AgentTool::from_str`].
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Agent> {
    // Columns are fetched in the same order as before, so the first failing column is the
    // one whose error gets reported.
    let role_text: String = row.get(2)?;
    let tool_text: String = row.get(3)?;
    let default_flag: i32 = row.get(5)?;
    let id = row.get(0)?;
    let name = row.get(1)?;
    let role = AgentRole::from_str(&role_text)?;
    let tool = AgentTool::from_str(&tool_text)?;
    let custom_prompt = row.get(4)?;
    let created_at = row.get(6)?;
    let updated_at = row.get(7)?;

    Ok(Agent {
        id,
        name,
        role,
        tool,
        custom_prompt,
        is_default: default_flag != 0,
        created_at,
        updated_at,
    })
}
impl Agent {
    /// Inserts a new, non-default agent and returns it.
    ///
    /// The id is a freshly generated UUID v4. Both `created_at` and `updated_at` are set to
    /// the current UTC time in RFC 3339 format.
    ///
    /// # Errors
    /// Propagates any SQLite error raised by the `INSERT`.
    pub fn insert(
        conn: &Connection,
        name: &str,
        role: AgentRole,
        tool: AgentTool,
        custom_prompt: &str,
    ) -> Result<Agent> {
        let id = Uuid::new_v4().to_string();
        let now = chrono::Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO agents (id, name, role, tool, custom_prompt, is_default, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5, 0, ?6, ?7)",
            params![id, name, role.as_str(), tool.as_str(), custom_prompt, now, now],
        )?;
        Ok(Agent {
            id,
            name: name.to_string(),
            role,
            tool,
            custom_prompt: custom_prompt.to_string(),
            is_default: false,
            created_at: now.clone(),
            updated_at: now,
        })
    }

    /// Returns every agent, ordered by role string (`analyst` before `developer`).
    /// Within each role the default agent comes first, then the others in descending
    /// `created_at` string order.
    pub fn list(conn: &Connection) -> Result<Vec<Agent>> {
        let mut stmt = conn.prepare(
            "SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at
             FROM agents
             ORDER BY role ASC, is_default DESC, created_at DESC",
        )?;
        let rows = stmt.query_map([], from_row)?;
        rows.collect()
    }

    /// Fetches a single agent by primary key.
    ///
    /// # Errors
    /// `rusqlite::Error::QueryReturnedNoRows` if no agent has this `id`.
    pub fn get_by_id(conn: &Connection, id: &str) -> Result<Agent> {
        conn.query_row(
            "SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at FROM agents WHERE id = ?1",
            params![id],
            from_row,
        )
    }

    /// Fetches the built-in agent for `role`.
    ///
    /// The lookup is by its fixed id (`DEFAULT_ANALYST_AGENT_ID` / `DEFAULT_DEVELOPER_AGENT_ID`).
    /// It also requires the row to still carry that role and `is_default = 1`.
    ///
    /// # Errors
    /// `rusqlite::Error::QueryReturnedNoRows` if no such row exists.
    pub fn get_default_by_role(conn: &Connection, role: AgentRole) -> Result<Agent> {
        let default_id = match role {
            AgentRole::Analyst => DEFAULT_ANALYST_AGENT_ID,
            AgentRole::Developer => DEFAULT_DEVELOPER_AGENT_ID,
        };
        conn.query_row(
            "SELECT id, name, role, tool, custom_prompt, is_default, created_at, updated_at
             FROM agents
             WHERE id = ?1 AND role = ?2 AND is_default = 1
             LIMIT 1",
            params![default_id, role.as_str()],
            from_row,
        )
    }

    /// Updates an agent and bumps its `updated_at` to the current UTC time.
    ///
    /// Default agents may only change `tool` and `custom_prompt`. Passing a different `name`
    /// or `role` for them is rejected. Non-default agents may change every field.
    ///
    /// NOTE(review): changing the role of a non-default agent does not touch the
    /// `watched_trackers` rows that reference it (unlike `delete`). Confirm that is intended.
    ///
    /// # Errors
    /// - `QueryReturnedNoRows` if the agent does not exist.
    /// - `InvalidParameterName` when trying to rename or re-role a default agent.
    /// - Any SQLite error raised by the `UPDATE`.
    pub fn update(
        conn: &Connection,
        id: &str,
        name: &str,
        role: AgentRole,
        tool: AgentTool,
        custom_prompt: &str,
    ) -> Result<()> {
        let existing = Self::get_by_id(conn, id)?;
        let now = chrono::Utc::now().to_rfc3339();
        if existing.is_default {
            if existing.name != name || existing.role != role {
                return Err(rusqlite::Error::InvalidParameterName(
                    "Default agents cannot change name or role".to_string(),
                ));
            }
            conn.execute(
                "UPDATE agents SET tool = ?1, custom_prompt = ?2, updated_at = ?3 WHERE id = ?4",
                params![tool.as_str(), custom_prompt, now, id],
            )?;
            return Ok(());
        }
        let affected = conn.execute(
            "UPDATE agents SET name = ?1, role = ?2, tool = ?3, custom_prompt = ?4, updated_at = ?5 WHERE id = ?6",
            params![name, role.as_str(), tool.as_str(), custom_prompt, now, id],
        )?;
        if affected == 0 {
            return Err(rusqlite::Error::QueryReturnedNoRows);
        }
        Ok(())
    }

    /// Deletes a non-default agent.
    ///
    /// Every watched tracker that referenced the agent is first re-pointed to the default
    /// agent of the same role.
    ///
    /// The tracker reassignment, the tracker status recomputation and the row deletion run
    /// inside one SQLite savepoint, so either all of them are applied or none are. A
    /// savepoint is used instead of `BEGIN` because it also nests correctly when the caller
    /// already has a transaction open on `conn`.
    ///
    /// # Errors
    /// - `QueryReturnedNoRows` if no agent has this `id`.
    /// - `InvalidParameterName` if the agent is a default agent, or if no default agent
    ///   exists for its role.
    /// - Any SQLite error from the writes, after which all of them are rolled back.
    pub fn delete(conn: &Connection, id: &str) -> Result<()> {
        let agent = Self::get_by_id(conn, id)?;
        if agent.is_default {
            return Err(rusqlite::Error::InvalidParameterName(
                "Default agents cannot be deleted".to_string(),
            ));
        }
        let default_agent = Self::get_default_by_role(conn, agent.role.clone()).map_err(|_| {
            rusqlite::Error::InvalidParameterName(format!(
                "No default agent found for role '{}'",
                agent.role.as_str()
            ))
        })?;

        conn.execute_batch("SAVEPOINT delete_agent")?;
        match Self::reassign_trackers_and_delete(conn, id, &agent.role, &default_agent.id) {
            Ok(()) => conn.execute_batch("RELEASE delete_agent"),
            Err(err) => {
                // Undo the partial writes. The original error is more useful to the caller
                // than a failure of the rollback itself, so a rollback error is ignored.
                let _ = conn.execute_batch("ROLLBACK TO delete_agent; RELEASE delete_agent");
                Err(err)
            }
        }
    }

    /// Write phase of [`Agent::delete`]; must run inside the caller's savepoint.
    ///
    /// 1. Points the role-specific agent column of `watched_trackers` from `id` to
    ///    `default_agent_id`.
    /// 2. Recomputes `status` for *every* tracker: `'invalid'` if either agent id is NULL,
    ///    otherwise `'valid'`. NOTE(review): this overwrites any other status value. Confirm
    ///    only these two exist.
    /// 3. Deletes the agent row, reporting `QueryReturnedNoRows` if it was already gone.
    fn reassign_trackers_and_delete(
        conn: &Connection,
        id: &str,
        role: &AgentRole,
        default_agent_id: &str,
    ) -> Result<()> {
        // The column name comes from a fixed enum match, never from user input.
        let reassign_sql = match role {
            AgentRole::Analyst => {
                "UPDATE watched_trackers SET analyst_agent_id = ?1 WHERE analyst_agent_id = ?2"
            }
            AgentRole::Developer => {
                "UPDATE watched_trackers SET developer_agent_id = ?1 WHERE developer_agent_id = ?2"
            }
        };
        conn.execute(reassign_sql, params![default_agent_id, id])?;
        conn.execute(
            "UPDATE watched_trackers
             SET status = CASE
                 WHEN analyst_agent_id IS NULL OR developer_agent_id IS NULL THEN 'invalid'
                 ELSE 'valid'
             END",
            [],
        )?;
        let affected = conn.execute("DELETE FROM agents WHERE id = ?1", params![id])?;
        if affected == 0 {
            return Err(rusqlite::Error::QueryReturnedNoRows);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db;
    use crate::models::project::Project;
    use crate::models::tracker::{NewWatchedTracker, WatchedTracker};

    /// Fresh in-memory database with the schema (and built-in agents) initialised.
    fn setup() -> Connection {
        db::init_in_memory().expect("db init should succeed")
    }

    #[test]
    fn test_defaults_exist() {
        let conn = setup();
        let analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();
        let developer = Agent::get_default_by_role(&conn, AgentRole::Developer).unwrap();
        assert_eq!(analyst.id, DEFAULT_ANALYST_AGENT_ID);
        assert_eq!(developer.id, DEFAULT_DEVELOPER_AGENT_ID);
        assert!(analyst.is_default);
        assert!(developer.is_default);
    }

    #[test]
    fn test_insert_and_get_agent() {
        let conn = setup();
        let created = Agent::insert(
            &conn,
            "Analyst Codex",
            AgentRole::Analyst,
            AgentTool::Codex,
            "Focus on root cause.",
        )
        .expect("insert should succeed");
        let found = Agent::get_by_id(&conn, &created.id).expect("get_by_id should succeed");
        assert_eq!(found.name, "Analyst Codex");
        assert_eq!(found.role, AgentRole::Analyst);
        assert_eq!(found.tool, AgentTool::Codex);
        assert_eq!(found.custom_prompt, "Focus on root cause.");
        assert!(!found.is_default);
    }

    /// The enum <-> string mapping is what gets persisted, so it must round-trip exactly and
    /// reject anything else.
    #[test]
    fn test_role_and_tool_string_round_trip() {
        for role in [AgentRole::Analyst, AgentRole::Developer] {
            assert_eq!(AgentRole::from_str(role.as_str()).unwrap(), role);
        }
        for tool in [AgentTool::Codex, AgentTool::ClaudeCode] {
            assert_eq!(AgentTool::from_str(tool.as_str()).unwrap(), tool);
        }
        assert!(AgentRole::from_str("manager").is_err());
        assert!(AgentRole::from_str("Analyst").is_err());
        assert!(AgentTool::from_str("claude").is_err());
        assert!(AgentTool::from_str("").is_err());
    }

    #[test]
    fn test_commands_match_cli_binaries() {
        assert_eq!(AgentTool::Codex.to_command(), "codex");
        assert_eq!(AgentTool::ClaudeCode.to_command(), "claude");
    }

    #[test]
    fn test_non_interactive_args_match_cli_expectations() {
        assert_eq!(
            AgentTool::Codex.to_non_interactive_args(),
            vec![
                "exec".to_string(),
                "--ephemeral".to_string(),
                "-c".to_string(),
                "mcp_servers.tuleap.enabled=false".to_string(),
                "-".to_string()
            ]
        );
        assert_eq!(
            AgentTool::ClaudeCode.to_non_interactive_args(),
            vec!["-p".to_string()]
        );
    }

    #[test]
    fn test_update_default_agent_allows_tool_and_prompt_only() {
        let conn = setup();
        let analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();
        let err = Agent::update(
            &conn,
            &analyst.id,
            "Renamed",
            AgentRole::Developer,
            AgentTool::ClaudeCode,
            "new script",
        )
        .unwrap_err();
        assert!(err
            .to_string()
            .contains("Default agents cannot change name or role"));
        Agent::update(
            &conn,
            &analyst.id,
            &analyst.name,
            analyst.role.clone(),
            AgentTool::ClaudeCode,
            "Prompt override",
        )
        .unwrap();
        let updated = Agent::get_by_id(&conn, &analyst.id).unwrap();
        assert_eq!(updated.tool, AgentTool::ClaudeCode);
        assert_eq!(updated.custom_prompt, "Prompt override");
    }

    #[test]
    fn test_delete_default_agent_is_rejected() {
        let conn = setup();
        let analyst = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();
        let err = Agent::delete(&conn, &analyst.id).unwrap_err();
        assert!(err.to_string().contains("Default agents cannot be deleted"));
    }

    #[test]
    fn test_delete_unknown_agent_reports_no_rows() {
        let conn = setup();
        let err = Agent::delete(&conn, "no-such-agent").unwrap_err();
        assert!(matches!(err, rusqlite::Error::QueryReturnedNoRows));
    }

    #[test]
    fn test_delete_agent_reassigns_trackers_to_default() {
        let conn = setup();
        let project = Project::insert(&conn, "P", "/tmp/p", None, "main").unwrap();
        let analyst_default = Agent::get_default_by_role(&conn, AgentRole::Analyst).unwrap();
        let developer_default = Agent::get_default_by_role(&conn, AgentRole::Developer).unwrap();
        let analyst =
            Agent::insert(&conn, "Analyst", AgentRole::Analyst, AgentTool::Codex, "").unwrap();
        let tracker = WatchedTracker::insert(
            &conn,
            NewWatchedTracker {
                project_id: project.id.clone(),
                tracker_id: 100,
                tracker_label: "Bugs".to_string(),
                polling_interval: 10,
                analyst_agent_id: analyst.id.clone(),
                developer_agent_id: developer_default.id.clone(),
                filters: vec![],
            },
        )
        .unwrap();
        Agent::delete(&conn, &analyst.id).unwrap();
        let reloaded = WatchedTracker::get_by_id(&conn, &tracker.id).unwrap();
        assert_eq!(reloaded.status, "valid");
        assert_eq!(
            reloaded.analyst_agent_id.as_deref(),
            Some(analyst_default.id.as_str())
        );
        assert_eq!(
            reloaded.developer_agent_id.as_deref(),
            Some(developer_default.id.as_str())
        );
        // The deleted agent must really be gone.
        assert!(matches!(
            Agent::get_by_id(&conn, &analyst.id),
            Err(rusqlite::Error::QueryReturnedNoRows)
        ));
    }
}