From 3cffb6cd940761534d0703131e2b728b5c7bc1e4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 25 May 2025 14:48:05 +1200 Subject: [PATCH] Bug fixing --- README.md | 12 +++++++++- src/config.cpp | 17 ++++++++++++++ src/config.hpp | 5 ++++ src/server.cpp | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/server.hpp | 3 +++ 5 files changed, 98 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f7a2337..ca5f114 100644 --- a/README.md +++ b/README.md @@ -83,10 +83,20 @@ The server can be configured by creating a JSON configuration file at `~/.config "host": "localhost", "port": 8080, "storage_path": "/path/to/storage", - "write_tokens": ["your-secret-token"] + "write_tokens": ["your-secret-token"], } ``` +Optionally, you can modify the CORS configuration. Defaults: +```json + "cors": { + "allowed_origins": ["*"], + "allowed_methods": ["GET", "PUT", "POST", "DELETE", "OPTIONS"], + "allowed_headers": ["Authorization", "Content-Type"], + "allow_credentials": false + } +``` + ## API Endpoints ### Upload a File diff --git a/src/config.cpp b/src/config.cpp index df21b82..6ddd1db 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -42,6 +42,23 @@ bool load_config(const std::string& config_path, ServerConfig& config) { config.port = j["port"].get(); } + // Parse CORS configuration + if (j.contains("cors")) { + const auto& cors = j["cors"]; + if (cors.contains("allowed_origins")) { + config.allowed_origins = cors["allowed_origins"].get>(); + } + if (cors.contains("allowed_methods")) { + config.allowed_methods = cors["allowed_methods"].get>(); + } + if (cors.contains("allowed_headers")) { + config.allowed_headers = cors["allowed_headers"].get>(); + } + if (cors.contains("allow_credentials")) { + config.allow_credentials = cors["allow_credentials"].get(); + } + } + return true; } catch (const std::exception& e) { std::cerr << "Error parsing config file: " << e.what() << std::endl; diff --git a/src/config.hpp b/src/config.hpp index ef20890..4704437 100644 --- a/src/config.hpp +++ b/src/config.hpp @@ -12,6 +12,11 @@ struct ServerConfig { std::filesystem::path object_store_path; std::string host = "0.0.0.0"; uint16_t port = 0; + // CORS configuration + std::vector allowed_origins = {"*"}; // Default to allow all origins + std::vector allowed_methods = {"GET", "PUT", "POST", "DELETE", "OPTIONS"}; + std::vector allowed_headers = {"Authorization", "Content-Type"}; + bool allow_credentials = false; }; bool load_config(const std::string& config_path, ServerConfig& config); diff --git a/src/server.cpp b/src/server.cpp index ac2042d..ff4ca7c 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -123,6 +123,16 @@ void Server::stop() { } void Server::setup_routes() { + // Add CORS preflight handler for all routes + server_.Options(".*", [this](const httplib::Request& req, httplib::Response& res) { + handle_cors_preflight(req, res); + }); + + // Add CORS headers to all responses + server_.set_post_routing_handler([this](const httplib::Request& req, httplib::Response& res) { + add_cors_headers(req, res); + }); + const std::string welcome_page = "

simple_object_storage Template Registry

"; // Welcome page server_.Get("/", [welcome_page](const httplib::Request&, httplib::Response& res) { @@ -180,6 +190,58 @@ void Server::setup_routes() { }); } +void Server::handle_cors_preflight(const httplib::Request& req, httplib::Response& res) { + add_cors_headers(req, res); + res.status = 204; // No content +} + +void Server::add_cors_headers(const httplib::Request& req, httplib::Response& res) { + // Get the origin from the request + std::string origin = req.get_header_value("Origin"); + + // If no origin header, no CORS headers needed + if (origin.empty()) { + return; + } + + // Check if origin is allowed + bool origin_allowed = false; + if (config_.allowed_origins.empty() || + std::find(config_.allowed_origins.begin(), config_.allowed_origins.end(), "*") != config_.allowed_origins.end()) { + origin_allowed = true; + } else { + origin_allowed = std::find(config_.allowed_origins.begin(), config_.allowed_origins.end(), origin) != config_.allowed_origins.end(); + } + + if (origin_allowed) { + res.set_header("Access-Control-Allow-Origin", origin); + + // Add other CORS headers + std::string methods = join(config_.allowed_methods, ", "); + res.set_header("Access-Control-Allow-Methods", methods); + + std::string headers = join(config_.allowed_headers, ", "); + res.set_header("Access-Control-Allow-Headers", headers); + + if (config_.allow_credentials) { + res.set_header("Access-Control-Allow-Credentials", "true"); + } + + // Add max age for preflight requests + res.set_header("Access-Control-Max-Age", "86400"); // 24 hours + } +} + +std::string Server::join(const std::vector& strings, const std::string& delimiter) { + if (strings.empty()) return ""; + + std::string result = strings[0]; + for (size_t i = 1; i < strings.size(); ++i) { + result += delimiter + strings[i]; + } + return result; +} + void Server::handle_get_object(const httplib::Request& req, httplib::Response& res) { const auto& key = req.matches[1].str(); std::string hash_str = key; diff --git a/src/server.hpp b/src/server.hpp index 90a2d8a..6f1990d 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -34,6 +34,9 @@ private: void handle_get_metadata(const httplib::Request& req, httplib::Response& res); void handle_delete_object(const httplib::Request& req, httplib::Response& res); void handle_exists(const httplib::Request& req, httplib::Response& res); + void handle_cors_preflight(const httplib::Request& req, httplib::Response& res); + void add_cors_headers(const httplib::Request& req, httplib::Response& res); + std::string join(const std::vector& strings, const std::string& delimiter); bool init_db();