Add realtime gateway and simulator bridge
This commit is contained in:
277
realtime-gateway/internal/gateway/client.go
Normal file
277
realtime-gateway/internal/gateway/client.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
|
||||
"realtime-gateway/internal/channel"
|
||||
"realtime-gateway/internal/config"
|
||||
"realtime-gateway/internal/model"
|
||||
"realtime-gateway/internal/plugin"
|
||||
"realtime-gateway/internal/router"
|
||||
"realtime-gateway/internal/session"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
conn *websocket.Conn
|
||||
logger *slog.Logger
|
||||
cfg config.GatewayConfig
|
||||
hub *router.Hub
|
||||
channels *channel.Manager
|
||||
plugins *plugin.Bus
|
||||
session *session.Session
|
||||
auth config.AuthConfig
|
||||
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func serveClient(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
logger *slog.Logger,
|
||||
cfg config.Config,
|
||||
hub *router.Hub,
|
||||
channels *channel.Manager,
|
||||
plugins *plugin.Bus,
|
||||
sessions *session.Manager,
|
||||
) {
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("websocket accept failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
sess := sessions.Create()
|
||||
c := &client{
|
||||
conn: conn,
|
||||
logger: logger.With("sessionId", sess.ID),
|
||||
cfg: cfg.Gateway,
|
||||
hub: hub,
|
||||
channels: channels,
|
||||
plugins: plugins,
|
||||
session: sess,
|
||||
auth: cfg.Auth,
|
||||
}
|
||||
|
||||
hub.Register(c, nil)
|
||||
defer func() {
|
||||
if sess.ChannelID != "" {
|
||||
channels.Unbind(sess.ChannelID, sess.Role)
|
||||
}
|
||||
hub.Unregister(sess.ID)
|
||||
sessions.Delete(sess.ID)
|
||||
_ = conn.Close(websocket.StatusNormalClosure, "session closed")
|
||||
}()
|
||||
|
||||
if err := c.run(r.Context()); err != nil && !errors.Is(err, context.Canceled) {
|
||||
c.logger.Warn("client closed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ID() string {
|
||||
return c.session.ID
|
||||
}
|
||||
|
||||
func (c *client) Send(message model.ServerMessage) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.cfg.WriteWait())
|
||||
defer cancel()
|
||||
return wsjson.Write(ctx, c.conn, message)
|
||||
}
|
||||
|
||||
func (c *client) run(ctx context.Context) error {
|
||||
if err := c.Send(model.ServerMessage{
|
||||
Type: "welcome",
|
||||
SessionID: c.session.ID,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pingCtx, cancelPing := context.WithCancel(ctx)
|
||||
defer cancelPing()
|
||||
go c.pingLoop(pingCtx)
|
||||
|
||||
for {
|
||||
readCtx, cancel := context.WithTimeout(ctx, c.cfg.PongWait())
|
||||
var message model.ClientMessage
|
||||
err := wsjson.Read(readCtx, c.conn, &message)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.handleMessage(message); err != nil {
|
||||
_ = c.Send(model.ServerMessage{
|
||||
Type: "error",
|
||||
Error: err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleMessage(message model.ClientMessage) error {
|
||||
switch message.Type {
|
||||
case "authenticate":
|
||||
return c.handleAuthenticate(message)
|
||||
case "join_channel":
|
||||
return c.handleJoinChannel(message)
|
||||
case "subscribe":
|
||||
return c.handleSubscribe(message)
|
||||
case "publish":
|
||||
return c.handlePublish(message)
|
||||
case "snapshot":
|
||||
return c.handleSnapshot(message)
|
||||
default:
|
||||
return errors.New("unsupported message type")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleJoinChannel(message model.ClientMessage) error {
|
||||
if strings.TrimSpace(message.ChannelID) == "" {
|
||||
return errors.New("channelId is required")
|
||||
}
|
||||
|
||||
snapshot, err := c.channels.Join(message.ChannelID, message.Token, message.Role)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.session.ChannelID != "" {
|
||||
c.channels.Unbind(c.session.ChannelID, c.session.Role)
|
||||
}
|
||||
if err := c.channels.Bind(snapshot.ID, message.Role); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.session.Role = message.Role
|
||||
c.session.Authenticated = true
|
||||
c.session.ChannelID = snapshot.ID
|
||||
c.session.Subscriptions = nil
|
||||
c.hub.UpdateSubscriptions(c.session.ID, nil)
|
||||
|
||||
return c.Send(model.ServerMessage{
|
||||
Type: "joined_channel",
|
||||
SessionID: c.session.ID,
|
||||
State: json.RawMessage([]byte(
|
||||
`{"channelId":"` + snapshot.ID + `","deliveryMode":"` + snapshot.DeliveryMode + `"}`,
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handleAuthenticate(message model.ClientMessage) error {
|
||||
if !authorize(c.auth, message.Role, message.Token) {
|
||||
return errors.New("authentication failed")
|
||||
}
|
||||
|
||||
c.session.Role = message.Role
|
||||
c.session.Authenticated = true
|
||||
return c.Send(model.ServerMessage{
|
||||
Type: "authenticated",
|
||||
SessionID: c.session.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handleSubscribe(message model.ClientMessage) error {
|
||||
if !c.session.Authenticated && !c.auth.AllowAnonymousConsumers {
|
||||
return errors.New("consumer must authenticate before subscribe")
|
||||
}
|
||||
|
||||
subscriptions := normalizeSubscriptions(c.session.ChannelID, message.Subscriptions)
|
||||
c.session.Subscriptions = subscriptions
|
||||
c.hub.UpdateSubscriptions(c.session.ID, subscriptions)
|
||||
return c.Send(model.ServerMessage{
|
||||
Type: "subscribed",
|
||||
SessionID: c.session.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handlePublish(message model.ClientMessage) error {
|
||||
if !c.session.Authenticated {
|
||||
return errors.New("authentication required")
|
||||
}
|
||||
if c.session.Role != model.RoleProducer && c.session.Role != model.RoleController {
|
||||
return errors.New("publish is only allowed for producer or controller")
|
||||
}
|
||||
if message.Envelope == nil {
|
||||
return errors.New("envelope is required")
|
||||
}
|
||||
|
||||
envelope := *message.Envelope
|
||||
if envelope.Source.Kind == "" {
|
||||
envelope.Source.Kind = c.session.Role
|
||||
}
|
||||
|
||||
if c.session.ChannelID != "" {
|
||||
envelope.Target.ChannelID = c.session.ChannelID
|
||||
}
|
||||
deliveryMode := channel.DeliveryModeCacheLatest
|
||||
if envelope.Target.ChannelID != "" {
|
||||
deliveryMode = c.channels.DeliveryMode(envelope.Target.ChannelID)
|
||||
}
|
||||
result := c.hub.Publish(envelope, deliveryMode)
|
||||
if !result.Dropped {
|
||||
c.plugins.Publish(envelope)
|
||||
}
|
||||
return c.Send(model.ServerMessage{
|
||||
Type: "published",
|
||||
SessionID: c.session.ID,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) handleSnapshot(message model.ClientMessage) error {
|
||||
if len(message.Subscriptions) == 0 || message.Subscriptions[0].DeviceID == "" {
|
||||
return errors.New("snapshot requires deviceId in first subscription")
|
||||
}
|
||||
channelID := message.Subscriptions[0].ChannelID
|
||||
if channelID == "" {
|
||||
channelID = c.session.ChannelID
|
||||
}
|
||||
state, ok := c.hub.Snapshot(channelID, message.Subscriptions[0].DeviceID)
|
||||
if !ok {
|
||||
return errors.New("snapshot not found")
|
||||
}
|
||||
return c.Send(model.ServerMessage{
|
||||
Type: "snapshot",
|
||||
SessionID: c.session.ID,
|
||||
State: json.RawMessage(state),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) pingLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(c.cfg.PingInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
pingCtx, cancel := context.WithTimeout(ctx, c.cfg.WriteWait())
|
||||
_ = c.conn.Ping(pingCtx)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSubscriptions(channelID string, subscriptions []model.Subscription) []model.Subscription {
|
||||
items := make([]model.Subscription, 0, len(subscriptions))
|
||||
for _, entry := range subscriptions {
|
||||
if channelID != "" && strings.TrimSpace(entry.ChannelID) == "" {
|
||||
entry.ChannelID = channelID
|
||||
}
|
||||
items = append(items, entry)
|
||||
}
|
||||
return items
|
||||
}
|
||||
Reference in New Issue
Block a user