mirror of
https://github.com/poem-web/poem.git
synced 2026-01-25 04:18:25 +00:00
feat(mcp): support prompts (#1155)
* feat(mcp): support prompts
* fmt
* fix
* Revert "fix"
This reverts commit d370e9be86.
* fix get
* not allow array
* add test case
* fix clippy
* fix connect with rmcp client
* fix
This commit is contained in:
20
examples/mcpserver/prompts-streamable-http/Cargo.toml
Normal file
20
examples/mcpserver/prompts-streamable-http/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "prompts-streamable-http"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
default-run = "prompts-streamable-http"
|
||||
|
||||
[[bin]]
|
||||
name = "prompts-streamable-http-client"
|
||||
path = "src/client.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
poem-mcpserver = { workspace = true, features = ["streamable-http"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
schemars = "1.0"
|
||||
poem = { workspace = true, features = ["sse"] }
|
||||
rmcp = { version = "0.13", features = ["client", "transport-streamable-http-client-reqwest"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }
|
||||
tracing-subscriber.workspace = true
|
||||
27
examples/mcpserver/prompts-streamable-http/README.md
Normal file
27
examples/mcpserver/prompts-streamable-http/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# prompts-streamable-http
|
||||
|
||||
This example runs a streamable HTTP MCP server with tools and prompts, plus a small rmcp client that connects to it.
|
||||
|
||||
## Run the server
|
||||
|
||||
1. From the repo root, in one terminal:
|
||||
cargo run --manifest-path ./examples/mcpserver/prompts-streamable-http/Cargo.toml
|
||||
|
||||
Or from the example directory:
|
||||
cargo run
|
||||
|
||||
The server listens on http://127.0.0.1:8000/.
|
||||
|
||||
## Run the client (rmcp)
|
||||
|
||||
2. From the repo root, in another terminal:
|
||||
cargo run --manifest-path ./examples/mcpserver/prompts-streamable-http/Cargo.toml --bin prompts-streamable-http-client
|
||||
|
||||
Or from the example directory:
|
||||
cargo run --bin prompts-streamable-http-client
|
||||
|
||||
The client lists tools and invokes the get_review_count tool.
|
||||
|
||||
## Tip
|
||||
|
||||
Because this example is not a workspace member, run commands from this directory or use --manifest-path from the repo root.
|
||||
27
examples/mcpserver/prompts-streamable-http/src/client.rs
Normal file
27
examples/mcpserver/prompts-streamable-http/src/client.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use anyhow::Result;
|
||||
use rmcp::{
|
||||
model::CallToolRequestParam, service::ServiceExt, transport::StreamableHttpClientTransport,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let transport = StreamableHttpClientTransport::from_uri("http://127.0.0.1:8000/");
|
||||
let service = ().serve(transport).await?;
|
||||
|
||||
let tools = service.list_tools(Default::default()).await?;
|
||||
println!("Available tools: {tools:#?}");
|
||||
|
||||
// Sending empty arguments object for get_review_count tool to satisfy schema
|
||||
|
||||
let response = service
|
||||
.call_tool(CallToolRequestParam {
|
||||
name: "get_review_count".to_string().into(),
|
||||
arguments: Some(serde_json::Map::new()),
|
||||
task: None,
|
||||
})
|
||||
.await?;
|
||||
println!("get_review_count response: {response:#?}");
|
||||
|
||||
service.cancel().await?;
|
||||
Ok(())
|
||||
}
|
||||
183
examples/mcpserver/prompts-streamable-http/src/main.rs
Normal file
183
examples/mcpserver/prompts-streamable-http/src/main.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use poem::{listener::TcpListener, middleware::Cors, EndpointExt, Route, Server};
|
||||
use poem_mcpserver::{
|
||||
content::Text, prompts::PromptMessages, streamable_http, McpServer, Prompts, Tools,
|
||||
};
|
||||
|
||||
/// A collection of development assistant tools.
|
||||
struct DevTools {
|
||||
/// History of reviewed code snippets
|
||||
review_count: u32,
|
||||
}
|
||||
|
||||
/// This server provides development assistant tools for code analysis.
|
||||
#[Tools]
|
||||
impl DevTools {
|
||||
/// Analyze code complexity and return metrics.
|
||||
async fn analyze_complexity(
|
||||
&mut self,
|
||||
/// The code to analyze
|
||||
code: String,
|
||||
) -> Text<String> {
|
||||
let lines = code.lines().count();
|
||||
let chars = code.len();
|
||||
self.review_count += 1;
|
||||
Text(format!(
|
||||
"Code Analysis #{}\n- Lines: {}\n- Characters: {}\n- Estimated complexity: {}",
|
||||
self.review_count,
|
||||
lines,
|
||||
chars,
|
||||
if lines > 50 {
|
||||
"High"
|
||||
} else if lines > 20 {
|
||||
"Medium"
|
||||
} else {
|
||||
"Low"
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
/// Count occurrences of a pattern in code.
|
||||
async fn count_pattern(
|
||||
&self,
|
||||
/// The code to search in
|
||||
code: String,
|
||||
/// The pattern to search for
|
||||
pattern: String,
|
||||
) -> Text<String> {
|
||||
let count = code.matches(&pattern).count();
|
||||
Text(format!("Found {} occurrences of '{}'", count, pattern))
|
||||
}
|
||||
|
||||
/// Get the total number of code reviews performed.
|
||||
async fn get_review_count(&self) -> Text<u32> {
|
||||
Text(self.review_count)
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection of development assistant prompts.
|
||||
struct DevPrompts {
|
||||
/// The assistant's persona name
|
||||
assistant_name: String,
|
||||
}
|
||||
|
||||
/// This server provides development assistant prompts for code review,
|
||||
/// documentation generation, and debugging help.
|
||||
///
|
||||
/// Use the 'code_review' prompt for reviewing code snippets.
|
||||
/// Use the 'generate_docs' prompt for generating documentation.
|
||||
/// Use the 'debug_help' prompt for debugging assistance.
|
||||
#[Prompts]
|
||||
impl DevPrompts {
|
||||
/// Review code for potential issues, style, and best practices.
|
||||
async fn code_review(
|
||||
&self,
|
||||
/// The code snippet to review
|
||||
#[mcp(required)]
|
||||
code: Option<String>,
|
||||
/// The programming language of the code
|
||||
language: Option<String>,
|
||||
/// Focus area: "security", "performance", "style", or "all"
|
||||
focus: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let lang = language.unwrap_or_else(|| "unknown".to_string());
|
||||
let focus_area = focus.unwrap_or_else(|| "all".to_string());
|
||||
|
||||
PromptMessages::new()
|
||||
.user(Text(format!(
|
||||
"Please review the following {} code. Focus on: {}\n\n```{}\n{}\n```",
|
||||
lang,
|
||||
focus_area,
|
||||
lang,
|
||||
code.unwrap()
|
||||
)))
|
||||
.assistant(Text(format!(
|
||||
"I'm {}, and I'll review this {} code focusing on {}. Let me analyze it...",
|
||||
self.assistant_name, lang, focus_area
|
||||
)))
|
||||
}
|
||||
|
||||
/// Generate documentation for a code snippet.
|
||||
async fn generate_docs(
|
||||
&self,
|
||||
/// The code to document
|
||||
#[mcp(required)]
|
||||
code: Option<String>,
|
||||
/// Documentation style: "markdown", "jsdoc", "rustdoc", etc.
|
||||
style: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let doc_style = style.unwrap_or_else(|| "markdown".to_string());
|
||||
|
||||
PromptMessages::new().user(Text(format!(
|
||||
"Generate {} documentation for the following code:\n\n```\n{}\n```",
|
||||
doc_style,
|
||||
code.unwrap()
|
||||
)))
|
||||
}
|
||||
|
||||
/// Get help debugging an issue.
|
||||
async fn debug_help(
|
||||
&self,
|
||||
/// Description of the problem
|
||||
#[mcp(required)]
|
||||
problem: Option<String>,
|
||||
/// The error message, if any
|
||||
error_message: Option<String>,
|
||||
/// Relevant code snippet
|
||||
code: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let mut prompt = format!(
|
||||
"I need help debugging an issue.\n\nProblem: {}",
|
||||
problem.unwrap()
|
||||
);
|
||||
|
||||
if let Some(err) = error_message {
|
||||
prompt.push_str(&format!("\n\nError message:\n```\n{}\n```", err));
|
||||
}
|
||||
|
||||
if let Some(code_snippet) = code {
|
||||
prompt.push_str(&format!("\n\nRelevant code:\n```\n{}\n```", code_snippet));
|
||||
}
|
||||
|
||||
PromptMessages::new()
|
||||
.user(Text(prompt))
|
||||
.assistant(Text(format!(
|
||||
"I'm {} and I'll help you debug this issue. Let me analyze the problem...",
|
||||
self.assistant_name
|
||||
)))
|
||||
}
|
||||
|
||||
/// Get a simple greeting from the assistant.
|
||||
async fn greet(&self) -> String {
|
||||
format!(
|
||||
"Hello! I'm {}, your development assistant. I can help you with:\n\
|
||||
- Code reviews (use 'code_review' prompt)\n\
|
||||
- Documentation generation (use 'generate_docs' prompt)\n\
|
||||
- Debugging help (use 'debug_help' prompt)\n\n\
|
||||
How can I assist you today?",
|
||||
self.assistant_name
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
if std::env::var_os("RUST_LOG").is_none() {
|
||||
std::env::set_var("RUST_LOG", "poem=debug");
|
||||
}
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:8000");
|
||||
let app = Route::new()
|
||||
.at(
|
||||
"/",
|
||||
streamable_http::endpoint(|_| {
|
||||
let tools = DevTools { review_count: 0 };
|
||||
let prompts = DevPrompts {
|
||||
assistant_name: "CodeBot".to_string(),
|
||||
};
|
||||
McpServer::new().tools(tools).prompts(prompts)
|
||||
}),
|
||||
)
|
||||
.with(Cors::new());
|
||||
Server::new(listener).run(app).await
|
||||
}
|
||||
10
examples/mcpserver/prompts/Cargo.toml
Normal file
10
examples/mcpserver/prompts/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "prompts-example"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
poem-mcpserver.workspace = true
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
schemars = "1.0"
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }
|
||||
169
examples/mcpserver/prompts/src/main.rs
Normal file
169
examples/mcpserver/prompts/src/main.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use poem_mcpserver::{
|
||||
content::Text, prompts::PromptMessages, stdio::stdio, McpServer, Prompts, Tools,
|
||||
};
|
||||
|
||||
/// A collection of development assistant tools.
|
||||
struct DevTools {
|
||||
/// History of reviewed code snippets
|
||||
review_count: u32,
|
||||
}
|
||||
|
||||
/// This server provides development assistant tools for code analysis.
|
||||
#[Tools]
|
||||
impl DevTools {
|
||||
/// Analyze code complexity and return metrics.
|
||||
async fn analyze_complexity(
|
||||
&mut self,
|
||||
/// The code to analyze
|
||||
code: String,
|
||||
) -> Text<String> {
|
||||
let lines = code.lines().count();
|
||||
let chars = code.len();
|
||||
self.review_count += 1;
|
||||
Text(format!(
|
||||
"Code Analysis #{}\n- Lines: {}\n- Characters: {}\n- Estimated complexity: {}",
|
||||
self.review_count,
|
||||
lines,
|
||||
chars,
|
||||
if lines > 50 {
|
||||
"High"
|
||||
} else if lines > 20 {
|
||||
"Medium"
|
||||
} else {
|
||||
"Low"
|
||||
}
|
||||
))
|
||||
}
|
||||
|
||||
/// Count occurrences of a pattern in code.
|
||||
async fn count_pattern(
|
||||
&self,
|
||||
/// The code to search in
|
||||
code: String,
|
||||
/// The pattern to search for
|
||||
pattern: String,
|
||||
) -> Text<String> {
|
||||
let count = code.matches(&pattern).count();
|
||||
Text(format!("Found {} occurrences of '{}'", count, pattern))
|
||||
}
|
||||
|
||||
/// Get the total number of code reviews performed.
|
||||
async fn get_review_count(&self) -> Text<u32> {
|
||||
Text(self.review_count)
|
||||
}
|
||||
}
|
||||
|
||||
/// A collection of development assistant prompts.
|
||||
struct DevPrompts {
|
||||
/// The assistant's persona name
|
||||
assistant_name: String,
|
||||
}
|
||||
|
||||
/// This server provides development assistant prompts for code review,
|
||||
/// documentation generation, and debugging help.
|
||||
///
|
||||
/// Use the 'code_review' prompt for reviewing code snippets.
|
||||
/// Use the 'generate_docs' prompt for generating documentation.
|
||||
/// Use the 'debug_help' prompt for debugging assistance.
|
||||
#[Prompts]
|
||||
impl DevPrompts {
|
||||
/// Review code for potential issues, style, and best practices.
|
||||
async fn code_review(
|
||||
&self,
|
||||
/// The code snippet to review
|
||||
#[mcp(required)]
|
||||
code: Option<String>,
|
||||
/// The programming language of the code
|
||||
language: Option<String>,
|
||||
/// Focus area: "security", "performance", "style", or "all"
|
||||
focus: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let lang = language.unwrap_or_else(|| "unknown".to_string());
|
||||
let focus_area = focus.unwrap_or_else(|| "all".to_string());
|
||||
|
||||
PromptMessages::new()
|
||||
.user(Text(format!(
|
||||
"Please review the following {} code. Focus on: {}\n\n```{}\n{}\n```",
|
||||
lang,
|
||||
focus_area,
|
||||
lang,
|
||||
code.unwrap()
|
||||
)))
|
||||
.assistant(Text(format!(
|
||||
"I'm {}, and I'll review this {} code focusing on {}. Let me analyze it...",
|
||||
self.assistant_name, lang, focus_area
|
||||
)))
|
||||
}
|
||||
|
||||
/// Generate documentation for a code snippet.
|
||||
async fn generate_docs(
|
||||
&self,
|
||||
/// The code to document
|
||||
#[mcp(required)]
|
||||
code: Option<String>,
|
||||
/// Documentation style: "markdown", "jsdoc", "rustdoc", etc.
|
||||
style: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let doc_style = style.unwrap_or_else(|| "markdown".to_string());
|
||||
|
||||
PromptMessages::new().user(Text(format!(
|
||||
"Generate {} documentation for the following code:\n\n```\n{}\n```",
|
||||
doc_style,
|
||||
code.unwrap()
|
||||
)))
|
||||
}
|
||||
|
||||
/// Get help debugging an issue.
|
||||
async fn debug_help(
|
||||
&self,
|
||||
/// Description of the problem
|
||||
#[mcp(required)]
|
||||
problem: Option<String>,
|
||||
/// The error message, if any
|
||||
error_message: Option<String>,
|
||||
/// Relevant code snippet
|
||||
code: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let mut prompt = format!(
|
||||
"I need help debugging an issue.\n\nProblem: {}",
|
||||
problem.unwrap()
|
||||
);
|
||||
|
||||
if let Some(err) = error_message {
|
||||
prompt.push_str(&format!("\n\nError message:\n```\n{}\n```", err));
|
||||
}
|
||||
|
||||
if let Some(code_snippet) = code {
|
||||
prompt.push_str(&format!("\n\nRelevant code:\n```\n{}\n```", code_snippet));
|
||||
}
|
||||
|
||||
PromptMessages::new()
|
||||
.user(Text(prompt))
|
||||
.assistant(Text(format!(
|
||||
"I'm {} and I'll help you debug this issue. Let me analyze the problem...",
|
||||
self.assistant_name
|
||||
)))
|
||||
}
|
||||
|
||||
/// Get a simple greeting from the assistant.
|
||||
async fn greet(&self) -> String {
|
||||
format!(
|
||||
"Hello! I'm {}, your development assistant. I can help you with:\n\
|
||||
- Code reviews (use 'code_review' prompt)\n\
|
||||
- Documentation generation (use 'generate_docs' prompt)\n\
|
||||
- Debugging help (use 'debug_help' prompt)\n\n\
|
||||
How can I assist you today?",
|
||||
self.assistant_name
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let tools = DevTools { review_count: 0 };
|
||||
let prompts = DevPrompts {
|
||||
assistant_name: "CodeBot".to_string(),
|
||||
};
|
||||
|
||||
stdio(McpServer::new().tools(tools).prompts(prompts)).await
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
mod prompts;
|
||||
mod tools;
|
||||
mod utils;
|
||||
|
||||
@@ -33,3 +34,14 @@ pub fn Tools(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
Err(err) => err.write_errors().into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
#[allow(non_snake_case)]
|
||||
pub fn Prompts(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let prompt_args = parse_nested_meta!(prompts::PromptsArgs, args);
|
||||
let item_impl = parse_macro_input!(input as ItemImpl);
|
||||
match prompts::generate(prompt_args, item_impl) {
|
||||
Ok(stream) => stream.into(),
|
||||
Err(err) => err.write_errors().into(),
|
||||
}
|
||||
}
|
||||
|
||||
149
poem-mcpserver-macros/src/prompts.rs
Normal file
149
poem-mcpserver-macros/src/prompts.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
use darling::{Error, FromMeta, Result};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use syn::{FnArg, ImplItem, ItemImpl, Pat};
|
||||
|
||||
use crate::utils::*;
|
||||
|
||||
#[derive(FromMeta, Default)]
|
||||
pub(crate) struct PromptsArgs {}
|
||||
|
||||
#[derive(FromMeta, Default)]
|
||||
pub(crate) struct PromptArgs {
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(FromMeta, Default)]
|
||||
pub(crate) struct PromptParamArgs {
|
||||
name: Option<String>,
|
||||
#[darling(default)]
|
||||
required: bool,
|
||||
}
|
||||
|
||||
pub(crate) fn generate(_args: PromptsArgs, mut item_impl: ItemImpl) -> Result<TokenStream> {
|
||||
let crate_name = get_crate_name();
|
||||
let ident = item_impl.self_ty.clone();
|
||||
let mut prompts_descriptions = vec![];
|
||||
let mut get_branches = vec![];
|
||||
|
||||
for item in &mut item_impl.items {
|
||||
if let ImplItem::Fn(method) = item {
|
||||
let prompt_args = parse_mcp_attrs::<PromptArgs>(&method.attrs)?;
|
||||
remove_mcp_attrs(&mut method.attrs);
|
||||
|
||||
let prompt_name = match &prompt_args.name {
|
||||
Some(name) => name.clone(),
|
||||
None => method.sig.ident.to_string(),
|
||||
};
|
||||
let prompt_description = get_description(&method.attrs).unwrap_or_default();
|
||||
|
||||
if method.sig.asyncness.is_none() {
|
||||
return Err(Error::custom("must be asynchronous").with_span(&method.sig.ident));
|
||||
}
|
||||
|
||||
if method.sig.inputs.is_empty() {
|
||||
return Err(Error::custom("at least one `&self` receiver is required.")
|
||||
.with_span(&method.sig.ident));
|
||||
}
|
||||
|
||||
if !matches!(&method.sig.inputs[0], FnArg::Receiver(_)) {
|
||||
return Err(
|
||||
Error::custom("the first parameter must be a `&self` receiver.")
|
||||
.with_span(&method.sig.inputs[0]),
|
||||
);
|
||||
}
|
||||
|
||||
let mut prompt_arguments = vec![];
|
||||
let mut arg_extractions = vec![];
|
||||
let mut arg_names = vec![];
|
||||
let mut required_checks = vec![];
|
||||
|
||||
for arg in method.sig.inputs.iter_mut().skip(1) {
|
||||
let FnArg::Typed(pat) = arg else {
|
||||
unreachable!()
|
||||
};
|
||||
let Pat::Ident(ident) = &mut *pat.pat else {
|
||||
return Err(Error::custom("expected ident").with_span(&pat.pat));
|
||||
};
|
||||
|
||||
let param_args = parse_mcp_attrs::<PromptParamArgs>(&pat.attrs)?;
|
||||
remove_mcp_attrs(&mut pat.attrs);
|
||||
|
||||
let param_name = match ¶m_args.name {
|
||||
Some(name) => name.clone(),
|
||||
None => ident.ident.to_string(),
|
||||
};
|
||||
let param_desc = get_description(&pat.attrs).unwrap_or_default();
|
||||
remove_description(&mut pat.attrs);
|
||||
let is_required = param_args.required;
|
||||
|
||||
let arg_ident = &ident.ident;
|
||||
|
||||
prompt_arguments.push(quote! {
|
||||
#crate_name::protocol::prompts::PromptArgument {
|
||||
name: #param_name,
|
||||
description: #param_desc,
|
||||
required: #is_required,
|
||||
},
|
||||
});
|
||||
|
||||
if is_required {
|
||||
required_checks.push(quote! {
|
||||
if !arguments.contains_key(#param_name) {
|
||||
return ::std::result::Result::Err(
|
||||
#crate_name::protocol::rpc::RpcError::invalid_params(
|
||||
format!("missing required argument: {}", #param_name)
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
arg_extractions.push(quote! {
|
||||
let #arg_ident: ::std::option::Option<::std::string::String> = arguments.get(#param_name).cloned();
|
||||
});
|
||||
arg_names.push(quote! { #arg_ident });
|
||||
}
|
||||
|
||||
let method_ident = &method.sig.ident;
|
||||
|
||||
get_branches.push(quote! {
|
||||
#prompt_name => {
|
||||
#(#required_checks)*
|
||||
#(#arg_extractions)*
|
||||
let response = self.#method_ident(#(#arg_names),*).await;
|
||||
::std::result::Result::Ok(#crate_name::prompts::IntoPromptResponse::into_prompt_response(response))
|
||||
}
|
||||
});
|
||||
|
||||
prompts_descriptions.push(quote! {
|
||||
#crate_name::protocol::prompts::Prompt {
|
||||
name: #prompt_name,
|
||||
description: #prompt_description,
|
||||
arguments: &[#(#prompt_arguments)*],
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(quote! {
|
||||
#item_impl
|
||||
|
||||
impl #crate_name::prompts::Prompts for #ident {
|
||||
fn list() -> ::std::vec::Vec<#crate_name::protocol::prompts::Prompt> {
|
||||
::std::vec![#(#prompts_descriptions)*]
|
||||
}
|
||||
|
||||
async fn get(
|
||||
&self,
|
||||
name: &::std::primitive::str,
|
||||
arguments: ::std::collections::HashMap<::std::string::String, ::std::string::String>,
|
||||
) -> ::std::result::Result<#crate_name::protocol::prompts::PromptGetResponse, #crate_name::protocol::rpc::RpcError> {
|
||||
match name {
|
||||
#(#get_branches)*
|
||||
_ => ::std::result::Result::Err(#crate_name::protocol::rpc::RpcError::method_not_found(format!("prompt not found: {}", name))),
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -7,6 +7,7 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
pub mod content;
|
||||
pub mod prompts;
|
||||
pub mod protocol;
|
||||
mod server;
|
||||
pub mod stdio;
|
||||
@@ -14,7 +15,7 @@ pub mod stdio;
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "streamable-http")))]
|
||||
pub mod streamable_http;
|
||||
pub mod tool;
|
||||
pub use poem_mcpserver_macros::Tools;
|
||||
pub use poem_mcpserver_macros::{Prompts, Tools};
|
||||
pub use schemars::JsonSchema;
|
||||
pub use server::McpServer;
|
||||
|
||||
@@ -22,5 +23,5 @@ pub use server::McpServer;
|
||||
pub mod private {
|
||||
pub use serde_json;
|
||||
|
||||
pub use crate::tool::IntoToolResponse;
|
||||
pub use crate::{prompts::IntoPromptResponse, tool::IntoToolResponse};
|
||||
}
|
||||
|
||||
180
poem-mcpserver/src/prompts.rs
Normal file
180
poem-mcpserver/src/prompts.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Types for prompts.
|
||||
|
||||
use std::future::Future;
|
||||
|
||||
use crate::{
|
||||
content::IntoContent,
|
||||
protocol::{
|
||||
content::Content,
|
||||
prompts::{Prompt, PromptGetResponse, PromptMessage, Role},
|
||||
rpc::RpcError,
|
||||
},
|
||||
};
|
||||
|
||||
/// Represents a type that can be converted into a prompt response.
|
||||
pub trait IntoPromptResponse {
|
||||
/// Consumes the object and converts it into a prompt response.
|
||||
fn into_prompt_response(self) -> PromptGetResponse;
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for PromptGetResponse {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for PromptMessage {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: vec![self],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for Vec<PromptMessage> {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for String {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: vec![PromptMessage {
|
||||
role: Role::User,
|
||||
content: Content::Text { text: self },
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for &str {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: vec![PromptMessage {
|
||||
role: Role::User,
|
||||
content: Content::Text {
|
||||
text: self.to_string(),
|
||||
},
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoPromptResponse for (Role, T)
|
||||
where
|
||||
T: IntoContent,
|
||||
{
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: vec![PromptMessage {
|
||||
role: self.0,
|
||||
content: self.1.into_content(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for creating prompt responses with multiple messages.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct PromptMessages {
|
||||
messages: Vec<PromptMessage>,
|
||||
}
|
||||
|
||||
impl PromptMessages {
|
||||
/// Creates a new empty prompt messages builder.
|
||||
#[inline]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a user message to the prompt.
|
||||
#[inline]
|
||||
pub fn user(mut self, content: impl IntoContent) -> Self {
|
||||
self.messages.push(PromptMessage {
|
||||
role: Role::User,
|
||||
content: content.into_content(),
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds an assistant message to the prompt.
|
||||
#[inline]
|
||||
pub fn assistant(mut self, content: impl IntoContent) -> Self {
|
||||
self.messages.push(PromptMessage {
|
||||
role: Role::Assistant,
|
||||
content: content.into_content(),
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a message with a specific role to the prompt.
|
||||
#[inline]
|
||||
pub fn message(mut self, role: Role, content: impl IntoContent) -> Self {
|
||||
self.messages.push(PromptMessage {
|
||||
role,
|
||||
content: content.into_content(),
|
||||
});
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPromptResponse for PromptMessages {
|
||||
#[inline]
|
||||
fn into_prompt_response(self) -> PromptGetResponse {
|
||||
PromptGetResponse {
|
||||
description: "",
|
||||
messages: self.messages,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a prompts collection.
|
||||
pub trait Prompts {
|
||||
/// Returns a list of prompts.
|
||||
fn list() -> Vec<Prompt>;
|
||||
|
||||
/// Gets a prompt by name with the given arguments.
|
||||
fn get(
|
||||
&self,
|
||||
name: &str,
|
||||
arguments: std::collections::HashMap<String, String>,
|
||||
) -> impl Future<Output = Result<PromptGetResponse, RpcError>> + Send;
|
||||
}
|
||||
|
||||
/// Empty prompts collection.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NoPrompts;
|
||||
|
||||
impl Prompts for NoPrompts {
|
||||
#[inline]
|
||||
fn list() -> Vec<Prompt> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
async fn get(
|
||||
&self,
|
||||
name: &str,
|
||||
_arguments: std::collections::HashMap<String, String>,
|
||||
) -> Result<PromptGetResponse, RpcError> {
|
||||
Err(RpcError::method_not_found(format!(
|
||||
"prompt '{name}' not found"
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,17 @@ pub struct PromptsListRequest {
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
/// A request to get a prompt.
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptsGetRequest {
|
||||
/// The name of the prompt to get.
|
||||
pub name: String,
|
||||
/// Arguments to pass to the prompt.
|
||||
#[serde(default)]
|
||||
pub arguments: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Prompt argument.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PromptArgument {
|
||||
|
||||
@@ -6,7 +6,7 @@ use serde_json::Value;
|
||||
|
||||
use crate::protocol::{
|
||||
initialize::InitializeRequest,
|
||||
prompts::PromptsListRequest,
|
||||
prompts::{PromptsGetRequest, PromptsListRequest},
|
||||
tool::{ToolsCallRequest, ToolsListRequest},
|
||||
};
|
||||
|
||||
@@ -62,6 +62,12 @@ pub enum Requests {
|
||||
#[serde(default)]
|
||||
params: PromptsListRequest,
|
||||
},
|
||||
/// Get a prompt.
|
||||
#[serde(rename = "prompts/get")]
|
||||
PromptsGet {
|
||||
/// Prompts get request parameters.
|
||||
params: PromptsGetRequest,
|
||||
},
|
||||
/// Resources list.
|
||||
#[serde(rename = "resources/list")]
|
||||
ResourcesList {
|
||||
|
||||
@@ -3,13 +3,14 @@ use std::collections::HashSet;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::{
|
||||
prompts::{NoPrompts, Prompts},
|
||||
protocol::{
|
||||
JSON_RPC_VERSION,
|
||||
initialize::{
|
||||
InitializeRequest, InitializeResponse, PromptsCapability, ResourcesCapability,
|
||||
ServerCapabilities, ServerInfo, ToolsCapability,
|
||||
},
|
||||
prompts::PromptsListResponse,
|
||||
prompts::{PromptsGetRequest, PromptsListResponse},
|
||||
resources::ResourcesListResponse,
|
||||
rpc::{Request, RequestId, Requests, Response},
|
||||
tool::{ToolsCallRequest, ToolsListResponse},
|
||||
@@ -18,25 +19,27 @@ use crate::{
|
||||
};
|
||||
|
||||
/// A server that can be used to handle MCP requests.
|
||||
pub struct McpServer<ToolsType = NoTools> {
|
||||
pub struct McpServer<ToolsType = NoTools, PromptsType = NoPrompts> {
|
||||
tools: ToolsType,
|
||||
prompts: PromptsType,
|
||||
disabled_tools: HashSet<String>,
|
||||
server_info: ServerInfo,
|
||||
}
|
||||
|
||||
impl Default for McpServer<NoTools> {
|
||||
impl Default for McpServer<NoTools, NoPrompts> {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl McpServer<NoTools> {
|
||||
impl McpServer<NoTools, NoPrompts> {
|
||||
/// Creates a new MCP server.
|
||||
#[inline]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: NoTools,
|
||||
prompts: NoPrompts,
|
||||
disabled_tools: HashSet::new(),
|
||||
server_info: ServerInfo {
|
||||
name: "poem-mcpserver".to_string(),
|
||||
@@ -46,18 +49,34 @@ impl McpServer<NoTools> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<ToolsType> McpServer<ToolsType>
|
||||
impl<ToolsType, PromptsType> McpServer<ToolsType, PromptsType>
|
||||
where
|
||||
ToolsType: Tools,
|
||||
PromptsType: Prompts,
|
||||
{
|
||||
/// Sets the tools that the server will use.
|
||||
#[inline]
|
||||
pub fn tools<T>(self, tools: T) -> McpServer<T>
|
||||
pub fn tools<T>(self, tools: T) -> McpServer<T, PromptsType>
|
||||
where
|
||||
T: Tools,
|
||||
{
|
||||
McpServer {
|
||||
tools,
|
||||
prompts: self.prompts,
|
||||
disabled_tools: self.disabled_tools,
|
||||
server_info: self.server_info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the prompts that the server will use.
|
||||
#[inline]
|
||||
pub fn prompts<P>(self, prompts: P) -> McpServer<ToolsType, P>
|
||||
where
|
||||
P: Prompts,
|
||||
{
|
||||
McpServer {
|
||||
tools: self.tools,
|
||||
prompts,
|
||||
disabled_tools: self.disabled_tools,
|
||||
server_info: self.server_info,
|
||||
}
|
||||
@@ -172,6 +191,41 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_prompts_list(&self, id: Option<RequestId>) -> Response<Value> {
|
||||
Response {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id,
|
||||
result: Some(PromptsListResponse {
|
||||
prompts: PromptsType::list(),
|
||||
}),
|
||||
error: None,
|
||||
}
|
||||
.map_result_to_value()
|
||||
}
|
||||
|
||||
async fn handle_prompts_get(
|
||||
&self,
|
||||
request: PromptsGetRequest,
|
||||
id: Option<RequestId>,
|
||||
) -> Response<Value> {
|
||||
match self.prompts.get(&request.name, request.arguments).await {
|
||||
Ok(response) => Response {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id,
|
||||
result: Some(response),
|
||||
error: None,
|
||||
}
|
||||
.map_result_to_value(),
|
||||
Err(err) => Response::<()> {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(err),
|
||||
}
|
||||
.map_result_to_value(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles a request and returns a response.
|
||||
pub async fn handle_request(&mut self, request: Request) -> Option<Response<Value>> {
|
||||
match request.body {
|
||||
@@ -183,15 +237,10 @@ where
|
||||
Requests::ToolsCall { params } => {
|
||||
Some(self.handle_tools_call(params, request.id).await)
|
||||
}
|
||||
Requests::PromptsList { .. } => Some(
|
||||
Response {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: request.id,
|
||||
result: Some(PromptsListResponse { prompts: vec![] }),
|
||||
error: None,
|
||||
Requests::PromptsList { .. } => Some(self.handle_prompts_list(request.id)),
|
||||
Requests::PromptsGet { params } => {
|
||||
Some(self.handle_prompts_get(params, request.id).await)
|
||||
}
|
||||
.map_result_to_value(),
|
||||
),
|
||||
Requests::ResourcesList { .. } => Some(
|
||||
Response {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
|
||||
@@ -5,6 +5,7 @@ use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
|
||||
use crate::{
|
||||
McpServer,
|
||||
prompts::Prompts,
|
||||
protocol::{
|
||||
JSON_RPC_VERSION,
|
||||
rpc::{BatchRequest, Response, RpcError},
|
||||
@@ -17,9 +18,12 @@ fn print_response(response: impl Serialize) {
|
||||
}
|
||||
|
||||
/// Run the server using standard input and output.
|
||||
pub async fn stdio<ToolsType>(server: McpServer<ToolsType>) -> std::io::Result<()>
|
||||
pub async fn stdio<ToolsType, PromptsType>(
|
||||
server: McpServer<ToolsType, PromptsType>,
|
||||
) -> std::io::Result<()>
|
||||
where
|
||||
ToolsType: Tools,
|
||||
PromptsType: Prompts,
|
||||
{
|
||||
let mut server = server;
|
||||
let mut input = BufReader::new(tokio::io::stdin()).lines();
|
||||
|
||||
@@ -6,100 +6,119 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use mime::Mime;
|
||||
use poem::{
|
||||
EndpointExt, IntoEndpoint, IntoResponse, Request, handler,
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::StatusCode,
|
||||
post,
|
||||
web::{
|
||||
Accept, Data, Json,
|
||||
Accept, Data, Json, Query,
|
||||
sse::{Event, SSE},
|
||||
},
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tokio::time::Instant;
|
||||
|
||||
use crate::{
|
||||
McpServer,
|
||||
protocol::rpc::{BatchRequest as McpBatchRequest, Request as McpRequest},
|
||||
prompts::Prompts,
|
||||
protocol::rpc::{BatchRequest as McpBatchRequest, Request as McpRequest, Requests},
|
||||
tool::Tools,
|
||||
};
|
||||
|
||||
const SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 5);
|
||||
|
||||
type ServerFactoryFn<ToolsType> = Box<dyn Fn(&Request) -> McpServer<ToolsType> + Send + Sync>;
|
||||
type ServerFactoryFn<ToolsType, PromptsType> =
|
||||
Box<dyn Fn(&Request) -> McpServer<ToolsType, PromptsType> + Send + Sync>;
|
||||
|
||||
struct Session<ToolsType> {
|
||||
server: Arc<tokio::sync::Mutex<McpServer<ToolsType>>>,
|
||||
struct Session<ToolsType, PromptsType> {
|
||||
server: Arc<tokio::sync::Mutex<McpServer<ToolsType, PromptsType>>>,
|
||||
sender: Option<tokio::sync::mpsc::UnboundedSender<String>>,
|
||||
last_active: Instant,
|
||||
}
|
||||
|
||||
struct State<ToolsType> {
|
||||
server_factory: ServerFactoryFn<ToolsType>,
|
||||
sessions: Mutex<HashMap<String, Session<ToolsType>>>,
|
||||
struct State<ToolsType, PromptsType> {
|
||||
server_factory: ServerFactoryFn<ToolsType, PromptsType>,
|
||||
sessions: Mutex<HashMap<String, Session<ToolsType, PromptsType>>>,
|
||||
}
|
||||
|
||||
async fn handle_request<ToolsType>(
|
||||
server: Arc<tokio::sync::Mutex<McpServer<ToolsType>>>,
|
||||
session_id: &str,
|
||||
accept: &Mime,
|
||||
requests: impl Iterator<Item = McpRequest> + Send + 'static,
|
||||
) -> impl IntoResponse
|
||||
async fn process_request<ToolsType, PromptsType>(
|
||||
server: Arc<tokio::sync::Mutex<McpServer<ToolsType, PromptsType>>>,
|
||||
request: McpRequest,
|
||||
) -> Option<crate::protocol::rpc::Response<Value>>
|
||||
where
|
||||
ToolsType: Tools + Send + Sync + 'static,
|
||||
PromptsType: Prompts + Send + Sync + 'static,
|
||||
{
|
||||
tracing::info!(
|
||||
session_id = session_id,
|
||||
accept = accept.essence_str(),
|
||||
"handling requests"
|
||||
);
|
||||
|
||||
match accept.essence_str() {
|
||||
"application/json" => {
|
||||
let mut resps = vec![];
|
||||
for request in requests {
|
||||
tracing::info!(session_id = session_id, request = ?request, "received request");
|
||||
let resp = server.lock().await.handle_request(request).await;
|
||||
tracing::info!(session_id = session_id, response = ?resp, "sending response");
|
||||
resps.extend(resp);
|
||||
}
|
||||
Json(resps)
|
||||
.with_content_type("application/json")
|
||||
.into_response()
|
||||
}
|
||||
"text/event-stream" => {
|
||||
let session_id = session_id.to_string();
|
||||
SSE::new(async_stream::stream! {
|
||||
for request in requests {
|
||||
tracing::info!(session_id = session_id, request = ?request, "received request");
|
||||
let resp = server.lock().await.handle_request(request).await;
|
||||
tracing::info!(session_id = session_id, response = ?resp, "sending response");
|
||||
yield Event::message(serde_json::to_string(&resp).unwrap()).event_type("message");
|
||||
}
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
_ => StatusCode::BAD_REQUEST.into_response(),
|
||||
}
|
||||
server.lock().await.handle_request(request).await
|
||||
}
|
||||
|
||||
#[handler]
|
||||
async fn post_handler<ToolsType>(
|
||||
data: Data<&Arc<State<ToolsType>>>,
|
||||
async fn get_handler<ToolsType, PromptsType>(
|
||||
data: Data<&Arc<State<ToolsType, PromptsType>>>,
|
||||
request: &Request,
|
||||
batch_request: Json<McpBatchRequest>,
|
||||
accept: Accept,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
ToolsType: Tools + Send + Sync + 'static,
|
||||
PromptsType: Prompts + Send + Sync + 'static,
|
||||
{
|
||||
let Some(accept) = accept.0.first() else {
|
||||
let session_id = session_id();
|
||||
let server = (data.0.server_factory)(request);
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
{
|
||||
let mut sessions = data.0.sessions.lock().unwrap();
|
||||
sessions.insert(
|
||||
session_id.clone(),
|
||||
Session {
|
||||
server: Arc::new(tokio::sync::Mutex::new(server)),
|
||||
sender: Some(tx),
|
||||
last_active: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
session_id = session_id,
|
||||
"created new standard session (SSE)"
|
||||
);
|
||||
|
||||
SSE::new(async_stream::stream! {
|
||||
let endpoint_uri = format!("?session_id={}", session_id);
|
||||
yield Event::message(endpoint_uri).event_type("endpoint");
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
yield Event::message(msg).event_type("message");
|
||||
}
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
||||
#[handler]
|
||||
async fn post_handler<ToolsType, PromptsType>(
|
||||
data: Data<&Arc<State<ToolsType, PromptsType>>>,
|
||||
request: &Request,
|
||||
batch_request: Json<McpBatchRequest>,
|
||||
accept: Accept,
|
||||
query: Query<HashMap<String, String>>,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
ToolsType: Tools + Send + Sync + 'static,
|
||||
PromptsType: Prompts + Send + Sync + 'static,
|
||||
{
|
||||
let session_id_param = request
|
||||
.headers()
|
||||
.get("Mcp-Session-Id")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(String::from)
|
||||
.or_else(|| query.get("session_id").cloned());
|
||||
|
||||
if session_id_param.is_none() {
|
||||
let Some(_accept) = accept.0.first() else {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
};
|
||||
|
||||
if batch_request.len() == 1
|
||||
&& batch_request.requests()[0].is_initialize()
|
||||
&& !request.headers().contains_key("Mcp-Session-Id")
|
||||
{
|
||||
if batch_request.len() == 1 && batch_request.requests()[0].is_initialize() {
|
||||
let session_id = session_id();
|
||||
let mut server = (data.0.server_factory)(request);
|
||||
let initialize_request = batch_request.0.into_iter().next().unwrap();
|
||||
@@ -112,54 +131,127 @@ where
|
||||
session_id.clone(),
|
||||
Session {
|
||||
server: Arc::new(tokio::sync::Mutex::new(server)),
|
||||
sender: None,
|
||||
last_active: Instant::now(),
|
||||
},
|
||||
);
|
||||
|
||||
tracing::info!(session_id = session_id, "created new session");
|
||||
tracing::info!(session_id = session_id, "created new legacy session");
|
||||
return Json(resp)
|
||||
.with_header("Mcp-Session-Id", session_id)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let Some(session_id) = request
|
||||
.headers()
|
||||
.get("Mcp-Session-Id")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
else {
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
};
|
||||
}
|
||||
|
||||
let server = {
|
||||
let session_id = session_id_param.unwrap();
|
||||
|
||||
let (server, sender) = {
|
||||
let mut sessions = data.0.sessions.lock().unwrap();
|
||||
let Some(session) = sessions.get_mut(session_id) else {
|
||||
let Some(session) = sessions.get_mut(&session_id) else {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
};
|
||||
session.last_active = Instant::now();
|
||||
session.server.clone()
|
||||
(session.server.clone(), session.sender.clone())
|
||||
};
|
||||
|
||||
handle_request(server, session_id, accept, batch_request.0.into_iter())
|
||||
.await
|
||||
if let Some(tx) = sender {
|
||||
for request in batch_request.0 {
|
||||
tracing::info!(session_id = session_id, request = ?request, "received request (std)");
|
||||
let resp = process_request(server.clone(), request).await;
|
||||
if let Some(resp) = resp {
|
||||
tracing::info!(session_id = session_id, response = ?resp, "pushing to SSE");
|
||||
if tx.send(serde_json::to_string(&resp).unwrap()).is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
return StatusCode::ACCEPTED.into_response();
|
||||
}
|
||||
|
||||
let all_notifications = batch_request.requests().iter().all(|request| {
|
||||
matches!(
|
||||
request.body,
|
||||
Requests::Initialized | Requests::Cancelled { .. }
|
||||
)
|
||||
});
|
||||
|
||||
let requests = batch_request.0.into_iter();
|
||||
|
||||
let accept = accept
|
||||
.0
|
||||
.first()
|
||||
.map(|value| value.essence_str())
|
||||
.unwrap_or("application/json");
|
||||
|
||||
match accept {
|
||||
"text/event-stream" => {
|
||||
if all_notifications {
|
||||
return StatusCode::ACCEPTED.into_response();
|
||||
}
|
||||
let session_id = session_id.clone();
|
||||
SSE::new(async_stream::stream! {
|
||||
for request in requests {
|
||||
tracing::info!(session_id = session_id, request = ?request, "received request");
|
||||
let resp = process_request(server.clone(), request).await;
|
||||
if let Some(resp) = resp {
|
||||
tracing::info!(session_id = session_id, response = ?resp, "sending response");
|
||||
yield Event::message(serde_json::to_string(&resp).unwrap()).event_type("message");
|
||||
}
|
||||
}
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
_ => {
|
||||
let mut resps = vec![];
|
||||
for request in requests {
|
||||
tracing::info!(session_id = session_id, request = ?request, "received request");
|
||||
let resp = process_request(server.clone(), request).await;
|
||||
if let Some(resp) = resp {
|
||||
tracing::info!(session_id = session_id, response = ?resp, "sending response");
|
||||
resps.push(resp);
|
||||
}
|
||||
}
|
||||
if resps.is_empty() {
|
||||
return StatusCode::ACCEPTED.into_response();
|
||||
}
|
||||
Json(resps)
|
||||
.with_content_type("application/json")
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[handler]
|
||||
async fn delete_handler<ToolsType>(
|
||||
data: Data<&Arc<State<ToolsType>>>,
|
||||
headers: &HeaderMap,
|
||||
async fn delete_handler<ToolsType, PromptsType>(
|
||||
data: Data<&Arc<State<ToolsType, PromptsType>>>,
|
||||
req: &Request,
|
||||
query: Query<HashMap<String, String>>,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
ToolsType: Tools + Send + Sync + 'static,
|
||||
PromptsType: Prompts + Send + Sync + 'static,
|
||||
{
|
||||
let Some(session_id) = headers
|
||||
let session_id = req
|
||||
.headers()
|
||||
.get("Mcp-Session-Id")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
else {
|
||||
.map(String::from)
|
||||
.or_else(|| query.get("session_id").cloned());
|
||||
|
||||
let Some(session_id) = session_id else {
|
||||
return StatusCode::BAD_REQUEST;
|
||||
};
|
||||
|
||||
if data.sessions.lock().unwrap().remove(session_id).is_none() {
|
||||
if data
|
||||
.0
|
||||
.sessions
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&session_id)
|
||||
.is_none()
|
||||
{
|
||||
return StatusCode::NOT_FOUND;
|
||||
}
|
||||
|
||||
@@ -168,10 +260,11 @@ where
|
||||
}
|
||||
|
||||
/// A streamable http endpoint that can be used to handle MCP requests.
|
||||
pub fn endpoint<F, ToolsType>(server_factory: F) -> impl IntoEndpoint
|
||||
pub fn endpoint<F, ToolsType, PromptsType>(server_factory: F) -> impl IntoEndpoint
|
||||
where
|
||||
F: Fn(&Request) -> McpServer<ToolsType> + Send + Sync + 'static,
|
||||
F: Fn(&Request) -> McpServer<ToolsType, PromptsType> + Send + Sync + 'static,
|
||||
ToolsType: Tools + Send + Sync + 'static,
|
||||
PromptsType: Prompts + Send + Sync + 'static,
|
||||
{
|
||||
let state = Arc::new(State {
|
||||
server_factory: Box::new(server_factory),
|
||||
@@ -190,8 +283,9 @@ where
|
||||
}
|
||||
});
|
||||
|
||||
post(post_handler::<ToolsType>::default())
|
||||
.delete(delete_handler::<ToolsType>::default())
|
||||
post(post_handler::<ToolsType, PromptsType>::default())
|
||||
.get(get_handler::<ToolsType, PromptsType>::default())
|
||||
.delete(delete_handler::<ToolsType, PromptsType>::default())
|
||||
.data(state)
|
||||
}
|
||||
|
||||
|
||||
@@ -117,7 +117,15 @@ where
|
||||
T: Serialize + JsonSchema,
|
||||
{
|
||||
fn output_schema() -> Option<Schema> {
|
||||
Some(schemars::SchemaGenerator::default().into_root_schema_for::<T>())
|
||||
let schema = schemars::SchemaGenerator::default().into_root_schema_for::<T>();
|
||||
if let Ok(value) = serde_json::to_value(&schema) {
|
||||
if value.get("type") == Some(&serde_json::Value::String("array".to_string())) {
|
||||
panic!(
|
||||
"Tool return type must be an object, but found array. Please wrap the return value in a struct."
|
||||
);
|
||||
}
|
||||
}
|
||||
Some(schema)
|
||||
}
|
||||
|
||||
fn into_tool_response(self) -> ToolsCallResponse {
|
||||
@@ -137,7 +145,15 @@ where
|
||||
E: Display,
|
||||
{
|
||||
fn output_schema() -> Option<Schema> {
|
||||
Some(schemars::SchemaGenerator::default().into_root_schema_for::<T>())
|
||||
let schema = schemars::SchemaGenerator::default().into_root_schema_for::<T>();
|
||||
if let Ok(value) = serde_json::to_value(&schema) {
|
||||
if value.get("type") == Some(&serde_json::Value::String("array".to_string())) {
|
||||
panic!(
|
||||
"Tool return type must be an object, but found array. Please wrap the return value in a struct."
|
||||
);
|
||||
}
|
||||
}
|
||||
Some(schema)
|
||||
}
|
||||
|
||||
fn into_tool_response(self) -> ToolsCallResponse {
|
||||
|
||||
315
poem-mcpserver/tests/prompts.rs
Normal file
315
poem-mcpserver/tests/prompts.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use poem_mcpserver::{
|
||||
McpServer, Prompts,
|
||||
content::Text,
|
||||
prompts::PromptMessages,
|
||||
protocol::{
|
||||
JSON_RPC_VERSION,
|
||||
prompts::{PromptsGetRequest, PromptsListRequest},
|
||||
rpc::{Request, RequestId, Requests},
|
||||
},
|
||||
};
|
||||
|
||||
struct TestPrompts {
|
||||
system_name: String,
|
||||
}
|
||||
|
||||
impl TestPrompts {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
system_name: "TestSystem".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[Prompts]
|
||||
impl TestPrompts {
|
||||
/// A simple greeting prompt.
|
||||
async fn greet(
|
||||
&self,
|
||||
/// The name to greet
|
||||
#[mcp(required)]
|
||||
name: Option<String>,
|
||||
) -> String {
|
||||
format!("Hello, {}! Welcome to {}.", name.unwrap(), self.system_name)
|
||||
}
|
||||
|
||||
/// A code review prompt with optional language parameter.
|
||||
async fn code_review(
|
||||
&self,
|
||||
/// The code to review
|
||||
#[mcp(required)]
|
||||
code: Option<String>,
|
||||
/// The programming language
|
||||
language: Option<String>,
|
||||
) -> PromptMessages {
|
||||
let lang = language.unwrap_or_else(|| "unknown".to_string());
|
||||
PromptMessages::new()
|
||||
.user(Text(format!(
|
||||
"Please review the following {} code:\n\n```{}\n{}\n```",
|
||||
lang,
|
||||
lang,
|
||||
code.unwrap()
|
||||
)))
|
||||
.assistant(Text(
|
||||
"I'll review this code for you. Let me analyze it...".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
/// A simple prompt without required arguments.
|
||||
async fn help(&self) -> String {
|
||||
"How can I help you today?".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_list() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(1)),
|
||||
body: Requests::PromptsList {
|
||||
params: PromptsListRequest { cursor: None },
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&resp).unwrap(),
|
||||
serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"prompts": [
|
||||
{
|
||||
"name": "greet",
|
||||
"description": "A simple greeting prompt.",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "name",
|
||||
"description": "The name to greet",
|
||||
"required": true
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "code_review",
|
||||
"description": "A code review prompt with optional language parameter.",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "code",
|
||||
"description": "The code to review",
|
||||
"required": true
|
||||
},
|
||||
{
|
||||
"name": "language",
|
||||
"description": "The programming language",
|
||||
"required": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "help",
|
||||
"description": "A simple prompt without required arguments.",
|
||||
"arguments": []
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_get_simple() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("name".to_string(), "Alice".to_string());
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(2)),
|
||||
body: Requests::PromptsGet {
|
||||
params: PromptsGetRequest {
|
||||
name: "greet".to_string(),
|
||||
arguments,
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&resp).unwrap(),
|
||||
serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"result": {
|
||||
"description": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Hello, Alice! Welcome to TestSystem."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_get_with_multiple_messages() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("code".to_string(), "fn main() {}".to_string());
|
||||
arguments.insert("language".to_string(), "rust".to_string());
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(3)),
|
||||
body: Requests::PromptsGet {
|
||||
params: PromptsGetRequest {
|
||||
name: "code_review".to_string(),
|
||||
arguments,
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&resp).unwrap(),
|
||||
serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"description": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Please review the following rust code:\n\n```rust\nfn main() {}\n```"
|
||||
}
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "I'll review this code for you. Let me analyze it..."
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_get_missing_required_argument() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(4)),
|
||||
body: Requests::PromptsGet {
|
||||
params: PromptsGetRequest {
|
||||
name: "greet".to_string(),
|
||||
arguments: HashMap::new(),
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
let resp_value = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(resp_value["jsonrpc"], "2.0");
|
||||
assert_eq!(resp_value["id"], 4);
|
||||
assert!(resp_value["error"]["code"].as_i64().is_some());
|
||||
assert!(
|
||||
resp_value["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("missing required argument: name")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_get_unknown_prompt() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(5)),
|
||||
body: Requests::PromptsGet {
|
||||
params: PromptsGetRequest {
|
||||
name: "unknown_prompt".to_string(),
|
||||
arguments: HashMap::new(),
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
let resp_value = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(resp_value["jsonrpc"], "2.0");
|
||||
assert_eq!(resp_value["id"], 5);
|
||||
assert!(resp_value["error"]["code"].as_i64().is_some());
|
||||
assert!(
|
||||
resp_value["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("prompt not found")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompts_get_no_arguments_needed() {
|
||||
let prompts = TestPrompts::new();
|
||||
let mut server = McpServer::new().prompts(prompts);
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(6)),
|
||||
body: Requests::PromptsGet {
|
||||
params: PromptsGetRequest {
|
||||
name: "help".to_string(),
|
||||
arguments: HashMap::new(),
|
||||
},
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
serde_json::to_value(&resp).unwrap(),
|
||||
serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"result": {
|
||||
"description": "",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "How can I help you today?"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
@@ -278,3 +278,74 @@ async fn disable_tools() {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(JsonSchema, Serialize)]
|
||||
struct StringList {
|
||||
items: Vec<String>,
|
||||
}
|
||||
|
||||
struct CollectionTools;
|
||||
|
||||
#[Tools]
|
||||
impl CollectionTools {
|
||||
async fn get_list(&self) -> StructuredContent<StringList> {
|
||||
StructuredContent(StringList {
|
||||
items: vec!["a".to_string(), "b".to_string()],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn collection_schema() {
|
||||
let tools = CollectionTools;
|
||||
let mut server = McpServer::new().tools(tools);
|
||||
|
||||
let resp = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(1)),
|
||||
body: Requests::ToolsList {
|
||||
params: ToolsListRequest { cursor: None },
|
||||
},
|
||||
})
|
||||
.await;
|
||||
|
||||
let resp_json = serde_json::to_value(&resp).unwrap();
|
||||
let tools = resp_json["result"]["tools"].as_array().unwrap();
|
||||
let tool = &tools[0];
|
||||
|
||||
assert_eq!(tool["name"], "get_list");
|
||||
|
||||
let output_schema = &tool["outputSchema"];
|
||||
assert_eq!(output_schema["type"], "object");
|
||||
assert_eq!(output_schema["title"], "StringList");
|
||||
assert!(output_schema["properties"]["items"]["type"] == "array");
|
||||
}
|
||||
|
||||
struct ArrayTools;
|
||||
|
||||
#[Tools]
|
||||
impl ArrayTools {
|
||||
async fn array_ret(&self) -> StructuredContent<Vec<String>> {
|
||||
StructuredContent(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic(
|
||||
expected = "Tool return type must be an object, but found array. Please wrap the return value in a struct."
|
||||
)]
|
||||
async fn test_array_panic() {
|
||||
let tools = ArrayTools;
|
||||
let mut server = McpServer::new().tools(tools);
|
||||
|
||||
let _ = server
|
||||
.handle_request(Request {
|
||||
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||
id: Some(RequestId::Int(1)),
|
||||
body: Requests::ToolsList {
|
||||
params: ToolsListRequest { cursor: None },
|
||||
},
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user