Files
cmr-mini/backend/internal/service/ops_auth_service.go

396 lines
12 KiB
Go

package service
import (
"context"
"net/http"
"strings"
"time"
"cmr-backend/internal/apperr"
"cmr-backend/internal/platform/jwtx"
"cmr-backend/internal/platform/security"
"cmr-backend/internal/store/postgres"
)
// OpsAuthSettings holds the configuration values read by OpsAuthService.
type OpsAuthSettings struct {
	// AppEnv is the deployment environment name. "development"
	// (case-insensitive) makes SendSMSCode echo the code in its result.
	AppEnv string
	// RefreshTTL is added to the issue time to compute refresh-token expiry.
	RefreshTTL time.Duration
	// SMSCodeTTL is added to the send time to compute an SMS code's expiry.
	SMSCodeTTL time.Duration
	// SMSCodeCooldown is added to the send time to compute when another code
	// may be requested for the same country code, mobile and scene.
	SMSCodeCooldown time.Duration
	// SMSProvider is recorded on each stored code. "console"
	// (case-insensitive) makes SendSMSCode echo the code in its result.
	SMSProvider string
	// DevSMSCode, when non-empty, is used instead of a random 6-digit code.
	DevSMSCode string
}
// OpsAuthService implements SMS-code registration and login, refresh-token
// rotation and logout for ops console users.
type OpsAuthService struct {
	cfg        OpsAuthSettings  // TTLs, SMS provider and dev-mode settings
	store      *postgres.Store  // persistence for codes, users, roles and tokens
	jwtManager *jwtx.Manager    // issues access tokens
}
// Request payloads accepted by OpsAuthService. The service methods normalize
// the string fields (trimming, country-code and mobile normalization); JSON
// decoding does not.
type (
	// OpsSendSMSCodeInput is the payload for OpsAuthService.SendSMSCode.
	// Scene "ops_register" requests a sign-up code; any other value is
	// treated as "ops_login".
	OpsSendSMSCodeInput struct {
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		DeviceKey   string `json:"deviceKey"`
		Scene       string `json:"scene"`
	}

	// OpsRegisterInput is the payload for OpsAuthService.Register.
	OpsRegisterInput struct {
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		Code        string `json:"code"`
		DeviceKey   string `json:"deviceKey"`
		DisplayName string `json:"displayName"`
	}

	// OpsLoginSMSInput is the payload for OpsAuthService.LoginSMS.
	OpsLoginSMSInput struct {
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		Code        string `json:"code"`
		DeviceKey   string `json:"deviceKey"`
	}

	// OpsRefreshTokenInput is the payload for OpsAuthService.Refresh. When
	// both this DeviceKey and the stored token's device key are set, they
	// must match.
	OpsRefreshTokenInput struct {
		RefreshToken string `json:"refreshToken"`
		DeviceKey    string `json:"deviceKey"`
	}

	// OpsLogoutInput is the payload for OpsAuthService.Logout. A blank
	// RefreshToken makes logout a no-op.
	OpsLogoutInput struct {
		RefreshToken string `json:"refreshToken"`
	}
)
// OpsAuthUser is the public view of an ops user returned by auth endpoints.
// RoleCode is the user's primary role code, or "" when none is assigned.
type OpsAuthUser struct {
	ID          string `json:"id"`
	PublicID    string `json:"publicId"`
	DisplayName string `json:"displayName"`
	Status      string `json:"status"`
	RoleCode    string `json:"roleCode"`
}
// OpsAuthResult is returned by successful register, login and refresh calls.
// NewUser is true only for Register. DevLoginBypass is omitted from JSON when
// false; nothing in this file sets it.
type OpsAuthResult struct {
	User           OpsAuthUser `json:"user"`
	Tokens         AuthTokens  `json:"tokens"`
	NewUser        bool        `json:"newUser"`
	DevLoginBypass bool        `json:"devLoginBypass,omitempty"`
}
// NewOpsAuthService builds an OpsAuthService from its settings, persistence
// store and JWT manager. The dependencies are stored as given.
func NewOpsAuthService(cfg OpsAuthSettings, store *postgres.Store, jwtManager *jwtx.Manager) *OpsAuthService {
	svc := &OpsAuthService{
		cfg:        cfg,
		store:      store,
		jwtManager: jwtManager,
	}
	return svc
}
// SendSMSCode creates a one-time SMS code for an ops console scene
// ("ops_login" or "ops_register") and stores its hash.
//
// It fails with 400 when mobile or deviceKey is blank. It fails with 429 while
// the latest code for the same country code, mobile and scene is still inside
// its cooldown window. The plaintext code is included in the result only when
// the SMS provider is "console" or AppEnv is "development" (both
// case-insensitive).
func (s *OpsAuthService) SendSMSCode(ctx context.Context, input OpsSendSMSCodeInput) (*SendSMSCodeResult, error) {
	input.CountryCode = normalizeCountryCode(input.CountryCode)
	input.Mobile = normalizeMobile(input.Mobile)
	input.Scene = normalizeOpsScene(input.Scene)
	// Trim once and persist the trimmed key, as Register and LoginSMS do,
	// rather than validating the trimmed value but storing the raw one.
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	if input.Mobile == "" || input.DeviceKey == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile and deviceKey are required")
	}
	latest, err := s.store.GetLatestSMSCodeMeta(ctx, input.CountryCode, input.Mobile, "ops", input.Scene)
	if err != nil {
		return nil, err
	}
	now := time.Now().UTC()
	if latest != nil && latest.CooldownUntil.After(now) {
		return nil, apperr.New(http.StatusTooManyRequests, "sms_cooldown", "sms code sent too frequently")
	}
	// NOTE(review): a configured DevSMSCode replaces the random code regardless
	// of AppEnv; confirm it is never set outside development.
	code := s.cfg.DevSMSCode
	if code == "" {
		code, err = security.GenerateNumericCode(6)
		if err != nil {
			return nil, err
		}
	}
	expiresAt := now.Add(s.cfg.SMSCodeTTL)
	cooldownUntil := now.Add(s.cfg.SMSCodeCooldown)
	// Only the hash of the code is persisted.
	// NOTE(review): this method does not hand the code to any SMS provider;
	// presumably delivery happens elsewhere — confirm.
	if err := s.store.CreateSMSCode(ctx, postgres.CreateSMSCodeParams{
		Scene:         input.Scene,
		CountryCode:   input.CountryCode,
		Mobile:        input.Mobile,
		ClientType:    "ops",
		DeviceKey:     input.DeviceKey,
		CodeHash:      security.HashText(code),
		ProviderName:  s.cfg.SMSProvider,
		ProviderDebug: map[string]any{"mode": s.cfg.SMSProvider, "channel": "ops_console"},
		ExpiresAt:     expiresAt,
		CooldownUntil: cooldownUntil,
	}); err != nil {
		return nil, err
	}
	result := &SendSMSCodeResult{
		TTLSeconds:      int64(s.cfg.SMSCodeTTL.Seconds()),
		CooldownSeconds: int64(s.cfg.SMSCodeCooldown.Seconds()),
	}
	if strings.EqualFold(s.cfg.SMSProvider, "console") || strings.EqualFold(s.cfg.AppEnv, "development") {
		result.DevCode = &code
	}
	return result, nil
}
// Register creates a new ops console user from a valid "ops_register" SMS
// code and returns a freshly issued token pair.
//
// The user is assigned the "owner" role when CountOpsUsers reports zero
// users; otherwise the user gets "operator". Code consumption, user creation,
// role assignment, the login touch and refresh-token storage all run on one
// transaction, which the deferred Rollback undoes on any early return.
//
// Errors:
//   - 400 when mobile, code, deviceKey or displayName is blank.
//   - 401 when no matching valid code exists or the code was already used.
//   - 409 when an ops user with this mobile already exists.
//   - 500 when the chosen role row is missing.
func (s *OpsAuthService) Register(ctx context.Context, input OpsRegisterInput) (*OpsAuthResult, error) {
	input.CountryCode = normalizeCountryCode(input.CountryCode)
	input.Mobile = normalizeMobile(input.Mobile)
	input.Code = strings.TrimSpace(input.Code)
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	input.DisplayName = strings.TrimSpace(input.DisplayName)
	if input.Mobile == "" || input.Code == "" || input.DeviceKey == "" || input.DisplayName == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code, deviceKey and displayName are required")
	}
	codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, "ops", "ops_register")
	if err != nil {
		return nil, err
	}
	if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
	}
	tx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback(ctx)
	consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
	if err != nil {
		return nil, err
	}
	if !consumed {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
	}
	existing, err := s.store.GetOpsUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
	if err != nil {
		return nil, err
	}
	if existing != nil {
		return nil, apperr.New(http.StatusConflict, "ops_user_exists", "ops user already exists")
	}

	// Choose the role before inserting the new user. CountOpsUsers is not
	// given tx, so presumably it sees only committed rows either way. Counting
	// first makes it explicit that 0 means "no user existed before this one".
	// A count failure is returned instead of silently demoting the first user
	// to "operator".
	// NOTE(review): two concurrent first registrations can both observe 0 and
	// both become owner; confirm whether a lock or DB constraint is needed.
	roleCode := "operator"
	userCount, err := s.store.CountOpsUsers(ctx)
	if err != nil {
		return nil, err
	}
	if userCount == 0 {
		roleCode = "owner"
	}

	publicID, err := security.GeneratePublicID("ops")
	if err != nil {
		return nil, err
	}
	user, err := s.store.CreateOpsUser(ctx, tx, postgres.CreateOpsUserParams{
		PublicID:    publicID,
		CountryCode: input.CountryCode,
		Mobile:      input.Mobile,
		DisplayName: input.DisplayName,
		Status:      "active",
	})
	if err != nil {
		return nil, err
	}
	role, err := s.store.GetOpsRoleByCode(ctx, tx, roleCode)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return nil, apperr.New(http.StatusInternalServerError, "ops_role_missing", "default ops role is missing")
	}
	if err := s.store.AssignOpsRole(ctx, tx, user.ID, role.ID); err != nil {
		return nil, err
	}
	if err := s.store.TouchOpsUserLogin(ctx, tx, user.ID); err != nil {
		return nil, err
	}
	result, _, err := s.issueAuthResult(ctx, tx, *user, input.DeviceKey, true)
	if err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return result, nil
}
// LoginSMS signs in an existing, active ops user with an "ops_login" SMS
// code and returns a new access/refresh token pair.
//
// Consuming the code, touching the user's login and storing the refresh token
// share one transaction, which the deferred Rollback undoes on any early
// return.
//
// Errors:
//   - 400 when mobile, code or deviceKey is blank.
//   - 401 when the code does not match or was already used.
//   - 404 when no ops user has this mobile.
//   - 403 when the user's status is not "active".
func (s *OpsAuthService) LoginSMS(ctx context.Context, input OpsLoginSMSInput) (*OpsAuthResult, error) {
	countryCode := normalizeCountryCode(input.CountryCode)
	mobile := normalizeMobile(input.Mobile)
	smsCode := strings.TrimSpace(input.Code)
	deviceKey := strings.TrimSpace(input.DeviceKey)
	if mobile == "" || smsCode == "" || deviceKey == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code and deviceKey are required")
	}

	pending, err := s.store.GetLatestValidSMSCode(ctx, countryCode, mobile, "ops", "ops_login")
	if err != nil {
		return nil, err
	}
	if pending == nil || pending.CodeHash != security.HashText(smsCode) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
	}

	dbTx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	defer dbTx.Rollback(ctx)

	// ConsumeSMSCode reports false when the code can no longer be consumed,
	// for example when a concurrent request already used it.
	used, err := s.store.ConsumeSMSCode(ctx, dbTx, pending.ID)
	switch {
	case err != nil:
		return nil, err
	case !used:
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
	}

	account, err := s.store.GetOpsUserByMobile(ctx, dbTx, countryCode, mobile)
	switch {
	case err != nil:
		return nil, err
	case account == nil:
		return nil, apperr.New(http.StatusNotFound, "ops_user_not_found", "ops user not found")
	case account.Status != "active":
		return nil, apperr.New(http.StatusForbidden, "ops_user_inactive", "ops user is not active")
	}

	if err := s.store.TouchOpsUserLogin(ctx, dbTx, account.ID); err != nil {
		return nil, err
	}
	authResult, _, err := s.issueAuthResult(ctx, dbTx, *account, deviceKey, false)
	if err != nil {
		return nil, err
	}
	if err := dbTx.Commit(ctx); err != nil {
		return nil, err
	}
	return authResult, nil
}
// Refresh exchanges a valid refresh token for a new access/refresh token pair
// and marks the presented token as rotated via RotateOpsRefreshToken.
//
// The token row is loaded with GetOpsRefreshTokenForUpdate inside a
// transaction. The call fails with 401 in any of these cases:
//   - the token is unknown, revoked or expired;
//   - a non-empty input device key differs from the token's stored device key;
//   - the token's user is missing or not "active".
//
// The new refresh token inherits the old token's device key.
func (s *OpsAuthService) Refresh(ctx context.Context, input OpsRefreshTokenInput) (*OpsAuthResult, error) {
	input.RefreshToken = strings.TrimSpace(input.RefreshToken)
	// Register and LoginSMS store trimmed device keys, so compare a trimmed
	// value; otherwise stray whitespace causes a spurious device mismatch.
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	if input.RefreshToken == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "refreshToken is required")
	}
	tx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback(ctx)
	record, err := s.store.GetOpsRefreshTokenForUpdate(ctx, tx, security.HashText(input.RefreshToken))
	if err != nil {
		return nil, err
	}
	// A token is unusable from its expiry instant onward, not only after it.
	if record == nil || record.IsRevoked || !record.ExpiresAt.After(time.Now().UTC()) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token is invalid or expired")
	}
	// The device binding is enforced only when both sides carry a device key.
	if input.DeviceKey != "" && record.DeviceKey != nil && input.DeviceKey != *record.DeviceKey {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token device mismatch")
	}
	user, err := s.store.GetOpsUserByID(ctx, tx, record.OpsUserID)
	if err != nil {
		return nil, err
	}
	if user == nil || user.Status != "active" {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token user not found")
	}
	result, newTokenID, err := s.issueAuthResult(ctx, tx, *user, nullableStringValue(record.DeviceKey), false)
	if err != nil {
		return nil, err
	}
	if err := s.store.RotateOpsRefreshToken(ctx, tx, record.ID, newTokenID); err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return result, nil
}
// Logout revokes the refresh token in input, looked up by its hash.
// A blank token is accepted and nothing is revoked.
func (s *OpsAuthService) Logout(ctx context.Context, input OpsLogoutInput) error {
	token := strings.TrimSpace(input.RefreshToken)
	if token == "" {
		return nil
	}
	return s.store.RevokeOpsRefreshToken(ctx, security.HashText(token))
}
// GetMe loads the ops user with ID opsUserID and their primary role, using
// the store's connection pool rather than a transaction.
//
// It returns a 404 apperr when no such user exists. The user's status is not
// checked here.
func (s *OpsAuthService) GetMe(ctx context.Context, opsUserID string) (*OpsAuthUser, error) {
	user, err := s.store.GetOpsUserByID(ctx, s.store.Pool(), opsUserID)
	switch {
	case err != nil:
		return nil, err
	case user == nil:
		return nil, apperr.New(http.StatusNotFound, "ops_user_not_found", "ops user not found")
	}
	primaryRole, roleErr := s.store.GetPrimaryOpsRole(ctx, s.store.Pool(), user.ID)
	if roleErr != nil {
		return nil, roleErr
	}
	me := buildOpsAuthUser(*user, primaryRole)
	return &me, nil
}
// issueAuthResult mints an access token and a new refresh token for user.
// The refresh token is stored inside tx, and both tokens are returned with
// the user's public profile.
//
// The second return value is the ID of the stored refresh-token row; Refresh
// passes it to RotateOpsRefreshToken. deviceKey is saved on the refresh
// token, and newUser is copied into the result unchanged.
func (s *OpsAuthService) issueAuthResult(ctx context.Context, tx postgres.Tx, user postgres.OpsUser, deviceKey string, newUser bool) (*OpsAuthResult, string, error) {
	role, err := s.store.GetPrimaryOpsRole(ctx, tx, user.ID)
	if err != nil {
		return nil, "", err
	}
	// The profile's RoleCode ("" when the user has no role) is also the role
	// embedded in the access token.
	profile := buildOpsAuthUser(user, role)

	accessToken, accessExpiresAt, err := s.jwtManager.IssueActorAccessToken(user.ID, user.PublicID, "ops", profile.RoleCode)
	if err != nil {
		return nil, "", err
	}

	rawRefresh, err := security.GenerateToken(32)
	if err != nil {
		return nil, "", err
	}
	refreshExpiresAt := time.Now().UTC().Add(s.cfg.RefreshTTL)
	// Only the hash is persisted; the raw token goes back in the result.
	refreshID, err := s.store.CreateOpsRefreshToken(ctx, tx, postgres.CreateOpsRefreshTokenParams{
		OpsUserID: user.ID,
		DeviceKey: deviceKey,
		TokenHash: security.HashText(rawRefresh),
		ExpiresAt: refreshExpiresAt,
	})
	if err != nil {
		return nil, "", err
	}

	return &OpsAuthResult{
		User: profile,
		Tokens: AuthTokens{
			AccessToken:           accessToken,
			AccessTokenExpiresAt:  accessExpiresAt.Format(time.RFC3339),
			RefreshToken:          rawRefresh,
			RefreshTokenExpiresAt: refreshExpiresAt.Format(time.RFC3339),
		},
		NewUser: newUser,
	}, refreshID, nil
}
// buildOpsAuthUser converts a stored ops user and its optional role into the
// public OpsAuthUser shape. A nil role leaves RoleCode empty.
func buildOpsAuthUser(user postgres.OpsUser, role *postgres.OpsRole) OpsAuthUser {
	out := OpsAuthUser{
		ID:          user.ID,
		PublicID:    user.PublicID,
		DisplayName: user.DisplayName,
		Status:      user.Status,
	}
	if role != nil {
		out.RoleCode = role.RoleCode
	}
	return out
}
// normalizeOpsScene maps a requested SMS scene onto one of the two ops scenes.
// Only an exact "ops_register" (surrounding whitespace ignored, case
// significant) selects registration; everything else, including "", becomes
// "ops_login".
func normalizeOpsScene(value string) string {
	if strings.TrimSpace(value) == "ops_register" {
		return "ops_register"
	}
	return "ops_login"
}