396 lines
12 KiB
Go
396 lines
12 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"cmr-backend/internal/apperr"
|
|
"cmr-backend/internal/platform/jwtx"
|
|
"cmr-backend/internal/platform/security"
|
|
"cmr-backend/internal/store/postgres"
|
|
)
|
|
|
|
type OpsAuthSettings struct {
|
|
AppEnv string
|
|
RefreshTTL time.Duration
|
|
SMSCodeTTL time.Duration
|
|
SMSCodeCooldown time.Duration
|
|
SMSProvider string
|
|
DevSMSCode string
|
|
}
|
|
|
|
type OpsAuthService struct {
|
|
cfg OpsAuthSettings
|
|
store *postgres.Store
|
|
jwtManager *jwtx.Manager
|
|
}
|
|
|
|
// OpsSendSMSCodeInput is the request body for requesting an ops SMS code.
// Scene selects "ops_register"; any other value is treated as "ops_login".
type OpsSendSMSCodeInput struct {
	CountryCode string `json:"countryCode"`
	Mobile      string `json:"mobile"`
	DeviceKey   string `json:"deviceKey"` // required, must not be blank
	Scene       string `json:"scene"`
}
|
|
|
|
// OpsRegisterInput is the request body for creating a new ops user with an
// "ops_register" SMS code. Mobile, Code, DeviceKey and DisplayName are all
// required after trimming.
type OpsRegisterInput struct {
	CountryCode string `json:"countryCode"`
	Mobile      string `json:"mobile"`
	Code        string `json:"code"`
	DeviceKey   string `json:"deviceKey"`
	DisplayName string `json:"displayName"`
}
|
|
|
|
// OpsLoginSMSInput is the request body for logging an existing ops user in
// with an "ops_login" SMS code.
type OpsLoginSMSInput struct {
	CountryCode string `json:"countryCode"`
	Mobile      string `json:"mobile"`
	Code        string `json:"code"`
	DeviceKey   string `json:"deviceKey"`
}
|
|
|
|
// OpsRefreshTokenInput is the request body for rotating a refresh token.
// DeviceKey is optional; when given it must match the key bound to the token.
type OpsRefreshTokenInput struct {
	RefreshToken string `json:"refreshToken"`
	DeviceKey    string `json:"deviceKey"`
}
|
|
|
|
// OpsLogoutInput is the request body for revoking a refresh token.
type OpsLogoutInput struct {
	RefreshToken string `json:"refreshToken"` // blank means nothing to revoke
}
|
|
|
|
// OpsAuthUser is the client-facing view of an ops user. RoleCode is empty
// when the user has no primary role.
type OpsAuthUser struct {
	ID          string `json:"id"`
	PublicID    string `json:"publicId"`
	DisplayName string `json:"displayName"`
	Status      string `json:"status"`
	RoleCode    string `json:"roleCode"`
}
|
|
|
|
type OpsAuthResult struct {
|
|
User OpsAuthUser `json:"user"`
|
|
Tokens AuthTokens `json:"tokens"`
|
|
NewUser bool `json:"newUser"`
|
|
DevLoginBypass bool `json:"devLoginBypass,omitempty"`
|
|
}
|
|
|
|
func NewOpsAuthService(cfg OpsAuthSettings, store *postgres.Store, jwtManager *jwtx.Manager) *OpsAuthService {
|
|
return &OpsAuthService{cfg: cfg, store: store, jwtManager: jwtManager}
|
|
}
|
|
|
|
// SendSMSCode creates a one-time SMS verification code for the ops console.
//
// The country code and mobile are normalized, and the scene is coerced to
// "ops_register" or "ops_login". Mobile and deviceKey are required (400
// invalid_params). If the latest code for the same country/mobile/scene is
// still in its cooldown window, the call fails with 429 sms_cooldown.
// Otherwise a code is generated (or taken from DevSMSCode) and only its hash
// is persisted, together with its expiry and cooldown timestamps.
//
// The result reports the TTL and cooldown in seconds. The plaintext code is
// included in the result only when the SMS provider is "console" or the app
// environment is "development".
func (s *OpsAuthService) SendSMSCode(ctx context.Context, input OpsSendSMSCodeInput) (*SendSMSCodeResult, error) {
	input.CountryCode = normalizeCountryCode(input.CountryCode)
	input.Mobile = normalizeMobile(input.Mobile)
	input.Scene = normalizeOpsScene(input.Scene)
	// Trim once so the stored device key matches the trimmed keys that
	// Register, LoginSMS and the refresh-token records use.
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	if input.Mobile == "" || input.DeviceKey == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile and deviceKey are required")
	}

	latest, err := s.store.GetLatestSMSCodeMeta(ctx, input.CountryCode, input.Mobile, "ops", input.Scene)
	if err != nil {
		return nil, err
	}
	now := time.Now().UTC()
	if latest != nil && latest.CooldownUntil.After(now) {
		return nil, apperr.New(http.StatusTooManyRequests, "sms_cooldown", "sms code sent too frequently")
	}

	// NOTE(review): a non-empty DevSMSCode makes every code predictable,
	// regardless of AppEnv. Confirm it can never be set in production.
	code := s.cfg.DevSMSCode
	if code == "" {
		code, err = security.GenerateNumericCode(6)
		if err != nil {
			return nil, err
		}
	}
	expiresAt := now.Add(s.cfg.SMSCodeTTL)
	cooldownUntil := now.Add(s.cfg.SMSCodeCooldown)
	if err := s.store.CreateSMSCode(ctx, postgres.CreateSMSCodeParams{
		Scene:         input.Scene,
		CountryCode:   input.CountryCode,
		Mobile:        input.Mobile,
		ClientType:    "ops",
		DeviceKey:     input.DeviceKey,
		CodeHash:      security.HashText(code),
		ProviderName:  s.cfg.SMSProvider,
		ProviderDebug: map[string]any{"mode": s.cfg.SMSProvider, "channel": "ops_console"},
		ExpiresAt:     expiresAt,
		CooldownUntil: cooldownUntil,
	}); err != nil {
		return nil, err
	}

	result := &SendSMSCodeResult{
		TTLSeconds:      int64(s.cfg.SMSCodeTTL.Seconds()),
		CooldownSeconds: int64(s.cfg.SMSCodeCooldown.Seconds()),
	}
	if strings.EqualFold(s.cfg.SMSProvider, "console") || strings.EqualFold(s.cfg.AppEnv, "development") {
		result.DevCode = &code
	}
	return result, nil
}
|
|
|
|
// Register creates a new active ops user after verifying an "ops_register"
// SMS code, and returns tokens for that user with NewUser set.
//
// Errors:
//   - 400 invalid_params when mobile, code, deviceKey or displayName is blank;
//   - 401 invalid_sms_code when no valid code exists, the code does not match,
//     or the code was already consumed;
//   - 409 ops_user_exists when the mobile is already registered;
//   - 500 ops_role_missing when the chosen role row does not exist.
//
// The first ops user becomes "owner"; every later user becomes "operator".
// Code consumption, user creation, role assignment and token issuance all
// happen in a single transaction, so any failure rolls everything back
// (including consumption of the code).
func (s *OpsAuthService) Register(ctx context.Context, input OpsRegisterInput) (*OpsAuthResult, error) {
	input.CountryCode = normalizeCountryCode(input.CountryCode)
	input.Mobile = normalizeMobile(input.Mobile)
	input.Code = strings.TrimSpace(input.Code)
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	input.DisplayName = strings.TrimSpace(input.DisplayName)
	if input.Mobile == "" || input.Code == "" || input.DeviceKey == "" || input.DisplayName == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code, deviceKey and displayName are required")
	}

	codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, "ops", "ops_register")
	if err != nil {
		return nil, err
	}
	if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
	}

	// Pick the bootstrap role before opening the transaction. CountOpsUsers
	// does not take the tx, so calling it while the tx is open would hold one
	// pool connection while waiting for another. A failure is returned rather
	// than ignored: silently falling back to "operator" for the very first
	// user would leave the deployment with no owner, and every later count
	// would be non-zero.
	// NOTE(review): two concurrent first registrations can both observe zero
	// users and both become owner; closing that race needs a store-level guard.
	userCount, err := s.store.CountOpsUsers(ctx)
	if err != nil {
		return nil, err
	}
	roleCode := "operator"
	if userCount == 0 {
		roleCode = "owner"
	}

	tx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	// Becomes a no-op once Commit succeeds; otherwise undoes all writes below.
	defer tx.Rollback(ctx)

	consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
	if err != nil {
		return nil, err
	}
	if !consumed {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
	}

	existing, err := s.store.GetOpsUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
	if err != nil {
		return nil, err
	}
	if existing != nil {
		return nil, apperr.New(http.StatusConflict, "ops_user_exists", "ops user already exists")
	}

	publicID, err := security.GeneratePublicID("ops")
	if err != nil {
		return nil, err
	}
	user, err := s.store.CreateOpsUser(ctx, tx, postgres.CreateOpsUserParams{
		PublicID:    publicID,
		CountryCode: input.CountryCode,
		Mobile:      input.Mobile,
		DisplayName: input.DisplayName,
		Status:      "active",
	})
	if err != nil {
		return nil, err
	}

	role, err := s.store.GetOpsRoleByCode(ctx, tx, roleCode)
	if err != nil {
		return nil, err
	}
	if role == nil {
		return nil, apperr.New(http.StatusInternalServerError, "ops_role_missing", "default ops role is missing")
	}
	if err := s.store.AssignOpsRole(ctx, tx, user.ID, role.ID); err != nil {
		return nil, err
	}
	if err := s.store.TouchOpsUserLogin(ctx, tx, user.ID); err != nil {
		return nil, err
	}

	result, _, err := s.issueAuthResult(ctx, tx, *user, input.DeviceKey, true)
	if err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return result, nil
}
|
|
|
|
// LoginSMS authenticates an existing ops user with an "ops_login" SMS code.
//
// After the code is verified, it is consumed inside a transaction. The user
// must exist (404 ops_user_not_found) and be active (403 ops_user_inactive).
// The last-login timestamp is touched and a fresh token pair is issued with
// NewUser=false. Any failure after Begin rolls the transaction back, including
// consumption of the code.
func (s *OpsAuthService) LoginSMS(ctx context.Context, input OpsLoginSMSInput) (*OpsAuthResult, error) {
	countryCode := normalizeCountryCode(input.CountryCode)
	mobile := normalizeMobile(input.Mobile)
	smsCode := strings.TrimSpace(input.Code)
	deviceKey := strings.TrimSpace(input.DeviceKey)
	if mobile == "" || smsCode == "" || deviceKey == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code and deviceKey are required")
	}

	pending, err := s.store.GetLatestValidSMSCode(ctx, countryCode, mobile, "ops", "ops_login")
	if err != nil {
		return nil, err
	}
	if pending == nil || pending.CodeHash != security.HashText(smsCode) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
	}

	tx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback(ctx)

	ok, err := s.store.ConsumeSMSCode(ctx, tx, pending.ID)
	switch {
	case err != nil:
		return nil, err
	case !ok:
		return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
	}

	opsUser, err := s.store.GetOpsUserByMobile(ctx, tx, countryCode, mobile)
	switch {
	case err != nil:
		return nil, err
	case opsUser == nil:
		return nil, apperr.New(http.StatusNotFound, "ops_user_not_found", "ops user not found")
	case opsUser.Status != "active":
		return nil, apperr.New(http.StatusForbidden, "ops_user_inactive", "ops user is not active")
	}

	if err := s.store.TouchOpsUserLogin(ctx, tx, opsUser.ID); err != nil {
		return nil, err
	}
	out, _, err := s.issueAuthResult(ctx, tx, *opsUser, deviceKey, false)
	if err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return out, nil
}
|
|
|
|
// Refresh exchanges a valid refresh token for a new access/refresh token pair.
//
// Inside one transaction the stored token row is looked up by hash via
// GetOpsRefreshTokenForUpdate (presumably row-locking it — confirm in the
// store). Errors:
//   - 400 invalid_params when refreshToken is blank;
//   - 401 invalid_refresh_token when the token is missing, revoked or expired;
//   - 401 invalid_refresh_token when a non-empty deviceKey differs from the
//     key bound to the token;
//   - 401 invalid_refresh_token when the owning user is missing or not active.
//
// On success a new refresh token is issued for the token's original device
// key. RotateOpsRefreshToken is then called with the old and new row IDs
// before commit.
func (s *OpsAuthService) Refresh(ctx context.Context, input OpsRefreshTokenInput) (*OpsAuthResult, error) {
	input.RefreshToken = strings.TrimSpace(input.RefreshToken)
	// Device keys bound to tokens are trimmed when issued (Register/LoginSMS),
	// so trim the caller's key too. Otherwise stray whitespace would be
	// reported as a device mismatch.
	input.DeviceKey = strings.TrimSpace(input.DeviceKey)
	if input.RefreshToken == "" {
		return nil, apperr.New(http.StatusBadRequest, "invalid_params", "refreshToken is required")
	}

	tx, err := s.store.Begin(ctx)
	if err != nil {
		return nil, err
	}
	defer tx.Rollback(ctx)

	record, err := s.store.GetOpsRefreshTokenForUpdate(ctx, tx, security.HashText(input.RefreshToken))
	if err != nil {
		return nil, err
	}
	if record == nil || record.IsRevoked || record.ExpiresAt.Before(time.Now().UTC()) {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token is invalid or expired")
	}
	if input.DeviceKey != "" && record.DeviceKey != nil && input.DeviceKey != *record.DeviceKey {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token device mismatch")
	}

	user, err := s.store.GetOpsUserByID(ctx, tx, record.OpsUserID)
	if err != nil {
		return nil, err
	}
	if user == nil || user.Status != "active" {
		return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token user not found")
	}

	// The replacement token keeps the device binding of the token being rotated.
	result, newTokenID, err := s.issueAuthResult(ctx, tx, *user, nullableStringValue(record.DeviceKey), false)
	if err != nil {
		return nil, err
	}
	if err := s.store.RotateOpsRefreshToken(ctx, tx, record.ID, newTokenID); err != nil {
		return nil, err
	}
	if err := tx.Commit(ctx); err != nil {
		return nil, err
	}
	return result, nil
}
|
|
|
|
// Logout revokes the refresh token whose hash matches the (trimmed) input
// token. A blank token is a no-op and returns nil.
func (s *OpsAuthService) Logout(ctx context.Context, input OpsLogoutInput) error {
	token := strings.TrimSpace(input.RefreshToken)
	if token != "" {
		return s.store.RevokeOpsRefreshToken(ctx, security.HashText(token))
	}
	return nil
}
|
|
|
|
// GetMe loads the ops user with the given ID, together with their primary
// role, outside any transaction. It returns 404 ops_user_not_found when no
// such user exists. The user's status is not checked here.
func (s *OpsAuthService) GetMe(ctx context.Context, opsUserID string) (*OpsAuthUser, error) {
	found, err := s.store.GetOpsUserByID(ctx, s.store.Pool(), opsUserID)
	switch {
	case err != nil:
		return nil, err
	case found == nil:
		return nil, apperr.New(http.StatusNotFound, "ops_user_not_found", "ops user not found")
	}

	primary, err := s.store.GetPrimaryOpsRole(ctx, s.store.Pool(), found.ID)
	if err != nil {
		return nil, err
	}
	view := buildOpsAuthUser(*found, primary)
	return &view, nil
}
|
|
|
|
// issueAuthResult mints an access token and a fresh refresh token for user
// inside tx. Only the refresh token's hash is persisted; the plaintext is
// returned to the caller in the result.
//
// It returns the client-facing result and the ID of the newly stored
// refresh-token row, which Refresh uses when rotating the previous token.
func (s *OpsAuthService) issueAuthResult(ctx context.Context, tx postgres.Tx, user postgres.OpsUser, deviceKey string, newUser bool) (*OpsAuthResult, string, error) {
	role, err := s.store.GetPrimaryOpsRole(ctx, tx, user.ID)
	if err != nil {
		return nil, "", err
	}
	// buildOpsAuthUser maps a nil role to an empty role code, which is also
	// the value embedded in the access token in that case.
	authUser := buildOpsAuthUser(user, role)

	accessToken, accessExpiresAt, err := s.jwtManager.IssueActorAccessToken(user.ID, user.PublicID, "ops", authUser.RoleCode)
	if err != nil {
		return nil, "", err
	}

	plainRefresh, err := security.GenerateToken(32)
	if err != nil {
		return nil, "", err
	}
	hashedRefresh := security.HashText(plainRefresh)
	refreshExpiresAt := time.Now().UTC().Add(s.cfg.RefreshTTL)
	refreshID, err := s.store.CreateOpsRefreshToken(ctx, tx, postgres.CreateOpsRefreshTokenParams{
		OpsUserID: user.ID,
		DeviceKey: deviceKey,
		TokenHash: hashedRefresh,
		ExpiresAt: refreshExpiresAt,
	})
	if err != nil {
		return nil, "", err
	}

	tokens := AuthTokens{
		AccessToken:           accessToken,
		AccessTokenExpiresAt:  accessExpiresAt.Format(time.RFC3339),
		RefreshToken:          plainRefresh,
		RefreshTokenExpiresAt: refreshExpiresAt.Format(time.RFC3339),
	}
	return &OpsAuthResult{User: authUser, Tokens: tokens, NewUser: newUser}, refreshID, nil
}
|
|
|
|
func buildOpsAuthUser(user postgres.OpsUser, role *postgres.OpsRole) OpsAuthUser {
|
|
roleCode := ""
|
|
if role != nil {
|
|
roleCode = role.RoleCode
|
|
}
|
|
return OpsAuthUser{
|
|
ID: user.ID,
|
|
PublicID: user.PublicID,
|
|
DisplayName: user.DisplayName,
|
|
Status: user.Status,
|
|
RoleCode: roleCode,
|
|
}
|
|
}
|
|
|
|
// normalizeOpsScene maps a client-supplied scene onto one of the two ops SMS
// scenes. Only an exact (whitespace-trimmed, case-sensitive) "ops_register"
// selects registration; every other value falls back to "ops_login".
func normalizeOpsScene(value string) string {
	if strings.TrimSpace(value) == "ops_register" {
		return "ops_register"
	}
	return "ops_login"
}
|