Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ type ClientOptions struct {
// Setting CreateMessageHandler to a non-nil value causes the client to
// advertise the sampling capability.
CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error)
// SamplingSupportsTools indicates that the client's CreateMessageHandler
// supports tool use. If true and CreateMessageHandler is set, the
// sampling.tools capability is advertised.
SamplingSupportsTools bool
// SamplingSupportsContext indicates that the client supports
// includeContext values other than "none".
SamplingSupportsContext bool
// ElicitationHandler handles incoming requests for elicitation/create.
//
// Setting ElicitationHandler to a non-nil value causes the client to
Expand Down Expand Up @@ -131,6 +138,12 @@ func (c *Client) capabilities() *ClientCapabilities {
caps.Roots.ListChanged = true
if c.opts.CreateMessageHandler != nil {
caps.Sampling = &SamplingCapabilities{}
if c.opts.SamplingSupportsTools {
caps.Sampling.Tools = &SamplingToolsCapabilities{}
}
if c.opts.SamplingSupportsContext {
caps.Sampling.Context = &SamplingContextCapabilities{}
}
}
if c.opts.ElicitationHandler != nil {
caps.Elicitation = &ElicitationCapabilities{}
Expand Down
137 changes: 133 additions & 4 deletions mcp/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ import (
)

// A Content is a [TextContent], [ImageContent], [AudioContent],
// [ResourceLink], or [EmbeddedResource].
// [ResourceLink], [EmbeddedResource], [ToolUseContent], or [ToolResultContent].
//
// Note: [ToolUseContent] and [ToolResultContent] are only valid in sampling
// message contexts (CreateMessageParams/CreateMessageResult).
type Content interface {
MarshalJSON() ([]byte, error)
fromWire(*wireContent)
Expand Down Expand Up @@ -183,6 +186,104 @@ func (c *EmbeddedResource) fromWire(wire *wireContent) {
c.Annotations = wire.Annotations
}

// ToolUseContent represents a request from the assistant to invoke a tool.
// This content type is only valid in sampling messages.
type ToolUseContent struct {
// ID is a unique identifier for this tool use, used to match with ToolResultContent.
ID string
// Name is the name of the tool to invoke.
Name string
// Input contains the tool arguments as a JSON object.
Input map[string]any
Meta Meta
}

func (c *ToolUseContent) MarshalJSON() ([]byte, error) {
input := c.Input
if input == nil {
input = map[string]any{}
}
wire := struct {
Type string `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Input map[string]any `json:"input"`
Meta Meta `json:"_meta,omitempty"`
}{
Type: "tool_use",
ID: c.ID,
Name: c.Name,
Input: input,
Meta: c.Meta,
}
return json.Marshal(wire)
}

func (c *ToolUseContent) fromWire(wire *wireContent) {
c.ID = wire.ID
c.Name = wire.Name
c.Input = wire.Input
c.Meta = wire.Meta
}

// ToolResultContent represents the result of a tool invocation.
// This content type is only valid in sampling messages with role "user".
type ToolResultContent struct {
// ToolUseID references the ID from the corresponding ToolUseContent.
ToolUseID string
// Content holds the unstructured result of the tool call.
Content []Content
// StructuredContent holds an optional structured result as a JSON object.
StructuredContent any
// IsError indicates whether the tool call ended in an error.
IsError bool
Meta Meta
}

func (c *ToolResultContent) MarshalJSON() ([]byte, error) {
// Marshal nested content
var contentWire []*wireContent
for _, content := range c.Content {
data, err := content.MarshalJSON()
if err != nil {
return nil, err
}
var w wireContent
if err := json.Unmarshal(data, &w); err != nil {
return nil, err
}
contentWire = append(contentWire, &w)
}
if contentWire == nil {
contentWire = []*wireContent{} // avoid JSON null
}

wire := struct {
Type string `json:"type"`
ToolUseID string `json:"toolUseId"`
Content []*wireContent `json:"content"`
StructuredContent any `json:"structuredContent,omitempty"`
IsError bool `json:"isError,omitempty"`
Meta Meta `json:"_meta,omitempty"`
}{
Type: "tool_result",
ToolUseID: c.ToolUseID,
Content: contentWire,
StructuredContent: c.StructuredContent,
IsError: c.IsError,
Meta: c.Meta,
}
return json.Marshal(wire)
}

func (c *ToolResultContent) fromWire(wire *wireContent) {
c.ToolUseID = wire.ToolUseID
c.StructuredContent = wire.StructuredContent
c.IsError = wire.IsError
c.Meta = wire.Meta
// Content is handled separately in contentFromWire due to nested content
}

// ResourceContents contains the contents of a specific resource or
// sub-resource.
type ResourceContents struct {
Expand Down Expand Up @@ -224,10 +325,9 @@ func (r *ResourceContents) MarshalJSON() ([]byte, error) {

// wireContent is the wire format for content.
// It represents the protocol types TextContent, ImageContent, AudioContent,
// ResourceLink, and EmbeddedResource.
// ResourceLink, EmbeddedResource, ToolUseContent, and ToolResultContent.
// The Type field distinguishes them. In the protocol, each type has a constant
// value for the field.
// At most one of Text, Data, Resource, and URI is non-zero.
type wireContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Expand All @@ -242,6 +342,14 @@ type wireContent struct {
Meta Meta `json:"_meta,omitempty"`
Annotations *Annotations `json:"annotations,omitempty"`
Icons []Icon `json:"icons,omitempty"`
// Fields for ToolUseContent (type: "tool_use")
ID string `json:"id,omitempty"`
Input map[string]any `json:"input,omitempty"`
// Fields for ToolResultContent (type: "tool_result")
ToolUseID string `json:"toolUseId,omitempty"`
ToolResultContent []*wireContent `json:"content,omitempty"` // nested content for tool_result
StructuredContent any `json:"structuredContent,omitempty"`
IsError bool `json:"isError,omitempty"`
}

func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, error) {
Expand Down Expand Up @@ -284,6 +392,27 @@ func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error)
v := new(EmbeddedResource)
v.fromWire(wire)
return v, nil
case "tool_use":
v := new(ToolUseContent)
v.fromWire(wire)
return v, nil
case "tool_result":
v := new(ToolResultContent)
v.fromWire(wire)
// Handle nested content - tool_result content can contain text, image, audio,
// resource_link, and resource (same as CallToolResult.content)
if wire.ToolResultContent != nil {
toolResultContentAllow := map[string]bool{
"text": true, "image": true, "audio": true,
"resource_link": true, "resource": true,
}
nestedContent, err := contentsFromWire(wire.ToolResultContent, toolResultContentAllow)
if err != nil {
return nil, fmt.Errorf("tool_result nested content: %w", err)
}
v.Content = nestedContent
}
return v, nil
}
return nil, fmt.Errorf("internal error: unrecognized content type %s", wire.Type)
return nil, fmt.Errorf("unrecognized content type %q", wire.Type)
}
48 changes: 45 additions & 3 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ type CreateMessageParams struct {
Meta `json:"_meta,omitempty"`
// A request to include context from one or more MCP servers (including the
// caller), to be attached to the prompt. The client may ignore this request.
//
// The default behavior is Default is "none". Values "thisServer" and
// "allServers" are soft-deprecated. Servers SHOULD only use these values if
// the client declares ClientCapabilities.sampling.context. These values may
// be removed in future spec releases.
IncludeContext string `json:"includeContext,omitempty"`
// The maximum number of tokens to sample, as requested by the server. The
// client may choose to sample fewer tokens than requested.
Expand All @@ -307,6 +312,12 @@ type CreateMessageParams struct {
// may modify or omit this prompt.
SystemPrompt string `json:"systemPrompt,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
// Tools is an optional list of tools available for the model to use.
// Requires the client's sampling.tools capability.
Tools []*Tool `json:"tools,omitempty"`
// ToolChoice controls how the model should use tools.
// Requires the client's sampling.tools capability.
ToolChoice *ToolChoice `json:"toolChoice,omitempty"`
}

func (x *CreateMessageParams) isParams() {}
Expand All @@ -326,6 +337,12 @@ type CreateMessageResult struct {
Model string `json:"model"`
Role Role `json:"role"`
// The reason why sampling stopped, if known.
//
// Standard values:
// - "endTurn": natural end of the assistant's turn
// - "stopSequence": a stop sequence was encountered
// - "maxTokens": reached the maximyum token limit
// - "toolUse": the model wants to use one or more tools
StopReason string `json:"stopReason,omitempty"`
}

Expand All @@ -339,8 +356,9 @@ func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &wire); err != nil {
return err
}
// Allow text, image, audio, and tool_use in results
var err error
if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true}); err != nil {
return err
}
*r = CreateMessageResult(wire.result)
Expand Down Expand Up @@ -876,7 +894,27 @@ func (x *RootsListChangedParams) GetProgressToken() any { return getProgressTok
func (x *RootsListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }

// SamplingCapabilities describes the capabilities for sampling.
type SamplingCapabilities struct{}
type SamplingCapabilities struct {
// Context indicates the client supports includeContext values other than "none".
Context *SamplingContextCapabilities `json:"context,omitempty"`
// Tools indicates the client supports tools and toolChoice in sampling requests.
Tools *SamplingToolsCapabilities `json:"tools,omitempty"`
}

// SamplingContextCapabilities indicates the client supports context inclusion.
type SamplingContextCapabilities struct{}

// SamplingToolsCapabilities indicates the client supports tool use in sampling.
type SamplingToolsCapabilities struct{}

// ToolChoice controls how the model uses tools during sampling.
type ToolChoice struct {
// Mode controls tool invocation behavior:
// - "auto": Model decides whether to use tools (default)
// - "required": Model must use at least one tool
// - "none": Model must not use any tools
Mode string `json:"mode,omitempty"`
}

// ElicitationCapabilities describes the capabilities for elicitation.
//
Expand All @@ -895,6 +933,9 @@ type URLElicitationCapabilities struct {
}

// Describes a message issued to or received from an LLM API.
//
// For assistant messages, Content may be text, image, audio, or tool_use.
// For user messages, Content may be text, image, audio, or tool_result.
type SamplingMessage struct {
Content Content `json:"content"`
Role Role `json:"role"`
Expand All @@ -911,8 +952,9 @@ func (m *SamplingMessage) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &wire); err != nil {
return err
}
// Allow text, image, audio, tool_use, and tool_result in sampling messages
var err error
if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
if wire.msg.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true, "tool_use": true, "tool_result": true}); err != nil {
return err
}
*m = SamplingMessage(wire.msg)
Expand Down
Loading