feat(mcp): support prompts (#1155)

* feat(mcp): support prompts

* fmt

* fix

* Revert "fix"

This reverts commit d370e9be86.

* fix get

* not allow array

* add test case

* fix clippy

* fix connect with rmcp client

* fix
This commit is contained in:
Yiyu Lin
2026-01-21 10:14:00 +08:00
committed by GitHub
parent 54c42e3274
commit d2c6f986fa
18 changed files with 1464 additions and 120 deletions

View File

@@ -0,0 +1,20 @@
[package]
name = "prompts-streamable-http"
version = "0.1.0"
edition = "2021"
default-run = "prompts-streamable-http"
[[bin]]
name = "prompts-streamable-http-client"
path = "src/client.rs"
[dependencies]
anyhow = "1"
poem-mcpserver = { workspace = true, features = ["streamable-http"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { workspace = true }
schemars = "1.0"
poem = { workspace = true, features = ["sse"] }
rmcp = { version = "0.13", features = ["client", "transport-streamable-http-client-reqwest"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }
tracing-subscriber.workspace = true

View File

@@ -0,0 +1,27 @@
# prompts-streamable-http
This example runs a streamable HTTP MCP server with tools and prompts, plus a small rmcp client that connects to it.
## Run the server
1. From the repo root, in one terminal:
cargo run --manifest-path ./examples/mcpserver/prompts-streamable-http/Cargo.toml
Or from the example directory:
cargo run
The server listens on http://127.0.0.1:8000/.
## Run the client (rmcp)
2. From the repo root, in another terminal:
cargo run --manifest-path ./examples/mcpserver/prompts-streamable-http/Cargo.toml --bin prompts-streamable-http-client
Or from the example directory:
cargo run --bin prompts-streamable-http-client
The client lists tools and invokes the get_review_count tool.
## Tip
Because this example is not a workspace member, run commands from this directory or use --manifest-path from the repo root.

View File

@@ -0,0 +1,27 @@
use anyhow::Result;
use rmcp::{
model::CallToolRequestParam, service::ServiceExt, transport::StreamableHttpClientTransport,
};
#[tokio::main]
async fn main() -> Result<()> {
let transport = StreamableHttpClientTransport::from_uri("http://127.0.0.1:8000/");
let service = ().serve(transport).await?;
let tools = service.list_tools(Default::default()).await?;
println!("Available tools: {tools:#?}");
// Sending empty arguments object for get_review_count tool to satisfy schema
let response = service
.call_tool(CallToolRequestParam {
name: "get_review_count".to_string().into(),
arguments: Some(serde_json::Map::new()),
task: None,
})
.await?;
println!("get_review_count response: {response:#?}");
service.cancel().await?;
Ok(())
}

View File

@@ -0,0 +1,183 @@
use poem::{listener::TcpListener, middleware::Cors, EndpointExt, Route, Server};
use poem_mcpserver::{
content::Text, prompts::PromptMessages, streamable_http, McpServer, Prompts, Tools,
};
/// A collection of development assistant tools.
struct DevTools {
/// History of reviewed code snippets
review_count: u32,
}
/// This server provides development assistant tools for code analysis.
#[Tools]
impl DevTools {
/// Analyze code complexity and return metrics.
async fn analyze_complexity(
&mut self,
/// The code to analyze
code: String,
) -> Text<String> {
let lines = code.lines().count();
let chars = code.len();
self.review_count += 1;
Text(format!(
"Code Analysis #{}\n- Lines: {}\n- Characters: {}\n- Estimated complexity: {}",
self.review_count,
lines,
chars,
if lines > 50 {
"High"
} else if lines > 20 {
"Medium"
} else {
"Low"
}
))
}
/// Count occurrences of a pattern in code.
async fn count_pattern(
&self,
/// The code to search in
code: String,
/// The pattern to search for
pattern: String,
) -> Text<String> {
let count = code.matches(&pattern).count();
Text(format!("Found {} occurrences of '{}'", count, pattern))
}
/// Get the total number of code reviews performed.
async fn get_review_count(&self) -> Text<u32> {
Text(self.review_count)
}
}
/// A collection of development assistant prompts.
struct DevPrompts {
/// The assistant's persona name
assistant_name: String,
}
/// This server provides development assistant prompts for code review,
/// documentation generation, and debugging help.
///
/// Use the 'code_review' prompt for reviewing code snippets.
/// Use the 'generate_docs' prompt for generating documentation.
/// Use the 'debug_help' prompt for debugging assistance.
#[Prompts]
impl DevPrompts {
/// Review code for potential issues, style, and best practices.
async fn code_review(
&self,
/// The code snippet to review
#[mcp(required)]
code: Option<String>,
/// The programming language of the code
language: Option<String>,
/// Focus area: "security", "performance", "style", or "all"
focus: Option<String>,
) -> PromptMessages {
let lang = language.unwrap_or_else(|| "unknown".to_string());
let focus_area = focus.unwrap_or_else(|| "all".to_string());
PromptMessages::new()
.user(Text(format!(
"Please review the following {} code. Focus on: {}\n\n```{}\n{}\n```",
lang,
focus_area,
lang,
code.unwrap()
)))
.assistant(Text(format!(
"I'm {}, and I'll review this {} code focusing on {}. Let me analyze it...",
self.assistant_name, lang, focus_area
)))
}
/// Generate documentation for a code snippet.
async fn generate_docs(
&self,
/// The code to document
#[mcp(required)]
code: Option<String>,
/// Documentation style: "markdown", "jsdoc", "rustdoc", etc.
style: Option<String>,
) -> PromptMessages {
let doc_style = style.unwrap_or_else(|| "markdown".to_string());
PromptMessages::new().user(Text(format!(
"Generate {} documentation for the following code:\n\n```\n{}\n```",
doc_style,
code.unwrap()
)))
}
/// Get help debugging an issue.
async fn debug_help(
&self,
/// Description of the problem
#[mcp(required)]
problem: Option<String>,
/// The error message, if any
error_message: Option<String>,
/// Relevant code snippet
code: Option<String>,
) -> PromptMessages {
let mut prompt = format!(
"I need help debugging an issue.\n\nProblem: {}",
problem.unwrap()
);
if let Some(err) = error_message {
prompt.push_str(&format!("\n\nError message:\n```\n{}\n```", err));
}
if let Some(code_snippet) = code {
prompt.push_str(&format!("\n\nRelevant code:\n```\n{}\n```", code_snippet));
}
PromptMessages::new()
.user(Text(prompt))
.assistant(Text(format!(
"I'm {} and I'll help you debug this issue. Let me analyze the problem...",
self.assistant_name
)))
}
/// Get a simple greeting from the assistant.
async fn greet(&self) -> String {
format!(
"Hello! I'm {}, your development assistant. I can help you with:\n\
- Code reviews (use 'code_review' prompt)\n\
- Documentation generation (use 'generate_docs' prompt)\n\
- Debugging help (use 'debug_help' prompt)\n\n\
How can I assist you today?",
self.assistant_name
)
}
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
if std::env::var_os("RUST_LOG").is_none() {
std::env::set_var("RUST_LOG", "poem=debug");
}
tracing_subscriber::fmt::init();
let listener = TcpListener::bind("127.0.0.1:8000");
let app = Route::new()
.at(
"/",
streamable_http::endpoint(|_| {
let tools = DevTools { review_count: 0 };
let prompts = DevPrompts {
assistant_name: "CodeBot".to_string(),
};
McpServer::new().tools(tools).prompts(prompts)
}),
)
.with(Cors::new());
Server::new(listener).run(app).await
}

View File

@@ -0,0 +1,10 @@
[package]
name = "prompts-example"
version = "0.1.0"
edition = "2021"
[dependencies]
poem-mcpserver.workspace = true
serde = { version = "1.0.219", features = ["derive"] }
schemars = "1.0"
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }

View File

@@ -0,0 +1,169 @@
use poem_mcpserver::{
content::Text, prompts::PromptMessages, stdio::stdio, McpServer, Prompts, Tools,
};
/// A collection of development assistant tools.
struct DevTools {
/// History of reviewed code snippets
review_count: u32,
}
/// This server provides development assistant tools for code analysis.
#[Tools]
impl DevTools {
/// Analyze code complexity and return metrics.
async fn analyze_complexity(
&mut self,
/// The code to analyze
code: String,
) -> Text<String> {
let lines = code.lines().count();
let chars = code.len();
self.review_count += 1;
Text(format!(
"Code Analysis #{}\n- Lines: {}\n- Characters: {}\n- Estimated complexity: {}",
self.review_count,
lines,
chars,
if lines > 50 {
"High"
} else if lines > 20 {
"Medium"
} else {
"Low"
}
))
}
/// Count occurrences of a pattern in code.
async fn count_pattern(
&self,
/// The code to search in
code: String,
/// The pattern to search for
pattern: String,
) -> Text<String> {
let count = code.matches(&pattern).count();
Text(format!("Found {} occurrences of '{}'", count, pattern))
}
/// Get the total number of code reviews performed.
async fn get_review_count(&self) -> Text<u32> {
Text(self.review_count)
}
}
/// A collection of development assistant prompts.
struct DevPrompts {
/// The assistant's persona name
assistant_name: String,
}
/// This server provides development assistant prompts for code review,
/// documentation generation, and debugging help.
///
/// Use the 'code_review' prompt for reviewing code snippets.
/// Use the 'generate_docs' prompt for generating documentation.
/// Use the 'debug_help' prompt for debugging assistance.
#[Prompts]
impl DevPrompts {
/// Review code for potential issues, style, and best practices.
async fn code_review(
&self,
/// The code snippet to review
#[mcp(required)]
code: Option<String>,
/// The programming language of the code
language: Option<String>,
/// Focus area: "security", "performance", "style", or "all"
focus: Option<String>,
) -> PromptMessages {
let lang = language.unwrap_or_else(|| "unknown".to_string());
let focus_area = focus.unwrap_or_else(|| "all".to_string());
PromptMessages::new()
.user(Text(format!(
"Please review the following {} code. Focus on: {}\n\n```{}\n{}\n```",
lang,
focus_area,
lang,
code.unwrap()
)))
.assistant(Text(format!(
"I'm {}, and I'll review this {} code focusing on {}. Let me analyze it...",
self.assistant_name, lang, focus_area
)))
}
/// Generate documentation for a code snippet.
async fn generate_docs(
&self,
/// The code to document
#[mcp(required)]
code: Option<String>,
/// Documentation style: "markdown", "jsdoc", "rustdoc", etc.
style: Option<String>,
) -> PromptMessages {
let doc_style = style.unwrap_or_else(|| "markdown".to_string());
PromptMessages::new().user(Text(format!(
"Generate {} documentation for the following code:\n\n```\n{}\n```",
doc_style,
code.unwrap()
)))
}
/// Get help debugging an issue.
async fn debug_help(
&self,
/// Description of the problem
#[mcp(required)]
problem: Option<String>,
/// The error message, if any
error_message: Option<String>,
/// Relevant code snippet
code: Option<String>,
) -> PromptMessages {
let mut prompt = format!(
"I need help debugging an issue.\n\nProblem: {}",
problem.unwrap()
);
if let Some(err) = error_message {
prompt.push_str(&format!("\n\nError message:\n```\n{}\n```", err));
}
if let Some(code_snippet) = code {
prompt.push_str(&format!("\n\nRelevant code:\n```\n{}\n```", code_snippet));
}
PromptMessages::new()
.user(Text(prompt))
.assistant(Text(format!(
"I'm {} and I'll help you debug this issue. Let me analyze the problem...",
self.assistant_name
)))
}
/// Get a simple greeting from the assistant.
async fn greet(&self) -> String {
format!(
"Hello! I'm {}, your development assistant. I can help you with:\n\
- Code reviews (use 'code_review' prompt)\n\
- Documentation generation (use 'generate_docs' prompt)\n\
- Debugging help (use 'debug_help' prompt)\n\n\
How can I assist you today?",
self.assistant_name
)
}
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
let tools = DevTools { review_count: 0 };
let prompts = DevPrompts {
assistant_name: "CodeBot".to_string(),
};
stdio(McpServer::new().tools(tools).prompts(prompts)).await
}

View File

@@ -1,3 +1,4 @@
mod prompts;
mod tools;
mod utils;
@@ -33,3 +34,14 @@ pub fn Tools(args: TokenStream, input: TokenStream) -> TokenStream {
Err(err) => err.write_errors().into(),
}
}
#[proc_macro_attribute]
#[allow(non_snake_case)]
pub fn Prompts(args: TokenStream, input: TokenStream) -> TokenStream {
let prompt_args = parse_nested_meta!(prompts::PromptsArgs, args);
let item_impl = parse_macro_input!(input as ItemImpl);
match prompts::generate(prompt_args, item_impl) {
Ok(stream) => stream.into(),
Err(err) => err.write_errors().into(),
}
}

View File

@@ -0,0 +1,149 @@
use darling::{Error, FromMeta, Result};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{FnArg, ImplItem, ItemImpl, Pat};
use crate::utils::*;
#[derive(FromMeta, Default)]
pub(crate) struct PromptsArgs {}
#[derive(FromMeta, Default)]
pub(crate) struct PromptArgs {
name: Option<String>,
}
#[derive(FromMeta, Default)]
pub(crate) struct PromptParamArgs {
name: Option<String>,
#[darling(default)]
required: bool,
}
pub(crate) fn generate(_args: PromptsArgs, mut item_impl: ItemImpl) -> Result<TokenStream> {
let crate_name = get_crate_name();
let ident = item_impl.self_ty.clone();
let mut prompts_descriptions = vec![];
let mut get_branches = vec![];
for item in &mut item_impl.items {
if let ImplItem::Fn(method) = item {
let prompt_args = parse_mcp_attrs::<PromptArgs>(&method.attrs)?;
remove_mcp_attrs(&mut method.attrs);
let prompt_name = match &prompt_args.name {
Some(name) => name.clone(),
None => method.sig.ident.to_string(),
};
let prompt_description = get_description(&method.attrs).unwrap_or_default();
if method.sig.asyncness.is_none() {
return Err(Error::custom("must be asynchronous").with_span(&method.sig.ident));
}
if method.sig.inputs.is_empty() {
return Err(Error::custom("at least one `&self` receiver is required.")
.with_span(&method.sig.ident));
}
if !matches!(&method.sig.inputs[0], FnArg::Receiver(_)) {
return Err(
Error::custom("the first parameter must be a `&self` receiver.")
.with_span(&method.sig.inputs[0]),
);
}
let mut prompt_arguments = vec![];
let mut arg_extractions = vec![];
let mut arg_names = vec![];
let mut required_checks = vec![];
for arg in method.sig.inputs.iter_mut().skip(1) {
let FnArg::Typed(pat) = arg else {
unreachable!()
};
let Pat::Ident(ident) = &mut *pat.pat else {
return Err(Error::custom("expected ident").with_span(&pat.pat));
};
let param_args = parse_mcp_attrs::<PromptParamArgs>(&pat.attrs)?;
remove_mcp_attrs(&mut pat.attrs);
let param_name = match &param_args.name {
Some(name) => name.clone(),
None => ident.ident.to_string(),
};
let param_desc = get_description(&pat.attrs).unwrap_or_default();
remove_description(&mut pat.attrs);
let is_required = param_args.required;
let arg_ident = &ident.ident;
prompt_arguments.push(quote! {
#crate_name::protocol::prompts::PromptArgument {
name: #param_name,
description: #param_desc,
required: #is_required,
},
});
if is_required {
required_checks.push(quote! {
if !arguments.contains_key(#param_name) {
return ::std::result::Result::Err(
#crate_name::protocol::rpc::RpcError::invalid_params(
format!("missing required argument: {}", #param_name)
)
);
}
});
}
arg_extractions.push(quote! {
let #arg_ident: ::std::option::Option<::std::string::String> = arguments.get(#param_name).cloned();
});
arg_names.push(quote! { #arg_ident });
}
let method_ident = &method.sig.ident;
get_branches.push(quote! {
#prompt_name => {
#(#required_checks)*
#(#arg_extractions)*
let response = self.#method_ident(#(#arg_names),*).await;
::std::result::Result::Ok(#crate_name::prompts::IntoPromptResponse::into_prompt_response(response))
}
});
prompts_descriptions.push(quote! {
#crate_name::protocol::prompts::Prompt {
name: #prompt_name,
description: #prompt_description,
arguments: &[#(#prompt_arguments)*],
},
});
}
}
Ok(quote! {
#item_impl
impl #crate_name::prompts::Prompts for #ident {
fn list() -> ::std::vec::Vec<#crate_name::protocol::prompts::Prompt> {
::std::vec![#(#prompts_descriptions)*]
}
async fn get(
&self,
name: &::std::primitive::str,
arguments: ::std::collections::HashMap<::std::string::String, ::std::string::String>,
) -> ::std::result::Result<#crate_name::protocol::prompts::PromptGetResponse, #crate_name::protocol::rpc::RpcError> {
match name {
#(#get_branches)*
_ => ::std::result::Result::Err(#crate_name::protocol::rpc::RpcError::method_not_found(format!("prompt not found: {}", name))),
}
}
}
})
}

View File

@@ -7,6 +7,7 @@
#![warn(missing_docs)]
pub mod content;
pub mod prompts;
pub mod protocol;
mod server;
pub mod stdio;
@@ -14,7 +15,7 @@ pub mod stdio;
#[cfg_attr(docsrs, doc(cfg(feature = "streamable-http")))]
pub mod streamable_http;
pub mod tool;
pub use poem_mcpserver_macros::Tools;
pub use poem_mcpserver_macros::{Prompts, Tools};
pub use schemars::JsonSchema;
pub use server::McpServer;
@@ -22,5 +23,5 @@ pub use server::McpServer;
pub mod private {
pub use serde_json;
pub use crate::tool::IntoToolResponse;
pub use crate::{prompts::IntoPromptResponse, tool::IntoToolResponse};
}

View File

@@ -0,0 +1,180 @@
//! Types for prompts.
use std::future::Future;
use crate::{
content::IntoContent,
protocol::{
content::Content,
prompts::{Prompt, PromptGetResponse, PromptMessage, Role},
rpc::RpcError,
},
};
/// Represents a type that can be converted into a prompt response.
pub trait IntoPromptResponse {
/// Consumes the object and converts it into a prompt response.
fn into_prompt_response(self) -> PromptGetResponse;
}
impl IntoPromptResponse for PromptGetResponse {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
self
}
}
impl IntoPromptResponse for PromptMessage {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: vec![self],
}
}
}
impl IntoPromptResponse for Vec<PromptMessage> {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: self,
}
}
}
impl IntoPromptResponse for String {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: vec![PromptMessage {
role: Role::User,
content: Content::Text { text: self },
}],
}
}
}
impl IntoPromptResponse for &str {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: vec![PromptMessage {
role: Role::User,
content: Content::Text {
text: self.to_string(),
},
}],
}
}
}
impl<T> IntoPromptResponse for (Role, T)
where
T: IntoContent,
{
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: vec![PromptMessage {
role: self.0,
content: self.1.into_content(),
}],
}
}
}
/// A builder for creating prompt responses with multiple messages.
#[derive(Debug, Default)]
pub struct PromptMessages {
messages: Vec<PromptMessage>,
}
impl PromptMessages {
/// Creates a new empty prompt messages builder.
#[inline]
pub fn new() -> Self {
Self {
messages: Vec::new(),
}
}
/// Adds a user message to the prompt.
#[inline]
pub fn user(mut self, content: impl IntoContent) -> Self {
self.messages.push(PromptMessage {
role: Role::User,
content: content.into_content(),
});
self
}
/// Adds an assistant message to the prompt.
#[inline]
pub fn assistant(mut self, content: impl IntoContent) -> Self {
self.messages.push(PromptMessage {
role: Role::Assistant,
content: content.into_content(),
});
self
}
/// Adds a message with a specific role to the prompt.
#[inline]
pub fn message(mut self, role: Role, content: impl IntoContent) -> Self {
self.messages.push(PromptMessage {
role,
content: content.into_content(),
});
self
}
}
impl IntoPromptResponse for PromptMessages {
#[inline]
fn into_prompt_response(self) -> PromptGetResponse {
PromptGetResponse {
description: "",
messages: self.messages,
}
}
}
/// Represents a prompts collection.
pub trait Prompts {
/// Returns a list of prompts.
fn list() -> Vec<Prompt>;
/// Gets a prompt by name with the given arguments.
fn get(
&self,
name: &str,
arguments: std::collections::HashMap<String, String>,
) -> impl Future<Output = Result<PromptGetResponse, RpcError>> + Send;
}
/// Empty prompts collection.
#[derive(Debug, Clone, Copy)]
pub struct NoPrompts;
impl Prompts for NoPrompts {
#[inline]
fn list() -> Vec<Prompt> {
vec![]
}
#[inline]
async fn get(
&self,
name: &str,
_arguments: std::collections::HashMap<String, String>,
) -> Result<PromptGetResponse, RpcError> {
Err(RpcError::method_not_found(format!(
"prompt '{name}' not found"
)))
}
}

View File

@@ -12,6 +12,17 @@ pub struct PromptsListRequest {
pub cursor: Option<String>,
}
/// A request to get a prompt.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetRequest {
/// The name of the prompt to get.
pub name: String,
/// Arguments to pass to the prompt.
#[serde(default)]
pub arguments: std::collections::HashMap<String, String>,
}
/// Prompt argument.
#[derive(Debug, Serialize)]
pub struct PromptArgument {

View File

@@ -6,7 +6,7 @@ use serde_json::Value;
use crate::protocol::{
initialize::InitializeRequest,
prompts::PromptsListRequest,
prompts::{PromptsGetRequest, PromptsListRequest},
tool::{ToolsCallRequest, ToolsListRequest},
};
@@ -62,6 +62,12 @@ pub enum Requests {
#[serde(default)]
params: PromptsListRequest,
},
/// Get a prompt.
#[serde(rename = "prompts/get")]
PromptsGet {
/// Prompts get request parameters.
params: PromptsGetRequest,
},
/// Resources list.
#[serde(rename = "resources/list")]
ResourcesList {

View File

@@ -3,13 +3,14 @@ use std::collections::HashSet;
use serde_json::Value;
use crate::{
prompts::{NoPrompts, Prompts},
protocol::{
JSON_RPC_VERSION,
initialize::{
InitializeRequest, InitializeResponse, PromptsCapability, ResourcesCapability,
ServerCapabilities, ServerInfo, ToolsCapability,
},
prompts::PromptsListResponse,
prompts::{PromptsGetRequest, PromptsListResponse},
resources::ResourcesListResponse,
rpc::{Request, RequestId, Requests, Response},
tool::{ToolsCallRequest, ToolsListResponse},
@@ -18,25 +19,27 @@ use crate::{
};
/// A server that can be used to handle MCP requests.
pub struct McpServer<ToolsType = NoTools> {
pub struct McpServer<ToolsType = NoTools, PromptsType = NoPrompts> {
tools: ToolsType,
prompts: PromptsType,
disabled_tools: HashSet<String>,
server_info: ServerInfo,
}
impl Default for McpServer<NoTools> {
impl Default for McpServer<NoTools, NoPrompts> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl McpServer<NoTools> {
impl McpServer<NoTools, NoPrompts> {
/// Creates a new MCP server.
#[inline]
pub fn new() -> Self {
Self {
tools: NoTools,
prompts: NoPrompts,
disabled_tools: HashSet::new(),
server_info: ServerInfo {
name: "poem-mcpserver".to_string(),
@@ -46,18 +49,34 @@ impl McpServer<NoTools> {
}
}
impl<ToolsType> McpServer<ToolsType>
impl<ToolsType, PromptsType> McpServer<ToolsType, PromptsType>
where
ToolsType: Tools,
PromptsType: Prompts,
{
/// Sets the tools that the server will use.
#[inline]
pub fn tools<T>(self, tools: T) -> McpServer<T>
pub fn tools<T>(self, tools: T) -> McpServer<T, PromptsType>
where
T: Tools,
{
McpServer {
tools,
prompts: self.prompts,
disabled_tools: self.disabled_tools,
server_info: self.server_info,
}
}
/// Sets the prompts that the server will use.
#[inline]
pub fn prompts<P>(self, prompts: P) -> McpServer<ToolsType, P>
where
P: Prompts,
{
McpServer {
tools: self.tools,
prompts,
disabled_tools: self.disabled_tools,
server_info: self.server_info,
}
@@ -172,6 +191,41 @@ where
}
}
fn handle_prompts_list(&self, id: Option<RequestId>) -> Response<Value> {
Response {
jsonrpc: JSON_RPC_VERSION.to_string(),
id,
result: Some(PromptsListResponse {
prompts: PromptsType::list(),
}),
error: None,
}
.map_result_to_value()
}
async fn handle_prompts_get(
&self,
request: PromptsGetRequest,
id: Option<RequestId>,
) -> Response<Value> {
match self.prompts.get(&request.name, request.arguments).await {
Ok(response) => Response {
jsonrpc: JSON_RPC_VERSION.to_string(),
id,
result: Some(response),
error: None,
}
.map_result_to_value(),
Err(err) => Response::<()> {
jsonrpc: JSON_RPC_VERSION.to_string(),
id,
result: None,
error: Some(err),
}
.map_result_to_value(),
}
}
/// Handles a request and returns a response.
pub async fn handle_request(&mut self, request: Request) -> Option<Response<Value>> {
match request.body {
@@ -183,15 +237,10 @@ where
Requests::ToolsCall { params } => {
Some(self.handle_tools_call(params, request.id).await)
}
Requests::PromptsList { .. } => Some(
Response {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: request.id,
result: Some(PromptsListResponse { prompts: vec![] }),
error: None,
Requests::PromptsList { .. } => Some(self.handle_prompts_list(request.id)),
Requests::PromptsGet { params } => {
Some(self.handle_prompts_get(params, request.id).await)
}
.map_result_to_value(),
),
Requests::ResourcesList { .. } => Some(
Response {
jsonrpc: JSON_RPC_VERSION.to_string(),

View File

@@ -5,6 +5,7 @@ use tokio::io::{AsyncBufReadExt, BufReader};
use crate::{
McpServer,
prompts::Prompts,
protocol::{
JSON_RPC_VERSION,
rpc::{BatchRequest, Response, RpcError},
@@ -17,9 +18,12 @@ fn print_response(response: impl Serialize) {
}
/// Run the server using standard input and output.
pub async fn stdio<ToolsType>(server: McpServer<ToolsType>) -> std::io::Result<()>
pub async fn stdio<ToolsType, PromptsType>(
server: McpServer<ToolsType, PromptsType>,
) -> std::io::Result<()>
where
ToolsType: Tools,
PromptsType: Prompts,
{
let mut server = server;
let mut input = BufReader::new(tokio::io::stdin()).lines();

View File

@@ -6,100 +6,119 @@ use std::{
time::Duration,
};
use mime::Mime;
use poem::{
EndpointExt, IntoEndpoint, IntoResponse, Request, handler,
http::{HeaderMap, StatusCode},
http::StatusCode,
post,
web::{
Accept, Data, Json,
Accept, Data, Json, Query,
sse::{Event, SSE},
},
};
use serde_json::Value;
use tokio::time::Instant;
use crate::{
McpServer,
protocol::rpc::{BatchRequest as McpBatchRequest, Request as McpRequest},
prompts::Prompts,
protocol::rpc::{BatchRequest as McpBatchRequest, Request as McpRequest, Requests},
tool::Tools,
};
const SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 5);
type ServerFactoryFn<ToolsType> = Box<dyn Fn(&Request) -> McpServer<ToolsType> + Send + Sync>;
type ServerFactoryFn<ToolsType, PromptsType> =
Box<dyn Fn(&Request) -> McpServer<ToolsType, PromptsType> + Send + Sync>;
struct Session<ToolsType> {
server: Arc<tokio::sync::Mutex<McpServer<ToolsType>>>,
struct Session<ToolsType, PromptsType> {
server: Arc<tokio::sync::Mutex<McpServer<ToolsType, PromptsType>>>,
sender: Option<tokio::sync::mpsc::UnboundedSender<String>>,
last_active: Instant,
}
struct State<ToolsType> {
server_factory: ServerFactoryFn<ToolsType>,
sessions: Mutex<HashMap<String, Session<ToolsType>>>,
struct State<ToolsType, PromptsType> {
server_factory: ServerFactoryFn<ToolsType, PromptsType>,
sessions: Mutex<HashMap<String, Session<ToolsType, PromptsType>>>,
}
async fn handle_request<ToolsType>(
server: Arc<tokio::sync::Mutex<McpServer<ToolsType>>>,
session_id: &str,
accept: &Mime,
requests: impl Iterator<Item = McpRequest> + Send + 'static,
) -> impl IntoResponse
async fn process_request<ToolsType, PromptsType>(
server: Arc<tokio::sync::Mutex<McpServer<ToolsType, PromptsType>>>,
request: McpRequest,
) -> Option<crate::protocol::rpc::Response<Value>>
where
ToolsType: Tools + Send + Sync + 'static,
PromptsType: Prompts + Send + Sync + 'static,
{
tracing::info!(
session_id = session_id,
accept = accept.essence_str(),
"handling requests"
);
match accept.essence_str() {
"application/json" => {
let mut resps = vec![];
for request in requests {
tracing::info!(session_id = session_id, request = ?request, "received request");
let resp = server.lock().await.handle_request(request).await;
tracing::info!(session_id = session_id, response = ?resp, "sending response");
resps.extend(resp);
}
Json(resps)
.with_content_type("application/json")
.into_response()
}
"text/event-stream" => {
let session_id = session_id.to_string();
SSE::new(async_stream::stream! {
for request in requests {
tracing::info!(session_id = session_id, request = ?request, "received request");
let resp = server.lock().await.handle_request(request).await;
tracing::info!(session_id = session_id, response = ?resp, "sending response");
yield Event::message(serde_json::to_string(&resp).unwrap()).event_type("message");
}
})
.into_response()
}
_ => StatusCode::BAD_REQUEST.into_response(),
}
server.lock().await.handle_request(request).await
}
#[handler]
async fn post_handler<ToolsType>(
data: Data<&Arc<State<ToolsType>>>,
async fn get_handler<ToolsType, PromptsType>(
data: Data<&Arc<State<ToolsType, PromptsType>>>,
request: &Request,
batch_request: Json<McpBatchRequest>,
accept: Accept,
) -> impl IntoResponse
where
ToolsType: Tools + Send + Sync + 'static,
PromptsType: Prompts + Send + Sync + 'static,
{
let Some(accept) = accept.0.first() else {
let session_id = session_id();
let server = (data.0.server_factory)(request);
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
{
let mut sessions = data.0.sessions.lock().unwrap();
sessions.insert(
session_id.clone(),
Session {
server: Arc::new(tokio::sync::Mutex::new(server)),
sender: Some(tx),
last_active: Instant::now(),
},
);
}
tracing::info!(
session_id = session_id,
"created new standard session (SSE)"
);
SSE::new(async_stream::stream! {
let endpoint_uri = format!("?session_id={}", session_id);
yield Event::message(endpoint_uri).event_type("endpoint");
while let Some(msg) = rx.recv().await {
yield Event::message(msg).event_type("message");
}
})
.into_response()
}
#[handler]
async fn post_handler<ToolsType, PromptsType>(
data: Data<&Arc<State<ToolsType, PromptsType>>>,
request: &Request,
batch_request: Json<McpBatchRequest>,
accept: Accept,
query: Query<HashMap<String, String>>,
) -> impl IntoResponse
where
ToolsType: Tools + Send + Sync + 'static,
PromptsType: Prompts + Send + Sync + 'static,
{
let session_id_param = request
.headers()
.get("Mcp-Session-Id")
.and_then(|value| value.to_str().ok())
.map(String::from)
.or_else(|| query.get("session_id").cloned());
if session_id_param.is_none() {
let Some(_accept) = accept.0.first() else {
return StatusCode::BAD_REQUEST.into_response();
};
if batch_request.len() == 1
&& batch_request.requests()[0].is_initialize()
&& !request.headers().contains_key("Mcp-Session-Id")
{
if batch_request.len() == 1 && batch_request.requests()[0].is_initialize() {
let session_id = session_id();
let mut server = (data.0.server_factory)(request);
let initialize_request = batch_request.0.into_iter().next().unwrap();
@@ -112,54 +131,127 @@ where
session_id.clone(),
Session {
server: Arc::new(tokio::sync::Mutex::new(server)),
sender: None,
last_active: Instant::now(),
},
);
tracing::info!(session_id = session_id, "created new session");
tracing::info!(session_id = session_id, "created new legacy session");
return Json(resp)
.with_header("Mcp-Session-Id", session_id)
.into_response();
}
let Some(session_id) = request
.headers()
.get("Mcp-Session-Id")
.and_then(|value| value.to_str().ok())
else {
return StatusCode::BAD_REQUEST.into_response();
};
}
let server = {
let session_id = session_id_param.unwrap();
let (server, sender) = {
let mut sessions = data.0.sessions.lock().unwrap();
let Some(session) = sessions.get_mut(session_id) else {
let Some(session) = sessions.get_mut(&session_id) else {
return StatusCode::NOT_FOUND.into_response();
};
session.last_active = Instant::now();
session.server.clone()
(session.server.clone(), session.sender.clone())
};
handle_request(server, session_id, accept, batch_request.0.into_iter())
.await
if let Some(tx) = sender {
for request in batch_request.0 {
tracing::info!(session_id = session_id, request = ?request, "received request (std)");
let resp = process_request(server.clone(), request).await;
if let Some(resp) = resp {
tracing::info!(session_id = session_id, response = ?resp, "pushing to SSE");
if tx.send(serde_json::to_string(&resp).unwrap()).is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
}
}
return StatusCode::ACCEPTED.into_response();
}
let all_notifications = batch_request.requests().iter().all(|request| {
matches!(
request.body,
Requests::Initialized | Requests::Cancelled { .. }
)
});
let requests = batch_request.0.into_iter();
let accept = accept
.0
.first()
.map(|value| value.essence_str())
.unwrap_or("application/json");
match accept {
"text/event-stream" => {
if all_notifications {
return StatusCode::ACCEPTED.into_response();
}
let session_id = session_id.clone();
SSE::new(async_stream::stream! {
for request in requests {
tracing::info!(session_id = session_id, request = ?request, "received request");
let resp = process_request(server.clone(), request).await;
if let Some(resp) = resp {
tracing::info!(session_id = session_id, response = ?resp, "sending response");
yield Event::message(serde_json::to_string(&resp).unwrap()).event_type("message");
}
}
})
.into_response()
}
_ => {
let mut resps = vec![];
for request in requests {
tracing::info!(session_id = session_id, request = ?request, "received request");
let resp = process_request(server.clone(), request).await;
if let Some(resp) = resp {
tracing::info!(session_id = session_id, response = ?resp, "sending response");
resps.push(resp);
}
}
if resps.is_empty() {
return StatusCode::ACCEPTED.into_response();
}
Json(resps)
.with_content_type("application/json")
.into_response()
}
}
}
#[handler]
async fn delete_handler<ToolsType>(
data: Data<&Arc<State<ToolsType>>>,
headers: &HeaderMap,
async fn delete_handler<ToolsType, PromptsType>(
data: Data<&Arc<State<ToolsType, PromptsType>>>,
req: &Request,
query: Query<HashMap<String, String>>,
) -> impl IntoResponse
where
ToolsType: Tools + Send + Sync + 'static,
PromptsType: Prompts + Send + Sync + 'static,
{
let Some(session_id) = headers
let session_id = req
.headers()
.get("Mcp-Session-Id")
.and_then(|value| value.to_str().ok())
else {
.map(String::from)
.or_else(|| query.get("session_id").cloned());
let Some(session_id) = session_id else {
return StatusCode::BAD_REQUEST;
};
if data.sessions.lock().unwrap().remove(session_id).is_none() {
if data
.0
.sessions
.lock()
.unwrap()
.remove(&session_id)
.is_none()
{
return StatusCode::NOT_FOUND;
}
@@ -168,10 +260,11 @@ where
}
/// A streamable http endpoint that can be used to handle MCP requests.
pub fn endpoint<F, ToolsType>(server_factory: F) -> impl IntoEndpoint
pub fn endpoint<F, ToolsType, PromptsType>(server_factory: F) -> impl IntoEndpoint
where
F: Fn(&Request) -> McpServer<ToolsType> + Send + Sync + 'static,
F: Fn(&Request) -> McpServer<ToolsType, PromptsType> + Send + Sync + 'static,
ToolsType: Tools + Send + Sync + 'static,
PromptsType: Prompts + Send + Sync + 'static,
{
let state = Arc::new(State {
server_factory: Box::new(server_factory),
@@ -190,8 +283,9 @@ where
}
});
post(post_handler::<ToolsType>::default())
.delete(delete_handler::<ToolsType>::default())
post(post_handler::<ToolsType, PromptsType>::default())
.get(get_handler::<ToolsType, PromptsType>::default())
.delete(delete_handler::<ToolsType, PromptsType>::default())
.data(state)
}

View File

@@ -117,7 +117,15 @@ where
T: Serialize + JsonSchema,
{
fn output_schema() -> Option<Schema> {
Some(schemars::SchemaGenerator::default().into_root_schema_for::<T>())
let schema = schemars::SchemaGenerator::default().into_root_schema_for::<T>();
if let Ok(value) = serde_json::to_value(&schema) {
if value.get("type") == Some(&serde_json::Value::String("array".to_string())) {
panic!(
"Tool return type must be an object, but found array. Please wrap the return value in a struct."
);
}
}
Some(schema)
}
fn into_tool_response(self) -> ToolsCallResponse {
@@ -137,7 +145,15 @@ where
E: Display,
{
fn output_schema() -> Option<Schema> {
Some(schemars::SchemaGenerator::default().into_root_schema_for::<T>())
let schema = schemars::SchemaGenerator::default().into_root_schema_for::<T>();
if let Ok(value) = serde_json::to_value(&schema) {
if value.get("type") == Some(&serde_json::Value::String("array".to_string())) {
panic!(
"Tool return type must be an object, but found array. Please wrap the return value in a struct."
);
}
}
Some(schema)
}
fn into_tool_response(self) -> ToolsCallResponse {

View File

@@ -0,0 +1,315 @@
use std::collections::HashMap;
use poem_mcpserver::{
McpServer, Prompts,
content::Text,
prompts::PromptMessages,
protocol::{
JSON_RPC_VERSION,
prompts::{PromptsGetRequest, PromptsListRequest},
rpc::{Request, RequestId, Requests},
},
};
struct TestPrompts {
system_name: String,
}
impl TestPrompts {
fn new() -> Self {
Self {
system_name: "TestSystem".to_string(),
}
}
}
#[Prompts]
impl TestPrompts {
/// A simple greeting prompt.
async fn greet(
&self,
/// The name to greet
#[mcp(required)]
name: Option<String>,
) -> String {
format!("Hello, {}! Welcome to {}.", name.unwrap(), self.system_name)
}
/// A code review prompt with optional language parameter.
async fn code_review(
&self,
/// The code to review
#[mcp(required)]
code: Option<String>,
/// The programming language
language: Option<String>,
) -> PromptMessages {
let lang = language.unwrap_or_else(|| "unknown".to_string());
PromptMessages::new()
.user(Text(format!(
"Please review the following {} code:\n\n```{}\n{}\n```",
lang,
lang,
code.unwrap()
)))
.assistant(Text(
"I'll review this code for you. Let me analyze it...".to_string(),
))
}
/// A simple prompt without required arguments.
async fn help(&self) -> String {
"How can I help you today?".to_string()
}
}
#[tokio::test]
async fn prompts_list() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(1)),
body: Requests::PromptsList {
params: PromptsListRequest { cursor: None },
},
})
.await;
assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"result": {
"prompts": [
{
"name": "greet",
"description": "A simple greeting prompt.",
"arguments": [
{
"name": "name",
"description": "The name to greet",
"required": true
}
]
},
{
"name": "code_review",
"description": "A code review prompt with optional language parameter.",
"arguments": [
{
"name": "code",
"description": "The code to review",
"required": true
},
{
"name": "language",
"description": "The programming language",
"required": false
}
]
},
{
"name": "help",
"description": "A simple prompt without required arguments.",
"arguments": []
}
]
}
})
);
}
#[tokio::test]
async fn prompts_get_simple() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let mut arguments = HashMap::new();
arguments.insert("name".to_string(), "Alice".to_string());
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(2)),
body: Requests::PromptsGet {
params: PromptsGetRequest {
name: "greet".to_string(),
arguments,
},
},
})
.await;
assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!({
"jsonrpc": "2.0",
"id": 2,
"result": {
"description": "",
"messages": [
{
"role": "user",
"content": {
"type": "text",
"text": "Hello, Alice! Welcome to TestSystem."
}
}
]
}
})
);
}
#[tokio::test]
async fn prompts_get_with_multiple_messages() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let mut arguments = HashMap::new();
arguments.insert("code".to_string(), "fn main() {}".to_string());
arguments.insert("language".to_string(), "rust".to_string());
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(3)),
body: Requests::PromptsGet {
params: PromptsGetRequest {
name: "code_review".to_string(),
arguments,
},
},
})
.await;
assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!({
"jsonrpc": "2.0",
"id": 3,
"result": {
"description": "",
"messages": [
{
"role": "user",
"content": {
"type": "text",
"text": "Please review the following rust code:\n\n```rust\nfn main() {}\n```"
}
},
{
"role": "assistant",
"content": {
"type": "text",
"text": "I'll review this code for you. Let me analyze it..."
}
}
]
}
})
);
}
#[tokio::test]
async fn prompts_get_missing_required_argument() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(4)),
body: Requests::PromptsGet {
params: PromptsGetRequest {
name: "greet".to_string(),
arguments: HashMap::new(),
},
},
})
.await;
let resp_value = serde_json::to_value(&resp).unwrap();
assert_eq!(resp_value["jsonrpc"], "2.0");
assert_eq!(resp_value["id"], 4);
assert!(resp_value["error"]["code"].as_i64().is_some());
assert!(
resp_value["error"]["message"]
.as_str()
.unwrap()
.contains("missing required argument: name")
);
}
#[tokio::test]
async fn prompts_get_unknown_prompt() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(5)),
body: Requests::PromptsGet {
params: PromptsGetRequest {
name: "unknown_prompt".to_string(),
arguments: HashMap::new(),
},
},
})
.await;
let resp_value = serde_json::to_value(&resp).unwrap();
assert_eq!(resp_value["jsonrpc"], "2.0");
assert_eq!(resp_value["id"], 5);
assert!(resp_value["error"]["code"].as_i64().is_some());
assert!(
resp_value["error"]["message"]
.as_str()
.unwrap()
.contains("prompt not found")
);
}
#[tokio::test]
async fn prompts_get_no_arguments_needed() {
let prompts = TestPrompts::new();
let mut server = McpServer::new().prompts(prompts);
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(6)),
body: Requests::PromptsGet {
params: PromptsGetRequest {
name: "help".to_string(),
arguments: HashMap::new(),
},
},
})
.await;
assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!({
"jsonrpc": "2.0",
"id": 6,
"result": {
"description": "",
"messages": [
{
"role": "user",
"content": {
"type": "text",
"text": "How can I help you today?"
}
}
]
}
})
);
}

View File

@@ -278,3 +278,74 @@ async fn disable_tools() {
})
);
}
#[derive(JsonSchema, Serialize)]
struct StringList {
items: Vec<String>,
}
struct CollectionTools;
#[Tools]
impl CollectionTools {
async fn get_list(&self) -> StructuredContent<StringList> {
StructuredContent(StringList {
items: vec!["a".to_string(), "b".to_string()],
})
}
}
#[tokio::test]
async fn collection_schema() {
let tools = CollectionTools;
let mut server = McpServer::new().tools(tools);
let resp = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(1)),
body: Requests::ToolsList {
params: ToolsListRequest { cursor: None },
},
})
.await;
let resp_json = serde_json::to_value(&resp).unwrap();
let tools = resp_json["result"]["tools"].as_array().unwrap();
let tool = &tools[0];
assert_eq!(tool["name"], "get_list");
let output_schema = &tool["outputSchema"];
assert_eq!(output_schema["type"], "object");
assert_eq!(output_schema["title"], "StringList");
assert!(output_schema["properties"]["items"]["type"] == "array");
}
struct ArrayTools;
#[Tools]
impl ArrayTools {
async fn array_ret(&self) -> StructuredContent<Vec<String>> {
StructuredContent(vec![])
}
}
#[tokio::test]
#[should_panic(
expected = "Tool return type must be an object, but found array. Please wrap the return value in a struct."
)]
async fn test_array_panic() {
let tools = ArrayTools;
let mut server = McpServer::new().tools(tools);
let _ = server
.handle_request(Request {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: Some(RequestId::Int(1)),
body: Requests::ToolsList {
params: ToolsListRequest { cursor: None },
},
})
.await;
}