diff --git a/hub/route/server.go b/hub/route/server.go index d530684a..7c3a2334 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -1,6 +1,7 @@ package route import ( + "bytes" "encoding/json" "net/http" "strings" @@ -12,6 +13,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/cors" "github.com/go-chi/render" + "github.com/gorilla/websocket" ) var ( @@ -19,6 +21,8 @@ var ( serverAddr = "" uiPath = "" + + upgrader = websocket.Upgrader{} ) type Traffic struct { @@ -47,15 +51,12 @@ func Start(addr string, secret string) { MaxAge: 300, }) - root := chi.NewRouter().With(jsonContentType) - root.Get("/traffic", traffic) - root.Get("/logs", getLogs) - r.Get("/", hello) r.Group(func(r chi.Router) { r.Use(cors.Handler, authentication) - r.Mount("/", root) + r.Get("/logs", getLogs) + r.Get("/traffic", traffic) r.Mount("/configs", configRouter()) r.Mount("/proxies", proxyRouter()) r.Mount("/rules", ruleRouter()) @@ -78,14 +79,6 @@ func Start(addr string, secret string) { } } -func jsonContentType(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) -} - func authentication(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { header := r.Header.Get("Authorization") @@ -113,19 +106,44 @@ func hello(w http.ResponseWriter, r *http.Request) { } func traffic(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusOK) + var wsConn *websocket.Conn + if websocket.IsWebSocketUpgrade(r) { + var err error + wsConn, err = upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + } + + if wsConn == nil { + w.Header().Set("Content-Type", "application/json") + render.Status(r, http.StatusOK) + } tick := time.NewTicker(time.Second) t := T.Instance().Traffic() + buf := &bytes.Buffer{} + var err error for range tick.C { + buf.Reset() up, down := t.Now() - if err := json.NewEncoder(w).Encode(Traffic{ + if err := json.NewEncoder(buf).Encode(Traffic{ Up: up, Down: down, }); err != nil { break } - w.(http.Flusher).Flush() + + if wsConn == nil { + _, err = w.Write(buf.Bytes()) + w.(http.Flusher).Flush() + } else { + err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + } + + if err != nil { + break + } } } @@ -147,20 +165,47 @@ func getLogs(w http.ResponseWriter, r *http.Request) { return } + var wsConn *websocket.Conn + if websocket.IsWebSocketUpgrade(r) { + var err error + wsConn, err = upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + } + + if wsConn == nil { + w.Header().Set("Content-Type", "application/json") + render.Status(r, http.StatusOK) + } + sub := log.Subscribe() - render.Status(r, http.StatusOK) + defer log.UnSubscribe(sub) + buf := &bytes.Buffer{} + var err error for elm := range sub { + buf.Reset() log := elm.(*log.Event) if log.LogLevel < level { continue } - if err := json.NewEncoder(w).Encode(Log{ + if err := json.NewEncoder(buf).Encode(Log{ Type: log.Type(), Payload: log.Payload, }); err != nil { break } - w.(http.Flusher).Flush() + + if wsConn == nil { + _, err = w.Write(buf.Bytes()) + w.(http.Flusher).Flush() + } else { + err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) + } + + if err != nil { + break + } } } diff --git a/log/log.go b/log/log.go index 251929de..7190cc89 100644 --- a/log/log.go +++ b/log/log.go @@ -62,6 +62,11 @@ func Subscribe() observable.Subscription { return sub } +func UnSubscribe(sub observable.Subscription) { + source.UnSubscribe(sub) + return +} + func Level() LogLevel { return level }