diff --git a/goodnews/api.py b/goodnews/api.py index 1544396..7338ae9 100644 --- a/goodnews/api.py +++ b/goodnews/api.py @@ -432,6 +432,8 @@ def create_app() -> FastAPI: try: tokens = oauth_google.exchange_code(code, _google_redirect_uri(), verifier) info = oauth_google.verify_id_token(tokens["id_token"]) + if not info.get("picture") and tokens.get("access_token"): + info["picture"] = oauth_google.fetch_userinfo(tokens["access_token"]).get("picture") except Exception: return fail with get_conn() as conn: diff --git a/goodnews/auth.py b/goodnews/auth.py index d3032e8..278f238 100644 --- a/goodnews/auth.py +++ b/goodnews/auth.py @@ -69,27 +69,29 @@ def find_or_create_user( (provider, provider_subject), ).fetchone() if existing: - return existing["user_id"] + user_id = existing["user_id"] + else: + user = conn.execute("SELECT id FROM users WHERE email = ?", (email,)).fetchone() + if user: + user_id = user["id"] + else: + user_id = conn.execute( + "INSERT INTO users (email, display_name, avatar_url) VALUES (?, ?, ?)", + (email, display_name, avatar_url), + ).lastrowid + conn.execute( + "INSERT OR IGNORE INTO identities (user_id, provider, provider_subject) VALUES (?, ?, ?)", + (user_id, provider, provider_subject), + ) - user = conn.execute("SELECT id FROM users WHERE email = ?", (email,)).fetchone() - if user: - user_id = user["id"] - # Fill display name if missing; refresh avatar whenever the provider gives one. + # Always refresh provider-supplied profile bits (even for a returning identity): + # fill the name if missing, and keep the avatar current when the provider sends one. + if display_name or avatar_url: conn.execute( "UPDATE users SET display_name = COALESCE(display_name, ?), " "avatar_url = COALESCE(?, avatar_url), updated_at = CURRENT_TIMESTAMP WHERE id = ?", (display_name, avatar_url, user_id), ) - else: - user_id = conn.execute( - "INSERT INTO users (email, display_name, avatar_url) VALUES (?, ?, ?)", - (email, display_name, avatar_url), - ).lastrowid - - conn.execute( - "INSERT OR IGNORE INTO identities (user_id, provider, provider_subject) VALUES (?, ?, ?)", - (user_id, provider, provider_subject), - ) return user_id diff --git a/goodnews/oauth_google.py b/goodnews/oauth_google.py index a0fdc24..9fa1d81 100644 --- a/goodnews/oauth_google.py +++ b/goodnews/oauth_google.py @@ -19,6 +19,7 @@ from datetime import datetime, timezone AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" TOKEN_URL = "https://oauth2.googleapis.com/token" +USERINFO_URL = "https://openidconnect.googleapis.com/v1/userinfo" _VALID_ISS = {"accounts.google.com", "https://accounts.google.com"} @@ -75,6 +76,19 @@ def exchange_code(code: str, redirect_uri: str, code_verifier: str) -> dict: return json.loads(response.read()) +def fetch_userinfo(access_token: str) -> dict: + """Call the OIDC userinfo endpoint — a reliable source for the profile picture + if the ID token happens to omit it. Best-effort; returns {} on any error.""" + request = urllib.request.Request( + USERINFO_URL, headers={"Authorization": f"Bearer {access_token}", "Accept": "application/json"} + ) + try: + with urllib.request.urlopen(request, timeout=10) as response: + return json.loads(response.read()) + except Exception: + return {} + + def _decode_jwt_payload(token: str) -> dict: parts = token.split(".") if len(parts) != 3: diff --git a/tests/test_auth.py b/tests/test_auth.py index d1ca21e..6bcf9b3 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -33,6 +33,15 @@ def test_find_or_create_links_by_email_and_dedupes_identity(): assert auth.get_user(c, uid)["display_name"] == "A" +def test_returning_identity_refreshes_avatar(): + c = _db() + uid = auth.find_or_create_user(c, "a@b.com", "google", "gsub", display_name="A", avatar_url="http://pic/1") + assert auth.get_user(c, uid)["avatar_url"] == "http://pic/1" + # a repeat sign-in with the SAME identity must still refresh the picture + assert auth.find_or_create_user(c, "a@b.com", "google", "gsub", avatar_url="http://pic/2") == uid + assert auth.get_user(c, uid)["avatar_url"] == "http://pic/2" + + def test_magic_link_token_single_use(): c = _db() raw = auth.create_login_token(c, "a@b.com")