#ifndef MAMBA_API_CONFIGURATION_IMPL_HPP #define MAMBA_API_CONFIGURATION_IMPL_HPP #include #include #include #include #include "mamba/core/common_types.hpp" #include "mamba/core/context.hpp" #include "mamba/core/mamba_fs.hpp" namespace mamba { namespace detail { // Because CLI11 supports std::optional for options but not for flags... /************** * cli_config * **************/ template struct cli_config { using storage_type = std::optional; storage_type m_storage; cli_config() = default; cli_config(const T& value) : m_storage(value) { } storage_type& storage() { return m_storage; } bool has_value() const { return m_storage.has_value(); } const T& value() const { return m_storage.value(); } void reset() { m_storage.reset(); } }; /********************** * Source declaration * **********************/ template struct Source { static std::vector default_value(const T&) { return std::vector({ "default" }); }; static void merge( const std::map& values, const std::vector& sources, T& value, std::vector& source ); static T deserialize(const std::string& value); static bool is_sequence(); }; template struct Source> { static std::vector default_value(const std::vector& init) { return std::vector(init.size(), "default"); }; static void merge( const std::map>& values, const std::vector& sources, std::vector& value, std::vector& source ); static std::vector deserialize(const std::string& value); static bool is_sequence(); }; /************************* * Source implementation * *************************/ template void Source::merge( const std::map& values, const std::vector& sources, T& value, std::vector& source ) { source = sources; value = values.at(sources.front()); } template T Source::deserialize(const std::string& value) { if (value.empty()) { return YAML::Node("").as(); } else { return YAML::Load(value).as(); } } template bool Source::is_sequence() { return false; } template void Source>::merge( const std::map>& values, const std::vector& sources, std::vector& value, std::vector& source ) { value.clear(); source.clear(); for (auto& s : sources) { auto& vec = values.at(s); for (auto& v : vec) { auto find_v = std::find(value.begin(), value.end(), v); if (find_v == value.end()) { value.push_back(v); source.push_back(s); } } } } template std::vector Source>::deserialize(const std::string& value) { return YAML::Load("[" + value + "]").as>(); } template bool Source>::is_sequence() { return true; } } } /**************** * YAML parsers * ****************/ namespace YAML { template struct convert> { static Node encode(const T& rhs) { return Node(rhs.value()); } static bool decode(const Node& node, std::optional& rhs) { if (!node.IsScalar()) { return false; } rhs = std::optional(node.as()); return true; } }; template <> struct convert { static Node encode(const mamba::VerificationLevel& rhs) { if (rhs == mamba::VerificationLevel::kDisabled) { return Node("disabled"); } else if (rhs == mamba::VerificationLevel::kWarn) { return Node("warn"); } else if (rhs == mamba::VerificationLevel::kEnabled) { return Node("enabled"); } else { return Node(); } } static bool decode(const Node& node, mamba::VerificationLevel& rhs) { if (!node.IsScalar()) { return false; } auto str = node.as(); if (str == "enabled") { rhs = mamba::VerificationLevel::kEnabled; } else if (str == "warn") { rhs = mamba::VerificationLevel::kWarn; } else if (str == "disabled") { rhs = mamba::VerificationLevel::kDisabled; } else { throw std::runtime_error( "Invalid 'VerificationLevel', should be in {'enabled', 'warn', 'disabled'}" ); } return true; } }; template <> struct convert { static Node encode(const mamba::ChannelPriority& rhs) { if (rhs == mamba::ChannelPriority::kStrict) { return Node("strict"); } else if (rhs == mamba::ChannelPriority::kFlexible) { return Node("flexible"); } else if (rhs == mamba::ChannelPriority::kDisabled) { return Node("disabled"); } else { return Node(); } } static bool decode(const Node& node, mamba::ChannelPriority& rhs) { if (!node.IsScalar()) { return false; } auto str = node.as(); if (str == "strict") { rhs = mamba::ChannelPriority::kStrict; } else if ((str == "flexible") || (str == "true")) { rhs = mamba::ChannelPriority::kFlexible; } else if (str == "disabled") { rhs = mamba::ChannelPriority::kDisabled; } else { return false; } return true; } }; template <> struct convert { static Node encode(const fs::u8path& rhs) { return Node(rhs.string()); } static bool decode(const Node& node, fs::u8path& rhs) { if (!node.IsScalar()) { return false; } rhs = fs::u8path(node.as()); return true; } }; template <> struct convert { private: inline static const std::array log_level_names = { "trace", "debug", "info", "warning", "error", "critical", "off" }; public: static Node encode(const mamba::log_level& rhs) { return Node(log_level_names[static_cast(rhs)]); } static bool decode(const Node& node, mamba::log_level& rhs) { auto name = node.as(); auto it = std::find(log_level_names.begin(), log_level_names.end(), name); if (it != log_level_names.end()) { rhs = static_cast(it - log_level_names.begin()); return true; } LOG_ERROR << "Invalid log level, should be in {'critical', 'error', 'warning', 'info', 'debug', 'trace', 'off'} but is '" << name << "'"; return false; } }; } #endif