// File: cmr-mini/backend/internal/store/postgres/session_store.go
// (321 lines, 7.3 KiB, Go)

package postgres
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
)
// Session is one game_sessions row, flattened together with fields from the
// event (events) and event release (event_releases) it belongs to, as produced
// by the SELECT lists in this file. Pointer-typed fields are nil when the
// corresponding column is NULL.
type Session struct {
// ID is game_sessions.id (the internal key used by StartSession/FinishSession).
ID string
// SessionPublicID is game_sessions.session_public_id (lookup key for GetSessionByPublicID*).
SessionPublicID string
// UserID is game_sessions.user_id.
UserID string
// EventID is game_sessions.event_id (joined to events.id).
EventID string
// EventReleaseID is game_sessions.event_release_id (joined to event_releases.id).
EventReleaseID string
// ReleasePublicID is event_releases.release_public_id.
ReleasePublicID *string
// ConfigLabel is event_releases.config_label.
ConfigLabel *string
// ManifestURL is event_releases.manifest_url.
ManifestURL *string
// ManifestChecksum is event_releases.manifest_checksum_sha256.
ManifestChecksum *string
// DeviceKey is game_sessions.device_key.
DeviceKey string
// ClientType is game_sessions.client_type.
ClientType string
// AssignmentMode is game_sessions.assignment_mode.
AssignmentMode *string
// VariantID is game_sessions.variant_id.
VariantID *string
// VariantName is game_sessions.variant_name.
VariantName *string
// RouteCode is game_sessions.route_code.
RouteCode *string
// Status is game_sessions.status; this file writes 'running' (from 'launched')
// in StartSession and a caller-supplied value in FinishSession.
Status string
// SessionTokenHash is game_sessions.session_token_hash.
SessionTokenHash string
// SessionTokenExpiresAt is game_sessions.session_token_expires_at.
SessionTokenExpiresAt time.Time
// LaunchedAt is game_sessions.launched_at.
LaunchedAt time.Time
// StartedAt is game_sessions.started_at; StartSession/FinishSession fill it with NOW() if unset.
StartedAt *time.Time
// EndedAt is game_sessions.ended_at; FinishSession fills it with NOW() if unset.
EndedAt *time.Time
// EventPublicID is events.event_public_id.
EventPublicID *string
// EventDisplayName is events.display_name.
EventDisplayName *string
}
// FinishSessionParams are the inputs to FinishSession.
type FinishSessionParams struct {
// SessionID is the internal game_sessions.id of the session to finish.
SessionID string
// Status is written verbatim to game_sessions.status; FinishSession does not
// validate it.
Status string
}
// GetSessionByPublicID fetches a single game session by its public ID,
// enriched with event and event-release fields via inner joins.
//
// The query runs on the store's connection pool, outside any transaction, and
// takes no row locks. A missing session yields (nil, nil). Other failures come
// back wrapped by scanSession.
func (s *Store) GetSessionByPublicID(ctx context.Context, sessionPublicID string) (*Session, error) {
	const query = `
SELECT
gs.id,
gs.session_public_id,
gs.user_id,
gs.event_id,
gs.event_release_id,
er.release_public_id,
er.config_label,
er.manifest_url,
er.manifest_checksum_sha256,
gs.device_key,
gs.client_type,
gs.assignment_mode,
gs.variant_id,
gs.variant_name,
gs.route_code,
gs.status,
gs.session_token_hash,
gs.session_token_expires_at,
gs.launched_at,
gs.started_at,
gs.ended_at,
e.event_public_id,
e.display_name
FROM game_sessions gs
JOIN events e ON e.id = gs.event_id
JOIN event_releases er ON er.id = gs.event_release_id
WHERE gs.session_public_id = $1
LIMIT 1
`
	return scanSession(s.pool.QueryRow(ctx, query, sessionPublicID))
}
// GetSessionByPublicIDForUpdate loads the game session whose session_public_id
// equals sessionPublicID inside tx and row-locks it with SELECT ... FOR UPDATE.
// This lets the caller update it without a concurrent transaction modifying it
// in between. PostgreSQL holds the lock until tx commits or rolls back.
//
// The locking clause is restricted with OF gs on purpose. A bare FOR UPDATE on
// a join locks the matching rows of every table in FROM. Here that would also
// lock the shared events and event_releases rows, serializing every session
// transition for the same event and blocking writers to that event/release.
//
// Returns (nil, nil) when no session matches; other failures are wrapped by
// scanSession.
func (s *Store) GetSessionByPublicIDForUpdate(ctx context.Context, tx Tx, sessionPublicID string) (*Session, error) {
	row := tx.QueryRow(ctx, `
SELECT
gs.id,
gs.session_public_id,
gs.user_id,
gs.event_id,
gs.event_release_id,
er.release_public_id,
er.config_label,
er.manifest_url,
er.manifest_checksum_sha256,
gs.device_key,
gs.client_type,
gs.assignment_mode,
gs.variant_id,
gs.variant_name,
gs.route_code,
gs.status,
gs.session_token_hash,
gs.session_token_expires_at,
gs.launched_at,
gs.started_at,
gs.ended_at,
e.event_public_id,
e.display_name
FROM game_sessions gs
JOIN events e ON e.id = gs.event_id
JOIN event_releases er ON er.id = gs.event_release_id
WHERE gs.session_public_id = $1
FOR UPDATE OF gs
`, sessionPublicID)
	return scanSession(row)
}
// ListSessionsByUserID returns up to limit sessions owned by userID, newest
// first (ordered by game_sessions.created_at descending). Each session is
// joined with its event and event release.
//
// A limit outside 1..100 is replaced by the default of 20. When the user has
// no sessions the result is a nil slice with a nil error.
func (s *Store) ListSessionsByUserID(ctx context.Context, userID string, limit int) ([]Session, error) {
	if limit < 1 || 100 < limit {
		limit = 20
	}

	const query = `
SELECT
gs.id,
gs.session_public_id,
gs.user_id,
gs.event_id,
gs.event_release_id,
er.release_public_id,
er.config_label,
er.manifest_url,
er.manifest_checksum_sha256,
gs.device_key,
gs.client_type,
gs.assignment_mode,
gs.variant_id,
gs.variant_name,
gs.route_code,
gs.status,
gs.session_token_hash,
gs.session_token_expires_at,
gs.launched_at,
gs.started_at,
gs.ended_at,
e.event_public_id,
e.display_name
FROM game_sessions gs
JOIN events e ON e.id = gs.event_id
JOIN event_releases er ON er.id = gs.event_release_id
WHERE gs.user_id = $1
ORDER BY gs.created_at DESC
LIMIT $2
`
	rows, err := s.pool.Query(ctx, query, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("list sessions by user id: %w", err)
	}
	defer rows.Close()

	var result []Session
	for rows.Next() {
		item, scanErr := scanSessionFromRows(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, *item)
	}
	// rows.Err reports failures that ended iteration early.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate sessions by user id: %w", err)
	}
	return result, nil
}
// ListSessionsByUserAndEvent returns up to limit sessions that userID started
// for eventID, newest first (ordered by game_sessions.created_at descending).
// Each session is joined with its event and event release.
//
// A limit outside 1..100 is replaced by the default of 20. When nothing
// matches the result is a nil slice with a nil error.
func (s *Store) ListSessionsByUserAndEvent(ctx context.Context, userID, eventID string, limit int) ([]Session, error) {
	if limit < 1 || 100 < limit {
		limit = 20
	}

	const query = `
SELECT
gs.id,
gs.session_public_id,
gs.user_id,
gs.event_id,
gs.event_release_id,
er.release_public_id,
er.config_label,
er.manifest_url,
er.manifest_checksum_sha256,
gs.device_key,
gs.client_type,
gs.assignment_mode,
gs.variant_id,
gs.variant_name,
gs.route_code,
gs.status,
gs.session_token_hash,
gs.session_token_expires_at,
gs.launched_at,
gs.started_at,
gs.ended_at,
e.event_public_id,
e.display_name
FROM game_sessions gs
JOIN events e ON e.id = gs.event_id
JOIN event_releases er ON er.id = gs.event_release_id
WHERE gs.user_id = $1
AND gs.event_id = $2
ORDER BY gs.created_at DESC
LIMIT $3
`
	rows, err := s.pool.Query(ctx, query, userID, eventID, limit)
	if err != nil {
		return nil, fmt.Errorf("list sessions by user and event: %w", err)
	}
	defer rows.Close()

	var result []Session
	for rows.Next() {
		item, scanErr := scanSessionFromRows(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, *item)
	}
	// rows.Err reports failures that ended iteration early.
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate sessions by user and event: %w", err)
	}
	return result, nil
}
// StartSession marks the session with internal ID sessionID as started,
// within tx.
//
// Only a 'launched' session is moved to 'running'; any other status is left
// untouched. started_at is set to NOW() only if it is still NULL, so repeated
// calls keep the first start time. The number of affected rows is not
// checked, so an unknown sessionID is not reported as an error.
func (s *Store) StartSession(ctx context.Context, tx Tx, sessionID string) error {
	const query = `
UPDATE game_sessions
SET status = CASE WHEN status = 'launched' THEN 'running' ELSE status END,
started_at = COALESCE(started_at, NOW())
WHERE id = $1
`
	if _, err := tx.Exec(ctx, query, sessionID); err != nil {
		return fmt.Errorf("start session: %w", err)
	}
	return nil
}
// FinishSession records the end of the session params.SessionID within tx.
//
// status is overwritten with params.Status unconditionally. started_at and
// ended_at are each filled with NOW() only when still NULL, so a session that
// never started still gets a start time and a repeated finish keeps the
// original end time. The number of affected rows is not checked.
func (s *Store) FinishSession(ctx context.Context, tx Tx, params FinishSessionParams) error {
	const query = `
UPDATE game_sessions
SET status = $2,
started_at = COALESCE(started_at, NOW()),
ended_at = COALESCE(ended_at, NOW())
WHERE id = $1
`
	id, status := params.SessionID, params.Status
	if _, err := tx.Exec(ctx, query, id, status); err != nil {
		return fmt.Errorf("finish session: %w", err)
	}
	return nil
}
// scanSession decodes a single-row query result into a Session. The column
// order must match the SELECT lists used by the GetSessionByPublicID* queries.
//
// pgx.ErrNoRows is translated into (nil, nil) so callers can treat "not found"
// as a nil session rather than an error. Any other scan error is wrapped.
func scanSession(row pgx.Row) (*Session, error) {
	sess := new(Session)
	dest := []any{
		&sess.ID,
		&sess.SessionPublicID,
		&sess.UserID,
		&sess.EventID,
		&sess.EventReleaseID,
		&sess.ReleasePublicID,
		&sess.ConfigLabel,
		&sess.ManifestURL,
		&sess.ManifestChecksum,
		&sess.DeviceKey,
		&sess.ClientType,
		&sess.AssignmentMode,
		&sess.VariantID,
		&sess.VariantName,
		&sess.RouteCode,
		&sess.Status,
		&sess.SessionTokenHash,
		&sess.SessionTokenExpiresAt,
		&sess.LaunchedAt,
		&sess.StartedAt,
		&sess.EndedAt,
		&sess.EventPublicID,
		&sess.EventDisplayName,
	}

	switch err := row.Scan(dest...); {
	case err == nil:
		return sess, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, nil
	default:
		return nil, fmt.Errorf("scan session: %w", err)
	}
}
// scanSessionFromRows decodes the current row of a multi-row result into a
// fresh Session. The column order must match the SELECT lists used by the
// ListSessions* queries. Unlike scanSession, every scan error, including
// pgx.ErrNoRows, is returned wrapped rather than mapped to a nil session.
func scanSessionFromRows(rows pgx.Rows) (*Session, error) {
	sess := &Session{}
	if err := rows.Scan(
		&sess.ID,
		&sess.SessionPublicID,
		&sess.UserID,
		&sess.EventID,
		&sess.EventReleaseID,
		&sess.ReleasePublicID,
		&sess.ConfigLabel,
		&sess.ManifestURL,
		&sess.ManifestChecksum,
		&sess.DeviceKey,
		&sess.ClientType,
		&sess.AssignmentMode,
		&sess.VariantID,
		&sess.VariantName,
		&sess.RouteCode,
		&sess.Status,
		&sess.SessionTokenHash,
		&sess.SessionTokenExpiresAt,
		&sess.LaunchedAt,
		&sess.StartedAt,
		&sess.EndedAt,
		&sess.EventPublicID,
		&sess.EventDisplayName,
	); err != nil {
		return nil, fmt.Errorf("scan session row: %w", err)
	}
	return sess, nil
}