diff options
-rw-r--r-- | .rr.yaml | 8 | ||||
-rw-r--r-- | service/static/service_test.go | 101 |
2 files changed, 90 insertions, 19 deletions
@@ -162,6 +162,14 @@ static: # list of extensions for forbid for serving. forbid: [".php", ".htaccess"] + # Automatically add headers to every request. + request: + "Example-Request-Header": "Value" + + # Automatically add headers to every response. + response: + "X-Powered-By": "RoadRunner" + # health service configuration health: # http host to serve health requests. diff --git a/service/static/service_test.go b/service/static/service_test.go index 1a137cbc..842662c9 100644 --- a/service/static/service_test.go +++ b/service/static/service_test.go @@ -37,25 +37,6 @@ func (cfg *testCfg) Unmarshal(out interface{}) error { return j.Unmarshal([]byte(cfg.target), out) } -func get(url string) (string, *http.Response, error) { - r, err := http.Get(url) - if err != nil { - return "", nil, err - } - - b, err := ioutil.ReadAll(r.Body) - if err != nil { - return "", nil, err - } - - err = r.Body.Close() - if err != nil { - return "", nil, err - } - - return string(b), r, err -} - func Test_Files(t *testing.T) { logger, _ := test.NewNullLogger() logger.SetLevel(logrus.DebugLevel) @@ -442,6 +423,88 @@ func Test_Files_NotForbid(t *testing.T) { c.Stop() } +func TestStatic_Headers(t *testing.T) { + logger, _ := test.NewNullLogger() + logger.SetLevel(logrus.DebugLevel) + + c := service.NewContainer(logger) + c.Register(rrhttp.ID, &rrhttp.Service{}) + c.Register(ID, &Service{}) + + assert.NoError(t, c.Init(&testCfg{ + static: `{"enable":true, "dir":"../../tests", "forbid":[], "request":{"input": "custom-header"}, "response":{"output": "output-header"}}`, + httpCfg: `{ + "enable": true, + "address": ":8037", + "maxRequestSize": 1024, + "uploads": { + "dir": ` + tmpDir() + `, + "forbid": [] + }, + "workers":{ + "command": "php ../../tests/http/client.php pid pipes", + "relay": "pipes", + "pool": { + "numWorkers": 1, + "allocateTimeout": 10000000, + "destroyTimeout": 10000000 + } + } + }`})) + + go func() { + err := c.Serve() + if err != nil { + t.Errorf("serve error: %v", err) + } + }() + + time.Sleep(time.Millisecond * 500) + + req, err := http.NewRequest("GET", "http://localhost:8037/client.php", nil) + if err != nil { + t.Fatal(err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if resp.Header.Get("Output") != "output-header" { + t.Fatal("can't find output header in response") + } + + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, all("../../tests/client.php"), string(b)) + assert.Equal(t, all("../../tests/client.php"), string(b)) + c.Stop() +} + +func get(url string) (string, *http.Response, error) { + r, err := http.Get(url) + if err != nil { + return "", nil, err + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return "", nil, err + } + + err = r.Body.Close() + if err != nil { + return "", nil, err + } + + return string(b), r, err +} + func tmpDir() string { p := os.TempDir() j := json.ConfigCompatibleWithStandardLibrary |