use chrono::NaiveDate; use rusqlite::{Connection, Result as SqlResult, params}; use std::path::Path; use crate::models::*; const CURRENT_SCHEMA_VERSION: i32 = 1; pub struct Database { pub conn: Connection, } impl Database { pub fn open(path: &Path) -> SqlResult { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent).ok(); } let conn = Connection::open(path)?; let db = Database { conn }; db.init()?; Ok(db) } pub fn open_in_memory() -> SqlResult { let conn = Connection::open_in_memory()?; let db = Database { conn }; db.init()?; Ok(db) } fn init(&self) -> SqlResult<()> { self.conn.execute_batch("PRAGMA journal_mode=WAL;")?; self.conn.execute_batch("PRAGMA foreign_keys=ON;")?; self.create_tables()?; self.migrate()?; Ok(()) } fn create_tables(&self) -> SqlResult<()> { self.conn.execute_batch( "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER NOT NULL ); CREATE TABLE IF NOT EXISTS categories ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, icon TEXT, color TEXT, type TEXT NOT NULL, is_default INTEGER DEFAULT 0, sort_order INTEGER DEFAULT 0 ); CREATE TABLE IF NOT EXISTS transactions ( id INTEGER PRIMARY KEY, amount REAL NOT NULL, type TEXT NOT NULL, category_id INTEGER NOT NULL REFERENCES categories(id), currency TEXT NOT NULL DEFAULT 'USD', exchange_rate REAL DEFAULT 1.0, note TEXT, date TEXT NOT NULL, created_at TEXT NOT NULL, recurring_id INTEGER REFERENCES recurring_transactions(id) ); CREATE TABLE IF NOT EXISTS budgets ( id INTEGER PRIMARY KEY, category_id INTEGER NOT NULL REFERENCES categories(id), amount REAL NOT NULL, month TEXT NOT NULL, UNIQUE(category_id, month) ); CREATE TABLE IF NOT EXISTS recurring_transactions ( id INTEGER PRIMARY KEY, amount REAL NOT NULL, type TEXT NOT NULL, category_id INTEGER NOT NULL REFERENCES categories(id), currency TEXT NOT NULL DEFAULT 'USD', note TEXT, frequency TEXT NOT NULL, start_date TEXT NOT NULL, end_date TEXT, last_generated TEXT, active INTEGER DEFAULT 1 ); CREATE TABLE IF NOT EXISTS exchange_rates ( base TEXT NOT NULL, target TEXT NOT NULL, rate REAL NOT NULL, fetched_at TEXT NOT NULL, PRIMARY KEY (base, target) ); CREATE TABLE IF NOT EXISTS settings ( key TEXT PRIMARY KEY, value TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS budget_notifications ( category_id INTEGER NOT NULL, month TEXT NOT NULL, threshold INTEGER NOT NULL, notified_at TEXT NOT NULL, PRIMARY KEY (category_id, month, threshold) );" )?; Ok(()) } fn migrate(&self) -> SqlResult<()> { let version = self.get_schema_version()?; if version == 0 { self.seed_default_categories()?; self.set_schema_version(CURRENT_SCHEMA_VERSION)?; } // Future migrations go here: // if version < 2 { ... self.set_schema_version(2)?; } Ok(()) } fn get_schema_version(&self) -> SqlResult { let count: i64 = self.conn.query_row( "SELECT COUNT(*) FROM schema_version", [], |row| row.get(0), )?; if count == 0 { return Ok(0); } self.conn.query_row( "SELECT version FROM schema_version LIMIT 1", [], |row| row.get(0), ) } fn set_schema_version(&self, version: i32) -> SqlResult<()> { self.conn.execute("DELETE FROM schema_version", [])?; self.conn.execute( "INSERT INTO schema_version (version) VALUES (?1)", params![version], )?; Ok(()) } pub fn schema_version(&self) -> SqlResult { self.get_schema_version() } // -- Transaction CRUD -- pub fn insert_transaction(&self, txn: &NewTransaction) -> SqlResult { let now = chrono::Utc::now().to_rfc3339(); self.conn.execute( "INSERT INTO transactions (amount, type, category_id, currency, exchange_rate, note, date, created_at, recurring_id) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", params![ txn.amount, txn.transaction_type.as_str(), txn.category_id, txn.currency, txn.exchange_rate, txn.note, txn.date.format("%Y-%m-%d").to_string(), now, txn.recurring_id, ], )?; Ok(self.conn.last_insert_rowid()) } pub fn get_transaction(&self, id: i64) -> SqlResult { self.conn.query_row( "SELECT id, amount, type, category_id, currency, exchange_rate, note, date, created_at, recurring_id FROM transactions WHERE id = ?1", params![id], |row| Self::row_to_transaction(row), ) } pub fn update_transaction(&self, txn: &Transaction) -> SqlResult<()> { self.conn.execute( "UPDATE transactions SET amount=?1, type=?2, category_id=?3, currency=?4, exchange_rate=?5, note=?6, date=?7 WHERE id=?8", params![ txn.amount, txn.transaction_type.as_str(), txn.category_id, txn.currency, txn.exchange_rate, txn.note, txn.date.format("%Y-%m-%d").to_string(), txn.id, ], )?; Ok(()) } pub fn delete_transaction(&self, id: i64) -> SqlResult<()> { self.conn.execute("DELETE FROM transactions WHERE id = ?1", params![id])?; Ok(()) } pub fn list_transactions_by_month(&self, year: i32, month: u32) -> SqlResult> { let prefix = format!("{:04}-{:02}", year, month); let mut stmt = self.conn.prepare( "SELECT id, amount, type, category_id, currency, exchange_rate, note, date, created_at, recurring_id FROM transactions WHERE date LIKE ?1 ORDER BY date DESC, id DESC" )?; let rows = stmt.query_map(params![format!("{}%", prefix)], |row| { Self::row_to_transaction(row) })?; rows.collect() } pub fn list_recent_transactions(&self, limit: usize) -> SqlResult> { let mut stmt = self.conn.prepare( "SELECT id, amount, type, category_id, currency, exchange_rate, note, date, created_at, recurring_id FROM transactions ORDER BY date DESC, id DESC LIMIT ?1" )?; let rows = stmt.query_map(params![limit as i64], |row| { Self::row_to_transaction(row) })?; rows.collect() } fn row_to_transaction(row: &rusqlite::Row) -> SqlResult { let type_str: String = row.get(2)?; let date_str: String = row.get(7)?; Ok(Transaction { id: row.get(0)?, amount: row.get(1)?, transaction_type: TransactionType::from_str(&type_str).unwrap_or(TransactionType::Expense), category_id: row.get(3)?, currency: row.get(4)?, exchange_rate: row.get(5)?, note: row.get(6)?, date: NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").unwrap_or_default(), created_at: row.get(8)?, recurring_id: row.get(9)?, }) } // -- Category CRUD -- pub fn list_categories(&self, txn_type: Option) -> SqlResult> { match txn_type { Some(t) => { let mut stmt = self.conn.prepare( "SELECT id, name, icon, color, type, is_default, sort_order FROM categories WHERE type = ?1 ORDER BY sort_order" )?; let rows = stmt.query_map(params![t.as_str()], |row| Self::row_to_category(row))?; rows.collect() } None => { let mut stmt = self.conn.prepare( "SELECT id, name, icon, color, type, is_default, sort_order FROM categories ORDER BY type, sort_order" )?; let rows = stmt.query_map([], |row| Self::row_to_category(row))?; rows.collect() } } } pub fn get_category(&self, id: i64) -> SqlResult { self.conn.query_row( "SELECT id, name, icon, color, type, is_default, sort_order FROM categories WHERE id = ?1", params![id], |row| Self::row_to_category(row), ) } pub fn insert_category(&self, cat: &NewCategory) -> SqlResult { self.conn.execute( "INSERT INTO categories (name, icon, color, type, is_default, sort_order) VALUES (?1, ?2, ?3, ?4, 0, ?5)", params![ cat.name, cat.icon, cat.color, cat.transaction_type.as_str(), cat.sort_order, ], )?; Ok(self.conn.last_insert_rowid()) } pub fn update_category(&self, cat: &Category) -> SqlResult<()> { self.conn.execute( "UPDATE categories SET name=?1, icon=?2, color=?3, sort_order=?4 WHERE id=?5", params![cat.name, cat.icon, cat.color, cat.sort_order, cat.id], )?; Ok(()) } pub fn delete_category(&self, id: i64) -> SqlResult<()> { self.conn.execute("DELETE FROM categories WHERE id = ?1", params![id])?; Ok(()) } fn row_to_category(row: &rusqlite::Row) -> SqlResult { let type_str: String = row.get(4)?; Ok(Category { id: row.get(0)?, name: row.get(1)?, icon: row.get(2)?, color: row.get(3)?, transaction_type: TransactionType::from_str(&type_str).unwrap_or(TransactionType::Expense), is_default: row.get(5)?, sort_order: row.get(6)?, }) } // -- Aggregation queries -- pub fn get_monthly_totals_by_category( &self, year: i32, month: u32, txn_type: TransactionType, ) -> SqlResult> { let prefix = format!("{:04}-{:02}", year, month); let mut stmt = self.conn.prepare( "SELECT c.id, c.name, c.icon, c.color, c.type, c.is_default, c.sort_order, SUM(t.amount * t.exchange_rate) as total FROM transactions t JOIN categories c ON t.category_id = c.id WHERE t.date LIKE ?1 AND t.type = ?2 GROUP BY c.id ORDER BY total DESC" )?; let rows = stmt.query_map(params![format!("{}%", prefix), txn_type.as_str()], |row| { let cat = Self::row_to_category(row)?; let total: f64 = row.get(7)?; Ok((cat, total)) })?; rows.collect() } pub fn get_monthly_total( &self, year: i32, month: u32, txn_type: TransactionType, ) -> SqlResult { let prefix = format!("{:04}-{:02}", year, month); self.conn.query_row( "SELECT COALESCE(SUM(amount * exchange_rate), 0.0) FROM transactions WHERE date LIKE ?1 AND type = ?2", params![format!("{}%", prefix), txn_type.as_str()], |row| row.get(0), ) } pub fn get_daily_totals( &self, year: i32, month: u32, ) -> SqlResult> { let prefix = format!("{:04}-{:02}", year, month); let mut stmt = self.conn.prepare( "SELECT date, COALESCE(SUM(CASE WHEN type='income' THEN amount * exchange_rate ELSE 0 END), 0.0), COALESCE(SUM(CASE WHEN type='expense' THEN amount * exchange_rate ELSE 0 END), 0.0) FROM transactions WHERE date LIKE ?1 GROUP BY date ORDER BY date DESC" )?; let rows = stmt.query_map(params![format!("{}%", prefix)], |row| { let date_str: String = row.get(0)?; let date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").unwrap_or_default(); let income: f64 = row.get(1)?; let expense: f64 = row.get(2)?; Ok((date, income, expense)) })?; rows.collect() } // -- Exchange Rates -- pub fn get_cached_rate(&self, base: &str, target: &str) -> SqlResult> { match self.conn.query_row( "SELECT base, target, rate, fetched_at FROM exchange_rates WHERE base = ?1 AND target = ?2", params![base.to_lowercase(), target.to_lowercase()], |row| { Ok(ExchangeRate { base: row.get(0)?, target: row.get(1)?, rate: row.get(2)?, fetched_at: row.get(3)?, }) }, ) { Ok(rate) => Ok(Some(rate)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e), } } pub fn cache_rates(&self, base: &str, rates: &std::collections::HashMap) -> SqlResult<()> { let now = chrono::Utc::now().to_rfc3339(); let base_lower = base.to_lowercase(); let mut stmt = self.conn.prepare( "INSERT OR REPLACE INTO exchange_rates (base, target, rate, fetched_at) VALUES (?1, ?2, ?3, ?4)" )?; for (target, rate) in rates { stmt.execute(params![base_lower, target.to_lowercase(), rate, now])?; } Ok(()) } // -- Settings -- pub fn get_setting(&self, key: &str) -> SqlResult> { match self.conn.query_row( "SELECT value FROM settings WHERE key = ?1", params![key], |row| row.get(0), ) { Ok(val) => Ok(Some(val)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e), } } pub fn set_setting(&self, key: &str, value: &str) -> SqlResult<()> { self.conn.execute( "INSERT OR REPLACE INTO settings (key, value) VALUES (?1, ?2)", params![key, value], )?; Ok(()) } fn seed_default_categories(&self) -> SqlResult<()> { let expense_categories = [ ("Food & Dining", "\u{1f354}", "#e74c3c"), ("Groceries", "\u{1f6d2}", "#e67e22"), ("Transport", "\u{1f68c}", "#3498db"), ("Housing/Rent", "\u{1f3e0}", "#9b59b6"), ("Utilities", "\u{1f4a1}", "#f39c12"), ("Entertainment", "\u{1f3ac}", "#1abc9c"), ("Shopping", "\u{1f6cd}", "#e91e63"), ("Health", "\u{2695}", "#2ecc71"), ("Education", "\u{1f393}", "#00bcd4"), ("Subscriptions", "\u{1f4f1}", "#ff5722"), ("Personal Care", "\u{2728}", "#795548"), ("Gifts", "\u{1f381}", "#ff9800"), ("Travel", "\u{2708}", "#607d8b"), ("Other", "\u{1f4b8}", "#95a5a6"), ]; let income_categories = [ ("Salary", "\u{1f4b0}", "#27ae60"), ("Freelance", "\u{1f4bb}", "#2980b9"), ("Investment", "\u{1f4c8}", "#8e44ad"), ("Gift", "\u{1f381}", "#f1c40f"), ("Refund", "\u{1f504}", "#16a085"), ("Other", "\u{1f4b5}", "#7f8c8d"), ]; let mut stmt = self.conn.prepare( "INSERT INTO categories (name, icon, color, type, is_default, sort_order) VALUES (?1, ?2, ?3, ?4, 1, ?5)" )?; for (i, (name, icon, color)) in expense_categories.iter().enumerate() { stmt.execute(params![name, icon, color, "expense", i as i32])?; } for (i, (name, icon, color)) in income_categories.iter().enumerate() { stmt.execute(params![name, icon, color, "income", i as i32])?; } Ok(()) } } #[cfg(test)] mod tests { use super::*; use chrono::NaiveDate; #[test] fn test_init_creates_tables() { let db = Database::open_in_memory().unwrap(); let count: i64 = db.conn.query_row( "SELECT COUNT(*) FROM categories", [], |row| row.get(0), ).unwrap(); assert_eq!(count, 20); // 14 expense + 6 income } #[test] fn test_default_expense_categories() { let db = Database::open_in_memory().unwrap(); let count: i64 = db.conn.query_row( "SELECT COUNT(*) FROM categories WHERE type = 'expense'", [], |row| row.get(0), ).unwrap(); assert_eq!(count, 14); } #[test] fn test_default_income_categories() { let db = Database::open_in_memory().unwrap(); let count: i64 = db.conn.query_row( "SELECT COUNT(*) FROM categories WHERE type = 'income'", [], |row| row.get(0), ).unwrap(); assert_eq!(count, 6); } #[test] fn test_schema_version_set() { let db = Database::open_in_memory().unwrap(); assert_eq!(db.schema_version().unwrap(), CURRENT_SCHEMA_VERSION); } #[test] fn test_all_tables_exist() { let db = Database::open_in_memory().unwrap(); let tables = [ "categories", "transactions", "budgets", "recurring_transactions", "exchange_rates", "settings", "budget_notifications", "schema_version", ]; for table in tables { let exists: bool = db.conn.query_row( "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type='table' AND name=?1", params![table], |row| row.get(0), ).unwrap(); assert!(exists, "Table '{}' should exist", table); } } #[test] fn test_categories_have_icons_and_colors() { let db = Database::open_in_memory().unwrap(); let missing: i64 = db.conn.query_row( "SELECT COUNT(*) FROM categories WHERE icon IS NULL OR color IS NULL", [], |row| row.get(0), ).unwrap(); assert_eq!(missing, 0, "All default categories should have icon and color"); } fn make_expense(db: &Database) -> i64 { let cats = db.list_categories(Some(TransactionType::Expense)).unwrap(); let cat = &cats[0]; let txn = NewTransaction { amount: 12.50, transaction_type: TransactionType::Expense, category_id: cat.id, currency: "USD".to_string(), exchange_rate: 1.0, note: Some("Lunch".to_string()), date: NaiveDate::from_ymd_opt(2026, 3, 1).unwrap(), recurring_id: None, }; db.insert_transaction(&txn).unwrap() } fn make_income(db: &Database) -> i64 { let cats = db.list_categories(Some(TransactionType::Income)).unwrap(); let cat = &cats[0]; let txn = NewTransaction { amount: 3000.0, transaction_type: TransactionType::Income, category_id: cat.id, currency: "USD".to_string(), exchange_rate: 1.0, note: Some("Salary".to_string()), date: NaiveDate::from_ymd_opt(2026, 3, 1).unwrap(), recurring_id: None, }; db.insert_transaction(&txn).unwrap() } #[test] fn test_insert_and_get_transaction() { let db = Database::open_in_memory().unwrap(); let id = make_expense(&db); let txn = db.get_transaction(id).unwrap(); assert_eq!(txn.amount, 12.50); assert_eq!(txn.transaction_type, TransactionType::Expense); assert_eq!(txn.note.as_deref(), Some("Lunch")); assert_eq!(txn.date, NaiveDate::from_ymd_opt(2026, 3, 1).unwrap()); } #[test] fn test_update_transaction() { let db = Database::open_in_memory().unwrap(); let id = make_expense(&db); let mut txn = db.get_transaction(id).unwrap(); txn.amount = 15.75; txn.note = Some("Dinner".to_string()); db.update_transaction(&txn).unwrap(); let updated = db.get_transaction(id).unwrap(); assert_eq!(updated.amount, 15.75); assert_eq!(updated.note.as_deref(), Some("Dinner")); } #[test] fn test_delete_transaction() { let db = Database::open_in_memory().unwrap(); let id = make_expense(&db); db.delete_transaction(id).unwrap(); assert!(db.get_transaction(id).is_err()); } #[test] fn test_list_transactions_by_month() { let db = Database::open_in_memory().unwrap(); make_expense(&db); make_income(&db); let march = db.list_transactions_by_month(2026, 3).unwrap(); assert_eq!(march.len(), 2); let feb = db.list_transactions_by_month(2026, 2).unwrap(); assert_eq!(feb.len(), 0); } #[test] fn test_list_recent_transactions() { let db = Database::open_in_memory().unwrap(); make_expense(&db); make_income(&db); let recent = db.list_recent_transactions(1).unwrap(); assert_eq!(recent.len(), 1); let all = db.list_recent_transactions(10).unwrap(); assert_eq!(all.len(), 2); } #[test] fn test_list_categories_by_type() { let db = Database::open_in_memory().unwrap(); let expense = db.list_categories(Some(TransactionType::Expense)).unwrap(); assert_eq!(expense.len(), 14); let income = db.list_categories(Some(TransactionType::Income)).unwrap(); assert_eq!(income.len(), 6); let all = db.list_categories(None).unwrap(); assert_eq!(all.len(), 20); } #[test] fn test_insert_custom_category() { let db = Database::open_in_memory().unwrap(); let cat = NewCategory { name: "Pets".to_string(), icon: Some("\u{1f436}".to_string()), color: Some("#a0522d".to_string()), transaction_type: TransactionType::Expense, sort_order: 99, }; let id = db.insert_category(&cat).unwrap(); let fetched = db.get_category(id).unwrap(); assert_eq!(fetched.name, "Pets"); assert!(!fetched.is_default); } #[test] fn test_monthly_totals_by_category() { let db = Database::open_in_memory().unwrap(); make_expense(&db); make_expense(&db); let totals = db.get_monthly_totals_by_category(2026, 3, TransactionType::Expense).unwrap(); assert_eq!(totals.len(), 1); assert_eq!(totals[0].1, 25.0); // 12.50 * 2 } #[test] fn test_monthly_total() { let db = Database::open_in_memory().unwrap(); make_expense(&db); make_income(&db); let expenses = db.get_monthly_total(2026, 3, TransactionType::Expense).unwrap(); assert_eq!(expenses, 12.50); let income = db.get_monthly_total(2026, 3, TransactionType::Income).unwrap(); assert_eq!(income, 3000.0); } #[test] fn test_daily_totals() { let db = Database::open_in_memory().unwrap(); make_expense(&db); make_income(&db); let daily = db.get_daily_totals(2026, 3).unwrap(); assert_eq!(daily.len(), 1); let (date, inc, exp) = &daily[0]; assert_eq!(*date, NaiveDate::from_ymd_opt(2026, 3, 1).unwrap()); assert_eq!(*inc, 3000.0); assert_eq!(*exp, 12.50); } #[test] fn test_settings_crud() { let db = Database::open_in_memory().unwrap(); assert_eq!(db.get_setting("base_currency").unwrap(), None); db.set_setting("base_currency", "EUR").unwrap(); assert_eq!(db.get_setting("base_currency").unwrap(), Some("EUR".to_string())); db.set_setting("base_currency", "GBP").unwrap(); assert_eq!(db.get_setting("base_currency").unwrap(), Some("GBP".to_string())); } #[test] fn test_idempotent_init() { let db = Database::open_in_memory().unwrap(); // Re-init should not duplicate categories db.init().unwrap(); let count: i64 = db.conn.query_row( "SELECT COUNT(*) FROM categories", [], |row| row.get(0), ).unwrap(); assert_eq!(count, 20); } }