task_mcp/
lib.rs

1//! MCP server for managing background tasks.
2//!
3//! Provides tools to start, stop, list, and inspect background processes.
4
5#![forbid(unsafe_code)]
6
7mod error;
8mod process;
9mod state;
10
11use rmcp::handler::server::router::tool::ToolRouter;
12use rmcp::handler::server::tool::ToolCallContext;
13use rmcp::handler::server::wrapper::Parameters;
14use rmcp::model::{
15    CallToolRequestParam, CallToolResult, Content, ListToolsResult, PaginatedRequestParam,
16    ServerCapabilities, ServerInfo,
17};
18use rmcp::schemars::JsonSchema;
19use rmcp::service::{RequestContext, RoleServer};
20use rmcp::{tool, tool_router, ErrorData as McpError, ServerHandler};
21use serde::{Deserialize, Serialize};
22
23pub use error::{Error, Result};
24pub use state::{TaskInfo, TaskManager};
25
26// === Tool argument schemas ===
27
// Deserialized from the `task_ensure` tool-call arguments (see `Parameters<TaskEnsureArgs>`
// in `task_ensure`). The `JsonSchema` derive reads `///` doc comments into the generated
// schema, so the struct-level `///` line below is client-visible text. Field descriptions
// come from the explicit `#[schemars(description = ...)]` attributes.
/// Arguments for the `task_ensure` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskEnsureArgs {
    /// Unique name for the task.
    #[schemars(description = "Unique name for the task")]
    pub name: String, // Lookup key for the task manager (`get`/`remove` in `task_ensure`).
    /// Shell command to execute.
    #[schemars(description = "Shell command to execute")]
    pub command: String,
    /// Working directory (optional).
    #[schemars(description = "Working directory (optional)")]
    pub cwd: Option<String>, // `None` => no explicit directory is passed to `spawn_task`.
}
41
// Deserialized from the `task_stop` tool-call arguments. The `///` docs here feed the
// `JsonSchema`-generated input schema, so treat them as client-visible text.
/// Arguments for the `task_stop` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskStopArgs {
    /// Name of the task to stop.
    #[schemars(description = "Name of the task to stop")]
    pub name: String, // Must match a tracked task, otherwise `task_stop` returns invalid_params.
}
49
// Deserialized from the `task_logs` tool-call arguments. The `///` docs here feed the
// `JsonSchema`-generated input schema, so treat them as client-visible text.
/// Arguments for the `task_logs` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskLogsArgs {
    /// Name of the task.
    #[schemars(description = "Name of the task")]
    pub name: String, // Must match a tracked task, otherwise `task_logs` returns invalid_params.
    /// Number of lines to return (default: 50).
    #[schemars(description = "Number of lines to return (default: 50)")]
    pub tail: Option<usize>, // `None` => `task_logs` falls back to 50; keep the two in sync.
}
60
61// === Tool response schemas ===
62
// Point-in-time snapshot built by `task_to_status`. `alive` and `uptime_secs` are sampled
// when the snapshot is taken, not stored on the task. Field `///` docs become descriptions
// in the `JsonSchema`-generated schema.
/// Status information for a task.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskStatus {
    /// Task name.
    pub name: String,
    /// Process ID.
    pub pid: u32,
    /// The command being run.
    pub command: String,
    /// Working directory.
    pub cwd: Option<String>, // Rendered with `Path::display`, so non-UTF-8 paths are lossy.
    /// Whether the process is still alive.
    pub alive: bool,
    /// Seconds since the task was started.
    pub uptime_secs: u64, // Whole seconds; the fractional part is truncated by `as_secs`.
}
79
// Serialized as pretty JSON and returned as the text content of `task_ensure`.
// NOTE(review): `status` is stringly typed. An enum with `#[serde(rename_all = "snake_case")]`
// would serialize to the same strings, but changing this public field's type is breaking.
/// Result of `task_ensure`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskEnsureResult {
    /// `"started"` or `"already_running"`.
    pub status: String,
    /// Task status information.
    pub task: TaskStatus,
}
88
// Serialized as pretty JSON and returned as the text content of a successful `task_stop`.
/// Result of `task_stop`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskStopResult {
    /// Always "stopped".
    pub status: String,
    /// Name of the stopped task.
    pub name: String, // Echoes the name from `TaskStopArgs`.
}
97
// Serialized as pretty JSON by `task_list`: one `TaskStatus` per entry that
// `TaskManager::list` returns, including tasks whose process has died (`alive: false`).
/// Result of `task_list`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskListResult {
    /// List of all tracked tasks.
    pub tasks: Vec<TaskStatus>,
}
104
// Serialized as pretty JSON by `task_logs`. Each stream holds whatever
// `process::read_log_tail` returned (presumably the last `tail` lines — confirm). A failed
// read is turned into an empty string by `unwrap_or_default`, so "" can mean "no output"
// or "log unreadable".
/// Result of `task_logs`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskLogsResult {
    /// Task name.
    pub name: String,
    /// Recent stdout output.
    pub stdout: String,
    /// Recent stderr output.
    pub stderr: String,
}
115
116// === MCP Server ===
117
118/// Convert `TaskInfo` to `TaskStatus` for API responses.
119fn task_to_status(info: &TaskInfo) -> TaskStatus {
120    TaskStatus {
121        name: info.name.clone(),
122        pid: info.pid,
123        command: info.command.clone(),
124        cwd: info.cwd.as_ref().map(|p| p.display().to_string()),
125        alive: process::is_alive(info.pid),
126        uptime_secs: info.started_at.elapsed().as_secs(),
127    }
128}
129
130/// MCP server for managing background tasks.
131#[derive(Clone)]
132pub struct TaskMcpServer {
133    manager: TaskManager,
134    tool_router: ToolRouter<Self>,
135}
136
137impl TaskMcpServer {
138    /// Create a new task MCP server.
139    #[must_use]
140    pub fn new() -> Self {
141        Self {
142            manager: TaskManager::new(),
143            tool_router: Self::tool_router(),
144        }
145    }
146}
147
148impl Default for TaskMcpServer {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154#[tool_router]
155impl TaskMcpServer {
156    /// Ensure a background task is running.
157    ///
158    /// Idempotent: succeeds whether the task was started fresh or was already running.
159    /// If the task exists but the process is dead, it will be restarted.
160    #[tool(description = "Ensure a background task is running. Idempotent: succeeds whether task was started fresh or was already running.")]
161    async fn task_ensure(
162        &self,
163        Parameters(args): Parameters<TaskEnsureArgs>,
164    ) -> std::result::Result<CallToolResult, McpError> {
165        let TaskEnsureArgs { name, command, cwd } = args;
166
167        // Check if task already exists and is alive
168        if let Some(existing) = self.manager.get(&name).await {
169            if process::is_alive(existing.pid) {
170                let result = TaskEnsureResult {
171                    status: "already_running".to_string(),
172                    task: task_to_status(&existing),
173                };
174                let json = serde_json::to_string_pretty(&result)
175                    .map_err(|e| McpError::internal_error(e.to_string(), None))?;
176                return Ok(CallToolResult::success(vec![Content::text(json)]));
177            }
178            // Task exists but process is dead - clean up and restart
179            if let Some(old) = self.manager.remove(&name).await {
180                process::cleanup_logs(&old.stdout_path, &old.stderr_path).await;
181            }
182        }
183
184        // Spawn new task
185        let cwd_path = cwd.as_ref().map(std::path::Path::new);
186        let info = process::spawn_task(&name, &command, cwd_path)
187            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
188
189        let status = task_to_status(&info);
190        self.manager.insert(info).await;
191
192        let result = TaskEnsureResult {
193            status: "started".to_string(),
194            task: status,
195        };
196        let json = serde_json::to_string_pretty(&result)
197            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
198        Ok(CallToolResult::success(vec![Content::text(json)]))
199    }
200
201    /// Stop a background task and clean up its log files.
202    #[tool(description = "Stop a background task and clean up its log files.")]
203    async fn task_stop(
204        &self,
205        Parameters(args): Parameters<TaskStopArgs>,
206    ) -> std::result::Result<CallToolResult, McpError> {
207        let TaskStopArgs { name } = args;
208
209        let info = self.manager.remove(&name).await.ok_or_else(|| {
210            McpError::invalid_params(format!("task not found: {name}"), None)
211        })?;
212
213        // Terminate process if alive
214        if process::is_alive(info.pid) {
215            let _ = process::terminate(info.pid);
216        }
217
218        // Clean up log files
219        process::cleanup_logs(&info.stdout_path, &info.stderr_path).await;
220
221        let result = TaskStopResult {
222            status: "stopped".to_string(),
223            name,
224        };
225        let json = serde_json::to_string_pretty(&result)
226            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
227        Ok(CallToolResult::success(vec![Content::text(json)]))
228    }
229
230    /// List all background tasks with their current status.
231    #[tool(description = "List all background tasks with their current status.")]
232    async fn task_list(&self) -> std::result::Result<CallToolResult, McpError> {
233        let tasks = self.manager.list().await;
234        let statuses: Vec<TaskStatus> = tasks.iter().map(task_to_status).collect();
235
236        let result = TaskListResult { tasks: statuses };
237        let json = serde_json::to_string_pretty(&result)
238            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
239        Ok(CallToolResult::success(vec![Content::text(json)]))
240    }
241
242    /// Get the stdout and stderr logs from a background task.
243    #[tool(description = "Get the stdout and stderr logs from a background task.")]
244    async fn task_logs(
245        &self,
246        Parameters(args): Parameters<TaskLogsArgs>,
247    ) -> std::result::Result<CallToolResult, McpError> {
248        let TaskLogsArgs { name, tail } = args;
249
250        let info = self.manager.get(&name).await.ok_or_else(|| {
251            McpError::invalid_params(format!("task not found: {name}"), None)
252        })?;
253
254        let tail = tail.unwrap_or(50);
255
256        let stdout = process::read_log_tail(&info.stdout_path, tail)
257            .await
258            .unwrap_or_default();
259        let stderr = process::read_log_tail(&info.stderr_path, tail)
260            .await
261            .unwrap_or_default();
262
263        let result = TaskLogsResult {
264            name,
265            stdout,
266            stderr,
267        };
268        let json = serde_json::to_string_pretty(&result)
269            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
270        Ok(CallToolResult::success(vec![Content::text(json)]))
271    }
272}
273
274impl ServerHandler for TaskMcpServer {
275    fn get_info(&self) -> ServerInfo {
276        ServerInfo {
277            capabilities: ServerCapabilities::builder().enable_tools().build(),
278            instructions: Some(
279                "Background task manager. Use task_ensure to start tasks, \
280                 task_stop to terminate them, task_list to see all tasks, \
281                 and task_logs to view output."
282                    .to_string(),
283            ),
284            ..Default::default()
285        }
286    }
287
288    fn call_tool(
289        &self,
290        request: CallToolRequestParam,
291        context: RequestContext<RoleServer>,
292    ) -> impl std::future::Future<Output = std::result::Result<CallToolResult, McpError>> + Send + '_
293    {
294        let tool_context = ToolCallContext::new(self, request, context);
295        async move { self.tool_router.call(tool_context).await }
296    }
297
298    fn list_tools(
299        &self,
300        _request: Option<PaginatedRequestParam>,
301        _context: RequestContext<RoleServer>,
302    ) -> impl std::future::Future<Output = std::result::Result<ListToolsResult, McpError>> + Send + '_
303    {
304        std::future::ready(Ok(ListToolsResult {
305            tools: self.tool_router.list_all(),
306            next_cursor: None,
307            meta: None,
308        }))
309    }
310}