◐ Shell
clean mode source ↗

fix: only trust x-forwarded-host from configured trusted proxies (#2…… · coder/coder@3c46473

@@ -68,7 +68,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {

6868

})

69697070

// Wrap the test handler with the Logger middleware

71-

loggerMiddleware := Logger(logger)

71+

loggerMiddleware := Logger(logger, nil)

7272

wrappedHandler := loggerMiddleware(testHandler)

73737474

// Create a test HTTP request

@@ -91,7 +91,7 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {

9191

}

92929393

// Check that the log contains the expected fields

94-

requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "user_agent", "latency_ms"}

94+

requiredFields := []string{"host", "received_host", "path", "proto", "remote_addr", "start", "took", "status_code", "user_agent", "latency_ms"}

9595

for _, field := range requiredFields {

9696

_, exists := fieldsMap[field]

9797

require.True(t, exists, "field %q is missing in log fields", field)

@@ -103,6 +103,38 @@ func TestLoggerMiddleware_SingleRequest(t *testing.T) {

103103

require.Equal(t, fieldsMap["status_code"], http.StatusOK)

104104

}

105105106+

func TestLoggerMiddleware_HostFields(t *testing.T) {

107+

t.Parallel()

108+109+

sink := testutil.NewFakeSink(t)

110+

logger := sink.Logger()

111+112+

testHandler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {

113+

rw.WriteHeader(http.StatusOK)

114+

})

115+116+

loggerMiddleware := Logger(logger, func(_ *http.Request) string {

117+

return "effective.test"

118+

})

119+

wrappedHandler := loggerMiddleware(testHandler)

120+121+

req := httptest.NewRequest(http.MethodGet, "http://received.test/path", nil)

122+123+

sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}

124+

wrappedHandler.ServeHTTP(sw, req)

125+126+

entries := sink.Entries()

127+

require.Len(t, entries, 1, "expected exactly one log entry")

128+129+

fieldsMap := make(map[string]any)

130+

for _, field := range entries[0].Fields {

131+

fieldsMap[field.Name] = field.Value

132+

}

133+134+

require.Equal(t, "effective.test", fieldsMap["host"])

135+

require.Equal(t, "received.test", fieldsMap["received_host"])

136+

}

137+106138

func TestLoggerMiddleware_WebSocket(t *testing.T) {

107139

t.Parallel()

108140

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)

@@ -129,7 +161,7 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {

129161

})

130162131163

// Wrap the test handler with the Logger middleware

132-

loggerMiddleware := Logger(logger)

164+

loggerMiddleware := Logger(logger, nil)

133165

wrappedHandler := loggerMiddleware(testHandler)

134166135167

// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter

@@ -186,7 +218,7 @@ func TestRequestLogger_HTTPRouteParams(t *testing.T) {

186218

})

187219188220

// Wrap the test handler with the Logger middleware

189-

loggerMiddleware := Logger(logger)

221+

loggerMiddleware := Logger(logger, nil)

190222

wrappedHandler := loggerMiddleware(testHandler)

191223192224

// Create a test HTTP request