1#![forbid(unsafe_code)]
6
7mod error;
8mod process;
9mod state;
10
11use rmcp::handler::server::router::tool::ToolRouter;
12use rmcp::handler::server::tool::ToolCallContext;
13use rmcp::handler::server::wrapper::Parameters;
14use rmcp::model::{
15 CallToolRequestParam, CallToolResult, Content, ListToolsResult, PaginatedRequestParam,
16 ServerCapabilities, ServerInfo,
17};
18use rmcp::schemars::JsonSchema;
19use rmcp::service::{RequestContext, RoleServer};
20use rmcp::{tool, tool_router, ErrorData as McpError, ServerHandler};
21use serde::{Deserialize, Serialize};
22
23pub use error::{Error, Result};
24pub use state::{TaskInfo, TaskManager};
25
26#[derive(Debug, Deserialize, JsonSchema)]
30pub struct TaskEnsureArgs {
31 #[schemars(description = "Unique name for the task")]
33 pub name: String,
34 #[schemars(description = "Shell command to execute")]
36 pub command: String,
37 #[schemars(description = "Working directory (optional)")]
39 pub cwd: Option<String>,
40}
41
42#[derive(Debug, Deserialize, JsonSchema)]
44pub struct TaskStopArgs {
45 #[schemars(description = "Name of the task to stop")]
47 pub name: String,
48}
49
50#[derive(Debug, Deserialize, JsonSchema)]
52pub struct TaskLogsArgs {
53 #[schemars(description = "Name of the task")]
55 pub name: String,
56 #[schemars(description = "Number of lines to return (default: 50)")]
58 pub tail: Option<usize>,
59}
60
61#[derive(Debug, Serialize, JsonSchema)]
65pub struct TaskStatus {
66 pub name: String,
68 pub pid: u32,
70 pub command: String,
72 pub cwd: Option<String>,
74 pub alive: bool,
76 pub uptime_secs: u64,
78}
79
80#[derive(Debug, Serialize, JsonSchema)]
82pub struct TaskEnsureResult {
83 pub status: String,
85 pub task: TaskStatus,
87}
88
89#[derive(Debug, Serialize, JsonSchema)]
91pub struct TaskStopResult {
92 pub status: String,
94 pub name: String,
96}
97
98#[derive(Debug, Serialize, JsonSchema)]
100pub struct TaskListResult {
101 pub tasks: Vec<TaskStatus>,
103}
104
105#[derive(Debug, Serialize, JsonSchema)]
107pub struct TaskLogsResult {
108 pub name: String,
110 pub stdout: String,
112 pub stderr: String,
114}
115
116fn task_to_status(info: &TaskInfo) -> TaskStatus {
120 TaskStatus {
121 name: info.name.clone(),
122 pid: info.pid,
123 command: info.command.clone(),
124 cwd: info.cwd.as_ref().map(|p| p.display().to_string()),
125 alive: process::is_alive(info.pid),
126 uptime_secs: info.started_at.elapsed().as_secs(),
127 }
128}
129
130#[derive(Clone)]
132pub struct TaskMcpServer {
133 manager: TaskManager,
134 tool_router: ToolRouter<Self>,
135}
136
137impl TaskMcpServer {
138 #[must_use]
140 pub fn new() -> Self {
141 Self {
142 manager: TaskManager::new(),
143 tool_router: Self::tool_router(),
144 }
145 }
146}
147
148impl Default for TaskMcpServer {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154#[tool_router]
155impl TaskMcpServer {
156 #[tool(description = "Ensure a background task is running. Idempotent: succeeds whether task was started fresh or was already running.")]
161 async fn task_ensure(
162 &self,
163 Parameters(args): Parameters<TaskEnsureArgs>,
164 ) -> std::result::Result<CallToolResult, McpError> {
165 let TaskEnsureArgs { name, command, cwd } = args;
166
167 if let Some(existing) = self.manager.get(&name).await {
169 if process::is_alive(existing.pid) {
170 let result = TaskEnsureResult {
171 status: "already_running".to_string(),
172 task: task_to_status(&existing),
173 };
174 let json = serde_json::to_string_pretty(&result)
175 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
176 return Ok(CallToolResult::success(vec![Content::text(json)]));
177 }
178 if let Some(old) = self.manager.remove(&name).await {
180 process::cleanup_logs(&old.stdout_path, &old.stderr_path).await;
181 }
182 }
183
184 let cwd_path = cwd.as_ref().map(std::path::Path::new);
186 let info = process::spawn_task(&name, &command, cwd_path)
187 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
188
189 let status = task_to_status(&info);
190 self.manager.insert(info).await;
191
192 let result = TaskEnsureResult {
193 status: "started".to_string(),
194 task: status,
195 };
196 let json = serde_json::to_string_pretty(&result)
197 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
198 Ok(CallToolResult::success(vec![Content::text(json)]))
199 }
200
201 #[tool(description = "Stop a background task and clean up its log files.")]
203 async fn task_stop(
204 &self,
205 Parameters(args): Parameters<TaskStopArgs>,
206 ) -> std::result::Result<CallToolResult, McpError> {
207 let TaskStopArgs { name } = args;
208
209 let info = self.manager.remove(&name).await.ok_or_else(|| {
210 McpError::invalid_params(format!("task not found: {name}"), None)
211 })?;
212
213 if process::is_alive(info.pid) {
215 let _ = process::terminate(info.pid);
216 }
217
218 process::cleanup_logs(&info.stdout_path, &info.stderr_path).await;
220
221 let result = TaskStopResult {
222 status: "stopped".to_string(),
223 name,
224 };
225 let json = serde_json::to_string_pretty(&result)
226 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
227 Ok(CallToolResult::success(vec![Content::text(json)]))
228 }
229
230 #[tool(description = "List all background tasks with their current status.")]
232 async fn task_list(&self) -> std::result::Result<CallToolResult, McpError> {
233 let tasks = self.manager.list().await;
234 let statuses: Vec<TaskStatus> = tasks.iter().map(task_to_status).collect();
235
236 let result = TaskListResult { tasks: statuses };
237 let json = serde_json::to_string_pretty(&result)
238 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
239 Ok(CallToolResult::success(vec![Content::text(json)]))
240 }
241
242 #[tool(description = "Get the stdout and stderr logs from a background task.")]
244 async fn task_logs(
245 &self,
246 Parameters(args): Parameters<TaskLogsArgs>,
247 ) -> std::result::Result<CallToolResult, McpError> {
248 let TaskLogsArgs { name, tail } = args;
249
250 let info = self.manager.get(&name).await.ok_or_else(|| {
251 McpError::invalid_params(format!("task not found: {name}"), None)
252 })?;
253
254 let tail = tail.unwrap_or(50);
255
256 let stdout = process::read_log_tail(&info.stdout_path, tail)
257 .await
258 .unwrap_or_default();
259 let stderr = process::read_log_tail(&info.stderr_path, tail)
260 .await
261 .unwrap_or_default();
262
263 let result = TaskLogsResult {
264 name,
265 stdout,
266 stderr,
267 };
268 let json = serde_json::to_string_pretty(&result)
269 .map_err(|e| McpError::internal_error(e.to_string(), None))?;
270 Ok(CallToolResult::success(vec![Content::text(json)]))
271 }
272}
273
274impl ServerHandler for TaskMcpServer {
275 fn get_info(&self) -> ServerInfo {
276 ServerInfo {
277 capabilities: ServerCapabilities::builder().enable_tools().build(),
278 instructions: Some(
279 "Background task manager. Use task_ensure to start tasks, \
280 task_stop to terminate them, task_list to see all tasks, \
281 and task_logs to view output."
282 .to_string(),
283 ),
284 ..Default::default()
285 }
286 }
287
288 fn call_tool(
289 &self,
290 request: CallToolRequestParam,
291 context: RequestContext<RoleServer>,
292 ) -> impl std::future::Future<Output = std::result::Result<CallToolResult, McpError>> + Send + '_
293 {
294 let tool_context = ToolCallContext::new(self, request, context);
295 async move { self.tool_router.call(tool_context).await }
296 }
297
298 fn list_tools(
299 &self,
300 _request: Option<PaginatedRequestParam>,
301 _context: RequestContext<RoleServer>,
302 ) -> impl std::future::Future<Output = std::result::Result<ListToolsResult, McpError>> + Send + '_
303 {
304 std::future::ready(Ok(ListToolsResult {
305 tools: self.tool_router.list_all(),
306 next_cursor: None,
307 meta: None,
308 }))
309 }
310}