diff --git a/.gitignore b/.gitignore index 7f6b888..98273d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ build*/ include/strata/version.hpp +examples/ complogs.txt +include/strata/db_config.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 3e89d3b..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,7 +0,0 @@ -- ```FEAT```: - - Added support for environmental variables with ability to be set in the program itself, instead of setting them in the ```config.json``` file. - (see examples/ for usage) - - Added support for updates and deletes, see ```README.md``` examples section for a demo. - -- ```FIXES```: - - Fixed the ```'text'``` datatype where we were appending the size '0' when it is not needed. Removed that from the ```text``` datatype in ```CharField``` class. diff --git a/CMakeLists.txt b/CMakeLists.txt index 660a742..92a73cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,17 +4,16 @@ project(StrataORM VERSION 0.2.0 LANGUAGES CXX) #set CXX standard set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) - set (CMAKE_CXX_FLAGS_DEBUG "-std=c++20 -Wall -Wextra -Wpedantic -g") - set (CMAKE_CXX_FLAGS_RELEASE "-std=c++20 -O3") set(SOURCES src/datatypes.cpp src/models.cpp + src/utils.cpp ) #Build options -option(DB_ENGINE "Target database engine(only 'PSQL' supported now)" "PSQL") +option(DB_ENGINE "Target database engine." "PSQL") #Database Engine Macro setup if(DB_ENGINE STREQUAL "PSQL") @@ -22,6 +21,11 @@ if(DB_ENGINE STREQUAL "PSQL") file(GLOB DB_SRC src/psql/*.cpp) list(APPEND SOURCES ${DB_SRC}) set(DB_HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/include/strata/psql") +elseif(DB_ENGINE STREQUAL "MARIADB") + set(DB_MACRO "#define MARIADB") + file(GLOB DB_SRC src/mariadb/*.cpp) + list(APPEND SOURCES ${DB_SRC}) + set(DB_HEADERS "${CMAKE_CURRENT_SOURCE_DIR}/include/strata/mariadb") else() message(FATAL_ERROR "Unsupported database engine chosen: ${DB_ENGINE}") endif() @@ -40,6 +44,9 @@ set_target_properties(strata_shared PROPERTIES OUTPUT_NAME "strata") add_library(strata_static STATIC ${SOURCES}) set_target_properties(strata_static PROPERTIES OUTPUT_NAME "strata") +# target_link_libraries(strata_shared PUBLIC ${DEP_LIBS}) +# target_link_libraries(strata_static PUBLIC ${DEP_LIBS}) + #generate version.hpp from version.hpp.in in include/strata/ configure_file( ${CMAKE_CURRENT_SOURCE_DIR}/include/strata/version.hpp.in @@ -66,8 +73,8 @@ target_include_directories(strata_static #Install both targets install(TARGETS strata_shared strata_static EXPORT strata-targets - ARCHIVE DESTINATION lib - LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib/strata + LIBRARY DESTINATION lib/strata ) #Install top-level headers @@ -75,6 +82,7 @@ install(DIRECTORY include/strata DESTINATION include FILES_MATCHING PATTERN "*.hpp" PATTERN "psql" EXCLUDE + PATTERN "mariadb" EXCLUDE ) #Install database-specific headers diff --git a/README.md b/README.md index 4da8a68..6dd4ede 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,12 @@ An ORM built on C++20 that is inspired by Django's ORM, built to support multipl [![C++20](https://img.shields.io/badge/C%2B%2B-20-blue)](https://en.cppreference.com/w/cpp/20.html) [![CMake](https://img.shields.io/badge/Build-CMake%203.16%2B-brightblue)](https://cmake.org) -[![PostgreSQL](https://img.shields.io/badge/Database-PostgreSQL-blueviolet)](https://www.postgresql.org) -[![libpqxx](https://img.shields.io/badge/Dependency-libpqxx--dev-pink)](https://github.com/jtv/libpqxx) [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-darkred)](https://www.gnu.org/licenses/gpl-3.0.en.html) -Strata provides an easy-to-use, intuitive API to interact with databases. As of now, it only -supports PostgreSQL but plans to add support for more database engines are in mind. +Strata provides an easy-to-use, intuitive API to interact with databases. +Supports the following databases: + - PostgreSQL + - MariaDB Documentation can be found at the [WIKI](https://github.com/bitflaw/strataorm/wiki). @@ -23,100 +23,55 @@ Documentation can be found at the [WIKI](https://github.com/bitflaw/strataorm/wi - [X] Clean abstraction over raw, basic SQL datatypes using classes. - [X] Support for performing fetches, filters(limited) and joins. - [X] Environmental variables set in-program for db connections. +- [X] Support for user-defined datatypes. +- [X] Support for more database engines ie. postgres and Mariadb - [ ] Support for nullable values. -- [ ] Support for user-defined datatypes. -- [ ] Support for more database engines eg MySQL, SQLite, MSSQL etc. ## Dependencies -Since this library supports only PostgreSQL now, dependencies are: - A C++ compiler supporting [```-std=C++20```](https://en.cppreference.com/w/cpp/compiler_support/20). - [CMake](https://cmake.org/download/) (at least version 3.16). -- [PostgreSQL](https://www.postgresql.org/download/) library. +**If using PostgreSQL** - [libpqxx](https://github.com/jtv/libpqxx) -> Official C++ client library for postgres. - - ->[!Warning] -> This library depends on a feature from libpqxx that is only available on the latest development version(not released yet). -> Options are to build from source or wait for the upcoming [```8.0```](https://github.com/jtv/libpqxx/pull/914) release. - +- [PostgreSQL](https://postgresql.org/download) -> PostgreSQL database. +**If using MariaDB** +- [mdbcxx](https://github.com/bitflaw/mdbcxx) -> A MariaDB C++ Connector. +- [MariaDB's C connector](https://github.com/mariadb-corporation/mariadb-connector-c) -> C connector for MariaDB which is a dependency for `mdbcxx` +- [MariaDB](https://mariadb.com/downloads) -> MariaDB database ## Build & Installation - -### Step 1: Clone the repository ```bash -$ git clone git@github.com:bitflaw/strataorm.git -$ cd strataorm -``` -### Step 2: Build the library -Since we are using CMake, I recommend building in a dedicated build directory: -```bash -$ mkdir build -$ cmake -B ${BUILD_DIR} -S . -DDB_ENGINE=PSQL -$ cmake --build ${BUILD_DIR} -``` - -- ```-DDB_ENGINE=PSQL``` to specify the database you want to use the ORM with. - This flag only takes ```PSQL``` for now since only postgres is supported for now. -- Note that both static and dynamic libraries will be built for both use cases, avoiding rebuilding just to -use a desired one. - - -### Step 3: Install to System -To install to the default location specified by CMake, run: -```bash -$ cmake --install ${BUILD_DIR} -``` -To install to a specified location, do: -```bash -$ cmake --install ${BUILD_DIR} --prefix ${DESTINATION} -``` -> [!NOTE] -> You might need sudo/admin privileges to run this command. - - -If it's a CMake project, add this in your CMakeLists.txt file immediately after cloning the repo: -``` -add_subdirectory(strata) -target_link_libraries(my_project PRIVATE strata) +git clone git@github.com:bitflaw/strataorm.git +cd strataorm +cmake -B ${BUILD_DIR} -S . -DDB_ENGINE={PSQL or MARIADB} +cmake --build ${BUILD_DIR} +cmake --install ${BUILD_DIR} #installs to /usr/local/ ``` -## Examples -Examples can be found under the ```examples``` directory in the source tree. - **Model usage example** ```cpp -#include #include #include class users : public Model{ public: users(){ - col_map["username"] = std::make_shared(CharField("varchar", 24, true, true)); - col_map["email"] = std::make_shared(CharField("varchar", 50, true, true)); - col_map["pin"] = std::make_shared(IntegerField("integer", false, true)); + col_map["username"] = db::Field::CharField("varchar", 24, true, false); + col_map["email"] = db::Field::CharField("varchar", 51, false, true); + col_map["pin"] = db::Field::IntegerField("tinyint", false, true); } };REGISTER_MODEL(users); -class message : public Model{ -public: - message(){ - col_map["sender"] = std::make_shared(ForeignKey("sender", "users", "users_id", std::nullopt, "CASCADE", "CASCADE")); - col_map["receiver"] = std::make_shared(ForeignKey("receiver", "users", "users_id", std::nullopt, "CASCADE", "CASCADE")); - col_map["content"] = std::make_shared(CharField("varchar", 256, true)); - } -};REGISTER_MODEL(message); - int main(){ - Utils::dbenvars vars = { - //insert your db credentials here - {"DBUSER", ""}, - {"DBPASS", ""}, - {"DBNAME", ""}, - {"DBHOST", ""}, - {"DBPORT", ""} - }; - Utils::set_dbenvars(vars); + //can remove if u don't plan to apply the changes to the actual db. + //this is only relevant when there is a db in play. + // Utils::dbenvars vars = { + // {"DBUSER", ""}, + // {"DBPASS", ""}, + // {"DBNAME", ""}, + // {"DBHOST", ""}, + // {"DBPORT", ""} + // }; + // Utils::set_dbenvars(vars); Model model {}; nlohmann::json mrm {}; @@ -124,8 +79,7 @@ int main(){ std::string sql_filename {"migrations.sql"}; model.make_migrations(mrm, frm, sql_filename); - - db_adapter::opt_result_t result = db_adapter::execute_sql(sql_filename); + // db::opt_result_t result = db::execute_sql(sql_filename); return 0; } ``` @@ -148,22 +102,14 @@ int main(){ Utils::set_dbenvars(vars); users user {}; - message m {}; using params = std::vector; params user_rows = user.parse_json_rows(); - params message_rows = m.parse_json_rows(); - pqxx::connection cxn = db_adapter::prepare_insert(); + pqxx::connection cxn = db::prepare_insert(); for(pqxx::params& user_row : user_rows){ - db_adapter::exec_insert(cxn, user_row); + db::exec_insert(cxn, user_row); } - - cxn = db_adapter::prepare_insert(); - for(pqxx::params& message_r : message_rows){ - db_adapter::exec_insert(cxn, message_r); - } - return 0; } ``` @@ -185,16 +131,16 @@ int main(){ users user {}; - db_adapter::query::fetch_all(user, "*"); - //db_adapter::query::get(user, "username", "berna"); + db::query::fetch_all(user, "*"); + //db::query::get(user, "username", "berna"); filters filters = { {"email", OP::CONTAINS, "gmail"}, {"username", OP::STARTSWITH, "b"} }; - db_adapter::query::filter(user, "or", filters); + db::query::filter(user, "or", filters); - std::vector my_users = db_adapter::query::to_instances(user); + std::vector my_users = db::query::to_instances(user); return 0; } @@ -216,7 +162,7 @@ int main(){ Utils::set_dbenvars(vars); users user {}; - db_adapter::query::JoinBuilder JB {user}; + db::query::JoinBuilder JB {user}; pqxx::result result = JB.select("username, email") .inner_join("message") .on("and", "users.users_id = message.sender") @@ -244,7 +190,7 @@ int main(){ users user {}; - db_adapter::Update user_update {}; + db::Update user_update {}; Utils::filters filters = { {"username", OP::EQ, "janedoe"} }; @@ -254,7 +200,7 @@ int main(){ .where("and", filters) .commit(); - db_adapter::query::get(user, "username", "'janny'"); + db::query::get(user, "username", "'janny'"); return 0; } ``` @@ -274,13 +220,11 @@ int main(){ }; Utils::set_dbenvars(vars); - users user {}; - Utils::filters filters = { {"users_id", OP::EQ, 3} }; - db_adapter::delete_row("and", filters); + db::delete_row("and", filters); return 0; } @@ -289,5 +233,3 @@ int main(){ > [!NOTE] > Tests have not been implemented yet but will be soon. -## Contributing -All contributions are welcome. Please open an issue or submit a pull request for contributions to the library. diff --git a/examples/deletes/CMakeLists.txt b/examples/deletes/CMakeLists.txt deleted file mode 100644 index 30f06dc..0000000 --- a/examples/deletes/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -cmake_minimum_required(VERSION 3.16) -project(delete CXX) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -add_executable(delete - main.cpp -) - -# Add include directories (e.g., your headers) -target_include_directories(delete - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../include -) - -target_link_libraries(delete - PRIVATE - pq - pqxx - strata -) - -target_compile_options(delete PRIVATE -Wall -Wextra -Wpedantic) diff --git a/examples/deletes/main.cpp b/examples/deletes/main.cpp deleted file mode 100644 index 02fba80..0000000 --- a/examples/deletes/main.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include "../include/models.hpp" -#include - -int main(){ - Utils::dbenvars vars = { - {"DBUSER", ""}, - {"DBPASS", ""}, - {"DBNAME", ""}, - {"DBHOST", ""}, - {"DBPORT", ""} - }; - Utils::set_dbenvars(vars); - - - users user {}; - - db_adapter::query::fetch_all(user, "*"); - int records_size = user.records.size(); - std::vector my_users = db_adapter::to_instances(user); - std::cout<< "Before deleting a user with id 3:\n"; - for (int i = 0; i < records_size; ++i) { - std::cout<< my_users[i].id<<": " - << my_users[i].username<<", " - << my_users[i].email<<", " - << my_users[i].pin<< "\n"; - } - std::cout<("and", filters); - - user.records.clear(); - db_adapter::query::fetch_all(user, "*"); - records_size = user.records.size(); - my_users = db_adapter::to_instances(user); - std::cout<< "After deleting user with id 3:\n"; - for (int i = 0; i < records_size; ++i) { - std::cout<< my_users[i].id<<": " - << my_users[i].username<<", " - << my_users[i].email<<", " - << my_users[i].pin<< "\n"; - } - std::cout< -#include -#include -#include -#include -#include - -class users{ -public: - std::string table_name = "users"; - int id; - int pin; - std::string email; - std::string username; - std::vector records; - std::string col_str = "pin,email,username"; - int col_map_size = 3; - - users() = default; - template - users(tuple_T tup){ - std::tie(id,pin,email,username) = tup; - } - - auto get_attr() const{ - return std::make_tuple(id,pin,email,username); - } - - std::vector parse_json_rows(){ - std::vector rows {}; - std::ifstream json_row_file("../../insert/users.json"); - if(!json_row_file.is_open()) std::runtime_error("Couldn't open file attempting to parse rows"); - - nlohmann::json json_data = nlohmann::json::parse(json_row_file); - - for(auto& json_row: json_data){ - rows.push_back(pqxx::params{json_row["pin"].get(), - json_row["email"].get(), - json_row["username"].get() - }); - } - return rows; - } -}; - -class message{ -public: - std::string table_name = "message"; - int id; - std::string content; - int receiver; - int sender; - std::vector records; - std::string col_str = "content,receiver,sender"; - int col_map_size = 3; - - message() = default; - template - message(tuple_T tup){ - std::tie(id,content,receiver,sender) = tup; - } - - auto get_attr() const{ - return std::make_tuple(id,content,receiver,sender); - } - - std::vector parse_json_rows(){ - std::vector rows {}; - std::ifstream json_row_file("../../insert/messages.json"); - if(!json_row_file.is_open()) std::runtime_error("Couldn't open file attempting to parse rows"); - - nlohmann::json json_data = nlohmann::json::parse(json_row_file); - - for(auto& json_row: json_data){ - rows.push_back(pqxx::params{json_row["content"].get(), - json_row["receiver"].get(), - json_row["sender"].get() - }); - } - return rows; - } -}; - diff --git a/examples/insert/CMakeLists.txt b/examples/insert/CMakeLists.txt deleted file mode 100644 index ec868f6..0000000 --- a/examples/insert/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -cmake_minimum_required(VERSION 3.16) -project(inserts CXX) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -add_executable(inserts - insert.cpp -) - -# Add include directories (e.g., your headers) -target_include_directories(inserts - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../include -) - -target_link_libraries(inserts - PRIVATE - pq - pqxx - strata -) - -target_compile_options(inserts PRIVATE -Wall -Wextra -Wpedantic) diff --git a/examples/insert/insert.cpp b/examples/insert/insert.cpp deleted file mode 100644 index 91ee9b3..0000000 --- a/examples/insert/insert.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "../include/models.hpp" -#include - -int main(){ - Utils::dbenvars vars = { - {"DBUSER", ""}, - {"DBPASS", ""}, - {"DBNAME", ""}, - {"DBHOST", ""}, - {"DBPORT", ""} - }; - Utils::set_dbenvars(vars); - - - users user {}; - message m {}; - - using params = std::vector; - params user_rows = user.parse_json_rows(); - params message_rows = m.parse_json_rows(); - - pqxx::connection cxn = db_adapter::prepare_insert(); - for(pqxx::params& user_row : user_rows){ - db_adapter::exec_insert(cxn, user_row); - } - - cxn = db_adapter::prepare_insert(); - for(pqxx::params& message_r : message_rows){ - db_adapter::exec_insert(cxn, message_r); - } - - return 0; -} diff --git a/examples/insert/messages.json b/examples/insert/messages.json deleted file mode 100644 index b8d4baa..0000000 --- a/examples/insert/messages.json +++ /dev/null @@ -1,22 +0,0 @@ -[ - { - "sender": 3, - "receiver": 1, - "content": "This is a test message :}" - }, - { - "sender": 4, - "receiver": 1, - "content": "This is another test message :)" - }, - { - "sender": 1, - "receiver": 5, - "content": "This is yet another test message =)" - }, - { - "sender": 5, - "receiver": 3, - "content": "Again, this is just a test message!" - } -] diff --git a/examples/insert/users.json b/examples/insert/users.json deleted file mode 100644 index 0632bc3..0000000 --- a/examples/insert/users.json +++ /dev/null @@ -1,27 +0,0 @@ -[ - { - "username": "barbie", - "email": "barbaradoe@hotmail.me", - "pin": 7890 - }, - { - "username": "berna", - "email": "bernarddoe@gmail.com", - "pin": 3456 - }, - { - "username": "peterdoe", - "email": "peterdoe@yahoo.com", - "pin": 9012 - }, - { - "username": "janedoe", - "email": "janedoe@gmail.com", - "pin": 5678 - }, - { - "username": "jamesdoe", - "email": "jamesdoe@gmail.com", - "pin": 1234 - } -] diff --git a/examples/models/CMakeLists.txt b/examples/models/CMakeLists.txt deleted file mode 100644 index 4170e56..0000000 --- a/examples/models/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -cmake_minimum_required(VERSION 3.16) -project(models CXX) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -add_executable(models - models.cpp -) - -target_link_libraries(models - PRIVATE - pq - pqxx - strata -) - -target_compile_options(models PRIVATE -Wall -Wextra -Wpedantic) diff --git a/examples/models/models.cpp b/examples/models/models.cpp deleted file mode 100755 index f10ceb4..0000000 --- a/examples/models/models.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include -#include - -class users : public Model{ -public: - users(){ - col_map["username"] = std::make_shared(CharField("varchar", 24, true, true)); - col_map["email"] = std::make_shared(CharField("varchar", 50, true, true)); - col_map["pin"] = std::make_shared(IntegerField("integer", false, true)); - } -};REGISTER_MODEL(users); - -class message : public Model{ -public: - message(){ - col_map["sender"] = std::make_shared(ForeignKey("sender", "users", "users_id", std::nullopt, "CASCADE", "CASCADE")); - col_map["receiver"] = std::make_shared(ForeignKey("receiver", "users", "users_id", std::nullopt, "CASCADE", "CASCADE")); - col_map["content"] = std::make_shared(CharField("varchar", 256, true)); - } -};REGISTER_MODEL(message); - -int main(){ - //can remove if u don't plan to apply the changes to the actual db. - //this is only relevant when there is a db in play. - Utils::dbenvars vars = { - {"DBUSER", ""}, - {"DBPASS", ""}, - {"DBNAME", ""}, - {"DBHOST", ""}, - {"DBPORT", ""} - }; - Utils::set_dbenvars(vars); - - - Model model {}; - nlohmann::json mrm {}; - nlohmann::json frm {}; - std::string sql_filename {"migrations.sql"}; - - model.make_migrations(mrm, frm, sql_filename); - - db_adapter::opt_result_t result = db_adapter::execute_sql(sql_filename); - - return 0; -} diff --git a/examples/queries/CMakeLists.txt b/examples/queries/CMakeLists.txt deleted file mode 100644 index 2b08837..0000000 --- a/examples/queries/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -cmake_minimum_required(VERSION 3.16) -project(query CXX) - -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -add_executable(query - queries.cpp -) - -# Add include directories (e.g., your headers) -target_include_directories(query - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../include -) - -target_link_directories(query - PRIVATE -) - -target_link_libraries(query - PRIVATE - pq - pqxx - strata -) - -target_compile_options(query PRIVATE -Wall -Wextra -Wpedantic) diff --git a/examples/queries/queries.cpp b/examples/queries/queries.cpp deleted file mode 100644 index 3db9a0b..0000000 --- a/examples/queries/queries.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include -#include -#include "../include/models.hpp" - -int main(){ - Utils::dbenvars vars = { - {"DBUSER", "root"}, - {"DBPASS", "root"}, - {"DBNAME", "testdb"}, - {"DBHOST", "localhost"}, - {"DBPORT", "5432"} - }; - Utils::set_dbenvars(vars); - - users user {}; - //message m {}; - -// ***************************************// - //uncomment as needed -// **************************************// - - db_adapter::query::fetch_all(user, "*"); - //db_adapter::query::get(user, "username", "berna"); - /*filters filters = { - {"email", OP::CONTAINS, "gmail"}, - {"username", OP::STARTSWITH, "b"} - }; - - db_adapter::query::filter(user, "or", filters);*/ - - int records_size = user.records.size(); - std::vector my_users = db_adapter::to_instances(user); - - for (int i = 0; i < records_size; ++i) { - std::cout<< my_users[i].id<<": " - << my_users[i].username<<", " - << my_users[i].email<<", " - << my_users[i].pin<< "\n"; - } - std::cout< -#include "../include/models.hpp" - -int main(){ - Utils::dbenvars vars = { - {"DBUSER", ""}, - {"DBPASS", ""}, - {"DBNAME", ""}, - {"DBHOST", ""}, - {"DBPORT", ""} - }; - Utils::set_dbenvars(vars); - - users user {}; - - db_adapter::Update user_update {}; - Utils::filters filters = { - {"username", OP::EQ, "janedoe"} - }; - - user_update.update_column("username", "email") - .set_to("janny", "jannysimpleton@gmail.com") - .where("and", filters) - .commit(); - - db_adapter::query::get(user, "username", "'janny'"); - return 0; -} diff --git a/include/strata/custom_array.hpp b/include/strata/custom_array.hpp index 75ede51..b5ba43c 100644 --- a/include/strata/custom_array.hpp +++ b/include/strata/custom_array.hpp @@ -7,7 +7,8 @@ namespace Utils { template -struct CustomArray{ +struct CustomArray +{ std::array wrapped_array {}; std::size_t index = 0; @@ -15,23 +16,21 @@ struct CustomArray{ template requires (sizeof...(Args) <= N) && (all_convertible_to_T || all_same_as_T) - constexpr CustomArray(Args&&... args): wrapped_array {static_cast(args)...}, index(sizeof...(args)){} + constexpr CustomArray(Args&&... args) + : wrapped_array {static_cast(args)...}, index(sizeof...(args)) + {} - constexpr void push_back(T value){ + constexpr void push_back(T value) + { if(index >= N) throw std::length_error("Custom Array is up to capacity!"); wrapped_array[index++] = value; } constexpr auto begin(){ return wrapped_array.begin(); } - constexpr auto begin() const { return wrapped_array.begin(); } constexpr auto end(){ return wrapped_array.begin() + index; } - constexpr auto end() const { return wrapped_array.begin() + index; } - constexpr T& operator[](std::size_t i){ - if (i >= index) throw "CustomArray: Index out of bounds!"; - return wrapped_array[i]; - } - constexpr const T& operator[](std::size_t i) const { + constexpr T& operator[](std::size_t i) + { if (i >= index) throw "CustomArray: Index out of bounds!"; return wrapped_array[i]; } diff --git a/include/strata/datatypes.hpp b/include/strata/datatypes.hpp deleted file mode 100644 index 8fb5d52..0000000 --- a/include/strata/datatypes.hpp +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once -#include -#include -#include "json.hpp" - -class FieldAttr{ -public: - std::string ctype, datatype, sql_segment; - bool primary_key, not_null, unique; - - FieldAttr(std::string ct = "null",std::string dt = "null", bool nn = false, bool uq = false, bool pk = false) - : ctype(ct), datatype(dt), primary_key(pk), not_null(nn), unique(uq) - {} - - ~FieldAttr() = default; -}; - -class IntegerField: public FieldAttr{ -public: - std::string check_condition; - int check_constraint; - - IntegerField() = default; - IntegerField(std::string datatype, bool pk = false, bool not_null = false, bool unique = false, - int check_constr = 0, std::string check_cond = ""); - - ~IntegerField() = default; -}; -void to_json(nlohmann::json& j, const IntegerField& field); -void from_json(const nlohmann::json& j, IntegerField& field); - -class DecimalField: public FieldAttr{ -public: - int max_length, decimal_places; - - DecimalField() = default; - DecimalField(std::string datatype, int max_length, int decimal_places, bool pk = false); - - ~DecimalField() = default; -}; -void to_json(nlohmann::json& j, const DecimalField& field); -void from_json(const nlohmann::json& j, DecimalField& field); - -class CharField: public FieldAttr{ -public: - int length; - - CharField() = default; - CharField(std::string datatype, int length = 0, bool not_null = false, bool unique = false, bool pk = false); - - ~CharField() = default; -}; -void to_json(nlohmann::json& j, const CharField& field); -void from_json(const nlohmann::json& j, CharField& field); - -class BoolField:public FieldAttr{ -public: - bool enable_default, default_value; - - BoolField(bool not_null = false, bool enable_default = false, bool default_value = false); - - ~BoolField() = default; -}; -void to_json(nlohmann::json& j, const BoolField& field); -void from_json(const nlohmann::json& j, BoolField& field); - -class BinaryField: public FieldAttr{ -public: - BinaryField() = default; - BinaryField(bool not_null, bool unique = false, bool pk = false); - - ~BinaryField() = default; -}; -void to_json(nlohmann::json& j, const BinaryField& field); -void from_json(const nlohmann::json& j, BinaryField& field); - -class DateTimeField:public FieldAttr{ -public: - bool enable_default; - std::string default_val; - - DateTimeField() = default; - DateTimeField(std::string datatype, bool enable_default = false, std::string default_val = "", bool pk = false); - - ~DateTimeField() = default; -}; -void to_json(nlohmann::json& j, const DateTimeField& field); -void from_json(const nlohmann::json& j, DateTimeField& field); - -class ForeignKey : public FieldAttr{ -public: - std::string col_name, sql_type, model_name, ref_col_name, on_delete, on_update; - - ForeignKey() = default; - ForeignKey(std::string cn, std::string mn, std::string rcn, std::optional pk_col_obj=std::nullopt, - std::string on_del="CASCADE", std::string on_upd="CASCADE"); - - ~ForeignKey() = default; -}; -void to_json(nlohmann::json& j, const ForeignKey& field); -void from_json(const nlohmann::json& j, ForeignKey& field); - -using DataTypeVariant = std::variant,std::shared_ptr, std::shared_ptr, - std::shared_ptr, std::shared_ptr, std::shared_ptr, - std::shared_ptr - >; - -void variant_to_json(nlohmann::json& j, const DataTypeVariant& variant); -void variant_from_json(const nlohmann::json& j, DataTypeVariant& variant); diff --git a/include/strata/db_adapters.hpp b/include/strata/db_adapters.hpp index 6ff198e..42d3028 100644 --- a/include/strata/db_adapters.hpp +++ b/include/strata/db_adapters.hpp @@ -1,12 +1,14 @@ #pragma once #include "db_config.hpp" +#include "variant_helpers.hpp" #ifdef PSQL #include "psql/alterers.hpp" #include "psql/connectors.hpp" #include "psql/converters.hpp" -#include "psql/creators.hpp" +#include "psql/create_table.hpp" +#include "psql/create_constraints.hpp" #include "psql/deleters.hpp" #include "psql/executor.hpp" #include "psql/fetcher.hpp" @@ -14,9 +16,25 @@ #include "psql/queriers.hpp" #include "psql/row_deleter.hpp" #include "psql/updater.hpp" -#include "psql/sql_generators.hpp" #include "psql/create_model_header.hpp" -namespace db_adapter = psql; +namespace db = psql; + +#elif defined(MARIADB) + +#include "mariadb/alterers.hpp" +#include "mariadb/connectors.hpp" +#include "mariadb/converters.hpp" +#include "mariadb/create_table.hpp" +#include "mariadb/create_constraints.hpp" +#include "mariadb/deleters.hpp" +#include "mariadb/executor.hpp" +#include "mariadb/fetcher.hpp" +#include "mariadb/inserters.hpp" +#include "mariadb/queriers.hpp" +#include "mariadb/row_deleter.hpp" +#include "mariadb/updater.hpp" +#include "mariadb/create_model_header.hpp" +namespace db = mariadb; #else #error "No valid db_engine specified" diff --git a/include/strata/db_config.hpp b/include/strata/db_config.hpp deleted file mode 100644 index c5de2b1..0000000 --- a/include/strata/db_config.hpp +++ /dev/null @@ -1,2 +0,0 @@ -#pragma once -#define PSQL diff --git a/include/strata/field_base.hpp b/include/strata/field_base.hpp new file mode 100644 index 0000000..a5b1f1f --- /dev/null +++ b/include/strata/field_base.hpp @@ -0,0 +1,20 @@ +#include "json.hpp" +#include + +class FieldAttr { +public: + std::string ctype, datatype, sql_segment; + bool primary_key, not_null, unique; + + FieldAttr(std::string ctype = "null", std::string datatype = "null", bool not_null = false, bool unique = false, bool pk = false) + : ctype(ctype), datatype(datatype), primary_key(pk), not_null(not_null), unique(unique) + {} + + virtual void gen_sql() = 0; + virtual void from_json(const nlohmann::json&) = 0; + virtual void to_json(nlohmann::json&) const = 0; + virtual void track(std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const = 0; + + virtual ~FieldAttr() = default; +}; diff --git a/include/strata/mariadb/alterers.hpp b/include/strata/mariadb/alterers.hpp new file mode 100644 index 0000000..2c10744 --- /dev/null +++ b/include/strata/mariadb/alterers.hpp @@ -0,0 +1,29 @@ +#include "../db_config.hpp" + +#ifdef MARIADB +#include +#include + +namespace mariadb +{ + +void alter_rename_table(const std::string& old_model_name, const std::string& new_model_name, std::ofstream& Migrations); + +void alter_add_column(const std::string& model_name, const std::string& column_name, + const std::string& column_sql_attributes, std::ofstream& Migrations); + +void alter_rename_column(const std::string& model_name, const std::string& old_column_name, + const std::string& new_column_name, std::ofstream& Migrations); + +void alter_column_type(const std::string& model_name, const std::string& column_name, + const std::string& sql_segment, std::ofstream& Migrations); + +void alter_column_defaultval(const std::string& model_name, const std::string& column_name, + const bool set_default, const std::string& defaultval, std::ofstream& Migrations); + +void alter_column_nullable(const std::string& model_name, const std::string& column_name, const bool nullable, std::ofstream& Migrations); + +} // INFO: namespace mariadb +namespace db = mariadb; + +#endif diff --git a/include/strata/mariadb/connectors.hpp b/include/strata/mariadb/connectors.hpp new file mode 100644 index 0000000..30e3cb8 --- /dev/null +++ b/include/strata/mariadb/connectors.hpp @@ -0,0 +1,35 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include +#include +#include "../utils.hpp" + +namespace mariadb +{ +inline mcxx::Connection connect () +{ + try + { + Utils::db_params params = Utils::parse_dbenvars(); + mcxx::Properties props {}; + props.db_name = params.db_name; + props.host = params.host; + props.passwd = params.passwd; + props.port = params.port; + props.user = params.user; + mcxx::Connection cxn {props}; + return cxn; + } catch (std::exception& e) + { + throw std::runtime_error (e.what()); + } catch (...) { + throw std::runtime_error("Failed to connect to the database!"); + } +} + +} +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/converters.hpp b/include/strata/mariadb/converters.hpp new file mode 100644 index 0000000..40ef864 --- /dev/null +++ b/include/strata/mariadb/converters.hpp @@ -0,0 +1,50 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include +namespace mariadb +{ + +template +Model_T to_instance(mcxx::Row& row) +{ + using tuple_T = decltype(std::declval().get_attr()); + return Model_T(row.template as_tuple()); +} + +template +std::vector to_instances(Model_T& obj) +{ + using tuple_T = decltype(obj.get_attr()); + std::vector instances {}; + instances.reserve(obj.records.size()); + + for(const mcxx::Row& row : obj.records) + instances.push_back(Model_T(row.template as_tuple())); + return instances; +} + +template +decltype(std::declval().get_attr()) to_tuple(mcxx::Row& row) +{ + using tuple_T = decltype(std::declval().get_attr()); + return row.template as_tuple(); +} + +template +std::vector().get_attr())> to_tuples(Model_T& obj) +{ + using tuple_T = decltype(obj.get_attr()); + std::vector values {}; + values.reserve(obj.records.size()); + + for(const mcxx::Row& row : obj.records) + values.push_back(row.template as_tuple()); + return values; +} + +} +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/create_constraints.hpp b/include/strata/mariadb/create_constraints.hpp new file mode 100644 index 0000000..383a0c1 --- /dev/null +++ b/include/strata/mariadb/create_constraints.hpp @@ -0,0 +1,44 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include +#include + +namespace mariadb +{ + +inline void create_pk_constraint(const std::string& model_name, const std::vector& pk_cols, std::ofstream& Migrations) +{ + std::string pk_seg = "CONSTRAINT pk_" + model_name + " PRIMARY KEY (" + model_name + "_id)"; + if (!pk_cols.empty()) { + pk_seg.replace(pk_seg.length() - 1, 1, ","); + for(const auto& col : pk_cols) { + pk_seg += col + ","; + } + pk_seg.replace(pk_seg.length() - 1, 1, ")"); + } + Migrations< +FieldAttr& as_ref(T& obj) +{ + if constexpr (std::is_same_v>) { + return *obj; + } else { + return obj; + } +} + +template +void create_table(const std::string& model_name, Col_Map& field_map, std::ofstream& Migrations) +{ + std::vector primary_key_cols; + std::vector unique_constraint_cols; + + Migrations<< "CREATE TABLE IF NOT EXISTS " + model_name + " (\n " + model_name + "_id SERIAL,\n "; + + for(auto& [col, dtv_obj] : field_map){ + std::visit([&](auto& visited_col){ + using col_T = decltype(visited_col); +//TODO: create some sort of stash map for fk fields, then create them at the end of table. + if constexpr(std::is_same_v){ + Migrations << " "; + create_column(col, visited_col.sql_type, Migrations); + Migrations << ",\n "; + create_fk_constraint(visited_col.sql_segment, col, Migrations); + Migrations << ",\n "; + return; + } + + FieldAttr& col_obj = as_ref(visited_col); + if(col_obj.primary_key){ + primary_key_cols.push_back(col); + } + + if(col_obj.unique){ + unique_constraint_cols.push_back(col); + } + + create_column(col, col_obj.sql_segment, Migrations); + Migrations << ",\n "; + }, dtv_obj); + } + + for(std::string& col: unique_constraint_cols){ + Migrations << " "; + create_uq_constraint(col, Migrations); + } + + Migrations << " "; + create_pk_constraint(model_name, primary_key_cols, Migrations); + Migrations<< "\n);\n\n"; +} + +} + +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/datatypes.hpp b/include/strata/mariadb/datatypes.hpp new file mode 100644 index 0000000..e7a9155 --- /dev/null +++ b/include/strata/mariadb/datatypes.hpp @@ -0,0 +1,537 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include +#include +#include +#include +#include +#include "../json.hpp" +#include "../field_base.hpp" +#include "../utils.hpp" +#include "alterers.hpp" +#include "create_constraints.hpp" +#include "deleters.hpp" + +namespace mariadb::Field { + +class IntegerField: public FieldAttr{ +public: + std::string check_condition; + int check_constraint; + + IntegerField() = default; + IntegerField(std::string datatype, bool pk=false, bool not_null=false, bool unique=false, + int check_constr=0, std::string check_cond="") + :FieldAttr("int", datatype, not_null, unique, pk), check_condition(check_cond), check_constraint(check_constr) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "TINYINT" && + datatype != "SMALLINT" && + datatype != "INT" && + datatype != "MEDIUMINT" && + datatype != "BIGINT") + { + throw std::runtime_error(std::format("Datatype '{}' is not supported by MariaDB.", datatype)); + } + sql_segment += datatype; + if (not_null) sql_segment += " NOT NULL"; + } + + void to_json (nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + {"check_constraint", check_constraint}, + {"check_condition", check_condition} + }; + } + + void from_json (const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + check_constraint = j.at("check_constraint").get(); + check_condition = j.at("check_condition").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try { + IntegerField int_obj = dynamic_cast(old_col_obj); + if(int_obj.datatype != datatype) + { + db::alter_column_type(new_model_name, col_name, datatype, Migrations); + } + //if((int_obj.check_condition != check_condition) && check_condition != ""){ + // string check = "CHECK(" + col_name + check_condition + std::to_string(check_constraint) + ")"; + // Migrations << "ALTER TABLE " + new_it->first + " ALTER COLUMN " + alterations + ";\n"; + //} + } catch (std::bad_cast& e) { + throw std::runtime_error( + std::format("[ERROR: in IntegerField::track()] {}", e.what()) + ); + } + } + + ~IntegerField () = default; +}; + +class DecimalField: public FieldAttr{ +public: + int max_length, decimal_places; + + DecimalField() = default; + DecimalField(std::string datatype, int max_length, int decimal_places, bool pk=false) + :FieldAttr("float", datatype, false, false, pk), max_length(max_length), decimal_places(decimal_places) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "DECIMAL" && + datatype != "FLOAT" && + datatype != "DOUBLE") + { + throw std::runtime_error(std::format("Datatype '{}' is not supported by MariaDB.", datatype)); + } + if(datatype == "FLOAT" || datatype == "DOUBLE") + { + sql_segment = datatype; + return; + } + + if(max_length > 0 || decimal_places > 0) + sql_segment = datatype + "(" + std::to_string(max_length) + "," + std::to_string(decimal_places) + ")"; + else + throw std::runtime_error(std::format("Max length and/or decimal places cannot be 0 for datatype '{}'", datatype)); + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"primary_key", primary_key}, + {"max_length", max_length}, + {"dec_places", decimal_places} + }; + } + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + primary_key = j.at("primary_key").get(); + max_length = j.at("max_length").get(); + decimal_places = j.at("dec_places").get(); + gen_sql(); + } + void track(std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try { + DecimalField init_obj = dynamic_cast(old_col_obj); + std::string alterations {}; + if(init_obj.datatype != datatype || + init_obj.max_length != max_length || + init_obj.decimal_places != decimal_places) + { + alterations = datatype + " (" + std::to_string(max_length) + "," + std::to_string(decimal_places) + ")"; + db::alter_column_type(new_model_name, col_name, alterations, Migrations); + } + } catch (std::bad_cast& e) { + throw std::runtime_error( + std::format("[ERROR: in DecimalField::track()] {}", e.what()) + ); + } + } + + ~DecimalField() = default; +}; + +class CharField: public FieldAttr{ +public: + int length; + + CharField() = default; + CharField(std::string datatype, int length=0, bool not_null=false, bool unique=false, bool pk=false) + :FieldAttr("std::string", datatype, not_null, unique, pk), length(length) + { + gen_sql(); + } + + void gen_sql () override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "VARCHAR" && + datatype != "CHAR" && + datatype != "TINYTEXT" && + datatype != "TEXT" && + datatype != "MEDIUMTEXT" && + datatype != "LONGTEXT") + { + throw std::runtime_error(std::format("Datatype '{}' is not supported by MariaDB.", datatype)); + } + sql_segment += datatype; + if (length == 0 && (datatype != "TINYTEXT" && datatype != "TEXT" && datatype != "MEDIUMTEXT" && datatype != "LONGTEXT")) + throw std::runtime_error(std::format("Length attribute is required for datatype '{}'", datatype)); + if (datatype != "TINYTEXT" && datatype != "TEXT" && datatype != "MEDIUMTEXT" && datatype != "LONGTEXT") + sql_segment += "(" + std::to_string(length) + ")"; + if (not_null) sql_segment += " NOT NULL"; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + {"length", length} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + length = j.at("length").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + CharField init_obj = dynamic_cast(old_col_obj); + std::string alterations {}; + if((init_obj.datatype != datatype) || (init_obj.length != length)) + { + alterations = "VARCHAR(" + std::to_string(length) + ")"; + db::alter_column_type(new_model_name, col_name, alterations, Migrations); + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in CharField::track()] {}", e.what()) + ); + } + } + + ~CharField() = default; +}; + +class BoolField : public FieldAttr{ +public: + bool enable_default, default_value; + + BoolField(bool not_null=false, bool enable_default=false, bool default_value=false) + :FieldAttr("bool", "BOOLEAN", not_null, false, false), enable_default(enable_default), default_value(default_value) + { + gen_sql(); + } + + void gen_sql() override + { + sql_segment += datatype; + if (not_null) sql_segment += " NOT NULL"; + if (enable_default) + { + sql_segment += (default_value ? " DEFAULT TRUE": " DEFAULT FALSE"); + } + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"not_null", not_null}, + {"enable_def", enable_default}, + {"default", default_value} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = "BOOLEAN"; + not_null = j.at("not_null").get(); + enable_default = j.at("enable_def").get(); + default_value = j.at("default").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + std::string alterations {}; + BoolField& init_obj = dynamic_cast(old_col_obj); + if(init_obj.enable_default != enable_default) + { + if(enable_default){ + db::alter_column_defaultval(new_model_name, col_name, true, std::to_string(default_value), Migrations); + }else{ + alterations = col_name + " DROP DEFAULT"; + db::alter_column_defaultval(new_model_name, col_name, false, "false", Migrations); + } + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in BoolField::track()] {}", e.what()) + ); + } + } + ~BoolField() = default; +}; + +class BinaryField: public FieldAttr{ +public: + std::size_t max_n {}; + BinaryField() = default; + BinaryField(std::string datatype, std::size_t size=0, bool not_null = false, bool unique=false, bool pk=false) + :FieldAttr("int", datatype, not_null, unique, pk), max_n(size) + { + gen_sql(); + } + + void gen_sql() override + { + if (datatype != "BINARY" && datatype != "VARBINARY") + { + throw std::runtime_error(std::format("Datatype '{}' doesn't exist in MariaDB::BinaryField", datatype)); + } + sql_segment = datatype; + if (datatype != "BINARY") + { + if (max_n != 0) { + sql_segment += "(" + std::to_string(max_n) + ")"; + }else { + throw std::runtime_error(std::format("Size must be provided for datatype '{}'!", datatype)); + } + } + if(not_null) sql_segment += " NOT NULL"; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype",datatype}, + {"max", max_n}, + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + max_n = j.at("max").get(); + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + BinaryField& bin_obj = dynamic_cast(old_col_obj); + if (bin_obj.datatype != datatype) + { + if (datatype == "VARBINARY") { + db::alter_column_type(new_model_name, col_name, "VARBINARY("+std::to_string(max_n)+")", Migrations); + } else { + db::alter_column_type(new_model_name, col_name, "BINARY", Migrations); + } + } + if (bin_obj.max_n != max_n && datatype == "VARBINARY") { + db::alter_column_type(new_model_name, col_name, "VARBINARY("+std::to_string(max_n)+")", Migrations); + } + } + + ~BinaryField() = default; +}; + +class DateTimeField:public FieldAttr{ +public: + bool enable_default; + std::string default_val; + + DateTimeField() = default; + DateTimeField(std::string datatype, bool enable_default=false, std::string default_val="", bool pk=false) + :FieldAttr("std::string",datatype, false, false, pk), enable_default(enable_default), default_val(default_val) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "DATE" && datatype != "TIME" && datatype != "DATETIME" && + datatype != "TIMESTAMP" && datatype != "YEAR") + { + throw std::runtime_error(std::format("Datatype '{}' not supported in MariaDB::DateTimeField.", datatype)); + } + + sql_segment = datatype; + if(enable_default && !default_val.empty()){ + default_val = Utils::str_to_upper(default_val); + sql_segment += " DEFAULT " + default_val; + } + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"primary_key", primary_key}, + {"default_value", default_val}, + {"enable_def", enable_default} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + primary_key = j.at("primary_key").get(); + enable_default = j.at("enable_def").get(); + default_val = j.at("default_value").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + DateTimeField init_obj = dynamic_cast(old_col_obj); + if(init_obj.datatype != datatype){ + db::alter_column_type(new_model_name, col_name, datatype, Migrations); + } + if((init_obj.enable_default != enable_default) && enable_default){ + db::alter_column_defaultval(new_model_name, col_name, true, default_val, Migrations); + }else{ + db::alter_column_defaultval(new_model_name, col_name, false, default_val, Migrations); + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in DateTimeField::track()] {}", e.what()) + ); + } + } + + ~DateTimeField() = default; +}; + +class ForeignKey : public FieldAttr{ +public: + std::string col_name, sql_type, model_name, ref_col_name, on_delete, on_update; + + ForeignKey() = default; + ForeignKey(std::string cn, std::string mn, std::string rcn, std::string ctype="int", + std::string rsql="INTEGER NOT NULL", std::string on_del="RESTRICT", std::string on_upd="RESTRICT") + :FieldAttr(ctype, "FOREIGN KEY", false, false, false), + col_name(cn), model_name(mn), ref_col_name(rcn), on_delete(on_del), on_update(on_upd) + { + sql_type = rsql; + gen_sql(); + } + + void gen_sql() override + { + sql_segment ="FOREIGN KEY(" + col_name + ") REFERENCES " + model_name + " (" + ref_col_name + ")"; + on_delete = Utils::str_to_upper(on_delete); + sql_segment += " ON DELETE " + on_delete; + on_update = Utils::str_to_upper(on_update); + sql_segment += " ON UPDATE " + on_update; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"column_name", col_name}, + {"model_name", model_name}, + {"referenced_column_name", ref_col_name}, + {"ctype", ctype}, + {"sql_type", sql_type}, + {"on_delete", on_delete}, + {"on_update", on_update}, + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = "FOREIGN KEY"; + col_name = j.at("column_name").get(); + model_name = j.at("model_name").get(); + ref_col_name = j.at("referenced_column_name").get(); + ctype = j.at("ctype").get(); + sql_type = j.at("sql_type").get(); + on_delete = j.at("on_delete").get(); + on_update = j.at("on_update").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm) const override + { + try + { + std::string constraint_name = "fk_"; + for(auto& [model_name, col_renames]: frm.items()){ + if(model_name == new_model_name){ + for(auto& [old_cn, new_cn] : col_renames.items()){ + if(col_name == new_cn.get()){ + constraint_name = constraint_name + "_" + old_cn; + }else{ + constraint_name = constraint_name + "_" + col_name; + } + } + }else{ + constraint_name = constraint_name + "_" + col_name; + } + } + db::drop_constraint(new_model_name, constraint_name, Migrations); + Migrations<<"ALTER TABLE " + new_model_name + " "; + db::create_fk_constraint(sql_segment, col_name, Migrations); + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in ForeignKey::track()] {}", e.what()) + ); + } + } + ~ForeignKey() = default; +}; + +}// INFO: namespace mariadb::Field +namespace mariadb { + +using DataTypeVariant = std::variant + >; +} //INFO: namespace mariadb + +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/deleters.hpp b/include/strata/mariadb/deleters.hpp new file mode 100644 index 0000000..6ff84d9 --- /dev/null +++ b/include/strata/mariadb/deleters.hpp @@ -0,0 +1,18 @@ +#include "../db_config.hpp" + +#ifdef MARIADB +#include +#include + +namespace mariadb { + +void drop_table(const std::string& model_name, std::ofstream& Migrations); + +void drop_column(const std::string& model_name, const std::string& column_name, std::ofstream& Migrations); + +void drop_constraint(const std::string& model_name, const std::string& constraint_name, std::ofstream& Migrations); + +}// INFO: namespace mariadb + +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/executor.hpp b/include/strata/mariadb/executor.hpp new file mode 100644 index 0000000..6c54076 --- /dev/null +++ b/include/strata/mariadb/executor.hpp @@ -0,0 +1,40 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include +#include "connectors.hpp" +#include + +namespace mariadb +{ +using opt_result_t = std::optional; +inline opt_result_t execute_sql(std::string& sql_file_or_str, bool is_file_name = true) +{ + std::ostringstream raw_sql {}; + + if(is_file_name){ + std::ifstream sql_file(sql_file_or_str); + if(!sql_file.is_open()) + throw std::runtime_error("Couldn't open the sql file to which the path is provided."); + raw_sql << sql_file.rdbuf(); + }else{ + raw_sql << sql_file_or_str; + } + + try{ + mcxx::Connection cxn = connect(); + mcxx::Transaction txn {cxn}; + return txn.exec(raw_sql.str()); + }catch (std::exception& e) + { + throw std::runtime_error(e.what()); + }catch (...) + { + throw std::runtime_error("SQL Execution failed!"); + } +} +} +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/fetcher.hpp b/include/strata/mariadb/fetcher.hpp new file mode 100644 index 0000000..7650302 --- /dev/null +++ b/include/strata/mariadb/fetcher.hpp @@ -0,0 +1,38 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include "connectors.hpp" +#include + +namespace mariadb +{ + +template +void dbfetch(Model_T& obj, std::string& sql_string, bool getfn_called = false) +{ + try + { + mcxx::Connection cxn = connect(); + mcxx::Transaction txn {cxn}; + + if(getfn_called) + { + mcxx::Row row = txn.exec1(sql_string); + obj.records.push_back(row); + return; + } + std::optional result = txn.exec(sql_string); + if (!result.has_value()) return; + for (mcxx::Row& row : result.value()) obj.records.push_back(row); + + } catch (const std::exception& e) + { + throw std::runtime_error(e.what()); + } +} + +}//namespace mariadb +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/inserters.hpp b/include/strata/mariadb/inserters.hpp new file mode 100644 index 0000000..02d5893 --- /dev/null +++ b/include/strata/mariadb/inserters.hpp @@ -0,0 +1,41 @@ +#pragma once +#include "../db_config.hpp" +#include +#include + +#ifdef MARIADB + +#include "connectors.hpp" +#include + +namespace mariadb +{ +template +mcxx::prepped_stmt prepare_insert() +{ + Model_T obj {}; + mcxx::Connection cxn = connect(); + std::string insert_statement = "insert into "+ obj.table_name + " (" + obj.col_str +") values("; + + for(int i=0; i +#include +#include + +namespace mariadb::query { + +template +void fetch_all(Model_T& obj, std::string columns){ + std::string sql_string {"select " + columns + " from " + obj.table_name + ";"}; + dbfetch(obj, sql_string); +} + +template +void get(Model_T& obj, Args... args){ + static_assert(sizeof...(args) > 0 || sizeof...(args)%2 == 0, + "[ERROR:'db_adapter::query::get()'] Args are provided in key-value pairs." + ); + + std::string sql_kwargs {}; + constexpr int N = sizeof...(args); + Utils::CustomArray, N/2> kwargs {}; + Utils::CustomArray parsed_args {Utils::to_str(args)...}; + + for(int i = 0; i < N; i+=2){ + sql_kwargs += parsed_args[i] + "=" + parsed_args[i+1] + " and "; + kwargs.push_back(std::make_pair(parsed_args[i], parsed_args[i+1])); + } + + sql_kwargs.replace(sql_kwargs.size()-5, 5, ";"); + + if(obj.records.empty()){ + std::string sql_str = "select * from " + obj.table_name + " where " + sql_kwargs; + dbfetch(obj, sql_str, true); + }else{ + std::vector filtered_rows {}; + for(const mcxx::Row& row : obj.records) + { + bool accept_row = true; + for(const auto& kwarg : kwargs){ + if(row[kwarg.first].template as() != kwarg.second){ + accept_row = false; + break; + } + } + if(accept_row){ + filtered_rows.push_back(row); + continue; + } + if(filtered_rows.size() == 1) break; + } + obj.records = filtered_rows; + } +} + +inline bool matches_conditions(mcxx::Field& field, OP op, Utils::Value_T v) +{ + bool accept = false; + std::any value = Utils::filter_val(v); + int int_cast = 0; + double double_cast = 0; + std::string str_cast {}; + + if(value.type() == typeid(int)) int_cast = std::any_cast(value); + else if(value.type() == typeid(double)) double_cast = std::any_cast(value); + else if(value.type() == typeid(std::string)) str_cast = std::any_cast(value); + else throw std::invalid_argument("[ERROR: 'filter().match_conditions()'] => Unsupported type passed to filters"); + + try{ + switch (op) { + case EQ: + if(int_cast) accept = (field.as() == int_cast); + else if(double_cast) accept = (field.as() == double_cast); + else if(!str_cast.empty()) accept = (field.as().find(str_cast) != std::string::npos); + else throw std::runtime_error("Unsupported type for OP::EQ"); + break; + case GT: + if(int_cast) accept = (field.as() > int_cast); + else if(double_cast) accept = (field.as() > double_cast); + else throw std::runtime_error("Unsupported type for OP::GT)"); + break; + case GTE: + if(int_cast) accept = (field.as() >= int_cast); + else if(double_cast) accept = (field.as() >= double_cast); + else throw std::runtime_error("Unsupported type for OP::GTE)"); + break; + case LT: + if(int_cast) accept = (field.as() < int_cast); + else if(double_cast) accept = (field.as() < double_cast); + else throw std::runtime_error("Unsupported type for OPERAND 'OP::LT'"); + break; + case LTE: + if(int_cast) accept = (field.as() <= int_cast); + else if(double_cast) accept = (field.as() <= double_cast); + else throw std::runtime_error("Unsupported type for OPERAND 'OP::LTE'"); + break; + case LIKE: + case ILIKE: + throw std::runtime_error("LIKE/ILIKE not implemented yet for filtering from obj.records."); + break; + case STARTSWITH: + case CONTAINS: + if(!str_cast.empty()) accept = (field.as().find(str_cast) != std::string::npos); + else throw std::runtime_error("Unsupported type for OPERAND(OP::STARTSWITH || OP::CONTAINS)"); + break; + case ENDSWITH: + if(!str_cast.empty()){ + std::string field_str = field.as(); + accept = (field.as().find(str_cast, field_str.size() - str_cast.size()) != std::string::npos); + } else throw std::runtime_error("Unsupported type for OP::ENDSWITH"); + break; + default: + throw std::runtime_error("Unknown operator!"); + } + }catch(const std::exception& e){ + throw std::runtime_error(e.what()); + } + return accept; +} + +template +void filter(Model_T& obj, std::string logical_op, Utils::filters& filters) +{ + if(obj.records.empty()) + { + std::string sql_str = "select * from " + obj.table_name + " where " + build_filter_args(logical_op, filters) + ";"; + dbfetch(obj, sql_str); + }else{ + std::vector filtered_rows {}; + if (logical_op == "and") + { + for(mcxx::Row& row : obj.records) + { + bool accept_row = true; + for(Utils::Condition& filter: filters) + { + if(!matches_conditions(row[filter.column], filter.op, filter.value)) + { + accept_row = false; + break; + } + } + if(accept_row) + { + filtered_rows.push_back(row); + continue; + } + } + }else if(logical_op == "or"){ + for(mcxx::Row& row : obj.records) + { + bool accept_row = false; + for(Utils::Condition& filter: filters) + { + if(matches_conditions(row[filter.column], filter.op, filter.value)){ + accept_row = true; + break; + } + } + if(accept_row) + { + filtered_rows.push_back(row); + continue; + } + } + }else { + throw std::runtime_error("Unknown logical operator for filter fn"); + } + + if(filtered_rows.size() <= obj.records.size()) obj.records = filtered_rows; + else throw std::runtime_error("Filtered rows are more than the actual initial rows!"); + } +} + +class JoinBuilder +{ + std::string query_str {}, table_name {}; + bool join_pending = true; +public: + template + JoinBuilder(T& model): table_name(model.table_name) {} + + template + requires all_same_as_T || all_convertible_to_T + JoinBuilder& select(Args&&... columns){ + query_str = "select " + ((to_str(columns) + ",") + ...); + query_str.pop_back(); + query_str += " from " + table_name; + return *this; + } + + JoinBuilder& inner_join(std::string join_table){ + query_str += " inner join " + join_table; + if(!join_pending) + throw std::runtime_error("You have not implemented .on() yet for the previous join!"); + join_pending = false; + return *this; + } + JoinBuilder& outer_join(std::string join_table){ + query_str += " outer join " + join_table; + if(!join_pending) + throw std::runtime_error("You have not implemented .on() yet for the previous join!"); + join_pending = false; + return *this; + } + JoinBuilder& full_join(std::string join_table){ + query_str += " full join " + join_table; + if(!join_pending) + throw std::runtime_error("You have not implemented .on() yet for the previous join!"); + join_pending = false; + return *this; + } + JoinBuilder& left_join(std::string join_table){ + query_str += " left join " + join_table; + if(!join_pending) + throw std::runtime_error("You have not implemented .on() yet for the previous join!"); + join_pending = false; + return *this; + } + JoinBuilder& right_join(std::string join_table){ + query_str += " right join " + join_table; + if(!join_pending) + throw std::runtime_error("You have not implemented .on() yet for the previous join!"); + join_pending = false; + return *this; + } + + template + requires all_same_as_T || all_convertible_to_T + JoinBuilder& on(std::string logical_op, Args&&... conditions){ + if(join_pending) throw std::runtime_error("Join pending"); + if(logical_op != "and" && logical_op != "or") + throw std::runtime_error(std::format("Unknown logical operator: {}", logical_op)); + query_str += " on " + ((to_str(conditions) + " " + logical_op + " ") + ...); + query_str.resize(query_str.size() - (logical_op.size() + 2)); + join_pending = true; + return *this; + } + + mcxx::Result execute() + { + try{ + mcxx::Connection cxn = connect(); + mcxx::Transaction txn {cxn}; + std::optional join_results = txn.exec(query_str + ";"); + txn.commit(); + return join_results.value_or({}); + }catch(const std::exception& e){ + throw std::runtime_error(e.what()); + } + } + + std::string str(){ return query_str + ";"; } +}; + +}//INFO: namespace mariadb::query +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/row_deleter.hpp b/include/strata/mariadb/row_deleter.hpp new file mode 100644 index 0000000..50678eb --- /dev/null +++ b/include/strata/mariadb/row_deleter.hpp @@ -0,0 +1,19 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB +#include "executor.hpp" + +namespace mariadb { + +template +void delete_row(std::string logical_op, Utils::filters& filters) +{ + Model_T obj {}; + std::string sql_str {"delete from " + obj.table_name + " where " + Utils::build_filter_args(logical_op, filters) + ";"}; + opt_result_t result = execute_sql(sql_str, false); +} + +} //INFO: namespace mariadb +namespace db = mariadb; +#endif diff --git a/include/strata/mariadb/updater.hpp b/include/strata/mariadb/updater.hpp new file mode 100644 index 0000000..5506207 --- /dev/null +++ b/include/strata/mariadb/updater.hpp @@ -0,0 +1,70 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef MARIADB + +#include "../concepts.hpp" +#include "connectors.hpp" +#include +#include + + +namespace mariadb +{ + +template +struct Update +{ + Model_T obj {}; + std::string query {"update "+ obj.table_name + " set "}; + std::size_t arg_count {0}; + mcxx::params param_list {}; + bool update_col_called = false; + + template + requires (sizeof...(Args) > 0) && (all_convertible_to_T || all_same_as_T) + Update& update_column(Args&&... args){ + ((query += std::string{std::forward(args)} + "=?,"), ...); + query.pop_back(); + arg_count += sizeof...(args); + update_col_called = true; + return *this; + } + + template + requires (sizeof...(Args) > 0) && (all_convertible_to_T || all_same_as_T) + Update& set_to(Args... args){ + if(!update_col_called) throw std::logic_error(".column() must be called first to set the column to be updated!"); + (param_list.append(Utils::to_str(args)), ...); + update_col_called = false; + return *this; + } + + Update& where (std::string logical_op, Utils::filters& filters){ + query.append(" where " + Utils::build_filter_args(logical_op, filters)); + return *this; + } + + void commit(){ + query.append(";"); + if(param_list.size() <= 0 || param_list.size() != arg_count) + throw std::length_error( + "Unable to commit transaction, number of parameter values doesn't match number of columns provided!" + ); + try + { + mcxx::Connection cxn = connect(); + cxn.prepare("update_stmt", query); + mcxx::prepped_stmt prepped {cxn.prepped("update_stmt")}; + mcxx::Transaction txn {cxn}; + txn.exec0(prepped, param_list); + }catch(const std::exception& e) + { + throw std::runtime_error(e.what()); + } + } +}; + +} +namespace db = mariadb; +#endif diff --git a/include/strata/models.hpp b/include/strata/models.hpp index df3213f..dd25d4e 100644 --- a/include/strata/models.hpp +++ b/include/strata/models.hpp @@ -1,12 +1,22 @@ #pragma once #include #include -#include "./datatypes.hpp" +#include "./db_config.hpp" -using fields = std::unordered_map; +#ifdef PSQL +#include "psql/datatypes.hpp" +#elif defined(MARIADB) +#include "mariadb/datatypes.hpp" +#else +#error "No database adapter specified" +#endif + + +using fields = std::unordered_map; using ms_map = std::unordered_map; -class Model{ +class Model +{ public: fields col_map; ms_map init_ms; @@ -20,19 +30,23 @@ class Model{ ~Model() = default; }; -class ModelFactory{ +class ModelFactory +{ public: using Creator = std::function()>; - static std::unordered_map& registry(){ + static std::unordered_map& registry() + { static std::unordered_map registry_map; return registry_map; } - static void register_model(const std::string& model_name, Creator creator){ + static void register_model(const std::string& model_name, Creator creator) + { registry()[model_name] = std::move(creator); } - static std::unique_ptr create_model_instance(const std::string& model_name){ + static std::unique_ptr create_model_instance(const std::string& model_name) + { auto it = registry().find(model_name); if(it != registry().end()){ return it->second(); diff --git a/include/strata/psql/alterers.hpp b/include/strata/psql/alterers.hpp index 5b50d18..bb66ae7 100644 --- a/include/strata/psql/alterers.hpp +++ b/include/strata/psql/alterers.hpp @@ -1,9 +1,10 @@ #pragma once #include "../db_config.hpp" + +#ifdef PSQL #include #include -#ifdef PSQL namespace psql { void alter_rename_table(const std::string& old_model_name, const std::string& new_model_name, std::ofstream& Migrations); @@ -22,5 +23,5 @@ void alter_column_defaultval(const std::string& model_name, const std::string& c void alter_column_nullable(const std::string& model_name, const std::string& column_name, const bool nullable, std::ofstream& Migrations); } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/connectors.hpp b/include/strata/psql/connectors.hpp index 4e3e727..96a70c6 100644 --- a/include/strata/psql/connectors.hpp +++ b/include/strata/psql/connectors.hpp @@ -1,11 +1,12 @@ #pragma once #include "../db_config.hpp" -#include "utils.hpp" -#include #ifdef PSQL -#include -#include + +#include "../utils.hpp" +#include +#include + namespace psql { inline pqxx::connection connect(){ @@ -24,5 +25,5 @@ inline pqxx::connection connect(){ } } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/converters.hpp b/include/strata/psql/converters.hpp index d22f42c..32c4e0f 100644 --- a/include/strata/psql/converters.hpp +++ b/include/strata/psql/converters.hpp @@ -2,6 +2,7 @@ #include "../db_config.hpp" #ifdef PSQL + #include namespace psql { @@ -42,5 +43,5 @@ std::vector().get_attr())> to_tuples(Model_T& obj } } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/create_constraints.hpp b/include/strata/psql/create_constraints.hpp new file mode 100644 index 0000000..6356ba1 --- /dev/null +++ b/include/strata/psql/create_constraints.hpp @@ -0,0 +1,42 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef PSQL + +#include +#include +#include + +namespace psql { + +inline void create_pk_constraint(const std::string& model_name, const std::vector& pk_cols, std::ofstream& Migrations){ + std::string pk_seg = "CONSTRAINT pk_" + model_name + " PRIMARY KEY (" + model_name + "_id)"; + if (!pk_cols.empty()) { + pk_seg.replace(pk_seg.length() - 1, 1, ","); + for(const auto& col : pk_cols) { + pk_seg += col + ","; + } + pk_seg.replace(pk_seg.length() - 1, 1, ")"); + } + Migrations< +FieldAttr& as_ref(T& obj) { + if constexpr (std::is_same_v>) { + return *obj; + } else { + return obj; + } +} + +template +void create_table(const std::string& model_name, Col_Map& field_map, std::ofstream& Migrations){ + std::vector primary_key_cols; + std::vector unique_constraint_cols; + + Migrations<< "CREATE TABLE IF NOT EXISTS " + model_name + " (\n " + model_name + "_id SERIAL NOT NULL,\n "; + + for(auto& [col, dtv_obj] : field_map){ + std::visit([&](auto& visited_col){ + using col_T = decltype(visited_col); + if constexpr(std::is_same_v){ + Migrations << " "; + create_column(col, visited_col.sql_type, Migrations); + Migrations << ",\n "; + create_fk_constraint(visited_col.sql_segment, col, Migrations); + Migrations << ",\n "; + return; + } + + FieldAttr& col_obj = as_ref(visited_col); + if(col_obj.primary_key){ + primary_key_cols.push_back(col); + } + + if(col_obj.unique){ + unique_constraint_cols.push_back(col); + } + + create_column(col, col_obj.sql_segment, Migrations); + Migrations << ",\n "; + }, dtv_obj); + } + + for(std::string& col: unique_constraint_cols){ + Migrations << " "; + create_uq_constraint(col, Migrations); + } + + Migrations << " "; + create_pk_constraint(model_name, primary_key_cols, Migrations); + Migrations<< "\n);\n\n"; +} + +} + +namespace db = psql; +#endif diff --git a/include/strata/psql/creators.hpp b/include/strata/psql/creators.hpp deleted file mode 100644 index 683e242..0000000 --- a/include/strata/psql/creators.hpp +++ /dev/null @@ -1,83 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include "../datatypes.hpp" -#include "../db_config.hpp" - -#ifdef PSQL - -namespace psql { - -inline void create_pk_constraint(const std::string& model_name, const std::vector& pk_cols, std::ofstream& Migrations){ - std::string pk_seg = "CONSTRAINT pk_" + model_name + " PRIMARY KEY (" + model_name + "_id)"; - if (!pk_cols.empty()) { - pk_seg.replace(pk_seg.length() - 1, 1, ","); - for(const auto& col : pk_cols) { - pk_seg += col + ","; - } - pk_seg.replace(pk_seg.length() - 1, 1, ")"); - } - Migrations< -void create_table(const std::string& model_name, Col_Map& field_map, std::ofstream& Migrations){ - std::vector primary_key_cols; - std::vector unique_constraint_cols; - - Migrations<< "CREATE TABLE IF NOT EXISTS " + model_name + " (\n " + model_name + "_id SERIAL NOT NULL,\n "; - - for(auto& [col, dtv_obj] : field_map){ - std::visit([&](auto& col_obj){ - if constexpr(std::is_same_v, std::shared_ptr>){ - Migrations << " "; - create_column(col, col_obj->sql_type, Migrations); - Migrations << ",\n "; - create_fk_constraint(col_obj->sql_segment, col, Migrations); - Migrations << ",\n "; - return; - } - if(col_obj->primary_key){ - primary_key_cols.push_back(col); - } - - if(col_obj->unique){ - unique_constraint_cols.push_back(col); - } - - create_column(col, col_obj->sql_segment, Migrations); - Migrations << ",\n "; - }, dtv_obj); - } - - for(std::string& col: unique_constraint_cols){ - Migrations << " "; - create_uq_constraint(col, Migrations); - } - - Migrations << " "; - create_pk_constraint(model_name, primary_key_cols, Migrations); - Migrations<< "\n);\n\n"; -} - -} - -namespace db_adapter = psql; -#endif diff --git a/include/strata/psql/datatypes.hpp b/include/strata/psql/datatypes.hpp new file mode 100644 index 0000000..6624ce4 --- /dev/null +++ b/include/strata/psql/datatypes.hpp @@ -0,0 +1,497 @@ +#pragma once +#include "../db_config.hpp" + +#ifdef PSQL + +#include +#include +#include +#include +#include +#include "../json.hpp" +#include "../field_base.hpp" +#include "../utils.hpp" +#include "alterers.hpp" +#include "create_constraints.hpp" +#include "deleters.hpp" + +namespace psql::Field { + +class IntegerField: public FieldAttr{ +public: + std::string check_condition; + int check_constraint; + + IntegerField() = default; + IntegerField(std::string datatype, bool pk=false, bool not_null=false, bool unique=false, + int check_constr=0, std::string check_cond="") + :FieldAttr("int", datatype, not_null, unique, pk), check_condition(check_cond), check_constraint(check_constr) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "INTEGER" && datatype != "SMALLINT" && datatype != "BIGINT"){ + throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL.", datatype)); + } + sql_segment += datatype; + if (not_null) sql_segment += " NOT NULL"; + + } + + void to_json (nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + {"check_constraint", check_constraint}, + {"check_condition", check_condition} + }; + } + + void from_json (const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + check_constraint = j.at("check_constraint").get(); + check_condition = j.at("check_condition").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try { + IntegerField int_obj = dynamic_cast(old_col_obj); + if(int_obj.datatype != datatype) + { + db::alter_column_type(new_model_name, col_name, datatype, Migrations); + } + //if((int_obj.check_condition != check_condition) && check_condition != ""){ + // string check = "CHECK(" + col_name + check_condition + std::to_string(check_constraint) + ")"; + // Migrations << "ALTER TABLE " + new_it->first + " ALTER COLUMN " + alterations + ";\n"; + //} + } catch (std::bad_cast& e) { + throw std::runtime_error(e.what()); + } + } + + ~IntegerField () = default; +}; + +class DecimalField: public FieldAttr{ +public: + int max_length, decimal_places; + + DecimalField() = default; + DecimalField(std::string datatype, int max_length, int decimal_places, bool pk=false) + :FieldAttr("float", datatype, false, false, pk), max_length(max_length), decimal_places(decimal_places) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "DECIMAL" && datatype != "REAL" && + datatype != "DOUBLE PRECISION" && datatype != "NUMERIC") + { + throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL. Provide a valid datatype", datatype)); + } + if(datatype == "REAL" || datatype == "DOUBLE PRECISION") + { + sql_segment = datatype; + return; + } + + if(max_length > 0 || decimal_places > 0) + sql_segment = datatype + "(" + std::to_string(max_length) + "," + std::to_string(decimal_places) + ")"; + else + throw std::runtime_error(std::format("Max length and/or decimal places cannot be 0 for datatype '{}'", datatype)); + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"primary_key", primary_key}, + {"max_length", max_length}, + {"dec_places", decimal_places} + }; + } + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + primary_key = j.at("primary_key").get(); + max_length = j.at("max_length").get(); + decimal_places = j.at("dec_places").get(); + gen_sql(); + } + void track(std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try { + DecimalField init_obj = dynamic_cast(old_col_obj); + std::string alterations {}; + if(init_obj.datatype != datatype || + init_obj.max_length != max_length || + init_obj.decimal_places != decimal_places) + { + alterations = datatype + " (" + std::to_string(max_length) + "," + std::to_string(decimal_places) + ")"; + db::alter_column_type(new_model_name, col_name, alterations, Migrations); + } + } catch (std::bad_cast& e) { + throw std::runtime_error( + std::format("[ERROR: in DecimalField::track()] {}", e.what()) + ); + } + } + + ~DecimalField() = default; +}; + +class CharField: public FieldAttr{ +public: + int length; + + CharField() = default; + CharField(std::string datatype, int length=0, bool not_null=false, bool unique=false, bool pk=false) + :FieldAttr("std::string", datatype, not_null, unique, pk), length(length) + { + gen_sql(); + } + + void gen_sql () override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "VARCHAR" && datatype != "CHAR" && datatype != "TEXT") + throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL.", datatype)); + sql_segment += datatype; + if (length == 0 && datatype != "TEXT") + throw std::runtime_error(std::format("Length attribute is required for datatype '{}'", datatype)); + if (datatype != "TEXT") + sql_segment += "(" + std::to_string(length) + ")"; + if (not_null) sql_segment += " NOT NULL"; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + {"length", length} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + length = j.at("length").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + CharField init_obj = dynamic_cast(old_col_obj); + std::string alterations {}; + if((init_obj.datatype != datatype) || (init_obj.length != length)) + { + alterations = "VARCHAR( " + std::to_string(length) + " )"; + db::alter_column_type(new_model_name, col_name, alterations, Migrations); + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in CharField::track()] {}", e.what()) + ); + } + } + + ~CharField() = default; +}; + +class BoolField : public FieldAttr{ +public: + bool enable_default, default_value; + + BoolField(bool not_null=false, bool enable_default=false, bool default_value=false) + :FieldAttr("bool", "BOOLEAN", not_null, false, false), enable_default(enable_default), default_value(default_value) + { + gen_sql(); + } + + void gen_sql() override + { + sql_segment += datatype; + if (not_null) sql_segment += " NOT NULL"; + if (enable_default) + { + if(default_value) sql_segment += " DEFAULT TRUE"; + else sql_segment += " DEFAULT FALSE"; + } + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"not_null", not_null}, + {"enable_def", enable_default}, + {"default", default_value} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = "BOOLEAN"; + not_null = j.at("not_null").get(); + enable_default = j.at("enable_def").get(); + default_value = j.at("default").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + std::string alterations {}; + BoolField init_obj = dynamic_cast(old_col_obj); + if(init_obj.enable_default != enable_default){ + if(enable_default){ + db::alter_column_defaultval(new_model_name, col_name, true, std::to_string(default_value), Migrations); + }else{ + alterations = col_name + " DROP DEFAULT"; + db::alter_column_defaultval(new_model_name, col_name, false, "false", Migrations); + } + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in BoolField::track()] {}", e.what()) + ); + } + } + ~BoolField() = default; +}; + +class BinaryField: public FieldAttr{ +public: + BinaryField() = default; + BinaryField(bool not_null, bool unique=false, bool pk=false) + :FieldAttr("int", "BYTEA", not_null, unique, pk) + { + gen_sql(); + } + + void gen_sql() override + { + sql_segment = datatype; + if(not_null) sql_segment += " NOT NULL"; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"not_null", not_null}, + {"unique", unique}, + {"primary_key", primary_key}, + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = "BYTEA"; + not_null = j.at("not_null").get(); + unique = j.at("unique").get(); + primary_key = j.at("primary_key").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { return; } + + ~BinaryField() = default; +}; + +class DateTimeField:public FieldAttr{ +public: + bool enable_default; + std::string default_val; + + DateTimeField() = default; + DateTimeField(std::string datatype, bool enable_default=false, std::string default_val="", bool pk=false) + :FieldAttr("std::string",datatype, false, false, pk), enable_default(enable_default), default_val(default_val) + { + gen_sql(); + } + + void gen_sql() override + { + datatype = Utils::str_to_upper(datatype); + if(datatype != "DATE" && datatype != "TIME" && datatype != "TIMESTAMP_WTZ" && + datatype != "TIMESTAMP" && datatype != "TIME_WTZ" && datatype != "INTERVAL"){ + throw std::runtime_error(std::format("Datatype '{}' not supported in postgreSQL.", datatype)); + } + + std::string::size_type n = datatype.find('_'); + if(n != std::string::npos){ + datatype.replace(n+1, n+3, "WITH TIME ZONE"); + } + + sql_segment = datatype; + if(enable_default && !default_val.empty()){ + default_val = Utils::str_to_upper(default_val); + sql_segment += " DEFAULT " + default_val; + } + } + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"datatype", datatype}, + {"primary_key", primary_key}, + {"default_value", default_val}, + {"enable_def", enable_default} + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = j.at("datatype").get(); + primary_key = j.at("primary_key").get(); + enable_default = j.at("enable_def").get(); + default_val = j.at("default_value").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm = {}) const override + { + try + { + DateTimeField init_obj = dynamic_cast(old_col_obj); + if(init_obj.datatype != datatype){ + db::alter_column_type(new_model_name, col_name, datatype, Migrations); + } + if((init_obj.enable_default != enable_default) && enable_default){ + db::alter_column_defaultval(new_model_name, col_name, true, default_val, Migrations); + }else{ + db::alter_column_defaultval(new_model_name, col_name, false, default_val, Migrations); + } + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in DateTimeField::track()] {}", e.what()) + ); + } + } + + ~DateTimeField() = default; +}; + +class ForeignKey : public FieldAttr{ +public: + std::string col_name, sql_type, model_name, ref_col_name, on_delete, on_update; + + ForeignKey() = default; + ForeignKey(std::string cn, std::string mn, std::string rcn, std::string ctype="int", + std::string rsql="INTEGER NOT NULL", std::string on_del="CASCADE", std::string on_upd="CASCADE") + :FieldAttr(ctype, "FOREIGN KEY", false, false, false), + col_name(cn), model_name(mn), ref_col_name(rcn), on_delete(on_del), on_update(on_upd) + { + sql_type = rsql; + gen_sql(); + } + + void gen_sql() override + { + sql_segment ="FOREIGN KEY(" + col_name + ") REFERENCES " + model_name + " (" + ref_col_name + ")"; + on_delete = Utils::str_to_upper(on_delete); + sql_segment += " ON DELETE " + on_delete; + on_update = Utils::str_to_upper(on_update); + sql_segment += " ON UPDATE " + on_update; + } + + void to_json(nlohmann::json& j) const override + { + j = nlohmann::json{ + {"column_name", col_name}, + {"model_name", model_name}, + {"referenced_column_name", ref_col_name}, + {"ctype", ctype}, + {"sql_type", sql_type}, + {"on_delete", on_delete}, + {"on_update", on_update}, + }; + } + + void from_json(const nlohmann::json& j) override + { + datatype = "FOREIGN KEY"; + col_name = j.at("column_name").get(); + model_name = j.at("model_name").get(); + ref_col_name = j.at("referenced_column_name").get(); + ctype = j.at("ctype").get(); + sql_type = j.at("sql_type").get(); + on_delete = j.at("on_delete").get(); + on_update = j.at("on_update").get(); + gen_sql(); + } + + void track (std::string new_model_name, const std::string col_name, FieldAttr& old_col_obj, + std::ofstream& Migrations, const nlohmann::json& frm) const override + { + try + { + std::string constraint_name = "fk_"; + for(auto& [model_name, col_renames]: frm.items()){ + if(model_name == new_model_name){ + for(auto& [old_cn, new_cn] : col_renames.items()){ + if(col_name == new_cn.get()){ + constraint_name = constraint_name + "_" + old_cn; + }else{ + constraint_name = constraint_name + "_" + col_name; + } + } + }else{ + constraint_name = constraint_name + "_" + col_name; + } + } + db::drop_constraint(new_model_name, constraint_name, Migrations); + Migrations<<"ALTER TABLE " + new_model_name + " ADD "; + db::create_fk_constraint(sql_segment, col_name, Migrations); + } catch (std::bad_cast& e) + { + throw std::runtime_error( + std::format("[ERROR: in ForeignKey::track()] {}", e.what()) + ); + } + } + ~ForeignKey() = default; +}; + +}// INFO: namespace psql::Field +namespace psql { + +using DataTypeVariant = std::variant + >; +} //INFO: namespace psql + +namespace db = psql; +#endif diff --git a/include/strata/psql/deleters.hpp b/include/strata/psql/deleters.hpp index 87560fe..0d91831 100644 --- a/include/strata/psql/deleters.hpp +++ b/include/strata/psql/deleters.hpp @@ -1,9 +1,11 @@ #pragma once -#include -#include #include "../db_config.hpp" #ifdef PSQL + +#include +#include + namespace psql { void drop_table(const std::string& model_name, std::ofstream& Migrations); @@ -14,5 +16,5 @@ void drop_constraint(const std::string& model_name, const std::string& constrain } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/executor.hpp b/include/strata/psql/executor.hpp index 19ca435..0c437c3 100644 --- a/include/strata/psql/executor.hpp +++ b/include/strata/psql/executor.hpp @@ -1,16 +1,18 @@ #pragma once #include "../db_config.hpp" -#include "connectors.hpp" -#include #ifdef PSQL + +#include "connectors.hpp" +#include #include #include namespace psql { using opt_result_t = std::optional; -inline opt_result_t execute_sql(std::string& sql_file_or_str, bool is_file_name = true){ +inline opt_result_t execute_sql(std::string& sql_file_or_str, bool is_file_name = true) +{ std::ostringstream raw_sql {}; if(is_file_name){ @@ -36,5 +38,5 @@ inline opt_result_t execute_sql(std::string& sql_file_or_str, bool is_file_name } } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/fetcher.hpp b/include/strata/psql/fetcher.hpp index df9f774..62b3b82 100644 --- a/include/strata/psql/fetcher.hpp +++ b/include/strata/psql/fetcher.hpp @@ -1,15 +1,16 @@ #pragma once #include "../db_config.hpp" -#include "connectors.hpp" #ifdef PSQL +#include "connectors.hpp" #include #include namespace psql { template -void dbfetch(Model_T& obj, std::string& sql_string, bool getfn_called = false){ +void dbfetch(Model_T& obj, std::string& sql_string, bool getfn_called = false) +{ pqxx::connection cxn= connect(); try{ pqxx::work txn(cxn); @@ -30,5 +31,5 @@ void dbfetch(Model_T& obj, std::string& sql_string, bool getfn_called = false){ } } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/inserters.hpp b/include/strata/psql/inserters.hpp index 7f9d051..12dfbdc 100644 --- a/include/strata/psql/inserters.hpp +++ b/include/strata/psql/inserters.hpp @@ -1,14 +1,15 @@ #pragma once #include "../db_config.hpp" -#include "connectors.hpp" #ifdef PSQL +#include "connectors.hpp" #include namespace psql { template -pqxx::connection prepare_insert(){ +pqxx::connection prepare_insert() +{ Model_T obj {}; pqxx::placeholders row_vals {}; pqxx::connection cxn = connect(); @@ -24,7 +25,8 @@ pqxx::connection prepare_insert(){ return cxn; } -inline void exec_insert(pqxx::connection& cxn, pqxx::params& row){ +inline void exec_insert(pqxx::connection& cxn, pqxx::params& row) +{ try{ pqxx::work txn(cxn); pqxx::result result = txn.exec(pqxx::prepped{"insert_stmt"}, row).no_rows(); @@ -35,5 +37,5 @@ inline void exec_insert(pqxx::connection& cxn, pqxx::params& row){ } } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/queriers.hpp b/include/strata/psql/queriers.hpp index a1df8b8..5915c14 100644 --- a/include/strata/psql/queriers.hpp +++ b/include/strata/psql/queriers.hpp @@ -1,13 +1,14 @@ #pragma once #include "../db_config.hpp" -#include "../custom_array.hpp" -#include "connectors.hpp" #ifdef PSQL +#include "../custom_array.hpp" +#include "connectors.hpp" #include #include -namespace psql::query{ +namespace psql::query +{ template void fetch_all(Model_T& obj, std::string columns){ @@ -17,7 +18,7 @@ void fetch_all(Model_T& obj, std::string columns){ template void get(Model_T& obj, Args... args){ - static_assert(sizeof...(args) > 0 || sizeof...(args)%2 == 0, "[ERROR:'db_adapter::query::get()'] => Args are provided in key-value pairs."); + static_assert(sizeof...(args) > 0 || sizeof...(args)%2 == 0, "Args are provided in key-value pairs."); std::string sql_kwargs {}; constexpr int N = sizeof...(args); @@ -51,11 +52,12 @@ void get(Model_T& obj, Args... args){ if(filtered_rows.size() == 1) break; } if(filtered_rows.size() <= obj.records.size()) obj.records = filtered_rows; - else throw std::runtime_error("[ERROR: get()] => filtered rows are more than the actual initial rows!"); + else throw std::runtime_error("Filtered rows are more than the actual initial rows!"); } } -inline bool matches_conditions(pqxx::field&& field, OP op, Utils::Value_T v){ +inline bool matches_conditions(pqxx::field&& field, OP op, Utils::Value_T v) +{ bool accept = false; std::any value = Utils::filter_val(v); int int_cast = 0; @@ -246,5 +248,5 @@ class JoinBuilder{ }; }//namespace query -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/row_deleter.hpp b/include/strata/psql/row_deleter.hpp index ffdba47..91e6979 100644 --- a/include/strata/psql/row_deleter.hpp +++ b/include/strata/psql/row_deleter.hpp @@ -1,8 +1,8 @@ #pragma once #include "../db_config.hpp" -#include "executor.hpp" #ifdef PSQL +#include "executor.hpp" namespace psql{ template @@ -14,5 +14,5 @@ void delete_row(std::string logical_op, Utils::filters& filters){ } -namespace x = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/sql_generators.hpp b/include/strata/psql/sql_generators.hpp deleted file mode 100644 index d4f751a..0000000 --- a/include/strata/psql/sql_generators.hpp +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once -#include "../db_config.hpp" -#include "../datatypes.hpp" - -#ifdef PSQL -namespace psql{ -void generate_int_sql(IntegerField& int_obj); - -void generate_char_sql(CharField& char_obj); - -void generate_decimal_sql(DecimalField& dec_obj); - -void generate_bool_sql(BoolField& bool_obj); - -void generate_bin_sql(BinaryField& bin_obj); - -void generate_datetime_sql(DateTimeField& dt_obj); - -void generate_foreignkey_sql(ForeignKey& fk_obj); - -} - -namespace db_adapter = psql; -#endif diff --git a/include/strata/psql/updater.hpp b/include/strata/psql/updater.hpp index 4753943..6a0f5a3 100644 --- a/include/strata/psql/updater.hpp +++ b/include/strata/psql/updater.hpp @@ -1,9 +1,9 @@ #pragma once #include "../db_config.hpp" -#include "../concepts.hpp" -#include "connectors.hpp" #ifdef PSQL +#include "../concepts.hpp" +#include "connectors.hpp" #include namespace psql { @@ -58,5 +58,5 @@ struct Update { }; } -namespace db_adapter = psql; +namespace db = psql; #endif diff --git a/include/strata/psql/utils.hpp b/include/strata/utils.hpp similarity index 90% rename from include/strata/psql/utils.hpp rename to include/strata/utils.hpp index 049ce09..951b3bb 100644 --- a/include/strata/psql/utils.hpp +++ b/include/strata/utils.hpp @@ -1,11 +1,11 @@ #pragma once -#include #include #include #include #include -enum OP{ +enum OP +{ EQ=1, GT, LT, @@ -18,19 +18,22 @@ enum OP{ CONTAINS }; -namespace Utils{ +namespace Utils +{ std::string str_to_upper(std::string& str); using Value_T = std::variant; -inline std::string to_sql_literal(Value_T& value){ +inline std::string to_sql_literal(Value_T& value) +{ return std::visit([](auto& v)-> std::string{ if constexpr(std::is_same_v, std::string>) return "'" + v + "'"; else return std::to_string(v); }, value); } -inline std::string op_to_str(OP op, Value_T v){ +inline std::string op_to_str(OP op, Value_T v) +{ std::string str, value; switch (op) { case EQ: @@ -76,13 +79,15 @@ inline std::string op_to_str(OP op, Value_T v){ return str; } -inline std::any filter_val(Value_T& val){ +inline std::any filter_val(Value_T& val) +{ return std::visit([](auto& v)->std::any{ return std::any{v}; }, val); } -struct Condition { +struct Condition +{ std::string column; OP op; Value_T value; @@ -92,7 +97,8 @@ struct Condition { }; using filters = std::vector; -inline std::string build_filter_args(std::string logical_op, filters& filters){ +inline std::string build_filter_args(std::string logical_op, filters& filters) +{ int op_size = logical_op.size(); std::string where_str {}; for(Condition& filter: filters){ @@ -106,7 +112,8 @@ inline std::string build_filter_args(std::string logical_op, filters& filters){ using dbenvars = std::vector>; void set_dbenvars(dbenvars&); -typedef struct{ +typedef struct +{ std::string db_name; std::string user; std::string passwd; @@ -116,7 +123,8 @@ typedef struct{ db_params parse_dbenvars(); template -std::string to_str(T& arg){ +std::string to_str(T& arg) +{ std::ostringstream ss; ss< - -IntegerField::IntegerField(std::string datatype, bool pk, bool not_null, bool unique, int check_constr, std::string check_cond) -:FieldAttr("int", datatype, not_null, unique, pk), check_constraint(check_constr), check_condition(check_cond) -{ - db_adapter::generate_int_sql(*this); -} - -void to_json(nlohmann::json& j, const IntegerField& field){ - j = nlohmann::json{ - {"datatype", field.datatype}, - {"not_null", field.not_null}, - {"unique", field.unique}, - {"primary_key", field.primary_key}, - {"check_constraint", field.check_constraint}, - {"check_condition", field.check_condition} - }; -} - -void from_json(const nlohmann::json& j, IntegerField& field){ - field.datatype = j.at("datatype").get(); - field.not_null = j.at("not_null").get(); - field.unique = j.at("unique").get(); - field.primary_key = j.at("primary_key").get(); - field.check_constraint = j.at("check_constraint").get(); - field.check_condition = j.at("check_condition").get(); - db_adapter::generate_int_sql(field); -} - -DecimalField::DecimalField(std::string datatype, int max_length, int decimal_places, bool pk) - :FieldAttr("float", datatype, false, false, pk), max_length(max_length), decimal_places(decimal_places) -{ - db_adapter::generate_decimal_sql(*this); -} - -void to_json(nlohmann::json& j, const DecimalField& field){ - j = nlohmann::json{ - {"datatype", field.datatype}, - {"primary_key", field.primary_key}, - {"max_length", field.max_length}, - {"dec_places", field.decimal_places} - }; -} - -void from_json(const nlohmann::json& j, DecimalField& field){ - field.datatype = j.at("datatype").get(); - field.primary_key = j.at("primary_key").get(); - field.max_length = j.at("max_length").get(); - field.decimal_places = j.at("dec_places").get(); - db_adapter::generate_decimal_sql(field); -} - -CharField::CharField(std::string datatype, int length, bool not_null, bool unique, bool pk) - :FieldAttr("std::string", datatype, not_null, unique, pk), length(length) -{ - db_adapter::generate_char_sql(*this); -} - -void to_json(nlohmann::json& j, const CharField& field){ - j = nlohmann::json{ - {"datatype", field.datatype}, - {"not_null", field.not_null}, - {"unique", field.unique}, - {"primary_key", field.primary_key}, - {"length", field.length} - }; -} - -void from_json(const nlohmann::json& j, CharField& field){ - field.datatype = j.at("datatype").get(); - field.not_null = j.at("not_null").get(); - field.unique = j.at("unique").get(); - field.primary_key = j.at("primary_key").get(); - field.length = j.at("length").get(); - db_adapter::generate_char_sql(field); -} - -BoolField::BoolField(bool not_null, bool enable_default, bool default_value) -:FieldAttr("bool", "BOOLEAN", not_null, false, false), enable_default(enable_default), default_value(default_value) -{ - db_adapter::generate_bool_sql(*this); -} - -void to_json(nlohmann::json& j, const BoolField& field){ - j = nlohmann::json{ - {"not_null", field.not_null}, - {"enable_def", field.enable_default}, - {"default", field.default_value} - }; -} - -void from_json(const nlohmann::json& j, BoolField& field){ - field.datatype = "BOOLEAN"; - field.not_null = j.at("not_null").get(); - field.enable_default = j.at("enable_def").get(); - field.default_value = j.at("default").get(); - db_adapter::generate_bool_sql(field); -} - -BinaryField::BinaryField(bool not_null, bool unique, bool pk) -:FieldAttr("int", "BYTEA", not_null, unique, pk) -{ - db_adapter::generate_bin_sql(*this); -} - -void to_json(nlohmann::json& j, const BinaryField& field){ - j = nlohmann::json{ - {"not_null", field.not_null}, - {"unique", field.unique}, - {"primary_key", field.primary_key}, - }; -} - -void from_json(const nlohmann::json& j, BinaryField& field){ - field.datatype = "BYTEA"; - field.not_null = j.at("not_null").get(); - field.unique = j.at("unique").get(); - field.primary_key = j.at("primary_key").get(); - db_adapter::generate_bin_sql(field); -} - -DateTimeField::DateTimeField(std::string datatype, bool enable_default, std::string default_val, bool pk) -:FieldAttr("std::string",datatype, false, false, pk), enable_default(enable_default), default_val(default_val) -{ - db_adapter::generate_datetime_sql(*this); -} - -void to_json(nlohmann::json& j, const DateTimeField& field){ - j = nlohmann::json{ - {"datatype", field.datatype}, - {"primary_key", field.primary_key}, - {"default_value", field.default_val}, - {"enable_def", field.enable_default} - }; -} - -void from_json(const nlohmann::json& j, DateTimeField& field){ - field.datatype = j.at("datatype").get(); - field.primary_key = j.at("primary_key").get(); - field.enable_default = j.at("enable_def").get(); - field.default_val = j.at("default_value").get(); - db_adapter::generate_datetime_sql(field); -} - -ForeignKey::ForeignKey(std::string cn, std::string mn, std::string rcn, std::optional pk_col_obj, std::string on_del, std::string on_upd) -:FieldAttr("null", "FOREIGN KEY", false, false, false), -col_name(cn), model_name(mn), ref_col_name(rcn), on_delete(on_del), on_update(on_upd) -{ - - if(pk_col_obj.has_value()){ - ctype = pk_col_obj->ctype; - sql_type = pk_col_obj->sql_segment; - }else{ - ctype="int"; - sql_type = "INTEGER NOT NULL"; - } - - db_adapter::generate_foreignkey_sql(*this); -} - -void to_json(nlohmann::json& j, const ForeignKey& field){ - j = nlohmann::json{ - {"column_name", field.col_name}, - {"model_name", field.model_name}, - {"referenced_column_name", field.ref_col_name}, - {"on_delete", field.on_delete}, - {"on_update", field.on_update}, - }; -} - -void from_json(const nlohmann::json& j, ForeignKey& field){ - field.datatype = "FOREIGN KEY"; - field.col_name = j.at("column_name").get(); - field.model_name = j.at("model_name").get(); - field.ref_col_name = j.at("referenced_column_name").get(); - field.on_delete = j.at("on_delete").get(); - field.on_update = j.at("on_update").get(); - db_adapter::generate_foreignkey_sql(field); -} +#ifdef PSQL +#include "../include/strata/psql/datatypes.hpp" +#elif defined(MARIADB) +#include "../include/strata/mariadb/datatypes.hpp" +#else +#error "No database adapter specified" +#endif template -bool try_set_variant(const nlohmann::json& j, DataTypeVariant& variant) { +bool try_set_variant(const nlohmann::json& j, db::DataTypeVariant& variant) { try { - T value_sptr = std::make_shared(); - from_json(j, *value_sptr); - variant = value_sptr; + T value {}; + if constexpr (std::is_same_v>) value->from_json(j); + else value.from_json(j); + variant = value; return true; - }catch (const std::exception& e) { + }catch (...) { return false; } } template -bool try_deserialize(const nlohmann::json& j, DataTypeVariant& variant, std::variant*) { +bool try_deserialize(const nlohmann::json& j, db::DataTypeVariant& variant, std::variant*) { return ((try_set_variant(j, variant)) || ...); } -void variant_to_json(nlohmann::json& j, const DataTypeVariant& variant){ - std::visit([&j](auto& arg) mutable { - to_json(j, *arg); +void variant_to_json(nlohmann::json& j, const db::DataTypeVariant& variant){ + std::visit([&j](auto& obj) mutable { + using T = std::decay_t; + if constexpr (std::is_same_v>) obj->to_json(j); + else obj.to_json(j); }, variant); } -void variant_from_json(const nlohmann::json& j, DataTypeVariant& variant) { - if (!try_deserialize(j, variant, static_cast(nullptr))) { +void variant_from_json(const nlohmann::json& j, db::DataTypeVariant& variant) { + if (!try_deserialize(j, variant, static_cast(nullptr))) { throw std::invalid_argument("Error occured while parsing JSON back to objects."); } } diff --git a/src/mariadb/alterers.cpp b/src/mariadb/alterers.cpp new file mode 100644 index 0000000..42f60b7 --- /dev/null +++ b/src/mariadb/alterers.cpp @@ -0,0 +1,48 @@ +#include "../../include/strata/mariadb/alterers.hpp" +#include + +namespace mariadb { + +void alter_rename_table(const std::string& old_model_name, const std::string& new_model_name, std::ofstream& Migrations){ + Migrations<< "ALTER TABLE " + old_model_name + " RENAME TO " + new_model_name + ";\n"; +} + +void alter_add_column(const std::string& model_name, const std::string& column_name, + const std::string& column_sql_attributes, std::ofstream& Migrations){ + Migrations<< "ALTER TABLE " + model_name + " ADD COLUMN IF NOT EXISTS " + column_name + " " + column_sql_attributes + ";\n"; +} + +void alter_rename_column(const std::string& model_name, const std::string& old_column_name, + const std::string& new_column_name, std::ofstream& Migrations){ + Migrations << "ALTER TABLE " + model_name + " RENAME COLUMN IF EXISTS " + old_column_name + " TO " + new_column_name + ";\n"; +} + +void alter_column_type(const std::string& model_name, const std::string& column_name, + const std::string& sql_segment, std::ofstream& Migrations){ + Migrations<< "ALTER TABLE " + model_name + " MODIFY " + column_name + " " + sql_segment + ";\n"; +} + +void alter_column_defaultval(const std::string& model_name, const std::string& column_name, + const bool set_default, const std::string& defaultval, std::ofstream& Migrations){ + if(set_default){ + Migrations << "ALTER TABLE " + model_name + " ALTER COLUMN " + column_name + " SET DEFAULT " + defaultval + ";\n"; + }else{ + Migrations<< "ALTER TABLE " + model_name + " ALTER COLUMN " + column_name + " DROP DEFAULT;\n"; + } +} + +//WARNING: Can't implement yet because mariadb's version of altering column nullability requires full column redefinition... +void alter_column_nullable(const std::string& model_name, const std::string& column_name, const bool nullable, std::ofstream& Migrations){ + if(nullable){ + Migrations<< "ALTER TABLE " + model_name + " ALTER COLUMN " + column_name + " DROP NOT NULL;\n"; + }else{ + std::string default_value; + std::cout<<"Provide a default value for the column '" + column_name +"' to be set to non-nullable: " << std::endl; + std::cin>> default_value; + Migrations<< std::boolalpha + << "UPDATE " + model_name + " SET " + column_name + " = '" + default_value + "' WHERE " + column_name + << " IS NULL;\n ALTER TABLE " + model_name + " ALTER COLUMN " + column_name + " SET NOT NULL;\n"; + } +} + +} // INFO: namespace mariadb diff --git a/src/mariadb/create_model_header.cpp b/src/mariadb/create_model_header.cpp new file mode 100644 index 0000000..de92c32 --- /dev/null +++ b/src/mariadb/create_model_header.cpp @@ -0,0 +1,45 @@ +#include "../../include/strata/mariadb/create_model_header.hpp" +#include +#include +#include + +namespace mariadb +{ + +void create_models_hpp(const ms_map& migrations) +{ + std::ofstream models_hpp("models.hpp"); + std::string cols_str {}; + + if(!migrations.empty()) + models_hpp<<"#include \n\n"; + + for(const auto& [model_name, col_map] : migrations) + { + models_hpp<<"class " + model_name + "\n{\npublic:\n" + <<" std::string table_name = \"" + model_name + "\";\n int id {};\n"; + for(const auto& [col_name, dtv_obj] : col_map){ + cols_str += col_name + ","; + std::visit([&](auto& col_obj){ + models_hpp<< " "; + using T = std::decay_t; + if constexpr (std::is_same_v>) models_hpp<< col_obj->ctype; + else models_hpp<< col_obj.ctype; + models_hpp<< " " + col_name + ";\n"; + }, dtv_obj); + } + cols_str.pop_back(); + models_hpp<< " std::vector records {};\n" + << " std::string col_str = \"" + cols_str + "\";\n" + << " int col_map_size = " + std::to_string(col_map.size()) + ";\n\n" + << " " + model_name + "() = default;\n" + << " template \n" + << " " + model_name + "(tuple_T tup){\n" + << " std::tie(id," + cols_str + ") = tup;\n }\n\n" + << " auto get_attr() const{\n" + << " return std::make_tuple(id," + cols_str + ");\n }\n};\n\n"; + cols_str.clear(); + } +} + +}// INFO: namespace mariadb diff --git a/src/mariadb/deleters.cpp b/src/mariadb/deleters.cpp new file mode 100644 index 0000000..d89a989 --- /dev/null +++ b/src/mariadb/deleters.cpp @@ -0,0 +1,20 @@ +#include "../../include/strata/mariadb/deleters.hpp" + +namespace mariadb { + +void drop_table(const std::string& model_name, std::ofstream& Migrations) +{ + Migrations << "DROP TABLE IF EXISTS " + model_name + ";\n"; +} + +void drop_column(const std::string& model_name, const std::string& column_name, std::ofstream& Migrations) +{ + Migrations << "ALTER TABLE " + model_name + " DROP COLUMN IF EXISTS " + column_name + ";\n"; +} + +void drop_constraint(const std::string& model_name, const std::string& constraint_name, std::ofstream& Migrations) +{ + Migrations<< "ALTER TABLE " + model_name + " DROP CONSTRAINT IF EXISTS " + constraint_name + ";\n"; +} + +}// INFO: namespace mariadb diff --git a/src/models.cpp b/src/models.cpp index 3c082aa..5cf7e9c 100644 --- a/src/models.cpp +++ b/src/models.cpp @@ -1,8 +1,6 @@ #include "../include/strata/models.hpp" #include "../include/strata/db_adapters.hpp" #include -#include -#include template struct overloaded : Ts... { using Ts::operator()...; }; @@ -10,7 +8,8 @@ struct overloaded : Ts... { using Ts::operator()...; }; template overloaded(Ts...) -> overloaded; -nlohmann::json jsonify(const ms_map& schema){ +nlohmann::json jsonify(const ms_map& schema) +{ nlohmann::json j; nlohmann::json j_col; @@ -26,30 +25,33 @@ nlohmann::json jsonify(const ms_map& schema){ return j; } -ms_map parse_to_obj(nlohmann::json& j){ - ms_map parsed; - std::unordered_map fields; - DataTypeVariant variant; +ms_map parse_to_obj(nlohmann::json& j) +{ + ms_map parsed {}; + fields col_fields {}; + db::DataTypeVariant variant {}; for(const auto& [model, j_field_map] : j.items()){ for(const auto& [col, json_dtv] : j_field_map.items()){ variant_from_json(json_dtv, variant); - fields[col] = variant; + col_fields[col] = variant; } - parsed[model] = fields; - fields.clear(); + parsed[model] = col_fields; + col_fields.clear(); } return parsed; } -void save_schema_ms(const ms_map& schema){ +void save_schema_ms(const ms_map& schema) +{ std::ofstream schema_ms_file("schema.json"); if(!schema_ms_file.is_open()) throw std::runtime_error("[ERROR: from 'save_schema_ms()'] => Could not write schema into file."); schema_ms_file << jsonify(schema).dump(2); } -ms_map load_schema_ms(){ +ms_map load_schema_ms() +{ std::ifstream schema_ms_file("schema.json"); if(!schema_ms_file.is_open()) throw std::runtime_error("[ERROR: from 'load_schema_ms()'] => Could not load schema from file."); nlohmann::json j; @@ -57,7 +59,8 @@ ms_map load_schema_ms(){ return parse_to_obj(j); } -void Model::make_migrations(const nlohmann::json& mrm, const nlohmann::json& frm, std::string sql_filename){ +void Model::make_migrations(const nlohmann::json& mrm, const nlohmann::json& frm, std::string sql_filename) +{ for(const auto& pair : ModelFactory::registry()){ new_ms[pair.first] = ModelFactory::create_model_instance(pair.first)->col_map; } @@ -66,15 +69,16 @@ void Model::make_migrations(const nlohmann::json& mrm, const nlohmann::json& frm } save_schema_ms(new_ms); track_changes(mrm, frm, sql_filename); - db_adapter::create_models_hpp(new_ms); + db::create_models_hpp(new_ms); } -void rename(const nlohmann::json& mrm, const nlohmann::json& frm, ms_map& init_ms, std::ofstream& Migrations){ +void rename(const nlohmann::json& mrm, const nlohmann::json& frm, ms_map& init_ms, std::ofstream& Migrations) +{ for(const auto& [old_mn, new_mn] : mrm.items()){ if(init_ms.find(old_mn) != init_ms.end()){ init_ms[new_mn.get()] = init_ms[old_mn]; - db_adapter::alter_rename_table(old_mn, new_mn.get(), Migrations); + db::alter_rename_table(old_mn, new_mn.get(), Migrations); init_ms.erase(old_mn); }else{ throw std::runtime_error(R"([ERROR: from 'rename()' inside model renames]=> @@ -88,7 +92,7 @@ void rename(const nlohmann::json& mrm, const nlohmann::json& frm, ms_map& init_m for(const auto& [old_cn, new_cn] : col_renames.items()){ if(init_ms[new_mn].find(old_cn) != init_ms[new_mn].end()){ init_ms[new_mn][new_cn.get()] = init_ms[new_mn][old_cn]; - db_adapter::alter_rename_column(new_mn, old_cn, new_cn.get(), Migrations); + db::alter_rename_column(new_mn, old_cn, new_cn.get(), Migrations); init_ms[new_mn].erase(old_cn); }else{ throw std::runtime_error(R"([ERROR: from 'rename()' in column renames]=> @@ -104,7 +108,8 @@ void rename(const nlohmann::json& mrm, const nlohmann::json& frm, ms_map& init_m } } -void create_or_drop_tables(ms_map& init_ms, ms_map& new_ms, std::ofstream& Migrations){ +void create_or_drop_tables(ms_map& init_ms, ms_map& new_ms, std::ofstream& Migrations) +{ char choice = 'n'; for(auto it = init_ms.begin(); it != init_ms.end();){ @@ -114,7 +119,7 @@ void create_or_drop_tables(ms_map& init_ms, ms_map& new_ms, std::ofstream& Migra std::cin >>choice; if(choice == 'y' || choice == 'Y'){ std::cout<<"The model "< -std::string type_name() { - const char* mangled = typeid(T).name(); - int status = 0; - std::unique_ptr demangled( - abi::__cxa_demangle(mangled, nullptr, nullptr, &status), - std::free - ); - return (status == 0) ? demangled.get() : mangled; -} - -void handle_types(ms_map::iterator& new_it, const std::string col, DataTypeVariant& dtv_obj, - const nlohmann::json& frm, DataTypeVariant& init_dtv, std::ofstream& Migrations) +template +void handle_types(std::string new_model_name, const std::string col, FieldAttr& new_col_obj, + const nlohmann::json& frm, FieldAttr& old_col_obj, std::ofstream& Migrations) { - std::string alterations; - - auto visitor = overloaded{ - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - if(init_field->datatype != col_obj->datatype){ - db_adapter::alter_column_type(new_it->first, col, col_obj->datatype, Migrations); - } - if((init_field->enable_default != col_obj->enable_default) && col_obj->enable_default){ - db_adapter::alter_column_defaultval(new_it->first, col, true, col_obj->default_val, Migrations); - }else{ - db_adapter::alter_column_defaultval(new_it->first, col, false, col_obj->default_val, Migrations); - } - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "[ERROR: in 'handle_types.DateTimeField()'] => Conversions from {} to DateTimeField are not compatible.", - type_name() - )); - //convert_to_DateTimeField(col_obj, init_field); - } - }, init_dtv); - }, - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - if(init_field->datatype != col_obj->datatype){ - db_adapter::alter_column_type(new_it->first, col, col_obj->datatype, Migrations); - } - /*if((init_field.check_condition != col_obj.check_condition) && col_obj.check_condition != "default"){ - string check = "CHECK(" + col + col_obj.check_condition + std::to_string(col_obj.check_constraint) + ")"; - Migrations << "ALTER TABLE " + new_it->first + " ALTER COLUMN " + alterations + ";\n"; - }*/ - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "[ERROR: in 'handle_types.IntegerField()'] => Conversions from {} to IntegerField are not compatible.", - type_name() - )); - //convert_to_IntegerField(col_obj, init_field); - return; - } - }, init_dtv); - }, - [&](std::shared_ptr& col_obj){ - std::string constraint_name = "fk_"; - for(auto& [model_name, col_renames]: frm.items()){ - if(model_name == new_it->first){ - for(auto& [old_cn, new_cn] : col_renames.items()){ - if(col == new_cn.get()){ - constraint_name = constraint_name + "_" + old_cn; - }else{ - constraint_name = constraint_name + "_" + col; - } - } - }else{ - constraint_name = constraint_name + "_" + col; - } - } - db_adapter::drop_constraint(new_it->first, constraint_name, Migrations); - Migrations<<"ALTER TABLE " + new_it->first + " ADD "; - db_adapter::create_fk_constraint(col_obj->sql_segment, col, Migrations); - return; - }, - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - if(init_field->datatype != col_obj->datatype || - init_field->max_length != col_obj->max_length || - init_field->decimal_places != col_obj->decimal_places){ - - alterations = col_obj->datatype + " (" + std::to_string(col_obj->max_length) + "," + std::to_string(col_obj->decimal_places) + ")"; - db_adapter::alter_column_type(new_it->first, col, alterations, Migrations); - } - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "[ERROR: in 'handle_types.DecimalField()'] => Conversions from {} to DecimalField are not compatible.)", - type_name() - )); - //convert_to_DecField(col_obj, init_field, col, model_name); - } - }, init_dtv); - }, - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - if((init_field->datatype != col_obj->datatype) || (init_field->length != col_obj->length)){ - alterations = "VARCHAR( " + std::to_string(col_obj->length) + " )"; - db_adapter::alter_column_type(new_it->first, col, alterations, Migrations); - } - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "[ERROR: in 'handle_types.CharField()'] => Conversions from {} to CharField are not compatible.", - type_name() - )); - //convert_to_CharField(col_obj, init_field); - } - }, init_dtv); - }, - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "([ERROR: in 'handle_types.BinaryField()'] => Conversions from {} to BinaryField are not compatible.", - type_name() - )); - //convert_to_BinaryField(col_obj, init_field) - return; - } - }, init_dtv); - }, - [&](std::shared_ptr& col_obj){ - std::visit(overloaded{ - [&](std::shared_ptr& init_field){ - if(init_field->enable_default != col_obj->enable_default){ - if(col_obj->enable_default){ - db_adapter::alter_column_defaultval(new_it->first, col, true, std::to_string(col_obj->default_value), Migrations); - }else{ - alterations = col + " DROP DEFAULT"; - db_adapter::alter_column_defaultval(new_it->first, col, false, "false", Migrations); - } - } - return; - }, - [&](auto& init_field){ - throw std::runtime_error(std::format( - "([ERROR: in 'handle_types.BoolField()'] => Conversions from {} to BoolField are not compatible.", - type_name() - )); - //convert_to_BoolField(col_obj itself, and the init_field for conversion compatibility checks); - return; - } - }, init_dtv); - } - }; - std::visit(visitor, dtv_obj); + if constexpr(!std::is_same_v) + { throw std::logic_error("Conversions are not yet supported!"); } + + if constexpr(std::is_same_v) + { + new_col_obj.track(new_model_name, col, old_col_obj, Migrations, frm); + return; + } + + new_col_obj.track(new_model_name, col, old_col_obj, Migrations); } -std::string find_uq_constraint(const nlohmann::json& frm, const std::string& new_model_name, const std::string& new_col){ +std::string find_uq_constraint(const nlohmann::json& frm, const std::string& new_model_name, const std::string& new_col) +{ std::string constraint_name; auto outer_it = frm.find(new_model_name); if(outer_it != frm.end()){ @@ -312,14 +172,29 @@ std::string find_uq_constraint(const nlohmann::json& frm, const std::string& new return constraint_name; } -void Model::track_changes(const nlohmann::json& mrm, const nlohmann::json& frm, std::string sql_filename){ +void check_for_column_drops(ms_map& new_ms, ms_map& init_ms, std::ofstream& Migrations) +{ + for(const auto& [new_model_name, new_col_map]:new_ms){ + const auto init_it = init_ms.find(new_model_name); + if(init_it == init_ms.end()){ + throw std::runtime_error(R"([ERROR: in 'track_changes()'] => + Error in check for init iterator with new model name against the initial migrations.)"); + } + for(auto& [old_col, dtv_obj] : init_it->second){ + if(new_col_map.find(old_col) == new_col_map.end()){ + db::drop_column(init_it->first, old_col, Migrations); + } + } + } +} +void Model::track_changes(const nlohmann::json& mrm, const nlohmann::json& frm, std::string sql_filename) +{ std::ofstream Migrations (sql_filename); - - if(init_ms.empty()){ - for(auto& [model_name, field_map] : new_ms){ - db_adapter::create_table(model_name, field_map, Migrations); - } + if(init_ms.empty()) + { + for(auto& [model_name, field_map] : new_ms) + db::create_table(model_name, field_map, Migrations); return; } @@ -329,44 +204,48 @@ void Model::track_changes(const nlohmann::json& mrm, const nlohmann::json& frm, std::vector pk_cols, uq_cols; std::string alterations, pk, fk; - for(auto& [init_model_name, init_col_map] : init_ms){ + for(auto& [init_model_name, init_col_map] : init_ms) + { auto new_it = new_ms.find(init_model_name); - if(new_it == new_ms.end()){ - throw std::runtime_error("[ERROR: in 'track_changes()'] => Error in check for new iterator with initial model name against new_ms."); - } - for(auto& [new_col, dtv_obj] : new_it->second){ - std::visit([&](auto& col_obj){ - if(init_col_map.find(new_col) == init_col_map.end()){ - db_adapter::alter_add_column(new_it->first, new_col, col_obj->sql_segment, Migrations); + if(new_it == new_ms.end()) + throw std::runtime_error("Error in check for new iterator with initial model name against new_ms."); + for(auto& [new_col, dtv_obj] : new_it->second) + { + std::visit([&](auto& new_raw_obj){ + using new_col_T = decltype(new_raw_obj); + FieldAttr& new_field = db::as_ref(new_raw_obj); + + if(init_col_map.find(new_col) == init_col_map.end()) + { + db::alter_add_column(new_it->first, new_col, new_field.sql_segment, Migrations); return; } - std::visit([&](auto& init_field){ - if(init_field->sql_segment != col_obj->sql_segment){ - handle_types(new_it, new_col, dtv_obj, frm, init_col_map[new_col], Migrations); - - if(col_obj->primary_key){ - pk_cols.push_back(new_col); - } - - if(init_field->not_null != col_obj->not_null){ - if(col_obj->not_null){ - db_adapter::alter_column_nullable(new_it->first, new_col, false, Migrations); - }else{ - db_adapter::alter_column_nullable(new_it->first, new_col, true, Migrations); - } - } - - if((init_field->unique != col_obj->unique) && col_obj->unique){ - if(frm.empty()){ + std::visit([&](auto& raw_obj){ + using old_col_T = decltype(raw_obj); + FieldAttr& init_field = db::as_ref(raw_obj); + + if(init_field.sql_segment != new_field.sql_segment) + { + handle_types(new_it->first, new_col, new_field, frm, init_field, Migrations); + if(new_field.primary_key) pk_cols.push_back(new_col); + + if(init_field.not_null != new_field.not_null) + db::alter_column_nullable(new_it->first, new_col, !new_field.not_null, Migrations); + + if((init_field.unique != new_field.unique) && new_field.unique) + { + if(frm.empty()) + { uq_cols.push_back(new_col); }else{ uq_cols.push_back(find_uq_constraint(frm, new_it->first, new_col)); } - }else if((init_field->unique != col_obj->unique) && !col_obj->unique){ - if(frm.empty()){ - db_adapter::drop_constraint(new_it->first, "uq_"+new_col , Migrations); + }else if((init_field.unique != new_field.unique) && !new_field.unique){ + if(frm.empty()) + { + db::drop_constraint(new_it->first, "uq_"+new_col , Migrations); }else{ - db_adapter::drop_constraint(new_it->first, "uq_"+find_uq_constraint(frm,new_it->first,new_col), Migrations); + db::drop_constraint(new_it->first, "uq_"+find_uq_constraint(frm,new_it->first,new_col), Migrations); } }else { return; @@ -375,47 +254,41 @@ void Model::track_changes(const nlohmann::json& mrm, const nlohmann::json& frm, }, init_col_map[std::string(new_col)]); }, dtv_obj); - for(const std::string& uq_col : uq_cols){ + for(const std::string& uq_col : uq_cols) + { Migrations<<"ALTER TABLE " + new_it->first + " ADD "; - db_adapter::create_uq_constraint(uq_col, Migrations); + db::create_uq_constraint(uq_col, Migrations); } uq_cols.clear(); - if(!mrm.empty() && !pk_cols.empty()){ + if(!mrm.empty() && !pk_cols.empty()) + { std::string pk_constraint = "pk_"; - for(auto& [old_mn, new_mn] : mrm.items()){ - if(new_mn.get() == new_it->first){ + for(auto& [old_mn, new_mn] : mrm.items()) + { + if(new_mn.get() == new_it->first) + { pk_constraint += old_mn; }else{ pk_constraint += new_it->first; } } - db_adapter::drop_constraint(new_it->first, pk_constraint, Migrations); + db::drop_constraint(new_it->first, pk_constraint, Migrations); }else if(mrm.empty() && !pk_cols.empty()){ - db_adapter::drop_constraint(new_it->first, "pk_" + new_it->first , Migrations); + db::drop_constraint(new_it->first, "pk_" + new_it->first , Migrations); }else{ continue; } - if(!pk_cols.empty()){ + if(!pk_cols.empty()) + { Migrations<<"ALTER TABLE " + new_it->first + " ADD "; - db_adapter::create_pk_constraint(new_it->first, pk_cols, Migrations); + db::create_pk_constraint(new_it->first, pk_cols, Migrations); Migrations<<";\n"; } pk_cols.clear(); } } - for(const auto& [new_model_name, new_col_map]:new_ms){ - const auto init_it = init_ms.find(new_model_name); - if(init_it == init_ms.end()){ - throw std::runtime_error(R"([ERROR: in 'track_changes()'] => - Error in check for init iterator with new model name against the initial migrations.)"); - } - for(auto& [old_col, dtv_obj] : init_it->second){ - if(new_col_map.find(old_col) == new_col_map.end()){ - db_adapter::drop_column(init_it->first, old_col, Migrations); - } - } - } + check_for_column_drops(new_ms, init_ms, Migrations); } diff --git a/src/psql/alterers.cpp b/src/psql/alterers.cpp index 0db35dc..4b16791 100644 --- a/src/psql/alterers.cpp +++ b/src/psql/alterers.cpp @@ -1,7 +1,8 @@ #include "../../include/strata/psql/alterers.hpp" #include +#include -namespace psql{ +namespace psql { void alter_rename_table(const std::string& old_model_name, const std::string& new_model_name, std::ofstream& Migrations){ Migrations<< "ALTER TABLE " + old_model_name + " RENAME TO " + new_model_name + ";\n"; @@ -44,4 +45,4 @@ void alter_column_nullable(const std::string& model_name, const std::string& col } } -} +} // INFO: namespace psql diff --git a/src/psql/create_model_header.cpp b/src/psql/create_model_header.cpp index 89f7ea8..f19b72d 100644 --- a/src/psql/create_model_header.cpp +++ b/src/psql/create_model_header.cpp @@ -1,7 +1,9 @@ #include "../../include/strata/psql/create_model_header.hpp" #include +#include +#include -namespace psql{ +namespace psql { void create_models_hpp(const ms_map& migrations){ std::ofstream models_hpp("models.hpp"); @@ -16,7 +18,11 @@ void create_models_hpp(const ms_map& migrations){ for(const auto& [col_name, dtv_obj] : col_map){ cols_str += col_name + ","; std::visit([&](auto& col_obj){ - models_hpp<< " " + col_obj->ctype + " " + col_name + ";\n"; + models_hpp<< " "; + using T = std::decay_t; + if constexpr (std::is_same_v>) models_hpp<< col_obj->ctype; + else models_hpp<< col_obj.ctype; + models_hpp<< " " + col_name + ";\n"; }, dtv_obj); } cols_str.pop_back(); @@ -33,4 +39,4 @@ void create_models_hpp(const ms_map& migrations){ } } -} +}// INFO: namespace psql diff --git a/src/psql/deleters.cpp b/src/psql/deleters.cpp index 4991e18..51e6570 100644 --- a/src/psql/deleters.cpp +++ b/src/psql/deleters.cpp @@ -14,4 +14,4 @@ void drop_constraint(const std::string& model_name, const std::string& constrain Migrations<< "ALTER TABLE " + model_name + " DROP CONSTRAINT " + constraint_name + ";\n"; } -} +}// INFO: namespace psql diff --git a/src/psql/sql_generators.cpp b/src/psql/sql_generators.cpp deleted file mode 100644 index 3529ae4..0000000 --- a/src/psql/sql_generators.cpp +++ /dev/null @@ -1,91 +0,0 @@ -#include "../../include/strata/psql/sql_generators.hpp" -#include "../../include/strata/psql/utils.hpp" -#include - -namespace psql { -void generate_int_sql(IntegerField& int_obj){ - int_obj.datatype = Utils::str_to_upper(int_obj.datatype); - if(int_obj.datatype != "INTEGER" && int_obj.datatype != "SMALLINT" && int_obj.datatype != "BIGINT"){ - throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL. Provide a valid datatype", int_obj.datatype)); - } - int_obj.sql_segment += int_obj.datatype; - if (int_obj.not_null) int_obj.sql_segment = int_obj.datatype + " NOT NULL"; -} - -void generate_char_sql(CharField& char_obj){ - char_obj.datatype = Utils::str_to_upper(char_obj.datatype); - if(char_obj.datatype != "VARCHAR" && char_obj.datatype != "CHAR" && char_obj.datatype != "TEXT"){ - throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL. Provide a valid datatype", char_obj.datatype)); - } - char_obj.sql_segment += char_obj.datatype; - if (char_obj.length == 0 && char_obj.datatype != "TEXT"){ - throw std::runtime_error(std::format("Length attribute is required for datatype '{}'", char_obj.datatype)); - } - if (char_obj.datatype != "TEXT") { - char_obj.sql_segment += "(" + std::to_string(char_obj.length) + ")"; - } - if (char_obj.not_null) char_obj.sql_segment += " NOT NULL"; -} - -void generate_decimal_sql(DecimalField& dec_obj){ - dec_obj.datatype = Utils::str_to_upper(dec_obj.datatype); - - if(dec_obj.datatype != "DECIMAL" && dec_obj.datatype != "REAL" && - dec_obj.datatype != "DOUBLE PRECISION" && dec_obj.datatype != "NUMERIC"){ - throw std::runtime_error(std::format("Datatype '{}' is not supported by postgreSQL. Provide a valid datatype", dec_obj.datatype)); - } - - if(dec_obj.datatype == "REAL" || dec_obj.datatype == "DOUBLE PRECISION"){ - dec_obj.sql_segment = dec_obj.datatype; - return; - } - - if(dec_obj.max_length > 0 || dec_obj.decimal_places > 0) - dec_obj.sql_segment = dec_obj.datatype + "(" + std::to_string(dec_obj.max_length) + "," + std::to_string(dec_obj.decimal_places) + ")"; - else - throw std::runtime_error(std::format("Max length and/or decimal places cannot be 0 for datatype '{}'", dec_obj.datatype)); -} - -void generate_bool_sql(BoolField& bool_obj){ - bool_obj.sql_segment += bool_obj.datatype; - if (bool_obj.not_null)bool_obj.sql_segment += " NOT NULL"; - if (bool_obj.enable_default){ - if(bool_obj.default_value)bool_obj.sql_segment += " DEFAULT TRUE"; - else bool_obj.sql_segment += " DEFAULT FALSE"; - } -} - -void generate_bin_sql(BinaryField& bin_obj){ - bin_obj.sql_segment = bin_obj.datatype; - if(bin_obj.not_null) bin_obj.sql_segment += " NOT NULL"; -} - -void generate_datetime_sql(DateTimeField& dt_obj){ - dt_obj.datatype = Utils::str_to_upper(dt_obj.datatype); - if(dt_obj.datatype != "DATE" && dt_obj.datatype != "TIME" && dt_obj.datatype != "TIMESTAMP_WTZ" && - dt_obj.datatype != "TIMESTAMP" && dt_obj.datatype != "TIME_WTZ" && dt_obj.datatype != "INTERVAL"){ - throw std::runtime_error(std::format("Datatype '{}' not supported in postgreSQL. Provide a valid datatype", dt_obj.datatype)); - return; - } - - std::string::size_type n = dt_obj.datatype.find('_'); - if(n != std::string::npos){ - dt_obj.datatype.replace(n+1, n+3, "WITH TIME ZONE"); - } - - dt_obj.sql_segment = dt_obj.datatype; - if(dt_obj.enable_default && !dt_obj.default_val.empty()){ - dt_obj.default_val = Utils::str_to_upper(dt_obj.default_val); - dt_obj.sql_segment += " DEFAULT " + dt_obj.default_val; - } -} - -void generate_foreignkey_sql(ForeignKey& fk_obj){ - fk_obj.sql_segment ="FOREIGN KEY(" + fk_obj.col_name + ") REFERENCES " + fk_obj.model_name + " (" + fk_obj.ref_col_name + ")"; - fk_obj.on_delete = Utils::str_to_upper(fk_obj.on_delete); - fk_obj.sql_segment += " ON DELETE " + fk_obj.on_delete; - fk_obj.on_update = Utils::str_to_upper(fk_obj.on_update); - fk_obj.sql_segment += " ON UPDATE " + fk_obj.on_update; -} - -} diff --git a/src/psql/utils.cpp b/src/utils.cpp similarity index 79% rename from src/psql/utils.cpp rename to src/utils.cpp index acffd44..c1829d6 100644 --- a/src/psql/utils.cpp +++ b/src/utils.cpp @@ -1,4 +1,4 @@ -#include "../../include/strata/psql/utils.hpp" +#include "../include/strata/utils.hpp" #include #include #include @@ -34,8 +34,11 @@ db_params parse_dbenvars(){ }; return params; } catch (std::exception& e) { - throw std::runtime_error(std::format("[ERROR: In 'parse_db_envars()'] => {}", e.what())); + throw std::runtime_error(std::format( + "{} \n Check if you have set environmental variables using Utils::set_dbenvars()!", + e.what() + )); } } -} +} //INFO: Utils namespace