596 lines
16 KiB
Go
596 lines
16 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"cmr-backend/internal/apperr"
|
|
"cmr-backend/internal/platform/jwtx"
|
|
"cmr-backend/internal/platform/security"
|
|
"cmr-backend/internal/platform/wechatmini"
|
|
"cmr-backend/internal/store/postgres"
|
|
)
|
|
|
|
type AuthSettings struct {
|
|
AppEnv string
|
|
RefreshTTL time.Duration
|
|
SMSCodeTTL time.Duration
|
|
SMSCodeCooldown time.Duration
|
|
SMSProvider string
|
|
DevSMSCode string
|
|
WechatMini *wechatmini.Client
|
|
}
|
|
|
|
type AuthService struct {
|
|
cfg AuthSettings
|
|
store *postgres.Store
|
|
jwtManager *jwtx.Manager
|
|
}
|
|
|
|
type (
	// SendSMSCodeInput is the request body for AuthService.SendSMSCode.
	// A blank Scene is treated as "login".
	SendSMSCodeInput struct {
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		ClientType  string `json:"clientType"`
		DeviceKey   string `json:"deviceKey"`
		Scene       string `json:"scene"`
	}

	// SendSMSCodeResult reports the code lifetime and the resend cooldown in
	// whole seconds. DevCode is set only when SendSMSCode runs with the
	// console SMS provider or in the development environment, and is omitted
	// from JSON otherwise.
	SendSMSCodeResult struct {
		TTLSeconds      int64   `json:"ttlSeconds"`
		CooldownSeconds int64   `json:"cooldownSeconds"`
		DevCode         *string `json:"devCode,omitempty"`
	}

	// LoginSMSInput is the request body for AuthService.LoginSMS.
	LoginSMSInput struct {
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		Code        string `json:"code"`
		ClientType  string `json:"clientType"`
		DeviceKey   string `json:"deviceKey"`
	}

	// LoginWechatMiniInput is the request body for AuthService.LoginWechatMini;
	// ClientType must be "wechat".
	LoginWechatMiniInput struct {
		Code       string `json:"code"`
		ClientType string `json:"clientType"`
		DeviceKey  string `json:"deviceKey"`
	}

	// BindMobileInput is the request body for AuthService.BindMobile. UserID
	// is never decoded from JSON; the caller fills it in (presumably from the
	// authenticated session — confirm at the handler).
	BindMobileInput struct {
		UserID      string `json:"-"`
		CountryCode string `json:"countryCode"`
		Mobile      string `json:"mobile"`
		Code        string `json:"code"`
		ClientType  string `json:"clientType"`
		DeviceKey   string `json:"deviceKey"`
	}

	// RefreshTokenInput is the request body for AuthService.Refresh.
	RefreshTokenInput struct {
		RefreshToken string `json:"refreshToken"`
		ClientType   string `json:"clientType"`
		DeviceKey    string `json:"deviceKey"`
	}

	// LogoutInput is the request body for AuthService.Logout. UserID is
	// excluded from JSON and is not consulted by Logout.
	LogoutInput struct {
		RefreshToken string `json:"refreshToken"`
		UserID       string `json:"-"`
	}
)
|
|
|
|
type (
	// AuthUser is the public view of a user returned after authentication.
	// Nickname and AvatarURL are omitted from JSON when nil.
	AuthUser struct {
		ID        string  `json:"id"`
		PublicID  string  `json:"publicId"`
		Status    string  `json:"status"`
		Nickname  *string `json:"nickname,omitempty"`
		AvatarURL *string `json:"avatarUrl,omitempty"`
	}

	// AuthTokens holds a freshly issued access/refresh token pair. The expiry
	// fields are timestamps rendered as RFC 3339 strings.
	AuthTokens struct {
		AccessToken           string `json:"accessToken"`
		AccessTokenExpiresAt  string `json:"accessTokenExpiresAt"`
		RefreshToken          string `json:"refreshToken"`
		RefreshTokenExpiresAt string `json:"refreshTokenExpiresAt"`
	}

	// AuthResult is the response payload of every successful login, bind or
	// refresh call. NewUser is the flag passed in by the issuing operation.
	AuthResult struct {
		User    AuthUser   `json:"user"`
		Tokens  AuthTokens `json:"tokens"`
		NewUser bool       `json:"newUser"`
	}
)
|
|
|
|
func NewAuthService(cfg AuthSettings, store *postgres.Store, jwtManager *jwtx.Manager) *AuthService {
|
|
return &AuthService{
|
|
cfg: cfg,
|
|
store: store,
|
|
jwtManager: jwtManager,
|
|
}
|
|
}
|
|
|
|
func (s *AuthService) SendSMSCode(ctx context.Context, input SendSMSCodeInput) (*SendSMSCodeResult, error) {
|
|
input.CountryCode = normalizeCountryCode(input.CountryCode)
|
|
input.Mobile = normalizeMobile(input.Mobile)
|
|
input.Scene = normalizeScene(input.Scene)
|
|
|
|
if err := validateClientType(input.ClientType); err != nil {
|
|
return nil, err
|
|
}
|
|
if input.Mobile == "" || input.DeviceKey == "" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile and deviceKey are required")
|
|
}
|
|
|
|
latest, err := s.store.GetLatestSMSCodeMeta(ctx, input.CountryCode, input.Mobile, input.ClientType, input.Scene)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
now := time.Now().UTC()
|
|
if latest != nil && latest.CooldownUntil.After(now) {
|
|
return nil, apperr.New(http.StatusTooManyRequests, "sms_cooldown", "sms code sent too frequently")
|
|
}
|
|
|
|
code := s.cfg.DevSMSCode
|
|
if code == "" {
|
|
code, err = security.GenerateNumericCode(6)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
expiresAt := now.Add(s.cfg.SMSCodeTTL)
|
|
cooldownUntil := now.Add(s.cfg.SMSCodeCooldown)
|
|
if err := s.store.CreateSMSCode(ctx, postgres.CreateSMSCodeParams{
|
|
Scene: input.Scene,
|
|
CountryCode: input.CountryCode,
|
|
Mobile: input.Mobile,
|
|
ClientType: input.ClientType,
|
|
DeviceKey: input.DeviceKey,
|
|
CodeHash: security.HashText(code),
|
|
ProviderName: s.cfg.SMSProvider,
|
|
ProviderDebug: map[string]any{"mode": s.cfg.SMSProvider},
|
|
ExpiresAt: expiresAt,
|
|
CooldownUntil: cooldownUntil,
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := &SendSMSCodeResult{
|
|
TTLSeconds: int64(s.cfg.SMSCodeTTL.Seconds()),
|
|
CooldownSeconds: int64(s.cfg.SMSCodeCooldown.Seconds()),
|
|
}
|
|
if strings.EqualFold(s.cfg.SMSProvider, "console") || strings.EqualFold(s.cfg.AppEnv, "development") {
|
|
result.DevCode = &code
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *AuthService) LoginSMS(ctx context.Context, input LoginSMSInput) (*AuthResult, error) {
|
|
input.CountryCode = normalizeCountryCode(input.CountryCode)
|
|
input.Mobile = normalizeMobile(input.Mobile)
|
|
input.Code = strings.TrimSpace(input.Code)
|
|
|
|
if err := validateClientType(input.ClientType); err != nil {
|
|
return nil, err
|
|
}
|
|
if input.Mobile == "" || input.DeviceKey == "" || input.Code == "" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_params", "mobile, code and deviceKey are required")
|
|
}
|
|
|
|
codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, input.ClientType, "login")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
|
|
}
|
|
|
|
tx, err := s.store.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !consumed {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
|
|
}
|
|
|
|
user, err := s.store.FindUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
newUser := false
|
|
if user == nil {
|
|
userPublicID, err := security.GeneratePublicID("usr")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user, err = s.store.CreateUser(ctx, tx, postgres.CreateUserParams{
|
|
PublicID: userPublicID,
|
|
Status: "active",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.store.CreateMobileIdentity(ctx, tx, postgres.CreateMobileIdentityParams{
|
|
UserID: user.ID,
|
|
CountryCode: input.CountryCode,
|
|
Mobile: input.Mobile,
|
|
Provider: "mobile",
|
|
ProviderSubj: input.CountryCode + ":" + input.Mobile,
|
|
IdentityType: "mobile",
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
newUser = true
|
|
}
|
|
|
|
if err := s.store.TouchUserLogin(ctx, tx, user.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result, err := s.issueAuthResult(ctx, tx, *user, input.ClientType, input.DeviceKey, newUser)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *AuthService) Refresh(ctx context.Context, input RefreshTokenInput) (*AuthResult, error) {
|
|
input.RefreshToken = strings.TrimSpace(input.RefreshToken)
|
|
if err := validateClientType(input.ClientType); err != nil {
|
|
return nil, err
|
|
}
|
|
if input.RefreshToken == "" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_params", "refreshToken is required")
|
|
}
|
|
|
|
tx, err := s.store.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
record, err := s.store.GetRefreshTokenForUpdate(ctx, tx, security.HashText(input.RefreshToken))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if record == nil || record.IsRevoked || record.ExpiresAt.Before(time.Now().UTC()) {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token is invalid or expired")
|
|
}
|
|
|
|
if input.ClientType != "" && input.ClientType != record.ClientType {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token client mismatch")
|
|
}
|
|
if input.DeviceKey != "" && record.DeviceKey != nil && input.DeviceKey != *record.DeviceKey {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token device mismatch")
|
|
}
|
|
|
|
user, err := s.store.GetUserByID(ctx, tx, record.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user == nil {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_refresh_token", "refresh token user not found")
|
|
}
|
|
|
|
result, refreshTokenID, err := s.issueAuthResultWithRefreshID(ctx, tx, *user, record.ClientType, nullableStringValue(record.DeviceKey), false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.store.RotateRefreshToken(ctx, tx, record.ID, refreshTokenID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *AuthService) LoginWechatMini(ctx context.Context, input LoginWechatMiniInput) (*AuthResult, error) {
|
|
input.Code = strings.TrimSpace(input.Code)
|
|
if err := validateClientType(input.ClientType); err != nil {
|
|
return nil, err
|
|
}
|
|
if input.ClientType != "wechat" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_client_type", "wechat mini login requires clientType=wechat")
|
|
}
|
|
if input.Code == "" || strings.TrimSpace(input.DeviceKey) == "" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_params", "code and deviceKey are required")
|
|
}
|
|
if s.cfg.WechatMini == nil {
|
|
return nil, apperr.New(http.StatusNotImplemented, "wechat_not_configured", "wechat mini provider is not configured")
|
|
}
|
|
|
|
session, err := s.cfg.WechatMini.ExchangeCode(ctx, input.Code)
|
|
if err != nil {
|
|
return nil, apperr.New(http.StatusUnauthorized, "wechat_login_failed", err.Error())
|
|
}
|
|
|
|
openIDSubject := session.AppID + ":" + session.OpenID
|
|
unionIDSubject := strings.TrimSpace(session.UnionID)
|
|
|
|
tx, err := s.store.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
user, err := s.store.FindUserByProviderSubject(ctx, tx, "wechat_mini", openIDSubject)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user == nil && unionIDSubject != "" {
|
|
user, err = s.store.FindUserByProviderSubject(ctx, tx, "wechat_unionid", unionIDSubject)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
newUser := false
|
|
if user == nil {
|
|
userPublicID, err := security.GeneratePublicID("usr")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user, err = s.store.CreateUser(ctx, tx, postgres.CreateUserParams{
|
|
PublicID: userPublicID,
|
|
Status: "active",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newUser = true
|
|
}
|
|
|
|
profileJSON, err := json.Marshal(map[string]any{
|
|
"appId": session.AppID,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.store.CreateIdentity(ctx, tx, postgres.CreateIdentityParams{
|
|
UserID: user.ID,
|
|
IdentityType: "wechat_mini_openid",
|
|
Provider: "wechat_mini",
|
|
ProviderSubj: openIDSubject,
|
|
ProfileJSON: string(profileJSON),
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if unionIDSubject != "" {
|
|
if err := s.store.CreateIdentity(ctx, tx, postgres.CreateIdentityParams{
|
|
UserID: user.ID,
|
|
IdentityType: "wechat_unionid",
|
|
Provider: "wechat_unionid",
|
|
ProviderSubj: unionIDSubject,
|
|
ProfileJSON: "{}",
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if err := s.store.TouchUserLogin(ctx, tx, user.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result, err := s.issueAuthResult(ctx, tx, *user, input.ClientType, input.DeviceKey, newUser)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *AuthService) BindMobile(ctx context.Context, input BindMobileInput) (*AuthResult, error) {
|
|
input.CountryCode = normalizeCountryCode(input.CountryCode)
|
|
input.Mobile = normalizeMobile(input.Mobile)
|
|
input.Code = strings.TrimSpace(input.Code)
|
|
|
|
if err := validateClientType(input.ClientType); err != nil {
|
|
return nil, err
|
|
}
|
|
if input.UserID == "" || input.Mobile == "" || input.Code == "" || strings.TrimSpace(input.DeviceKey) == "" {
|
|
return nil, apperr.New(http.StatusBadRequest, "invalid_params", "user, mobile, code and deviceKey are required")
|
|
}
|
|
|
|
codeRecord, err := s.store.GetLatestValidSMSCode(ctx, input.CountryCode, input.Mobile, input.ClientType, "bind_mobile")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if codeRecord == nil || codeRecord.CodeHash != security.HashText(input.Code) {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "invalid sms code")
|
|
}
|
|
|
|
tx, err := s.store.Begin(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
consumed, err := s.store.ConsumeSMSCode(ctx, tx, codeRecord.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !consumed {
|
|
return nil, apperr.New(http.StatusUnauthorized, "invalid_sms_code", "sms code already used")
|
|
}
|
|
|
|
currentUser, err := s.store.GetUserByID(ctx, tx, input.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if currentUser == nil {
|
|
return nil, apperr.New(http.StatusNotFound, "user_not_found", "current user not found")
|
|
}
|
|
|
|
mobileUser, err := s.store.FindUserByMobile(ctx, tx, input.CountryCode, input.Mobile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
finalUser := currentUser
|
|
newlyBound := false
|
|
|
|
if mobileUser == nil {
|
|
if err := s.store.CreateMobileIdentity(ctx, tx, postgres.CreateMobileIdentityParams{
|
|
UserID: currentUser.ID,
|
|
CountryCode: input.CountryCode,
|
|
Mobile: input.Mobile,
|
|
Provider: "mobile",
|
|
ProviderSubj: input.CountryCode + ":" + input.Mobile,
|
|
IdentityType: "mobile",
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
newlyBound = true
|
|
} else if mobileUser.ID != currentUser.ID {
|
|
if err := s.store.TransferNonMobileIdentities(ctx, tx, currentUser.ID, mobileUser.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.store.RevokeRefreshTokensByUserID(ctx, tx, currentUser.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := s.store.DeactivateUser(ctx, tx, currentUser.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
finalUser = mobileUser
|
|
}
|
|
|
|
if err := s.store.TouchUserLogin(ctx, tx, finalUser.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result, err := s.issueAuthResult(ctx, tx, *finalUser, input.ClientType, input.DeviceKey, newlyBound)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *AuthService) Logout(ctx context.Context, input LogoutInput) error {
|
|
if strings.TrimSpace(input.RefreshToken) == "" {
|
|
return nil
|
|
}
|
|
return s.store.RevokeRefreshToken(ctx, security.HashText(strings.TrimSpace(input.RefreshToken)))
|
|
}
|
|
|
|
func (s *AuthService) issueAuthResult(
|
|
ctx context.Context,
|
|
tx postgres.Tx,
|
|
user postgres.User,
|
|
clientType string,
|
|
deviceKey string,
|
|
newUser bool,
|
|
) (*AuthResult, error) {
|
|
result, _, err := s.issueAuthResultWithRefreshID(ctx, tx, user, clientType, deviceKey, newUser)
|
|
return result, err
|
|
}
|
|
|
|
func (s *AuthService) issueAuthResultWithRefreshID(
|
|
ctx context.Context,
|
|
tx postgres.Tx,
|
|
user postgres.User,
|
|
clientType string,
|
|
deviceKey string,
|
|
newUser bool,
|
|
) (*AuthResult, string, error) {
|
|
accessToken, accessExpiresAt, err := s.jwtManager.IssueAccessToken(user.ID, user.PublicID)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
refreshToken, err := security.GenerateToken(32)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
refreshTokenHash := security.HashText(refreshToken)
|
|
refreshExpiresAt := time.Now().UTC().Add(s.cfg.RefreshTTL)
|
|
|
|
refreshID, err := s.store.CreateRefreshToken(ctx, tx, postgres.CreateRefreshTokenParams{
|
|
UserID: user.ID,
|
|
ClientType: clientType,
|
|
DeviceKey: deviceKey,
|
|
TokenHash: refreshTokenHash,
|
|
ExpiresAt: refreshExpiresAt,
|
|
})
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return &AuthResult{
|
|
User: AuthUser{
|
|
ID: user.ID,
|
|
PublicID: user.PublicID,
|
|
Status: user.Status,
|
|
Nickname: user.Nickname,
|
|
AvatarURL: user.AvatarURL,
|
|
},
|
|
Tokens: AuthTokens{
|
|
AccessToken: accessToken,
|
|
AccessTokenExpiresAt: accessExpiresAt.Format(time.RFC3339),
|
|
RefreshToken: refreshToken,
|
|
RefreshTokenExpiresAt: refreshExpiresAt.Format(time.RFC3339),
|
|
},
|
|
NewUser: newUser,
|
|
}, refreshID, nil
|
|
}
|
|
|
|
func validateClientType(clientType string) error {
|
|
switch clientType {
|
|
case "app", "wechat":
|
|
return nil
|
|
default:
|
|
return apperr.New(http.StatusBadRequest, "invalid_client_type", "clientType must be app or wechat")
|
|
}
|
|
}
|
|
|
|
// normalizeCountryCode canonicalises a dialing country code. It removes
// surrounding whitespace and one leading "+" (plus any space after it), and
// returns "86" when nothing remains.
func normalizeCountryCode(value string) string {
	// Strip "+" before applying the default, so inputs such as "+" or "+ 86"
	// do not produce an empty or space-prefixed code.
	value = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(value), "+"))
	if value == "" {
		return "86"
	}
	return value
}
|
|
|
|
// normalizeMobile strips formatting from a phone number so that equivalent
// inputs map to the same stored identity. It removes every Unicode whitespace
// character, not only ASCII spaces: tabs, NBSP (U+00A0) and the ideographic
// space (U+3000) that CJK input methods produce are all dropped. It also
// removes hyphens and parentheses. No other validation is performed.
func normalizeMobile(value string) string {
	// strings.Fields splits on unicode.IsSpace, so joining the fields drops
	// all whitespace, leading, trailing and internal.
	compact := strings.Join(strings.Fields(value), "")
	return strings.Map(func(r rune) rune {
		switch r {
		case '-', '(', ')':
			return -1 // drop the rune
		}
		return r
	}, compact)
}
|
|
|
|
// normalizeScene trims the SMS scene and falls back to "login" when nothing
// is left. Case is preserved.
func normalizeScene(value string) string {
	if trimmed := strings.TrimSpace(value); trimmed != "" {
		return trimmed
	}
	return "login"
}
|
|
|
|
// nullableStringValue dereferences value, mapping nil to the empty string.
func nullableStringValue(value *string) string {
	var out string
	if value != nil {
		out = *value
	}
	return out
}
|