Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ bin

.history/*
coverage.html
.gobuildcache
.gocache
.gopath
1 change: 1 addition & 0 deletions internal/common/config/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type (
Args []string `json:"args,omitempty" yaml:"args,omitempty"` // for stdio
Env map[string]string `json:"env,omitempty" yaml:"env,omitempty"` // for stdio
URL string `json:"url,omitempty" yaml:"url,omitempty"` // for sse and streamable-http
Headers map[string]string `json:"headers,omitempty" yaml:"headers,omitempty"` // for sse and streamable-http
Policy cnst.MCPStartupPolicy `json:"policy" yaml:"policy"` // onStart or onDemand
Preinstalled bool `json:"preinstalled" yaml:"preinstalled"` // whether to install this MCP server when mcp-gateway starts
}
Expand Down
2 changes: 2 additions & 0 deletions internal/common/dto/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type MCPServerConfig struct {
Args []string `json:"args,omitempty"` // for stdio
Env map[string]string `json:"env,omitempty"` // for stdio
URL string `json:"url,omitempty"` // for sse and streamable-http
Headers map[string]string `json:"headers,omitempty"` // for sse and streamable-http
Policy string `json:"policy"` // onStart or onDemand
Preinstalled bool `json:"preinstalled"` // whether to install this MCP server when mcp-gateway starts
}
Expand Down Expand Up @@ -290,6 +291,7 @@ func FromMCPServerConfigs(cfgs []config.MCPServerConfig) []MCPServerConfig {
Args: cfg.Args,
Env: cfg.Env,
URL: cfg.URL,
Headers: cfg.Headers,
Policy: string(cfg.Policy),
Preinstalled: cfg.Preinstalled,
}
Expand Down
5 changes: 4 additions & 1 deletion internal/common/dto/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ func TestFromArgConfigsAndItemsConfig(t *testing.T) {

func TestFromMCPServerConfigsAndAuthPrompt(t *testing.T) {
ms := []config.MCPServerConfig{{
Type: "sse", Name: "n", Command: "cmd", Args: []string{"a"}, Env: map[string]string{"k": "v"}, URL: "u", Policy: "onStart", Preinstalled: true,
Type: "sse", Name: "n", Command: "cmd", Args: []string{"a"}, Env: map[string]string{"k": "v"}, URL: "u",
Headers: map[string]string{"Authorization": "Bearer token", "X-Req": "v"},
Policy: "onStart", Preinstalled: true,
}}
out := FromMCPServerConfigs(ms)
if assert.Len(t, out, 1) {
Expand All @@ -57,6 +59,7 @@ func TestFromMCPServerConfigsAndAuthPrompt(t *testing.T) {
assert.Equal(t, "cmd", out[0].Command)
assert.Equal(t, []string{"a"}, out[0].Args)
assert.Equal(t, "u", out[0].URL)
assert.Equal(t, map[string]string{"Authorization": "Bearer token", "X-Req": "v"}, out[0].Headers)
assert.Equal(t, "onStart", out[0].Policy)
assert.True(t, out[0].Preinstalled)
}
Expand Down
26 changes: 26 additions & 0 deletions internal/core/mcpproxy/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package mcpproxy

import (
"fmt"

"github.com/amoylab/unla/internal/template"
)

func renderHeaders(headers map[string]string, tmplCtx *template.Context) (map[string]string, error) {
if len(headers) == 0 {
return nil, nil
}
if tmplCtx == nil {
tmplCtx = template.NewContext()
}

rendered := make(map[string]string, len(headers))
for k, v := range headers {
out, err := template.RenderTemplate(v, tmplCtx)
if err != nil {
return nil, fmt.Errorf("failed to render header template: %w", err)
}
rendered[k] = out
}
return rendered, nil
}
46 changes: 46 additions & 0 deletions internal/core/mcpproxy/headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package mcpproxy

import (
"testing"

"github.com/amoylab/unla/internal/template"
"github.com/stretchr/testify/assert"
)

func TestRenderHeaders_Empty(t *testing.T) {
headers, err := renderHeaders(nil, nil)
assert.NoError(t, err)
assert.Nil(t, headers)

headers, err = renderHeaders(map[string]string{}, nil)
assert.NoError(t, err)
assert.Nil(t, headers)
}

func TestRenderHeaders_WithTemplateContext(t *testing.T) {
tmplCtx := template.NewContext()
tmplCtx.Env = func(key string) string {
if key == "MCP_AUTH_TOKEN" {
return "token"
}
return ""
}
tmplCtx.Request.Headers["X-Req"] = "req"

headers, err := renderHeaders(map[string]string{
"Authorization": "Bearer {{ env \"MCP_AUTH_TOKEN\" }}",
"X-Req": "{{ index .Request.Headers \"X-Req\" }}",
}, tmplCtx)

assert.NoError(t, err)
assert.Equal(t, "Bearer token", headers["Authorization"])
assert.Equal(t, "req", headers["X-Req"])
}

func TestRenderHeaders_InvalidTemplate(t *testing.T) {
headers, err := renderHeaders(map[string]string{
"X-Bad": "{{ .Request.Headers",
}, template.NewContext())
assert.Error(t, err)
assert.Nil(t, headers)
}
43 changes: 21 additions & 22 deletions internal/core/mcpproxy/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,17 @@ func (t *SSETransport) Start(ctx context.Context, tmplCtx *template.Context) err
return nil
}

renderedHeaders, err := renderHeaders(t.cfg.Headers, tmplCtx)
if err != nil {
return err
}
var opts []transport.ClientOption
if len(renderedHeaders) > 0 {
opts = append(opts, transport.WithHeaders(renderedHeaders))
}

// Create SSE transport
sseTransport, err := transport.NewSSE(t.cfg.URL)
sseTransport, err := transport.NewSSE(t.cfg.URL, opts...)
if err != nil {
return fmt.Errorf("failed to create SSE transport: %w", err)
}
Expand Down Expand Up @@ -157,31 +166,21 @@ func (t *SSETransport) CallTool(ctx context.Context, params mcp.CallToolParams,
ctx = scope.Ctx
defer scope.End()
if !t.IsRunning() {
if err := t.Start(ctx, nil); err != nil {
return nil, err
// Convert arguments to map[string]any
var args map[string]any
if err := json.Unmarshal(params.Arguments, &args); err != nil {
return nil, fmt.Errorf("invalid tool arguments: %w", err)
}
}

// Convert arguments to map[string]any
var args map[string]any
if err := json.Unmarshal(params.Arguments, &args); err != nil {
return nil, fmt.Errorf("invalid tool arguments: %w", err)
}

// Prepare template context for environment variables
tmplCtx, err := template.AssembleTemplateContext(req, args, nil)
if err != nil {
return nil, fmt.Errorf("failed to prepare template context: %w", err)
}

// Process environment variables with templates
renderedClientEnv := make(map[string]string)
for k, v := range t.cfg.Env {
rendered, err := template.RenderTemplate(v, tmplCtx)
// Prepare template context for header templates
tmplCtx, err := template.AssembleTemplateContext(req, args, nil)
if err != nil {
return nil, fmt.Errorf("failed to render env template: %w", err)
return nil, fmt.Errorf("failed to prepare template context: %w", err)
}

if err := t.Start(ctx, tmplCtx); err != nil {
return nil, err
}
renderedClientEnv[k] = rendered
}
Comment on lines 168 to 184
Copy link

Copilot AI Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The refactored CallTool method now creates template context only when the transport is not running. However, if the transport is already running (e.g., with PolicyOnStart), any dynamic header templates that depend on the current tool arguments won't be re-evaluated. This means the first request's template values will be reused for all subsequent requests. Consider whether this is the intended behavior or if headers should be re-rendered per request.

Copilot uses AI. Check for mistakes.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback


// Prepare tool call request parameters
Expand Down
11 changes: 10 additions & 1 deletion internal/core/mcpproxy/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,17 @@ func (t *StreamableTransport) Start(ctx context.Context, tmplCtx *template.Conte
return nil
}

renderedHeaders, err := renderHeaders(t.cfg.Headers, tmplCtx)
if err != nil {
return err
}
var opts []transport.StreamableHTTPCOption
if len(renderedHeaders) > 0 {
opts = append(opts, transport.WithHTTPHeaders(renderedHeaders))
}

// Create streamable transport
streamableTransport, err := transport.NewStreamableHTTP(t.cfg.URL)
streamableTransport, err := transport.NewStreamableHTTP(t.cfg.URL, opts...)
if err != nil {
return fmt.Errorf("failed to create Streamable HTTP transport: %w", err)
}
Expand Down
10 changes: 9 additions & 1 deletion internal/core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package state
import (
"context"
"fmt"
"reflect"
"time"

"github.com/amoylab/unla/internal/common/cnst"
Expand Down Expand Up @@ -58,6 +59,13 @@ func NewState() *State {
}
}

func headersEqual(a, b map[string]string) bool {
if len(a) == 0 && len(b) == 0 {
return true
}
return reflect.DeepEqual(a, b)
}

// BuildStateFromConfig creates a new State from the given configuration
func BuildStateFromConfig(ctx context.Context, cfgs []*config.MCPConfig, oldState *State, logger *zap.Logger) (*State, error) {
// Create new state
Expand Down Expand Up @@ -168,7 +176,7 @@ func BuildStateFromConfig(ctx context.Context, cfgs []*config.MCPConfig, oldStat
break
}
}
if argsMatch {
if argsMatch && headersEqual(oldConfig.Headers, mcpServer.Headers) {
// Reuse existing transport
transport = oldRuntime.transport
}
Expand Down
58 changes: 58 additions & 0 deletions internal/core/state/state_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package state

import (
"context"
"testing"

"github.com/amoylab/unla/internal/common/cnst"
"github.com/amoylab/unla/internal/common/config"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)

func TestBuildStateFromConfig_ReusesTransportWhenHeadersUnchanged(t *testing.T) {
ctx := context.Background()
logger := zap.NewNop()

cfg := buildMCPConfigWithHeaders(map[string]string{"Authorization": "Bearer a"})
oldState, err := BuildStateFromConfig(ctx, []*config.MCPConfig{cfg}, nil, logger)
assert.NoError(t, err)

newState, err := BuildStateFromConfig(ctx, []*config.MCPConfig{cfg}, oldState, logger)
assert.NoError(t, err)

assert.Same(t, oldState.GetTransport("/m"), newState.GetTransport("/m"))
}

func TestBuildStateFromConfig_RebuildsTransportWhenHeadersChange(t *testing.T) {
ctx := context.Background()
logger := zap.NewNop()

oldCfg := buildMCPConfigWithHeaders(map[string]string{"Authorization": "Bearer a"})
oldState, err := BuildStateFromConfig(ctx, []*config.MCPConfig{oldCfg}, nil, logger)
assert.NoError(t, err)

newCfg := buildMCPConfigWithHeaders(map[string]string{"Authorization": "Bearer b"})
newState, err := BuildStateFromConfig(ctx, []*config.MCPConfig{newCfg}, oldState, logger)
assert.NoError(t, err)

assert.NotSame(t, oldState.GetTransport("/m"), newState.GetTransport("/m"))
}

func buildMCPConfigWithHeaders(headers map[string]string) *config.MCPConfig {
return &config.MCPConfig{
Name: "c1",
Tenant: "t1",
Routers: []config.RouterConfig{{
Server: "ms1",
Prefix: "/m",
}},
McpServers: []config.MCPServerConfig{{
Type: cnst.BackendProtoSSE.String(),
Name: "ms1",
URL: "http://127.0.0.1:9/",
Policy: cnst.PolicyOnDemand,
Headers: headers,
}},
}
}
Loading
Loading