Files
cmr-mini/backend/internal/store/postgres/auth_store.go

311 lines
8.3 KiB
Go

package postgres
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"cmr-backend/internal/apperr"
"github.com/jackc/pgx/v5"
)
// SMSCodeMeta carries the id, code_hash, expires_at and cooldown_until columns
// of a single auth_sms_codes row, as read by GetLatestSMSCodeMeta and
// GetLatestValidSMSCode.
type SMSCodeMeta struct {
	ID, CodeHash             string
	ExpiresAt, CooldownUntil time.Time
}
type CreateSMSCodeParams struct {
Scene string
CountryCode string
Mobile string
ClientType string
DeviceKey string
CodeHash string
ProviderName string
ProviderDebug map[string]any
ExpiresAt time.Time
CooldownUntil time.Time
}
// CreateMobileIdentityParams is the input to CreateMobileIdentity: the identity
// key fields plus the plain (non-pointer) country code and mobile number.
type CreateMobileIdentityParams struct {
	UserID, IdentityType, Provider, ProviderSubj string
	CountryCode, Mobile                          string
}
// CreateIdentityParams is the input to CreateIdentity. CountryCode and Mobile
// are pointers so that a nil value can be passed through to the database for
// identities without a phone number. ProfileJSON is raw JSON text.
type CreateIdentityParams struct {
	UserID, IdentityType, Provider, ProviderSubj string
	CountryCode, Mobile                          *string
	ProfileJSON                                  string
}
// CreateRefreshTokenParams is the input to CreateRefreshToken. TokenHash is the
// stored hash of the token, not the token itself.
type CreateRefreshTokenParams struct {
	UserID, ClientType, DeviceKey, TokenHash string
	ExpiresAt                                time.Time
}
// RefreshTokenRecord is a refresh token row as read by GetRefreshTokenForUpdate.
// DeviceKey is nil when the stored device_key is NULL; IsRevoked reports whether
// revoked_at is set.
type RefreshTokenRecord struct {
	ID, UserID, ClientType string
	DeviceKey              *string
	ExpiresAt              time.Time
	IsRevoked              bool
}
// GetLatestSMSCodeMeta returns the most recently created auth_sms_codes row for
// the given country code, mobile number, client type and scene. Expired and
// consumed codes are included; no filter is applied on either. It returns
// (nil, nil) when no row matches.
func (s *Store) GetLatestSMSCodeMeta(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
	const query = `
SELECT id, code_hash, expires_at, cooldown_until
FROM auth_sms_codes
WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4
ORDER BY created_at DESC
LIMIT 1
`
	meta := new(SMSCodeMeta)
	err := s.pool.QueryRow(ctx, query, countryCode, mobile, clientType, scene).
		Scan(&meta.ID, &meta.CodeHash, &meta.ExpiresAt, &meta.CooldownUntil)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("query latest sms code meta: %w", err)
	default:
		return meta, nil
	}
}
// CreateSMSCode inserts one auth_sms_codes row built from params. The provider
// name and the provider's debug data are serialized together into
// provider_payload_jsonb as {"provider": ..., "debug": ...}. A nil
// ProviderDebug is stored as JSON null. DeviceKey is passed through unchanged.
func (s *Store) CreateSMSCode(ctx context.Context, params CreateSMSCodeParams) error {
	payload, err := json.Marshal(map[string]any{
		"provider": params.ProviderName,
		"debug":    params.ProviderDebug,
	})
	if err != nil {
		// Wrap like every other failure in this store so callers can tell
		// which step failed (ProviderDebug may hold unencodable values).
		return fmt.Errorf("marshal sms provider payload: %w", err)
	}
	_, err = s.pool.Exec(ctx, `
INSERT INTO auth_sms_codes (
scene, country_code, mobile, client_type, device_key, code_hash,
provider_payload_jsonb, expires_at, cooldown_until
)
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
`, params.Scene, params.CountryCode, params.Mobile, params.ClientType, params.DeviceKey, params.CodeHash, string(payload), params.ExpiresAt, params.CooldownUntil)
	if err != nil {
		return fmt.Errorf("insert sms code: %w", err)
	}
	return nil
}
// GetLatestValidSMSCode returns the most recently created auth_sms_codes row for
// the given country code, mobile number, client type and scene. Only rows that
// are not consumed (consumed_at IS NULL) and not expired are considered.
// Expiry is judged by the database clock via NOW(). It returns (nil, nil) when
// no such row exists.
func (s *Store) GetLatestValidSMSCode(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
	const query = `
SELECT id, code_hash, expires_at, cooldown_until
FROM auth_sms_codes
WHERE country_code = $1
AND mobile = $2
AND client_type = $3
AND scene = $4
AND consumed_at IS NULL
AND expires_at > NOW()
ORDER BY created_at DESC
LIMIT 1
`
	meta := new(SMSCodeMeta)
	err := s.pool.QueryRow(ctx, query, countryCode, mobile, clientType, scene).
		Scan(&meta.ID, &meta.CodeHash, &meta.ExpiresAt, &meta.CooldownUntil)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("query latest valid sms code: %w", err)
	default:
		return meta, nil
	}
}
// ConsumeSMSCode sets consumed_at on the SMS code with the given id inside tx.
// The update only applies to a code that is not yet consumed. The boolean is
// true exactly when this call performed that transition. It is false when the
// id does not exist or the code was already consumed.
func (s *Store) ConsumeSMSCode(ctx context.Context, tx Tx, id string) (bool, error) {
	const stmt = `
UPDATE auth_sms_codes
SET consumed_at = NOW()
WHERE id = $1 AND consumed_at IS NULL
`
	tag, err := tx.Exec(ctx, stmt, id)
	if err != nil {
		return false, fmt.Errorf("consume sms code: %w", err)
	}
	updated := tag.RowsAffected()
	return updated == 1, nil
}
// CreateMobileIdentity creates a login identity that carries a country code and
// mobile number and an empty JSON profile ("{}"). It delegates the insert to
// CreateIdentity within the same tx.
func (s *Store) CreateMobileIdentity(ctx context.Context, tx Tx, params CreateMobileIdentityParams) error {
	identity := CreateIdentityParams{
		UserID:       params.UserID,
		IdentityType: params.IdentityType,
		Provider:     params.Provider,
		ProviderSubj: params.ProviderSubj,
		// params is this call's own copy, so pointing into it is safe.
		CountryCode: &params.CountryCode,
		Mobile:      &params.Mobile,
		ProfileJSON: "{}",
	}
	return s.CreateIdentity(ctx, tx, identity)
}
// CreateIdentity inserts a login_identities row with status 'active' inside tx.
// If a row with the same (provider, provider_subject) already exists, the insert
// is skipped (ON CONFLICT ... DO NOTHING) and no error is returned. An empty
// ProfileJSON is replaced by "{}" before the jsonb cast.
func (s *Store) CreateIdentity(ctx context.Context, tx Tx, params CreateIdentityParams) error {
	const stmt = `
INSERT INTO login_identities (
user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb
)
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7::jsonb)
ON CONFLICT (provider, provider_subject) DO NOTHING
`
	profile := zeroJSON(params.ProfileJSON)
	if _, err := tx.Exec(ctx, stmt,
		params.UserID, params.IdentityType, params.Provider, params.ProviderSubj,
		params.CountryCode, params.Mobile, profile,
	); err != nil {
		return fmt.Errorf("create identity: %w", err)
	}
	return nil
}
// FindUserByProviderSubject looks up, inside tx, the user owning the 'active'
// login identity identified by (provider, providerSubject). Decoding of the row
// is delegated to scanUser. That includes the result when nothing matches,
// which is decided there, not here.
func (s *Store) FindUserByProviderSubject(ctx context.Context, tx Tx, provider, providerSubject string) (*User, error) {
	const query = `
SELECT u.id, u.user_public_id, u.status, u.nickname, u.avatar_url
FROM users u
JOIN login_identities li ON li.user_id = u.id
WHERE li.provider = $1
AND li.provider_subject = $2
AND li.status = 'active'
LIMIT 1
`
	return scanUser(tx.QueryRow(ctx, query, provider, providerSubject))
}
// CreateRefreshToken inserts a refresh token row inside tx and returns the id
// generated by the database (RETURNING id). An empty DeviceKey is stored as
// NULL via NULLIF.
func (s *Store) CreateRefreshToken(ctx context.Context, tx Tx, params CreateRefreshTokenParams) (string, error) {
	const stmt = `
INSERT INTO auth_refresh_tokens (user_id, client_type, device_key, token_hash, expires_at)
VALUES ($1, $2, NULLIF($3, ''), $4, $5)
RETURNING id
`
	var tokenID string
	err := tx.QueryRow(ctx, stmt,
		params.UserID, params.ClientType, params.DeviceKey, params.TokenHash, params.ExpiresAt,
	).Scan(&tokenID)
	if err != nil {
		return "", fmt.Errorf("create refresh token: %w", err)
	}
	return tokenID, nil
}
// GetRefreshTokenForUpdate loads the refresh token whose hash is tokenHash and
// row-locks it with SELECT ... FOR UPDATE inside tx. IsRevoked is true when
// revoked_at is set. It returns (nil, nil) when no token has that hash.
func (s *Store) GetRefreshTokenForUpdate(ctx context.Context, tx Tx, tokenHash string) (*RefreshTokenRecord, error) {
	const query = `
SELECT id, user_id, client_type, device_key, expires_at, revoked_at IS NOT NULL
FROM auth_refresh_tokens
WHERE token_hash = $1
FOR UPDATE
`
	rec := new(RefreshTokenRecord)
	err := tx.QueryRow(ctx, query, tokenHash).Scan(
		&rec.ID, &rec.UserID, &rec.ClientType, &rec.DeviceKey, &rec.ExpiresAt, &rec.IsRevoked,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("query refresh token for update: %w", err)
	default:
		return rec, nil
	}
}
// RotateRefreshToken revokes the refresh token oldTokenID (revoked_at = NOW())
// inside tx. It also records its successor in replaced_by_token_id = newTokenID.
//
// It returns an error when no row with oldTokenID exists. Previously the
// command tag was discarded, so a rotation that updated nothing looked like
// success and left the old token usable.
func (s *Store) RotateRefreshToken(ctx context.Context, tx Tx, oldTokenID, newTokenID string) error {
	commandTag, err := tx.Exec(ctx, `
UPDATE auth_refresh_tokens
SET revoked_at = NOW(), replaced_by_token_id = $2
WHERE id = $1
`, oldTokenID, newTokenID)
	if err != nil {
		return fmt.Errorf("rotate refresh token: %w", err)
	}
	if commandTag.RowsAffected() == 0 {
		return fmt.Errorf("rotate refresh token: no refresh token with id %q", oldTokenID)
	}
	return nil
}
// RevokeRefreshToken revokes the refresh token whose hash is tokenHash. It runs
// on the pool, outside any caller transaction. A token that is already revoked
// keeps its original revoked_at, and the call still succeeds. If no token has
// that hash, it returns an apperr with HTTP 404 and code
// "refresh_token_not_found".
func (s *Store) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
	const stmt = `
UPDATE auth_refresh_tokens
SET revoked_at = COALESCE(revoked_at, NOW())
WHERE token_hash = $1
`
	tag, err := s.pool.Exec(ctx, stmt, tokenHash)
	switch {
	case err != nil:
		return fmt.Errorf("revoke refresh token: %w", err)
	case tag.RowsAffected() == 0:
		return apperr.New(http.StatusNotFound, "refresh_token_not_found", "refresh token not found")
	default:
		return nil
	}
}
// RevokeRefreshTokensByUserID revokes, inside tx, every refresh token that
// belongs to userID. Already-revoked tokens keep their original revoked_at. A
// user with no tokens is not an error.
func (s *Store) RevokeRefreshTokensByUserID(ctx context.Context, tx Tx, userID string) error {
	const stmt = `
UPDATE auth_refresh_tokens
SET revoked_at = COALESCE(revoked_at, NOW())
WHERE user_id = $1
`
	if _, err := tx.Exec(ctx, stmt, userID); err != nil {
		return fmt.Errorf("revoke refresh tokens by user id: %w", err)
	}
	return nil
}
// TransferNonMobileIdentities reassigns, inside tx, every login identity of
// sourceUserID whose provider is not 'mobile' to targetUserID. Mobile
// identities stay with the source user. Transferring a user onto itself is a
// no-op.
//
// Identities are moved with one in-place UPDATE of user_id rather than
// copy-then-delete. login_identities has a unique index on
// (provider, provider_subject); it is the ON CONFLICT target used by this store.
// So an INSERT of a copy always conflicts with the still-existing source row,
// and DO NOTHING skips it. A following DELETE would then destroy the identity
// instead of moving it. The source row owns its (provider, provider_subject)
// pair, so the target cannot already hold it, and the UPDATE cannot collide on
// that key.
//
// NOTE(review): if the schema has further unique constraints involving user_id
// (e.g. one identity per provider per user — not visible here), a clashing row
// now makes this return an error instead of being silently dropped. Confirm
// that this is the desired account-merge behavior.
func (s *Store) TransferNonMobileIdentities(ctx context.Context, tx Tx, sourceUserID, targetUserID string) error {
	if sourceUserID == targetUserID {
		return nil
	}
	_, err := tx.Exec(ctx, `
UPDATE login_identities
SET user_id = $2, updated_at = NOW()
WHERE user_id = $1
AND provider <> 'mobile'
`, sourceUserID, targetUserID)
	if err != nil {
		return fmt.Errorf("transfer non-mobile identities: %w", err)
	}
	return nil
}
// zeroJSON returns value unchanged unless it is the empty string. An empty
// string is not valid JSON, so it is replaced with "{}", the empty JSON object.
// Only the exact empty string is replaced; whitespace is returned as-is.
func zeroJSON(value string) string {
	if value != "" {
		return value
	}
	return "{}"
}