You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
3.0 KiB
138 lines
3.0 KiB
package http
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-kratos/kratos/v2/internal/host"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// Fixture types shared by TestServer, testClient and BenchmarkServer.
type (
	// testKey is an unexported context key. The tests store the string
	// "test" under it in the context handed to srv.Start, and the handlers
	// check that r.Context().Value(testKey{}) yields that value.
	testKey struct{}

	// testData is the JSON payload echoed by the test handler; Path holds
	// the request URI the handler observed.
	testData struct {
		Path string `json:"path"`
	}
)
|
|
|
|
// TestServer starts a real server, registers an echo handler on /index and
// exercises it through testClient.
//
// The context passed to srv.Start carries testKey{} -> "test". The handler
// answers 500 when that value is not visible on the request context, so the
// test also checks that the server exposes its start context to handlers.
func TestServer(t *testing.T) {
	fn := func(w http.ResponseWriter, r *http.Request) {
		// Check before writing anything: once the body has been written the
		// status is implicitly 200 and a later WriteHeader is ignored.
		if r.Context().Value(testKey{}) != "test" {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// An encode failure here can only be a write error to the client;
		// the client-side assertions will surface it.
		_ = json.NewEncoder(w).Encode(&testData{Path: r.RequestURI})
	}
	ctx := context.WithValue(context.Background(), testKey{}, "test")
	srv := NewServer()
	srv.HandleFunc("/index", fn)

	// Resolve the endpoint before Start. NOTE(review): presumably this also
	// creates srv.lis, which testClient reads — confirm.
	if e, err := srv.Endpoint(); err != nil || e == nil {
		t.Fatal(e, err)
	}

	// Report Start's result on a channel instead of panicking in the
	// goroutine, which would abort the entire test binary.
	errCh := make(chan error, 1)
	go func() {
		errCh <- srv.Start(ctx)
	}()
	// Stop the server even when testClient fails the test early, so the
	// serving goroutine does not outlive this test.
	defer func() {
		if err := srv.Stop(ctx); err != nil {
			t.Errorf("stop server: %v", err)
			return
		}
		// net/http's Serve returns http.ErrServerClosed after Shutdown;
		// tolerate it in case Start passes it through.
		if err := <-errCh; err != nil && err != http.ErrServerClosed {
			t.Errorf("start server: %v", err)
		}
	}()

	// Give the serving goroutine time to begin accepting connections.
	time.Sleep(time.Second)
	testClient(t, srv)
}
|
|
|
|
// testClient sends every (method, path) pair in its table to srv twice.
// First it sends a raw *http.Request through client.Do and decodes the JSON
// body. Then it goes through client.Invoke. In both cases the echoed Path
// must equal the requested path.
func testClient(t *testing.T, srv *Server) {
	t.Helper()

	tests := []struct {
		method string
		path   string
	}{
		{http.MethodGet, "/index"},
		{http.MethodPut, "/index"},
		{http.MethodPost, "/index"},
		{http.MethodPatch, "/index"},
		{http.MethodDelete, "/index"},
	}
	port, ok := host.Port(srv.lis)
	if !ok {
		t.Fatalf("extract port error: %v", srv.lis)
	}
	client, err := NewClient(context.Background(), WithEndpoint(fmt.Sprintf("127.0.0.1:%d", port)))
	if err != nil {
		t.Fatal(err)
	}

	for _, test := range tests {
		url := fmt.Sprintf("http://127.0.0.1:%d%s", port, test.path)
		req, err := http.NewRequest(test.method, url, nil)
		if err != nil {
			t.Fatal(err)
		}
		resp, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		content, err := ioutil.ReadAll(resp.Body)
		// Close immediately rather than deferring: a defer inside this loop
		// would hold every response body open until the function returns.
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("http status got %d", resp.StatusCode)
		}
		if err != nil {
			t.Fatalf("read resp error %v", err)
		}
		var res testData
		if err := json.Unmarshal(content, &res); err != nil {
			t.Fatalf("unmarshal resp error %v", err)
		}
		if res.Path != test.path {
			t.Errorf("expected %s got %s", test.path, res.Path)
		}
	}

	for _, test := range tests {
		var res testData
		if err := client.Invoke(context.Background(), test.method, test.path, nil, &res); err != nil {
			t.Fatalf("invoke error %v", err)
		}
		if res.Path != test.path {
			t.Errorf("expected %s got %s", test.path, res.Path)
		}
	}
}
|
|
|
|
// BenchmarkServer measures client.Invoke round trips (POST /index) against a
// live server whose start context carries testKey{} -> "test".
func BenchmarkServer(b *testing.B) {
	fn := func(w http.ResponseWriter, r *http.Request) {
		// Check before writing: a WriteHeader after the body has been written
		// is ignored, so writing first could never report the 500.
		if r.Context().Value(testKey{}) != "test" {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		_ = json.NewEncoder(w).Encode(&testData{Path: r.RequestURI})
	}
	ctx := context.WithValue(context.Background(), testKey{}, "test")
	srv := NewServer()
	srv.HandleFunc("/index", fn)

	// Resolve the endpoint on this goroutine before Start, as TestServer
	// does, instead of reading srv.lis while the Start goroutine may still
	// be initialising it. NOTE(review): presumably Endpoint creates srv.lis
	// — confirm.
	if e, err := srv.Endpoint(); err != nil || e == nil {
		b.Fatal(e, err)
	}

	errCh := make(chan error, 1)
	go func() {
		errCh <- srv.Start(ctx)
	}()
	defer func() {
		// Keep server shutdown out of the measured time.
		b.StopTimer()
		if err := srv.Stop(ctx); err != nil {
			b.Errorf("stop server: %v", err)
			return
		}
		// net/http's Serve returns http.ErrServerClosed after Shutdown;
		// tolerate it in case Start passes it through.
		if err := <-errCh; err != nil && err != http.ErrServerClosed {
			b.Errorf("start server: %v", err)
		}
	}()

	time.Sleep(time.Second)
	port, ok := host.Port(srv.lis)
	if !ok {
		// Without a port every iteration would fail; stop immediately.
		b.Fatalf("extract port error: %v", srv.lis)
	}
	client, err := NewClient(context.Background(), WithEndpoint(fmt.Sprintf("127.0.0.1:%d", port)))
	if err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var res testData
		assert.NoError(b, client.Invoke(context.Background(), http.MethodPost, "/index", nil, &res))
	}
}
|
|
|