diff --git a/health/health.go b/health/health.go new file mode 100644 index 000000000..3c61b48d5 --- /dev/null +++ b/health/health.go @@ -0,0 +1,107 @@ +package health + +import ( + "context" + "encoding/json" + "net/http" +) + +type Status string + +const ( + Up Status = "UP" + Down Status = "DOWN" +) + +type Checker interface { + Check(ctx context.Context) error +} + +type CheckerFunc func(ctx context.Context) error + +func (f CheckerFunc) Check(ctx context.Context) error { + return f(ctx) +} + +type Health struct { + status Status + checkers map[string]Checker +} + +func New() *Health { + h := &Health{ + status: Down, + checkers: make(map[string]Checker), + } + return h +} + +func (h *Health) Register(name string, checker CheckerFunc) { + h.checkers[name] = checker +} + +func (h *Health) Start(_ context.Context) error { + h.status = Up + return nil +} + +func (h *Health) Stop(_ context.Context) error { + h.status = Down + return nil +} + +func (h *Health) Check(ctx context.Context) Result { + res := Result{Status: h.status, Details: make(map[string]Detail, len(h.checkers))} + for n, c := range h.checkers { + if err := c.Check(ctx); err != nil { + res.Status = Down + res.Details[n] = Detail{Status: Down, Error: err.Error()} + } else { + res.Details[n] = Detail{Status: Up} + } + } + return res +} + +func (h *Health) CheckService(ctx context.Context, svc string) Detail { + c, ok := h.checkers[svc] + if !ok { + return Detail{Status: Down, Error: "service not find"} + } + err := c.Check(ctx) + if err != nil { + return Detail{Status: Down, Error: err.Error()} + } + return Detail{Status: Up} +} + +func (h *Health) ServeHTTP(w http.ResponseWriter, r *http.Request) { + service := r.URL.Query().Get("service") + if service == "" { + res := h.Check(r.Context()) + if res.Status == Down { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusOK) + } + _ = json.NewEncoder(w).Encode(res) + } else { + detail := h.CheckService(r.Context(), service) + if detail.Status == Down { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.WriteHeader(http.StatusOK) + } + _ = json.NewEncoder(w).Encode(detail) + } +} + +type Result struct { + Status Status `json:"status"` + Details map[string]Detail `json:"details"` +} + +type Detail struct { + Status Status `json:"status"` + Error string `json:"error"` +} diff --git a/transport/http/server.go b/transport/http/server.go index 165bb99bc..b6bc9779f 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -9,6 +9,7 @@ import ( "net/url" "time" + "github.com/go-kratos/kratos/v2/health" "github.com/go-kratos/kratos/v2/internal/endpoint" "github.com/go-kratos/kratos/v2/internal/matcher" @@ -135,6 +136,17 @@ func PathPrefix(prefix string) ServerOption { } } +func HealthChecks(ho ...HealthOption) ServerOption { + return func(s *Server) { + s.ho = ho + } +} + +type HealthOption struct { + Name string + health.CheckerFunc +} + // Server is an HTTP server wrapper. type Server struct { *http.Server @@ -154,6 +166,8 @@ type Server struct { ene EncodeErrorFunc strictSlash bool router *mux.Router + health *health.Health + ho []HealthOption } // NewServer creates an HTTP server by options. @@ -170,10 +184,15 @@ func NewServer(opts ...ServerOption) *Server { ene: DefaultErrorEncoder, strictSlash: true, router: mux.NewRouter(), + health: health.New(), + ho: make([]HealthOption, 0), } for _, o := range opts { o(srv) } + for _, v := range srv.ho { + srv.health.Register(v.Name, v.CheckerFunc) + } srv.router.StrictSlash(srv.strictSlash) srv.router.NotFoundHandler = http.DefaultServeMux srv.router.MethodNotAllowedHandler = http.DefaultServeMux @@ -182,6 +201,9 @@ func NewServer(opts ...ServerOption) *Server { Handler: FilterChain(srv.filters...)(srv.router), TLSConfig: srv.tlsConf, } + + // health + srv.router.Handle("/health", srv.health).Methods("GET") return srv } @@ -301,6 +323,7 @@ func (s *Server) Start(ctx context.Context) error { return ctx } log.Infof("[HTTP] server listening on: %s", s.lis.Addr().String()) + _ = s.health.Start(ctx) var err error if s.tlsConf != nil { err = s.ServeTLS(s.lis, "", "") @@ -316,6 +339,7 @@ func (s *Server) Start(ctx context.Context) error { // Stop stop the HTTP server. func (s *Server) Stop(ctx context.Context) error { log.Info("[HTTP] server stopping") + _ = s.health.Stop(ctx) return s.Shutdown(ctx) }