rust-mcp-stack/rust-mcp-schema

Should we enforce the validation for `method` field in `xxxRequest`?

Closed this issue · 4 comments

In this repository, the ClientRequest enum is marked as #[serde(untagged)], so serde will deserialize this enum by trying each variant in order:

#[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum ClientRequest {
    InitializeRequest(InitializeRequest),
    PingRequest(PingRequest),
    ListResourcesRequest(ListResourcesRequest),
    ListResourceTemplatesRequest(ListResourceTemplatesRequest),
    ReadResourceRequest(ReadResourceRequest),
    SubscribeRequest(SubscribeRequest),
    UnsubscribeRequest(UnsubscribeRequest),
    ListPromptsRequest(ListPromptsRequest),
    GetPromptRequest(GetPromptRequest),
    ListToolsRequest(ListToolsRequest),
    CallToolRequest(CallToolRequest),
    SetLevelRequest(SetLevelRequest),
    CompleteRequest(CompleteRequest),
}

According to the docs, serde tries each variant of an untagged enum in order and returns the first one that deserializes successfully.

In the JSON schema, the method field of each xxxRequest is declared as a constant value, but this constraint is not enforced during deserialization. This can lead to ambiguity, because deserializing from JSON can put an illegal method name into the field.

So I think we should add a custom deserializer for the method field:

#[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)]
pub struct ListResourcesRequest {
    #[serde(deserialize_with = "deserialize_list_resources_method")]
    pub method: ::std::string::String,
    #[serde(default, skip_serializing_if = "::std::option::Option::is_none")]
    pub params: ::std::option::Option<ListResourcesRequestParams>,
}

// other code...

// Needed by the helper functions below.
use serde::{Deserialize, Deserializer};

fn deserialize_list_resources_method<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    deserialize_const_method(deserializer, mcp_methods::LIST_RESOURCES)
}

// other code...

fn deserialize_const_method<'de, D>(
    deserializer: D,
    expected: &'static str,
) -> std::result::Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    if value == expected {
        Ok(value)
    } else {
        Err(serde::de::Error::custom(format!(
            "Expected method '{}', but got '{}'",
            expected, value
        )))
    }
}

With that validation in place, code like the following becomes safe:

// Assumes `anyhow` for `?` and `bail!`; `Request` is the caller's transport type.
fn handle_request(request: &mut Request) -> anyhow::Result<ServerResult> {
    let call_req: ClientRequest = serde_json::from_reader(request.as_reader())?;
    let res: ServerResult = match call_req {
        ClientRequest::InitializeRequest(InitializeRequest { method, params }) => {
            // handle the initialize request
            todo!()
        }
        _ => {
            bail!("unsupported request")
        }
    };
    Ok(res)
}

If this approach is ok, I would be happy to implement it.

Hey @zhongyi51, thanks a lot for reporting this and for the thoughtful suggestion.

Overall, your justification is reasonable, but a few minor details were either overlooked or not fully considered. So let's review it together:

In this repository, the ClientRequest enum is marked as #[serde(untagged)], so serde will deserialize this enum by trying each variant in order:

#[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum ClientRequest {
    InitializeRequest(InitializeRequest),
    PingRequest(PingRequest),
    ListResourcesRequest(ListResourcesRequest),
    ListResourceTemplatesRequest(ListResourceTemplatesRequest),
    ReadResourceRequest(ReadResourceRequest),
    SubscribeRequest(SubscribeRequest),
    UnsubscribeRequest(UnsubscribeRequest),
    ListPromptsRequest(ListPromptsRequest),
    GetPromptRequest(GetPromptRequest),
    ListToolsRequest(ListToolsRequest),
    CallToolRequest(CallToolRequest),
    SetLevelRequest(SetLevelRequest),
    CompleteRequest(CompleteRequest),
}

According to the docs, serde tries each variant of an untagged enum in order and returns the first one that deserializes successfully.

1- The code you quoted ☝ is not exactly what we have in the repo.
2- Deserialization of that enum does not happen in order in this case, because there is already a custom Deserializer for ClientRequest that dispatches on the method name:

https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2024_11_05/mcp_schema.rs#L5785-L5853

So currently, at the ClientRequest level, deserializing a JSON payload with an illegal method name fails with Error("unknown variant 'method'").

So the following will fail:

let call_req: ClientRequest = serde_json::from_str(r#"{"jsonrpc":"2.0","id":0,"method":"INVALID-METHOD","params":{}}"#).unwrap();

That said, at the request struct level we may still run into the issue you mentioned.

For instance, this 👇 fails (because ding is not a valid method):

let call_req: ClientRequest = serde_json::from_str(r#"{"jsonrpc":"2.0","id":17,"method":"ding"}"#).unwrap();

but this 👇 does NOT fail, and creates a PingRequest with the wrong method name (ding):

let call_req: PingRequest = serde_json::from_str(r#"{"jsonrpc":"2.0","id":17,"method":"ding"}"#).unwrap();

This is an edge case that may not happen often, but your solution prevents it, so it is not a bad idea to pursue 👍

Because the schema code is auto-generated, I would need you to implement it in a way that I can adopt in the schema generator.

No need to implement it for all the structs. Could you implement just one example variant from each of the following enums:

  • ClientRequest
  • ClientNotification
  • ServerRequest
  • ServerNotification

Then I will use your code to update SchemaGen so that it produces the implementation for every other variant in those enums.

Thanks for your reply @hashemix! Below is the code generation pattern for fields defined with a const value in the JSON schema.

General Code Generation Pattern

When a JSON schema defines a field with a const value, for example:

{
  "type": "string",
  "const": "notifications/cancelled"
}

The code generator can produce the following Rust code with extra validation:

// Brings `String::deserialize` into scope for the function below.
use serde::Deserialize;

// A struct generated from the schema, e.g., `ParamType`.
#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
pub struct ParamType {
    // This field requires custom deserialization for validation.
    #[serde(deserialize_with = "deserialize_ParamType_some_field")]
    pub some_field: String,
}

// Custom deserialization function, following the `deserialize_#StructName_#FieldName` format.
#[allow(non_snake_case)]
fn deserialize_ParamType_some_field<'de, D>(
    deserializer: D,
) -> std::result::Result<String, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    // The expected constant value.
    let expected = "notifications/cancelled".to_owned();

    // Validate the deserialized value.
    if value == expected {
        Ok(value)
    } else {
        // The error message with format 
        // "Expected field `#FieldName` in struct `#StructName` as const value '{}', but got '{}'"
        Err(serde::de::Error::custom(format!(
            "Expected field `some_field` in struct `ParamType` as const value '{}', but got '{}'",
            expected, value
        )))
    }
}

Concrete Examples

Example: CallToolRequest

#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
pub struct CallToolRequest {
    #[serde(deserialize_with = "deserialize_CallToolRequest_method")]
    method: ::std::string::String,
    pub params: CallToolRequestParams,
}

#[allow(non_snake_case)]
fn deserialize_CallToolRequest_method<'de, D>(
    deserializer: D,
) -> std::result::Result<String, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    let expected = "tools/call".to_owned();
    if value == expected {
        Ok(value)
    } else {
        Err(serde::de::Error::custom(format!(
            "Expected field `method` in struct `CallToolRequest` as const value '{}', but got '{}'",
            expected, value
        )))
    }
}

Example: CancelledNotification

#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
pub struct CancelledNotification {
    #[serde(deserialize_with = "deserialize_CancelledNotification_method")]
    method: ::std::string::String,
    pub params: CancelledNotificationParams,
}

#[allow(non_snake_case)]
fn deserialize_CancelledNotification_method<'de, D>(
    deserializer: D,
) -> std::result::Result<String, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    let expected = "notifications/cancelled".to_owned();
    if value == expected {
        Ok(value)
    } else {
        Err(serde::de::Error::custom(format!(
            "Expected field `method` in struct `CancelledNotification` as const value '{}', but got '{}'",
            expected, value
        )))
    }
}

Example: CreateMessageRequest

#[derive(::serde::Deserialize, ::serde::Serialize, Clone, Debug)]
pub struct CreateMessageRequest {
    #[serde(deserialize_with = "deserialize_CreateMessageRequest_method")]
    method: ::std::string::String,
    pub params: CreateMessageRequestParams,
}

#[allow(non_snake_case)]
fn deserialize_CreateMessageRequest_method<'de, D>(
    deserializer: D,
) -> std::result::Result<String, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    let expected = "sampling/createMessage".to_owned();
    if value == expected {
        Ok(value)
    } else {
        Err(serde::de::Error::custom(format!(
            "Expected field `method` in struct `CreateMessageRequest` as const value '{}', but got '{}'",
            expected, value
        )))
    }
}

I hope this detailed breakdown is helpful.

Hey @zhongyi51, thanks so much for the code snippet and the clear, detailed explanation. I really appreciate the effort you put into it. This is great!

To be fair and to recognize your contribution properly, I’d love for this change to come from you. If you're open to it, could you create a PR with your idea?

You only need to implement that custom deserializer for one of the enums, like ServerRequest, and create a PR.

I’d be happy to collaborate and continue building on it from there using the schema generator, so the final result reflects your contribution.

Let me know what you think!