Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions core/translate/insert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{
error::{SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, SQLITE_CONSTRAINT_UNIQUE},
schema::{self, BTreeTable, ColDef, Column, Index, IndexColumn, ResolvedFkRef, Table},
schema::{
self, BTreeTable, ColDef, Column, Index, IndexColumn, ResolvedFkRef, Table,
SQLITE_SEQUENCE_TABLE_NAME,
},
sync::Arc,
translate::{
emitter::{
Expand Down Expand Up @@ -1233,16 +1236,44 @@ fn resolve_upserts(
Ok(())
}

fn get_valid_sqlite_sequence_table(
resolver: &Resolver,
database_id: usize,
) -> Result<Arc<BTreeTable>> {
let Some(seq_table) = resolver.with_schema(database_id, |s| {
s.get_btree_table(SQLITE_SEQUENCE_TABLE_NAME)
}) else {
crate::bail_corrupt_error!("missing sqlite_sequence table");
};

if !seq_table.has_rowid {
crate::bail_corrupt_error!("malformed sqlite_sequence: table must have rowid");
}

if seq_table.columns.len() != 2 {
crate::bail_corrupt_error!(
"malformed sqlite_sequence: expected 2 columns, got {}",
seq_table.columns.len()
);
}

let col0_name = seq_table.columns[0].name.as_deref();
let col1_name = seq_table.columns[1].name.as_deref();
if !matches!(col0_name, Some(name) if name.eq_ignore_ascii_case("name"))
|| !matches!(col1_name, Some(name) if name.eq_ignore_ascii_case("seq"))
{
crate::bail_corrupt_error!("malformed sqlite_sequence: expected columns (name, seq)");
}

Ok(seq_table)
}

fn init_autoincrement(
program: &mut ProgramBuilder,
ctx: &mut InsertEmitCtx,
resolver: &Resolver,
) -> Result<()> {
let seq_table = resolver
.with_schema(ctx.database_id, |s| s.get_btree_table("sqlite_sequence"))
.ok_or_else(|| {
crate::error::LimboError::InternalError("sqlite_sequence table not found".to_string())
})?;
let seq_table = get_valid_sqlite_sequence_table(resolver, ctx.database_id)?;
let seq_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(seq_table.clone()));
program.emit_insn(Insn::OpenWrite {
cursor_id: seq_cursor_id,
Expand Down Expand Up @@ -2668,11 +2699,7 @@ fn ensure_sequence_initialized(
table: &schema::BTreeTable,
database_id: usize,
) -> Result<()> {
let seq_table = resolver
.with_schema(database_id, |s| s.get_btree_table("sqlite_sequence"))
.ok_or_else(|| {
crate::error::LimboError::InternalError("sqlite_sequence table not found".to_string())
})?;
let seq_table = get_valid_sqlite_sequence_table(resolver, database_id)?;

let seq_cursor_id = program.alloc_cursor_id(CursorType::BTreeTable(seq_table.clone()));

Expand Down Expand Up @@ -2760,7 +2787,7 @@ fn ensure_sequence_initialized(
key_reg: new_rowid_reg,
record_reg,
flag: InsertFlags::new(),
table_name: "sqlite_sequence".to_string(),
table_name: SQLITE_SEQUENCE_TABLE_NAME.to_string(),
});

program.preassign_label_to_next_insn(entry_exists_label);
Expand Down Expand Up @@ -3033,9 +3060,7 @@ fn emit_update_sqlite_sequence(
extra_amount: 0,
});

let seq_table = resolver
.with_schema(database_id, |s| s.get_btree_table("sqlite_sequence"))
.unwrap();
let seq_table = get_valid_sqlite_sequence_table(resolver, database_id)?;
let affinity_str = seq_table
.columns
.iter()
Expand Down Expand Up @@ -3066,7 +3091,7 @@ fn emit_update_sqlite_sequence(
key_reg: r_seq_rowid,
record_reg,
flag: InsertFlags::new(),
table_name: "sqlite_sequence".to_string(),
table_name: SQLITE_SEQUENCE_TABLE_NAME.to_string(),
});
program.emit_insn(Insn::Goto {
target_pc: end_update_label,
Expand All @@ -3078,7 +3103,7 @@ fn emit_update_sqlite_sequence(
key_reg: r_seq_rowid,
record_reg,
flag: InsertFlags(turso_parser::ast::ResolveType::Replace.bit_value() as u8),
table_name: "sqlite_sequence".to_string(),
table_name: SQLITE_SEQUENCE_TABLE_NAME.to_string(),
});

program.preassign_label_to_next_insn(end_update_label);
Expand Down
79 changes: 79 additions & 0 deletions core/translate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ pub fn translate_inner(
mod tests {
use super::*;
use crate::io::MemoryIO;
use crate::schema::{BTreeTable, Table, SQLITE_SEQUENCE_TABLE_NAME};
use crate::Database;

/// Verify that REGEXP produces the correct error when no regexp function is registered.
Expand Down Expand Up @@ -414,4 +415,82 @@ mod tests {
"expected 'no such function: regexp', got: {err}"
);
}

#[test]
fn test_insert_autoincrement_with_malformed_sqlite_sequence_is_corrupt() {
let io = Arc::new(MemoryIO::new());
let db = Database::open_file(io, ":memory:").unwrap();
let conn = db.connect().unwrap();
conn.execute("CREATE TABLE t(id INTEGER PRIMARY KEY AUTOINCREMENT, v TEXT)")
.unwrap();

let mut schema = db.schema.lock().as_ref().clone();
let seq_root_page = schema
.get_btree_table(SQLITE_SEQUENCE_TABLE_NAME)
.expect("sqlite_sequence should exist after creating AUTOINCREMENT table")
.root_page;
let malformed_seq =
BTreeTable::from_sql("CREATE TABLE sqlite_sequence(name)", seq_root_page)
.expect("malformed sqlite_sequence SQL should parse");
schema.tables.insert(
SQLITE_SEQUENCE_TABLE_NAME.to_string(),
Arc::new(Table::BTree(Arc::new(malformed_seq))),
);

let pager = conn.pager.load().clone();
let syms = SymbolTable::new();

let mut parser = turso_parser::parser::Parser::new(b"INSERT INTO t(v) VALUES('x')");
let cmd = parser.next().unwrap().unwrap();
let stmt = match cmd {
ast::Cmd::Stmt(s) => s,
_ => panic!("expected statement"),
};

let err = translate(&schema, stmt, pager, conn, &syms, QueryMode::Normal, "")
.expect_err("translation should fail with malformed sqlite_sequence");
match err {
crate::LimboError::Corrupt(msg) => {
assert!(
msg.contains("sqlite_sequence"),
"expected sqlite_sequence corruption error, got: {msg}"
);
}
other => panic!("expected LimboError::Corrupt, got: {other}"),
}
}

#[test]
fn test_insert_autoincrement_with_missing_sqlite_sequence_is_corrupt() {
let io = Arc::new(MemoryIO::new());
let db = Database::open_file(io, ":memory:").unwrap();
let conn = db.connect().unwrap();
conn.execute("CREATE TABLE t(id INTEGER PRIMARY KEY AUTOINCREMENT, v TEXT)")
.unwrap();

let mut schema = db.schema.lock().as_ref().clone();
schema.tables.remove(SQLITE_SEQUENCE_TABLE_NAME);

let pager = conn.pager.load().clone();
let syms = SymbolTable::new();

let mut parser = turso_parser::parser::Parser::new(b"INSERT INTO t(v) VALUES('x')");
let cmd = parser.next().unwrap().unwrap();
let stmt = match cmd {
ast::Cmd::Stmt(s) => s,
_ => panic!("expected statement"),
};

let err = translate(&schema, stmt, pager, conn, &syms, QueryMode::Normal, "")
.expect_err("translation should fail with missing sqlite_sequence");
match err {
crate::LimboError::Corrupt(msg) => {
assert!(
msg.contains("missing sqlite_sequence"),
"expected missing sqlite_sequence error, got: {msg}"
);
}
other => panic!("expected LimboError::Corrupt, got: {other}"),
}
}
}
Loading