Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/tui/src/local_chatgpt_auth.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![cfg(test)]

use std::path::Path;

use codex_app_server_protocol::AuthMode;
Expand Down
170 changes: 137 additions & 33 deletions codex-rs/tui/src/onboarding/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use codex_app_server_protocol::CancelLoginAccountParams;
use codex_app_server_protocol::ClientRequest;
use codex_app_server_protocol::LoginAccountParams;
use codex_app_server_protocol::LoginAccountResponse;
use codex_login::AuthCredentialsStoreMode;
use codex_login::DeviceCode;
use codex_login::read_openai_api_key_from_env;
use crossterm::event::KeyCode;
use crossterm::event::KeyEvent;
Expand All @@ -33,6 +31,7 @@ use ratatui::widgets::WidgetRef;
use ratatui::widgets::Wrap;

use codex_protocol::config_types::ForcedLoginMethod;
use std::cell::Cell;
use std::sync::Arc;
use std::sync::RwLock;
use uuid::Uuid;
Expand Down Expand Up @@ -77,8 +76,6 @@ pub(crate) fn mark_url_hyperlink(buf: &mut Buffer, area: Rect, url: &str) {
}
}
}
use std::path::PathBuf;
use tokio::sync::Notify;

use super::onboarding_screen::StepState;

Expand Down Expand Up @@ -108,6 +105,20 @@ fn onboarding_request_id() -> codex_app_server_protocol::RequestId {
codex_app_server_protocol::RequestId::String(Uuid::new_v4().to_string())
}

pub(super) async fn cancel_login_attempt(
request_handle: &AppServerRequestHandle,
login_id: String,
) {
let _ = request_handle
.request_typed::<codex_app_server_protocol::CancelLoginAccountResponse>(
ClientRequest::CancelLoginAccount {
request_id: onboarding_request_id(),
params: CancelLoginAccountParams { login_id },
},
)
.await;
}

#[derive(Clone, Default)]
pub(crate) struct ApiKeyInputState {
value: String,
Expand All @@ -123,8 +134,49 @@ pub(crate) struct ContinueInBrowserState {

#[derive(Clone)]
pub(crate) struct ContinueWithDeviceCodeState {
device_code: Option<DeviceCode>,
cancel: Option<Arc<Notify>>,
request_id: String,
login_id: Option<String>,
verification_url: Option<String>,
user_code: Option<String>,
}

impl ContinueWithDeviceCodeState {
pub(crate) fn pending(request_id: String) -> Self {
Self {
request_id,
login_id: None,
verification_url: None,
user_code: None,
}
}

pub(crate) fn ready(
request_id: String,
login_id: String,
verification_url: String,
user_code: String,
) -> Self {
Self {
request_id,
login_id: Some(login_id),
verification_url: Some(verification_url),
user_code: Some(user_code),
}
}

pub(crate) fn login_id(&self) -> Option<&str> {
self.login_id.as_deref()
}

pub(crate) fn is_showing_copyable_auth(&self) -> bool {
self.verification_url
.as_deref()
.is_some_and(|url| !url.is_empty())
&& self
.user_code
.as_deref()
.is_some_and(|user_code| !user_code.is_empty())
}
}

impl KeyboardHandler for AuthModeWidget {
Expand Down Expand Up @@ -181,36 +233,41 @@ pub(crate) struct AuthModeWidget {
pub highlighted_mode: SignInOption,
pub error: Arc<RwLock<Option<String>>>,
pub sign_in_state: Arc<RwLock<SignInState>>,
pub codex_home: PathBuf,
pub cli_auth_credentials_store_mode: AuthCredentialsStoreMode,
pub login_status: LoginStatus,
pub app_server_request_handle: AppServerRequestHandle,
pub forced_chatgpt_workspace_id: Option<String>,
pub forced_login_method: Option<ForcedLoginMethod>,
pub animations_enabled: bool,
pub animations_suppressed: Cell<bool>,
}

impl AuthModeWidget {
pub(crate) fn set_animations_suppressed(&self, suppressed: bool) {
self.animations_suppressed.set(suppressed);
}

pub(crate) fn should_suppress_animations(&self) -> bool {
matches!(
&*self.sign_in_state.read().unwrap(),
SignInState::ChatGptContinueInBrowser(_) | SignInState::ChatGptDeviceCode(_)
)
}

pub(crate) fn cancel_active_attempt(&self) {
let mut sign_in_state = self.sign_in_state.write().unwrap();
match &*sign_in_state {
SignInState::ChatGptContinueInBrowser(state) => {
let request_handle = self.app_server_request_handle.clone();
let login_id = state.login_id.clone();
tokio::spawn(async move {
let _ = request_handle
.request_typed::<codex_app_server_protocol::CancelLoginAccountResponse>(
ClientRequest::CancelLoginAccount {
request_id: onboarding_request_id(),
params: CancelLoginAccountParams { login_id },
},
)
.await;
cancel_login_attempt(&request_handle, login_id).await;
});
}
SignInState::ChatGptDeviceCode(state) => {
if let Some(cancel) = &state.cancel {
cancel.notify_one();
if let Some(login_id) = state.login_id().map(str::to_owned) {
let request_handle = self.app_server_request_handle.clone();
tokio::spawn(async move {
cancel_login_attempt(&request_handle, login_id).await;
});
}
}
_ => return,
Expand Down Expand Up @@ -415,7 +472,7 @@ impl AuthModeWidget {

fn render_continue_in_browser(&self, area: Rect, buf: &mut Buffer) {
let mut spans = vec![" ".into()];
if self.animations_enabled {
if self.animations_enabled && !self.animations_suppressed.get() {
// Schedule a follow-up frame to keep the shimmer animation going.
self.request_frame
.schedule_frame_in(std::time::Duration::from_millis(100));
Expand Down Expand Up @@ -814,6 +871,9 @@ impl AuthModeWidget {
let is_matching_login = matches!(
&*guard,
SignInState::ChatGptContinueInBrowser(state) if state.login_id == login_id
) || matches!(
&*guard,
SignInState::ChatGptDeviceCode(state) if state.login_id() == Some(login_id.as_str())
);
drop(guard);
if !is_matching_login {
Expand Down Expand Up @@ -901,6 +961,7 @@ mod tests {
use codex_arg0::Arg0DispatchPaths;
use codex_cloud_requirements::cloud_requirements_loader_for_storage;
use codex_core::config::ConfigBuilder;
use codex_login::AuthCredentialsStoreMode;

use codex_protocol::protocol::SessionSource;
use pretty_assertions::assert_eq;
Expand Down Expand Up @@ -943,13 +1004,11 @@ mod tests {
highlighted_mode: SignInOption::ChatGpt,
error: Arc::new(RwLock::new(None)),
sign_in_state: Arc::new(RwLock::new(SignInState::PickMode)),
codex_home: codex_home_path.clone(),
cli_auth_credentials_store_mode: AuthCredentialsStoreMode::File,
login_status: LoginStatus::NotAuthenticated,
app_server_request_handle: AppServerRequestHandle::InProcess(client.request_handle()),
forced_chatgpt_workspace_id: None,
forced_login_method: Some(ForcedLoginMethod::Chatgpt),
animations_enabled: true,
animations_suppressed: std::cell::Cell::new(false),
};
(widget, codex_home)
}
Expand Down Expand Up @@ -1023,13 +1082,14 @@ mod tests {
#[tokio::test]
async fn cancel_active_attempt_notifies_device_code_login() {
let (widget, _tmp) = widget_forced_chatgpt().await;
let cancel = Arc::new(Notify::new());
*widget.error.write().unwrap() = Some("still logging in".to_string());
*widget.sign_in_state.write().unwrap() =
SignInState::ChatGptDeviceCode(ContinueWithDeviceCodeState {
device_code: None,
cancel: Some(cancel.clone()),
});
SignInState::ChatGptDeviceCode(ContinueWithDeviceCodeState::ready(
"request-1".to_string(),
"login-1".to_string(),
"https://chatgpt.com/device".to_string(),
"ABCD-EFGH".to_string(),
));

widget.cancel_active_attempt();

Expand All @@ -1038,11 +1098,6 @@ mod tests {
&*widget.sign_in_state.read().unwrap(),
SignInState::PickMode
));
assert!(
tokio::time::timeout(std::time::Duration::from_millis(50), cancel.notified())
.await
.is_ok()
);
}

/// Collects all buffer cell symbols that contain the OSC 8 open sequence
Expand Down Expand Up @@ -1085,6 +1140,55 @@ mod tests {
assert_eq!(found, url, "OSC 8 hyperlink should cover the full URL");
}

#[test]
fn auth_widget_suppresses_animations_when_device_code_is_visible() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let (widget, _tmp) = runtime.block_on(widget_forced_chatgpt());
*widget.sign_in_state.write().unwrap() =
SignInState::ChatGptDeviceCode(ContinueWithDeviceCodeState::ready(
"request-1".to_string(),
"login-1".to_string(),
"https://chatgpt.com/device".to_string(),
"ABCD-EFGH".to_string(),
));

assert_eq!(widget.should_suppress_animations(), true);
}

#[test]
fn auth_widget_suppresses_animations_while_requesting_device_code() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let (widget, _tmp) = runtime.block_on(widget_forced_chatgpt());
*widget.sign_in_state.write().unwrap() = SignInState::ChatGptDeviceCode(
ContinueWithDeviceCodeState::pending("request-1".to_string()),
);

assert_eq!(widget.should_suppress_animations(), true);
}

#[tokio::test]
async fn device_code_login_completion_advances_to_success_message() {
let (mut widget, _tmp) = widget_forced_chatgpt().await;
*widget.sign_in_state.write().unwrap() =
SignInState::ChatGptDeviceCode(ContinueWithDeviceCodeState::ready(
"request-1".to_string(),
"login-1".to_string(),
"https://chatgpt.com/device".to_string(),
"ABCD-EFGH".to_string(),
));

widget.on_account_login_completed(AccountLoginCompletedNotification {
login_id: Some("login-1".to_string()),
success: true,
error: None,
});

assert!(matches!(
&*widget.sign_in_state.read().unwrap(),
SignInState::ChatGptSuccessMessage
));
}

#[test]
fn mark_url_hyperlink_wraps_cyan_underlined_cells() {
let url = "https://example.com";
Expand Down
Loading
Loading