kratos/transport/http/router_test.go

194 lines
4.6 KiB

package http
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"reflect"
"strings"
"testing"
"time"
"github.com/go-kratos/kratos/v2/internal/host"
)
const (
	// appJSONStr is the Content-Type the JSON test requests send and the
	// routes under test are expected to answer with.
	appJSONStr = "application/json"
)

// User is the JSON payload the test routes bind from request bodies and
// echo back in responses.
type User struct {
	// Name is serialized under the "name" key.
	Name string `json:"name"`
}
// corsFilter short-circuits OPTIONS requests: it logs them, echoes the
// request method back in the Access-Control-Allow-Methods header and returns
// without invoking next. Requests with any other method are forwarded to next
// unchanged.
func corsFilter(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		log.Println("cors:", r.Method, r.RequestURI)
		// r.Method is always OPTIONS here, so that is what gets advertised.
		w.Header().Set("Access-Control-Allow-Methods", r.Method)
	}
	return http.HandlerFunc(handle)
}
// authFilter is a placeholder authentication filter for these tests: it
// performs no checks, only logs the request method and URI under an "auth:"
// prefix, and then always forwards the request to next.
func authFilter(next http.Handler) http.Handler {
	var wrapped http.HandlerFunc = func(rw http.ResponseWriter, req *http.Request) {
		log.Println("auth:", req.Method, req.RequestURI)
		// Hand off to the rest of the chain (another filter or the route handler).
		next.ServeHTTP(rw, req)
	}
	return wrapped
}
// loggingFilter logs each request's method and URI under a "logging:" prefix
// and then unconditionally passes the request on to next.
func loggingFilter(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		method, uri := req.Method, req.RequestURI
		log.Println("logging:", method, uri)
		next.ServeHTTP(rw, req)
	})
}
// TestRoute registers GET/POST/PUT handlers under the /v1 route group of a
// server wrapped in corsFilter and loggingFilter, starts the server, and
// drives it over real HTTP via testRoute. The server is always stopped when
// the test ends, including when testRoute fails.
func TestRoute(t *testing.T) {
	ctx := context.Background()
	srv := NewServer(
		Filter(corsFilter, loggingFilter),
	)
	route := srv.Route("/v1")
	route.GET("/users/{name}", func(ctx Context) error {
		u := new(User)
		u.Name = ctx.Vars().Get("name")
		return ctx.Result(200, u)
	}, authFilter)
	route.POST("/users", func(ctx Context) error {
		u := new(User)
		if err := ctx.Bind(u); err != nil {
			return err
		}
		return ctx.Result(201, u)
	})
	route.PUT("/users", func(ctx Context) error {
		u := new(User)
		if err := ctx.Bind(u); err != nil {
			return err
		}
		// Wrap a trivial handler returning the bound user with ctx.Middleware
		// and let Returns encode its result.
		h := ctx.Middleware(func(ctx context.Context, in interface{}) (interface{}, error) {
			return u, nil
		})
		return ctx.Returns(h(ctx, u))
	})
	if e, err := srv.Endpoint(); err != nil || e == nil {
		t.Fatal(e, err)
	}
	// Report Start's result through a channel instead of panicking in the
	// goroutine: a panic there would abort the whole test binary rather than
	// fail just this test.
	startErr := make(chan error, 1)
	go func() {
		startErr <- srv.Start(ctx)
	}()
	// Deferred so the server is stopped even if testRoute calls t.Fatal; then
	// wait (bounded) for Start to return so it does not outlive the test.
	defer func() {
		if err := srv.Stop(ctx); err != nil {
			t.Errorf("stop server: %v", err)
		}
		select {
		case err := <-startErr:
			if err != nil {
				t.Errorf("start server: %v", err)
			}
		case <-time.After(5 * time.Second):
			t.Error("server did not return from Start after Stop")
		}
	}()
	// Give the serving goroutine time to start accepting connections.
	time.Sleep(time.Second)
	testRoute(t, srv)
}
// testRoute drives the /v1 routes registered by TestRoute over real HTTP
// against srv's listener. It covers:
//   - GET with a path variable;
//   - POST and PUT with a JSON body;
//   - an OPTIONS request answered by corsFilter.
//
// Any failed expectation stops the test via t.Fatal.
func testRoute(t *testing.T, srv *Server) {
	port, ok := host.Port(srv.lis)
	if !ok {
		t.Fatalf("extract port error: %v", srv.lis)
	}
	base := fmt.Sprintf("http://127.0.0.1:%d/v1", port)
	// http.DefaultClient has no timeout; bound every request so a stuck
	// server fails the test instead of hanging it.
	client := &http.Client{Timeout: 10 * time.Second}

	// expectUser verifies a JSON User response: the transport error, the
	// status code, the Content-Type and the decoded name. The response body
	// is closed before it returns.
	expectUser := func(resp *http.Response, err error, wantCode int, wantName string) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		if resp.StatusCode != wantCode {
			t.Fatalf("code: %d", resp.StatusCode)
		}
		if v := resp.Header.Get("Content-Type"); v != appJSONStr {
			t.Fatalf("contentType: %s", v)
		}
		u := new(User)
		if err := json.NewDecoder(resp.Body).Decode(u); err != nil {
			t.Fatal(err)
		}
		if u.Name != wantName {
			t.Fatalf("got %s want %s", u.Name, wantName)
		}
	}

	// GET: the {name} path variable is echoed back.
	resp, err := client.Get(base + "/users/foo")
	expectUser(resp, err, http.StatusOK, "foo")

	// POST: the JSON body is bound and returned with 201.
	resp, err = client.Post(base+"/users", appJSONStr, strings.NewReader(`{"name":"bar"}`))
	expectUser(resp, err, http.StatusCreated, "bar")

	// PUT: the same body, returned via the handler's ctx.Middleware path with 200.
	req, err := http.NewRequest(http.MethodPut, base+"/users", strings.NewReader(`{"name":"bar"}`))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", appJSONStr)
	resp, err = client.Do(req)
	expectUser(resp, err, http.StatusOK, "bar")

	// OPTIONS: corsFilter answers itself without calling the next handler.
	req, err = http.NewRequest(http.MethodOptions, base+"/users", nil)
	if err != nil {
		t.Fatal(err)
	}
	resp, err = client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("code: %d", resp.StatusCode)
	}
	if resp.Header.Get("Access-Control-Allow-Methods") != http.MethodOptions {
		t.Fatal("cors failed")
	}
}
// TestRouter_Group checks that grouping a zero-value Router under "a"
// yields a router whose prefix is exactly "a".
func TestRouter_Group(t *testing.T) {
	const want = "a"
	passThrough := func(http.Handler) http.Handler { return nil }
	group := (&Router{}).Group(want, passThrough)
	if got := group.prefix; !reflect.DeepEqual(want, got) {
		t.Errorf("expected %q, got %q", want, got)
	}
}
// TestHandle registers a no-op handler through each of the Router's
// per-method helpers on a fresh router rooted at "/". It makes no
// assertions; it only fails if one of the registrations panics.
func TestHandle(t *testing.T) {
	router := newRouter("/", NewServer())
	noop := func(Context) error { return nil }
	router.GET("/get", noop)
	router.HEAD("/head", noop)
	router.PATCH("/patch", noop)
	router.DELETE("/delete", noop)
	router.CONNECT("/connect", noop)
	router.OPTIONS("/options", noop)
	router.TRACE("/trace", noop)
}