task_mcp/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4mod error;
5mod process;
6mod state;
7
8use rmcp::handler::server::router::tool::ToolRouter;
9use rmcp::handler::server::tool::ToolCallContext;
10use rmcp::handler::server::wrapper::Parameters;
11use rmcp::model::{
12    CallToolRequestParam, CallToolResult, Content, ListToolsResult, PaginatedRequestParam,
13    ServerCapabilities, ServerInfo,
14};
15use rmcp::schemars::JsonSchema;
16use rmcp::service::{RequestContext, RoleServer};
17use rmcp::{ErrorData as McpError, ServerHandler, tool, tool_router};
18use serde::{Deserialize, Serialize};
19
20pub use error::{Error, Result};
21pub use state::{TaskInfo, TaskManager};
22
23// === Tool argument schemas ===
24
/// Arguments for the `task_ensure` tool.
// The `#[schemars(description)]` strings below are what presumably reaches MCP
// clients in the tool's input schema; keep them in sync with the `///` docs.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskEnsureArgs {
    /// Unique name for the task.
    // Used as the lookup key in the task manager by `task_ensure`.
    #[schemars(description = "Unique name for the task")]
    pub name: String,
    /// Shell command to execute.
    // Only used when a new process is spawned; an already-running task with the
    // same name is returned as-is regardless of this value.
    #[schemars(description = "Shell command to execute")]
    pub command: String,
    /// Working directory (optional).
    // Converted to a `Path` and handed to `process::spawn_task` when present.
    #[schemars(description = "Working directory (optional)")]
    pub cwd: Option<String>,
}
38
/// Arguments for the `task_stop` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskStopArgs {
    /// Name of the task to stop.
    // `task_stop` answers with an invalid-params error when no task of this
    // name is tracked.
    #[schemars(description = "Name of the task to stop")]
    pub name: String,
}
46
/// Arguments for the `task_logs` tool.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct TaskLogsArgs {
    /// Name of the task.
    #[schemars(description = "Name of the task")]
    pub name: String,
    /// Number of lines to return (default: 50).
    // `None` falls back to 50 inside `task_logs`. The same count is applied to
    // stdout and stderr independently.
    #[schemars(description = "Number of lines to return (default: 50)")]
    pub tail: Option<usize>,
}
57
58// === Tool response schemas ===
59
/// Status information for a task.
// Built by `task_to_status`. `alive` and `uptime_secs` are a snapshot taken at
// the moment of conversion, not live values.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskStatus {
    /// Task name.
    pub name: String,
    /// Process ID.
    pub pid: u32,
    /// The command being run.
    pub command: String,
    /// Working directory.
    // Mirrors `TaskInfo::cwd` rendered via `Path::display`; `None` when that is `None`.
    pub cwd: Option<String>,
    /// Whether the process is still alive.
    pub alive: bool,
    /// Seconds since the task was started.
    // Measured from the recorded start time, so it keeps growing even after the
    // process has exited; check `alive` to tell the two apart.
    pub uptime_secs: u64,
}
76
/// Result of `task_ensure`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskEnsureResult {
    /// `"started"` or `"already_running"`.
    pub status: String,
    /// Task status information.
    // For `"started"` this describes the freshly spawned process; for
    // `"already_running"` it describes the existing tracked task.
    pub task: TaskStatus,
}
85
/// Result of `task_stop`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskStopResult {
    /// Always `"stopped"`.
    pub status: String,
    /// Name of the stopped task.
    pub name: String,
}
94
/// Result of `task_list`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskListResult {
    /// List of all tracked tasks.
    // One entry per task returned by `TaskManager::list`, in that order.
    // NOTE(review): whether that order is stable/sorted is decided in `state` — confirm.
    pub tasks: Vec<TaskStatus>,
}
101
/// Result of `task_logs`.
#[derive(Debug, Serialize, JsonSchema)]
pub struct TaskLogsResult {
    /// Task name.
    pub name: String,
    /// Recent stdout output.
    // Tail of the stdout log as returned by `process::read_log_tail`; an empty
    // string if the log could not be read.
    pub stdout: String,
    /// Recent stderr output.
    // Same best-effort semantics as `stdout`, read from the stderr log.
    pub stderr: String,
}
112
113// === MCP Server ===
114
115/// Convert `TaskInfo` to `TaskStatus` for API responses.
116fn task_to_status(info: &TaskInfo) -> TaskStatus {
117    TaskStatus {
118        name: info.name.clone(),
119        pid: info.pid,
120        command: info.command.clone(),
121        cwd: info.cwd.as_ref().map(|p| p.display().to_string()),
122        alive: process::is_alive(info.pid),
123        uptime_secs: info.started_at.elapsed().as_secs(),
124    }
125}
126
/// MCP server for managing background tasks.
///
/// Pairs the [`TaskManager`] that tracks spawned tasks with the router that the
/// `#[tool_router]` macro generates for the `task_*` tool methods.
#[derive(Clone)]
pub struct TaskMcpServer {
    /// Registry of tracked tasks, looked up by task name.
    manager: TaskManager,
    /// Dispatch table for the `#[tool]` methods; built by `Self::tool_router()`.
    tool_router: ToolRouter<Self>,
}
133
134impl TaskMcpServer {
135    /// Create a new task MCP server.
136    #[must_use]
137    pub fn new() -> Self {
138        Self {
139            manager: TaskManager::new(),
140            tool_router: Self::tool_router(),
141        }
142    }
143}
144
145impl Default for TaskMcpServer {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151#[tool_router]
152impl TaskMcpServer {
153    /// Ensure a background task is running.
154    ///
155    /// Idempotent: succeeds whether the task was started fresh or was already running.
156    /// If the task exists but the process is dead, it will be restarted.
157    #[tool(
158        description = "Ensure a background task is running. Idempotent: succeeds whether task was started fresh or was already running."
159    )]
160    async fn task_ensure(
161        &self,
162        Parameters(args): Parameters<TaskEnsureArgs>,
163    ) -> std::result::Result<CallToolResult, McpError> {
164        let TaskEnsureArgs { name, command, cwd } = args;
165
166        // Check if task already exists and is alive
167        if let Some(existing) = self.manager.get(&name).await {
168            if process::is_alive(existing.pid) {
169                let result = TaskEnsureResult {
170                    status: "already_running".to_string(),
171                    task: task_to_status(&existing),
172                };
173                let json = serde_json::to_string_pretty(&result)
174                    .map_err(|e| McpError::internal_error(e.to_string(), None))?;
175                return Ok(CallToolResult::success(vec![Content::text(json)]));
176            }
177            // Task exists but process is dead - clean up and restart
178            if let Some(old) = self.manager.remove(&name).await {
179                process::cleanup_logs(&old.stdout_path, &old.stderr_path).await;
180            }
181        }
182
183        // Spawn new task
184        let cwd_path = cwd.as_ref().map(std::path::Path::new);
185        let info = process::spawn_task(&name, &command, cwd_path)
186            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
187
188        let status = task_to_status(&info);
189        self.manager.insert(info).await;
190
191        let result = TaskEnsureResult {
192            status: "started".to_string(),
193            task: status,
194        };
195        let json = serde_json::to_string_pretty(&result)
196            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
197        Ok(CallToolResult::success(vec![Content::text(json)]))
198    }
199
200    /// Stop a background task and clean up its log files.
201    #[tool(description = "Stop a background task and clean up its log files.")]
202    async fn task_stop(
203        &self,
204        Parameters(args): Parameters<TaskStopArgs>,
205    ) -> std::result::Result<CallToolResult, McpError> {
206        let TaskStopArgs { name } = args;
207
208        let info = self
209            .manager
210            .remove(&name)
211            .await
212            .ok_or_else(|| McpError::invalid_params(format!("task not found: {name}"), None))?;
213
214        // Terminate process if alive
215        if process::is_alive(info.pid) {
216            let _ = process::terminate(info.pid);
217        }
218
219        // Clean up log files
220        process::cleanup_logs(&info.stdout_path, &info.stderr_path).await;
221
222        let result = TaskStopResult {
223            status: "stopped".to_string(),
224            name,
225        };
226        let json = serde_json::to_string_pretty(&result)
227            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
228        Ok(CallToolResult::success(vec![Content::text(json)]))
229    }
230
231    /// List all background tasks with their current status.
232    #[tool(description = "List all background tasks with their current status.")]
233    async fn task_list(&self) -> std::result::Result<CallToolResult, McpError> {
234        let tasks = self.manager.list().await;
235        let statuses: Vec<TaskStatus> = tasks.iter().map(task_to_status).collect();
236
237        let result = TaskListResult { tasks: statuses };
238        let json = serde_json::to_string_pretty(&result)
239            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
240        Ok(CallToolResult::success(vec![Content::text(json)]))
241    }
242
243    /// Get the stdout and stderr logs from a background task.
244    #[tool(description = "Get the stdout and stderr logs from a background task.")]
245    async fn task_logs(
246        &self,
247        Parameters(args): Parameters<TaskLogsArgs>,
248    ) -> std::result::Result<CallToolResult, McpError> {
249        let TaskLogsArgs { name, tail } = args;
250
251        let info = self
252            .manager
253            .get(&name)
254            .await
255            .ok_or_else(|| McpError::invalid_params(format!("task not found: {name}"), None))?;
256
257        let tail = tail.unwrap_or(50);
258
259        let stdout = process::read_log_tail(&info.stdout_path, tail)
260            .await
261            .unwrap_or_default();
262        let stderr = process::read_log_tail(&info.stderr_path, tail)
263            .await
264            .unwrap_or_default();
265
266        let result = TaskLogsResult {
267            name,
268            stdout,
269            stderr,
270        };
271        let json = serde_json::to_string_pretty(&result)
272            .map_err(|e| McpError::internal_error(e.to_string(), None))?;
273        Ok(CallToolResult::success(vec![Content::text(json)]))
274    }
275}
276
277impl ServerHandler for TaskMcpServer {
278    fn get_info(&self) -> ServerInfo {
279        ServerInfo {
280            capabilities: ServerCapabilities::builder().enable_tools().build(),
281            instructions: Some(
282                "Background task manager. Use task_ensure to start tasks, \
283                 task_stop to terminate them, task_list to see all tasks, \
284                 and task_logs to view output."
285                    .to_string(),
286            ),
287            ..Default::default()
288        }
289    }
290
291    fn call_tool(
292        &self,
293        request: CallToolRequestParam,
294        context: RequestContext<RoleServer>,
295    ) -> impl std::future::Future<Output = std::result::Result<CallToolResult, McpError>> + Send + '_
296    {
297        let tool_context = ToolCallContext::new(self, request, context);
298        async move { self.tool_router.call(tool_context).await }
299    }
300
301    fn list_tools(
302        &self,
303        _request: Option<PaginatedRequestParam>,
304        _context: RequestContext<RoleServer>,
305    ) -> impl std::future::Future<Output = std::result::Result<ListToolsResult, McpError>> + Send + '_
306    {
307        std::future::ready(Ok(ListToolsResult {
308            tools: self.tool_router.list_all(),
309            next_cursor: None,
310            meta: None,
311        }))
312    }
313}