Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 99 additions & 44 deletions codex-rs/core/src/rollout/recorder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ pub enum RolloutRecorderParams {
enum RolloutCmd {
AddItems(Vec<RolloutItem>),
Persist {
ack: oneshot::Sender<()>,
ack: oneshot::Sender<std::io::Result<()>>,
},
/// Ensure all prior writes are processed; respond when flushed.
Flush {
ack: oneshot::Sender<()>,
ack: oneshot::Sender<std::io::Result<()>>,
},
Shutdown {
ack: oneshot::Sender<()>,
Expand Down Expand Up @@ -453,7 +453,7 @@ impl RolloutRecorder {
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
tokio::task::spawn(rollout_writer(
file,
file.map(JsonlWriter::new),
deferred_log_file_info,
rx,
meta,
Expand Down Expand Up @@ -514,6 +514,7 @@ impl RolloutRecorder {
.map_err(|e| IoError::other(format!("failed to queue rollout persist: {e}")))?;
rx.await
.map_err(|e| IoError::other(format!("failed waiting for rollout persist: {e}")))
.and_then(|result| result)
}

/// Flush all queued writes and wait until they are committed by the writer task.
Expand All @@ -525,6 +526,7 @@ impl RolloutRecorder {
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
rx.await
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
.and_then(|result| result)
}

pub(crate) async fn load_rollout_items(
Expand Down Expand Up @@ -647,6 +649,7 @@ fn truncate_fs_page(
page
}

#[derive(Clone)]
struct LogFileInfo {
/// Full path to the rollout file.
path: PathBuf,
Expand Down Expand Up @@ -706,7 +709,7 @@ fn open_log_file(path: &Path) -> std::io::Result<File> {

#[allow(clippy::too_many_arguments)]
async fn rollout_writer(
file: Option<tokio::fs::File>,
file: Option<JsonlWriter>,
mut deferred_log_file_info: Option<LogFileInfo>,
mut rx: mpsc::Receiver<RolloutCmd>,
mut meta: Option<SessionMeta>,
Expand All @@ -717,8 +720,8 @@ async fn rollout_writer(
default_provider: String,
generate_memories: bool,
) -> std::io::Result<()> {
let mut writer = file.map(|file| JsonlWriter { file });
let mut buffered_items = Vec::<RolloutItem>::new();
let mut writer = file;
let mut pending_items = Vec::<RolloutItem>::new();
if let Some(builder) = state_builder.as_mut() {
builder.rollout_path = rollout_path.clone();
}
Expand Down Expand Up @@ -749,35 +752,49 @@ async fn rollout_writer(
continue;
}

pending_items.extend(items);

if writer.is_none() {
buffered_items.extend(items);
continue;
if meta.is_some() {
continue;
}
match reopen_rollout_writer(&rollout_path) {
Ok(reopened_writer) => writer = Some(reopened_writer),
Err(err) => {
warn!("rollout reopen failed; keeping pending items queued: {err}");
continue;
}
}
}

write_and_reconcile_items(
writer.as_mut(),
items.as_slice(),
pending_items.as_slice(),
&rollout_path,
state_db_ctx.as_deref(),
state_builder.as_ref(),
default_provider.as_str(),
)
.await?;
.await
.map(|()| pending_items.clear())
.unwrap_or_else(|err| {
writer = None;
warn!("rollout write failed; queued items will retry after reopen: {err}");
});
}
RolloutCmd::Persist { ack } => {
if writer.is_none() {
if writer.is_none() || meta.is_some() || !pending_items.is_empty() {
let result = async {
let Some(log_file_info) = deferred_log_file_info.take() else {
return Err(IoError::other(
"deferred rollout recorder missing log file metadata",
));
};
let file = open_log_file(log_file_info.path.as_path())?;
writer = Some(JsonlWriter {
file: tokio::fs::File::from_std(file),
});

if let Some(session_meta) = meta.take() {
if writer.is_none() {
let writer_path = deferred_log_file_info
.as_ref()
.map(|log_file_info| log_file_info.path.as_path())
.unwrap_or(rollout_path.as_path());
let file = open_log_file(writer_path)?;
writer = Some(JsonlWriter::new(tokio::fs::File::from_std(file)));
}

if let Some(session_meta) = meta.clone() {
write_session_meta(
writer.as_mut(),
session_meta,
Expand All @@ -789,41 +806,47 @@ async fn rollout_writer(
generate_memories,
)
.await?;
meta = None;
}

if !buffered_items.is_empty() {
if !pending_items.is_empty() {
write_and_reconcile_items(
writer.as_mut(),
buffered_items.as_slice(),
pending_items.as_slice(),
&rollout_path,
state_db_ctx.as_deref(),
state_builder.as_ref(),
default_provider.as_str(),
)
.await?;
buffered_items.clear();
pending_items.clear();
}

deferred_log_file_info = None;
Ok(())
}
.await;

if let Err(err) = result {
let _ = ack.send(());
return Err(err);
writer = None;
warn!("rollout persist failed; keeping writer alive: {err}");
let _ = ack.send(Err(err));
continue;
}
}
let _ = ack.send(());
let _ = ack.send(Ok(()));
}
RolloutCmd::Flush { ack } => {
// Deferred fresh threads may not have an initialized file yet.
if let Some(writer) = writer.as_mut()
&& let Err(e) = writer.file.flush().await
if let Some(current_writer) = writer.as_mut()
&& let Err(e) = current_writer.file.flush().await
{
let _ = ack.send(());
return Err(e);
writer = None;
warn!("rollout flush failed; keeping writer alive: {e}");
let _ = ack.send(Err(e));
continue;
}
let _ = ack.send(());
let _ = ack.send(Ok(()));
}
RolloutCmd::Shutdown { ack } => {
let _ = ack.send(());
Expand Down Expand Up @@ -879,9 +902,7 @@ async fn write_and_reconcile_items(
default_provider: &str,
) -> std::io::Result<()> {
if let Some(writer) = writer.as_mut() {
for item in items {
writer.write_rollout_item(item).await?;
}
writer.write_rollout_items(items).await?;
}
sync_thread_state_after_write(
state_db_ctx,
Expand Down Expand Up @@ -956,7 +977,42 @@ struct RolloutLineRef<'a> {
}

impl JsonlWriter {
fn new(file: tokio::fs::File) -> Self {
Self { file }
}

/// Appends a single rollout item by delegating to the batched writer,
/// so one item gets the same atomic write-or-roll-back treatment as many.
async fn write_rollout_item(&mut self, rollout_item: &RolloutItem) -> std::io::Result<()> {
    let single = std::slice::from_ref(rollout_item);
    self.write_rollout_items(single).await
}

/// Serializes `rollout_items` to JSONL and appends them to the log file in a
/// single buffered write, flushing afterwards.
///
/// On any write/flush error the file is truncated back to its pre-write
/// length so a partially written batch never leaves a torn line in the log.
/// If the truncation itself fails, both errors are reported together.
///
/// NOTE(review): the roll-back via `set_len` assumes the file handle writes
/// at the end of the file (e.g. opened in append mode by `open_log_file`), so
/// the cursor position after truncation is irrelevant — TODO confirm against
/// `open_log_file`.
async fn write_rollout_items(&mut self, rollout_items: &[RolloutItem]) -> std::io::Result<()> {
    // Remember the length before writing so a failed batch can be undone.
    let file_len_before_write = self.file.metadata().await?.len();
    // Build the whole batch in memory first: serialization errors abort
    // before any bytes hit the file, and the write becomes one syscall-ish op.
    let mut json = String::new();
    for rollout_item in rollout_items {
        json.push_str(&Self::rollout_line_json(rollout_item)?);
        json.push('\n');
    }

    // Group write + flush so a single error path below covers both.
    let result = async {
        self.file.write_all(json.as_bytes()).await?;
        self.file.flush().await
    }
    .await;

    if let Err(err) = result {
        // Undo any partial append; if even that fails, surface both errors.
        if let Err(truncate_err) = self.file.set_len(file_len_before_write).await {
            return Err(IoError::other(format!(
                "failed to roll back partial rollout write after {err}: {truncate_err}"
            )));
        }
        return Err(err);
    }

    Ok(())
}

fn rollout_line_json(rollout_item: &RolloutItem) -> std::io::Result<String> {
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
Expand All @@ -968,17 +1024,16 @@ impl JsonlWriter {
timestamp,
item: rollout_item,
};
self.write_line(&line).await
}
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
let mut json = serde_json::to_string(item)?;
json.push('\n');
self.file.write_all(json.as_bytes()).await?;
self.file.flush().await?;
Ok(())
serde_json::to_string(&line).map_err(IoError::other)
}
}

/// Re-opens the rollout log at `rollout_path` and wraps the handle in a
/// `JsonlWriter` for async JSONL appends.
fn reopen_rollout_writer(rollout_path: &Path) -> std::io::Result<JsonlWriter> {
    let std_file = open_log_file(rollout_path)?;
    let async_file = tokio::fs::File::from_std(std_file);
    Ok(JsonlWriter::new(async_file))
}

impl From<codex_state::ThreadsPage> for ThreadsPage {
fn from(db_page: codex_state::ThreadsPage) -> Self {
let items = db_page
Expand Down
Loading
Loading