configuration_impl.hpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. #ifndef MAMBA_API_CONFIGURATION_IMPL_HPP
  2. #define MAMBA_API_CONFIGURATION_IMPL_HPP
  #include <algorithm>
  #include <array>
  #include <cstddef>
  #include <map>
  #include <optional>
  #include <stdexcept>
  #include <string>
  #include <vector>

  #include <yaml-cpp/yaml.h>

  #include "mamba/core/common_types.hpp"
  #include "mamba/core/context.hpp"
  #include "mamba/core/mamba_fs.hpp"
  10. namespace mamba
  11. {
  12. namespace detail
  13. {
  14. // Because CLI11 supports std::optional for options but not for flags...
  15. /**************
  16. * cli_config *
  17. **************/
  /**
   * Holder for a value of type T that may or may not have been provided.
   *
   * The value lives in a std::optional<T>; `storage()` exposes that optional
   * by mutable reference so it can be filled in from outside the struct.
   */
  template <class T>
  struct cli_config
  {
      using storage_type = std::optional<T>;

      storage_type m_storage;

      /** Starts out empty. */
      cli_config() = default;

      /** Starts out holding a copy of `value` (implicit conversion from T is allowed). */
      cli_config(const T& value)
          : m_storage(value)
      {
      }

      /** True when a value is currently held. */
      bool has_value() const
      {
          return static_cast<bool>(m_storage);
      }

      /** Read access to the held value; throws std::bad_optional_access when empty. */
      const T& value() const
      {
          return m_storage.value();
      }

      /** Mutable access to the underlying optional. */
      storage_type& storage()
      {
          return m_storage;
      }

      /** Drops the held value, if any. */
      void reset()
      {
          m_storage = std::nullopt;
      }
  };
  45. /**********************
  46. * Source declaration *
  47. **********************/
  /**
   * Per-type policy for non-sequence values: how source labels are initialised,
   * how values coming from several sources are merged, and how a value is
   * deserialized from a string.
   */
  template <class T>
  struct Source
  {
      /** Source labels for a default value: a single "default" entry. */
      static std::vector<std::string> default_value(const T&)
      {
          return { "default" };
      }

      /**
       * Computes `value` and `source` from the per-source `values`, where
       * `sources` lists the keys of `values` to consider.
       */
      static void merge(
          const std::map<std::string, T>& values,
          const std::vector<std::string>& sources,
          T& value,
          std::vector<std::string>& source
      );

      /** Builds a T from its string representation. */
      static T deserialize(const std::string& value);

      /** Whether T is handled as a sequence. */
      static bool is_sequence();
  };
  64. template <class T>
  65. struct Source<std::vector<T>>
  66. {
  67. static std::vector<std::string> default_value(const std::vector<T>& init)
  68. {
  69. return std::vector<std::string>(init.size(), "default");
  70. };
  71. static void merge(
  72. const std::map<std::string, std::vector<T>>& values,
  73. const std::vector<std::string>& sources,
  74. std::vector<T>& value,
  75. std::vector<std::string>& source
  76. );
  77. static std::vector<T> deserialize(const std::string& value);
  78. static bool is_sequence();
  79. };
  80. /*************************
  81. * Source implementation *
  82. *************************/
  83. template <class T>
  84. void Source<T>::merge(
  85. const std::map<std::string, T>& values,
  86. const std::vector<std::string>& sources,
  87. T& value,
  88. std::vector<std::string>& source
  89. )
  90. {
  91. source = sources;
  92. value = values.at(sources.front());
  93. }
  94. template <class T>
  95. T Source<T>::deserialize(const std::string& value)
  96. {
  97. if (value.empty())
  98. {
  99. return YAML::Node("").as<T>();
  100. }
  101. else
  102. {
  103. return YAML::Load(value).as<T>();
  104. }
  105. }
  106. template <class T>
  107. bool Source<T>::is_sequence()
  108. {
  109. return false;
  110. }
  111. template <class T>
  112. void Source<std::vector<T>>::merge(
  113. const std::map<std::string, std::vector<T>>& values,
  114. const std::vector<std::string>& sources,
  115. std::vector<T>& value,
  116. std::vector<std::string>& source
  117. )
  118. {
  119. value.clear();
  120. source.clear();
  121. for (auto& s : sources)
  122. {
  123. auto& vec = values.at(s);
  124. for (auto& v : vec)
  125. {
  126. auto find_v = std::find(value.begin(), value.end(), v);
  127. if (find_v == value.end())
  128. {
  129. value.push_back(v);
  130. source.push_back(s);
  131. }
  132. }
  133. }
  134. }
  135. template <class T>
  136. std::vector<T> Source<std::vector<T>>::deserialize(const std::string& value)
  137. {
  138. return YAML::Load("[" + value + "]").as<std::vector<T>>();
  139. }
  140. template <class T>
  141. bool Source<std::vector<T>>::is_sequence()
  142. {
  143. return true;
  144. }
  145. }
  146. }
  147. /****************
  148. * YAML parsers *
  149. ****************/
  150. namespace YAML
  151. {
  152. template <class T>
  153. struct convert<std::optional<T>>
  154. {
  155. static Node encode(const T& rhs)
  156. {
  157. return Node(rhs.value());
  158. }
  159. static bool decode(const Node& node, std::optional<T>& rhs)
  160. {
  161. if (!node.IsScalar())
  162. {
  163. return false;
  164. }
  165. rhs = std::optional<T>(node.as<T>());
  166. return true;
  167. }
  168. };
  169. template <>
  170. struct convert<mamba::VerificationLevel>
  171. {
  172. static Node encode(const mamba::VerificationLevel& rhs)
  173. {
  174. if (rhs == mamba::VerificationLevel::kDisabled)
  175. {
  176. return Node("disabled");
  177. }
  178. else if (rhs == mamba::VerificationLevel::kWarn)
  179. {
  180. return Node("warn");
  181. }
  182. else if (rhs == mamba::VerificationLevel::kEnabled)
  183. {
  184. return Node("enabled");
  185. }
  186. else
  187. {
  188. return Node();
  189. }
  190. }
  191. static bool decode(const Node& node, mamba::VerificationLevel& rhs)
  192. {
  193. if (!node.IsScalar())
  194. {
  195. return false;
  196. }
  197. auto str = node.as<std::string>();
  198. if (str == "enabled")
  199. {
  200. rhs = mamba::VerificationLevel::kEnabled;
  201. }
  202. else if (str == "warn")
  203. {
  204. rhs = mamba::VerificationLevel::kWarn;
  205. }
  206. else if (str == "disabled")
  207. {
  208. rhs = mamba::VerificationLevel::kDisabled;
  209. }
  210. else
  211. {
  212. throw std::runtime_error(
  213. "Invalid 'VerificationLevel', should be in {'enabled', 'warn', 'disabled'}"
  214. );
  215. }
  216. return true;
  217. }
  218. };
  219. template <>
  220. struct convert<mamba::ChannelPriority>
  221. {
  222. static Node encode(const mamba::ChannelPriority& rhs)
  223. {
  224. if (rhs == mamba::ChannelPriority::kStrict)
  225. {
  226. return Node("strict");
  227. }
  228. else if (rhs == mamba::ChannelPriority::kFlexible)
  229. {
  230. return Node("flexible");
  231. }
  232. else if (rhs == mamba::ChannelPriority::kDisabled)
  233. {
  234. return Node("disabled");
  235. }
  236. else
  237. {
  238. return Node();
  239. }
  240. }
  241. static bool decode(const Node& node, mamba::ChannelPriority& rhs)
  242. {
  243. if (!node.IsScalar())
  244. {
  245. return false;
  246. }
  247. auto str = node.as<std::string>();
  248. if (str == "strict")
  249. {
  250. rhs = mamba::ChannelPriority::kStrict;
  251. }
  252. else if ((str == "flexible") || (str == "true"))
  253. {
  254. rhs = mamba::ChannelPriority::kFlexible;
  255. }
  256. else if (str == "disabled")
  257. {
  258. rhs = mamba::ChannelPriority::kDisabled;
  259. }
  260. else
  261. {
  262. return false;
  263. }
  264. return true;
  265. }
  266. };
  267. template <>
  268. struct convert<fs::u8path>
  269. {
  270. static Node encode(const fs::u8path& rhs)
  271. {
  272. return Node(rhs.string());
  273. }
  274. static bool decode(const Node& node, fs::u8path& rhs)
  275. {
  276. if (!node.IsScalar())
  277. {
  278. return false;
  279. }
  280. rhs = fs::u8path(node.as<std::string>());
  281. return true;
  282. }
  283. };
  284. template <>
  285. struct convert<mamba::log_level>
  286. {
  287. private:
  288. inline static const std::array<std::string, 7> log_level_names = {
  289. "trace", "debug", "info", "warning", "error", "critical", "off"
  290. };
  291. public:
  292. static Node encode(const mamba::log_level& rhs)
  293. {
  294. return Node(log_level_names[static_cast<size_t>(rhs)]);
  295. }
  296. static bool decode(const Node& node, mamba::log_level& rhs)
  297. {
  298. auto name = node.as<std::string>();
  299. auto it = std::find(log_level_names.begin(), log_level_names.end(), name);
  300. if (it != log_level_names.end())
  301. {
  302. rhs = static_cast<mamba::log_level>(it - log_level_names.begin());
  303. return true;
  304. }
  305. LOG_ERROR << "Invalid log level, should be in {'critical', 'error', 'warning', 'info', 'debug', 'trace', 'off'} but is '"
  306. << name << "'";
  307. return false;
  308. }
  309. };
  310. }
  311. #endif