Add backend foundation and config-driven workbench
This commit is contained in:
310
backend/internal/store/postgres/auth_store.go
Normal file
310
backend/internal/store/postgres/auth_store.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cmr-backend/internal/apperr"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type SMSCodeMeta struct {
|
||||
ID string
|
||||
CodeHash string
|
||||
ExpiresAt time.Time
|
||||
CooldownUntil time.Time
|
||||
}
|
||||
|
||||
type CreateSMSCodeParams struct {
|
||||
Scene string
|
||||
CountryCode string
|
||||
Mobile string
|
||||
ClientType string
|
||||
DeviceKey string
|
||||
CodeHash string
|
||||
ProviderName string
|
||||
ProviderDebug map[string]any
|
||||
ExpiresAt time.Time
|
||||
CooldownUntil time.Time
|
||||
}
|
||||
|
||||
type CreateMobileIdentityParams struct {
|
||||
UserID string
|
||||
IdentityType string
|
||||
Provider string
|
||||
ProviderSubj string
|
||||
CountryCode string
|
||||
Mobile string
|
||||
}
|
||||
|
||||
type CreateIdentityParams struct {
|
||||
UserID string
|
||||
IdentityType string
|
||||
Provider string
|
||||
ProviderSubj string
|
||||
CountryCode *string
|
||||
Mobile *string
|
||||
ProfileJSON string
|
||||
}
|
||||
|
||||
type CreateRefreshTokenParams struct {
|
||||
UserID string
|
||||
ClientType string
|
||||
DeviceKey string
|
||||
TokenHash string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type RefreshTokenRecord struct {
|
||||
ID string
|
||||
UserID string
|
||||
ClientType string
|
||||
DeviceKey *string
|
||||
ExpiresAt time.Time
|
||||
IsRevoked bool
|
||||
}
|
||||
|
||||
func (s *Store) GetLatestSMSCodeMeta(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
|
||||
row := s.pool.QueryRow(ctx, `
|
||||
SELECT id, code_hash, expires_at, cooldown_until
|
||||
FROM auth_sms_codes
|
||||
WHERE country_code = $1 AND mobile = $2 AND client_type = $3 AND scene = $4
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`, countryCode, mobile, clientType, scene)
|
||||
|
||||
var record SMSCodeMeta
|
||||
err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query latest sms code meta: %w", err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateSMSCode(ctx context.Context, params CreateSMSCodeParams) error {
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"provider": params.ProviderName,
|
||||
"debug": params.ProviderDebug,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
INSERT INTO auth_sms_codes (
|
||||
scene, country_code, mobile, client_type, device_key, code_hash,
|
||||
provider_payload_jsonb, expires_at, cooldown_until
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7::jsonb, $8, $9)
|
||||
`, params.Scene, params.CountryCode, params.Mobile, params.ClientType, params.DeviceKey, params.CodeHash, string(payload), params.ExpiresAt, params.CooldownUntil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert sms code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) GetLatestValidSMSCode(ctx context.Context, countryCode, mobile, clientType, scene string) (*SMSCodeMeta, error) {
|
||||
row := s.pool.QueryRow(ctx, `
|
||||
SELECT id, code_hash, expires_at, cooldown_until
|
||||
FROM auth_sms_codes
|
||||
WHERE country_code = $1
|
||||
AND mobile = $2
|
||||
AND client_type = $3
|
||||
AND scene = $4
|
||||
AND consumed_at IS NULL
|
||||
AND expires_at > NOW()
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`, countryCode, mobile, clientType, scene)
|
||||
|
||||
var record SMSCodeMeta
|
||||
err := row.Scan(&record.ID, &record.CodeHash, &record.ExpiresAt, &record.CooldownUntil)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query latest valid sms code: %w", err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (s *Store) ConsumeSMSCode(ctx context.Context, tx Tx, id string) (bool, error) {
|
||||
commandTag, err := tx.Exec(ctx, `
|
||||
UPDATE auth_sms_codes
|
||||
SET consumed_at = NOW()
|
||||
WHERE id = $1 AND consumed_at IS NULL
|
||||
`, id)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("consume sms code: %w", err)
|
||||
}
|
||||
return commandTag.RowsAffected() == 1, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateMobileIdentity(ctx context.Context, tx Tx, params CreateMobileIdentityParams) error {
|
||||
countryCode := params.CountryCode
|
||||
mobile := params.Mobile
|
||||
return s.CreateIdentity(ctx, tx, CreateIdentityParams{
|
||||
UserID: params.UserID,
|
||||
IdentityType: params.IdentityType,
|
||||
Provider: params.Provider,
|
||||
ProviderSubj: params.ProviderSubj,
|
||||
CountryCode: &countryCode,
|
||||
Mobile: &mobile,
|
||||
ProfileJSON: "{}",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) CreateIdentity(ctx context.Context, tx Tx, params CreateIdentityParams) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO login_identities (
|
||||
user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7::jsonb)
|
||||
ON CONFLICT (provider, provider_subject) DO NOTHING
|
||||
`, params.UserID, params.IdentityType, params.Provider, params.ProviderSubj, params.CountryCode, params.Mobile, zeroJSON(params.ProfileJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create identity: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) FindUserByProviderSubject(ctx context.Context, tx Tx, provider, providerSubject string) (*User, error) {
|
||||
row := tx.QueryRow(ctx, `
|
||||
SELECT u.id, u.user_public_id, u.status, u.nickname, u.avatar_url
|
||||
FROM users u
|
||||
JOIN login_identities li ON li.user_id = u.id
|
||||
WHERE li.provider = $1
|
||||
AND li.provider_subject = $2
|
||||
AND li.status = 'active'
|
||||
LIMIT 1
|
||||
`, provider, providerSubject)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
func (s *Store) CreateRefreshToken(ctx context.Context, tx Tx, params CreateRefreshTokenParams) (string, error) {
|
||||
row := tx.QueryRow(ctx, `
|
||||
INSERT INTO auth_refresh_tokens (user_id, client_type, device_key, token_hash, expires_at)
|
||||
VALUES ($1, $2, NULLIF($3, ''), $4, $5)
|
||||
RETURNING id
|
||||
`, params.UserID, params.ClientType, params.DeviceKey, params.TokenHash, params.ExpiresAt)
|
||||
|
||||
var id string
|
||||
if err := row.Scan(&id); err != nil {
|
||||
return "", fmt.Errorf("create refresh token: %w", err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetRefreshTokenForUpdate(ctx context.Context, tx Tx, tokenHash string) (*RefreshTokenRecord, error) {
|
||||
row := tx.QueryRow(ctx, `
|
||||
SELECT id, user_id, client_type, device_key, expires_at, revoked_at IS NOT NULL
|
||||
FROM auth_refresh_tokens
|
||||
WHERE token_hash = $1
|
||||
FOR UPDATE
|
||||
`, tokenHash)
|
||||
|
||||
var record RefreshTokenRecord
|
||||
err := row.Scan(&record.ID, &record.UserID, &record.ClientType, &record.DeviceKey, &record.ExpiresAt, &record.IsRevoked)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query refresh token for update: %w", err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (s *Store) RotateRefreshToken(ctx context.Context, tx Tx, oldTokenID, newTokenID string) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
UPDATE auth_refresh_tokens
|
||||
SET revoked_at = NOW(), replaced_by_token_id = $2
|
||||
WHERE id = $1
|
||||
`, oldTokenID, newTokenID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("rotate refresh token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) RevokeRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
commandTag, err := s.pool.Exec(ctx, `
|
||||
UPDATE auth_refresh_tokens
|
||||
SET revoked_at = COALESCE(revoked_at, NOW())
|
||||
WHERE token_hash = $1
|
||||
`, tokenHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke refresh token: %w", err)
|
||||
}
|
||||
if commandTag.RowsAffected() == 0 {
|
||||
return apperr.New(http.StatusNotFound, "refresh_token_not_found", "refresh token not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) RevokeRefreshTokensByUserID(ctx context.Context, tx Tx, userID string) error {
|
||||
_, err := tx.Exec(ctx, `
|
||||
UPDATE auth_refresh_tokens
|
||||
SET revoked_at = COALESCE(revoked_at, NOW())
|
||||
WHERE user_id = $1
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke refresh tokens by user id: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) TransferNonMobileIdentities(ctx context.Context, tx Tx, sourceUserID, targetUserID string) error {
|
||||
if sourceUserID == targetUserID {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := tx.Exec(ctx, `
|
||||
INSERT INTO login_identities (
|
||||
user_id, identity_type, provider, provider_subject, country_code, mobile, status, profile_jsonb, created_at, updated_at
|
||||
)
|
||||
SELECT
|
||||
$2,
|
||||
li.identity_type,
|
||||
li.provider,
|
||||
li.provider_subject,
|
||||
li.country_code,
|
||||
li.mobile,
|
||||
li.status,
|
||||
li.profile_jsonb,
|
||||
li.created_at,
|
||||
li.updated_at
|
||||
FROM login_identities li
|
||||
WHERE li.user_id = $1
|
||||
AND li.provider <> 'mobile'
|
||||
ON CONFLICT (provider, provider_subject) DO NOTHING
|
||||
`, sourceUserID, targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("copy non-mobile identities: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
DELETE FROM login_identities
|
||||
WHERE user_id = $1
|
||||
AND provider <> 'mobile'
|
||||
`, sourceUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete source non-mobile identities: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func zeroJSON(value string) string {
|
||||
if value == "" {
|
||||
return "{}"
|
||||
}
|
||||
return value
|
||||
}
|
||||
Reference in New Issue
Block a user