Skip to content

Add a TurnDiffTracker to create a unified diff for an entire turn #1770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions codex-rs/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ serde_json = "1"
serde_bytes = "0.11"
sha1 = "0.10.6"
shlex = "1.3.0"
similar = "2"
strum_macros = "0.27.2"
thiserror = "2.0.12"
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
Expand Down
111 changes: 93 additions & 18 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@ use crate::protocol::SandboxPolicy;
use crate::protocol::SessionConfiguredEvent;
use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TurnDiffEvent;
use crate::rollout::RolloutRecorder;
use crate::safety::SafetyCheck;
use crate::safety::assess_command_safety;
use crate::safety::assess_safety_for_untrusted_command;
use crate::shell;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::user_notification::UserNotification;
use crate::util::backoff;

Expand Down Expand Up @@ -362,7 +364,11 @@ impl Session {
}
}

async fn notify_exec_command_begin(&self, exec_command_context: ExecCommandContext) {
async fn on_exec_command_begin(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
exec_command_context: ExecCommandContext,
) {
let ExecCommandContext {
sub_id,
call_id,
Expand All @@ -374,11 +380,15 @@ impl Session {
Some(ApplyPatchCommandContext {
user_explicitly_approved_this_action,
changes,
}) => EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
auto_approved: !user_explicitly_approved_this_action,
changes,
}),
}) => {
let _ = turn_diff_tracker.on_patch_begin(&changes);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this doesn't have to return Result, then let _ can go away, of course, but depending on what sort of Err we expect, perhaps we should at least warn!() or error!()?


EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
auto_approved: !user_explicitly_approved_this_action,
changes,
})
}
None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
call_id,
command: command_for_display.clone(),
Expand All @@ -392,8 +402,10 @@ impl Session {
let _ = self.tx_event.send(event).await;
}

async fn notify_exec_command_end(
#[allow(clippy::too_many_arguments)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should maybe introduce a struct in a follow-up PR.

async fn on_exec_command_end(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
call_id: &str,
output: &ExecToolCallOutput,
Expand Down Expand Up @@ -433,6 +445,20 @@ impl Session {
msg,
};
let _ = self.tx_event.send(event).await;

// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
id: sub_id.into(),
msg,
};
let _ = self.tx_event.send(event).await;
}
}
}

/// Helper that emits a BackgroundEvent with the given message. This keeps
Expand Down Expand Up @@ -1006,6 +1032,10 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
.await;

let last_agent_message: Option<String>;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new();

loop {
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
Expand Down Expand Up @@ -1037,7 +1067,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
})
})
.collect();
match run_turn(&sess, sub_id.clone(), turn_input).await {
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
Ok(turn_output) => {
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
Expand Down Expand Up @@ -1163,6 +1193,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {

async fn run_turn(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll probably want a struct TurnContext or somesuch in the near future.

sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
Expand All @@ -1177,7 +1208,7 @@ async fn run_turn(

let mut retries = 0;
loop {
match try_run_turn(sess, &sub_id, &prompt).await {
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Expand Down Expand Up @@ -1223,6 +1254,7 @@ struct ProcessedResponseItem {

async fn try_run_turn(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
Expand Down Expand Up @@ -1310,7 +1342,8 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item(sess, sub_id, item.clone()).await?;
let response =
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;

output.push(ProcessedResponseItem { item, response });
}
Expand All @@ -1328,6 +1361,16 @@ async fn try_run_turn(
.ok();
}

let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
id: sub_id.to_string(),
msg,
};
let _ = sess.tx_event.send(event).await;
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm starting to think that we should do break token_usage; to get out of the loop and then do all of this post-loop stuff below just in case there ever ends up being another way to break out.

It would also eliminate this return statement buried in here (though admittedly it would bury the break statement instead).

return Ok(output);
}
ResponseEvent::OutputTextDelta(delta) => {
Expand Down Expand Up @@ -1432,6 +1475,7 @@ async fn run_compact_task(

async fn handle_response_item(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> {
Expand Down Expand Up @@ -1469,7 +1513,17 @@ async fn handle_response_item(
..
} => {
info!("FunctionCall: {arguments}");
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
Some(
handle_function_call(
sess,
turn_diff_tracker,
sub_id.to_string(),
name,
arguments,
call_id,
)
.await,
)
}
ResponseItem::LocalShellCall {
id,
Expand Down Expand Up @@ -1504,6 +1558,7 @@ async fn handle_response_item(
handle_container_exec_with_params(
exec_params,
sess,
turn_diff_tracker,
sub_id.to_string(),
effective_call_id,
)
Expand All @@ -1521,6 +1576,7 @@ async fn handle_response_item(

async fn handle_function_call(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
name: String,
arguments: String,
Expand All @@ -1534,7 +1590,8 @@ async fn handle_function_call(
return *output;
}
};
handle_container_exec_with_params(params, sess, sub_id, call_id).await
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
.await
}
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
_ => {
Expand Down Expand Up @@ -1608,6 +1665,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
async fn handle_container_exec_with_params(
params: ExecParams,
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, what would happen if we wanted to support parallel tool calls at one point. This would be a problem, no?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you think it would break?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because only one tool call could take ownership of TurnDiffTracker.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can solve this then. Maybe a Mutex or a channel or something, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, though also, if we introduce a struct TurnContext as mentioned above, that may also force the move to Mutex. But yes, does not have to be done in this PR.

sub_id: String,
call_id: String,
) -> ResponseInputItem {
Expand Down Expand Up @@ -1755,7 +1813,7 @@ async fn handle_container_exec_with_params(
},
),
};
sess.notify_exec_command_begin(exec_command_context.clone())
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context.clone())
.await;

let params = maybe_run_with_user_profile(params, sess);
Expand All @@ -1782,7 +1840,8 @@ async fn handle_container_exec_with_params(
duration,
} = &output;

sess.notify_exec_command_end(
sess.on_exec_command_end(
turn_diff_tracker,
&sub_id,
&call_id,
&output,
Expand All @@ -1806,7 +1865,15 @@ async fn handle_container_exec_with_params(
}
}
Err(CodexErr::Sandbox(error)) => {
handle_sandbox_error(params, exec_command_context, error, sandbox_type, sess).await
handle_sandbox_error(
turn_diff_tracker,
params,
exec_command_context,
error,
sandbox_type,
sess,
)
.await
}
Err(e) => {
// Handle non-sandbox errors
Expand All @@ -1822,6 +1889,7 @@ async fn handle_container_exec_with_params(
}

async fn handle_sandbox_error(
turn_diff_tracker: &mut TurnDiffTracker,
params: ExecParams,
exec_command_context: ExecCommandContext,
error: SandboxErr,
Expand Down Expand Up @@ -1878,7 +1946,8 @@ async fn handle_sandbox_error(
sess.notify_background_event(&sub_id, "retrying command without sandbox")
.await;

sess.notify_exec_command_begin(exec_command_context).await;
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context)
.await;

// This is an escalated retry; the policy will not be
// examined and the sandbox has been set to `None`.
Expand All @@ -1905,8 +1974,14 @@ async fn handle_sandbox_error(
duration,
} = &retry_output;

sess.notify_exec_command_end(&sub_id, &call_id, &retry_output, is_apply_patch)
.await;
sess.on_exec_command_end(
turn_diff_tracker,
&sub_id,
&call_id,
&retry_output,
is_apply_patch,
)
.await;

let is_success = *exit_code == 0;
let content = format_exec_output(
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(crate) mod safety;
pub mod seatbelt;
pub mod shell;
pub mod spawn;
pub mod turn_diff_tracker;
mod user_notification;
pub mod util;

Expand Down
7 changes: 7 additions & 0 deletions codex-rs/core/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ pub enum EventMsg {
/// Notification that a patch application has finished.
PatchApplyEnd(PatchApplyEndEvent),

TurnDiff(TurnDiffEvent),

/// Response to GetHistoryEntryRequest.
GetHistoryEntryResponse(GetHistoryEntryResponseEvent),

Expand Down Expand Up @@ -598,6 +600,11 @@ pub struct PatchApplyEndEvent {
pub success: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TurnDiffEvent {
pub unified_diff: String,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like this would be easier to work with programmatically if this were keyed by path, more like changes in PatchApplyBeginEvent. Maybe for a full add or a full delete for an individual file, we still want the unified diff, but it's nice to have added/modified/removed metadata for each path so it's easy to build a compact summary for the diff (maybe with +/- line counts)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What guarantees, if any, can we make about the paths in the unified_diff: will they all be absolute paths?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't currently need this because we parse the whole unified diff. Could we add it if/when we need it?

}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GetHistoryEntryResponseEvent {
pub offset: usize,
Expand Down
Loading
Loading