cast.hpp 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. // Copyright (c) 2023, QuantStack and Mamba Contributors
  2. //
  3. // Distributed under the terms of the BSD 3-Clause License.
  4. //
  5. // The full license is in the file LICENSE, distributed with this software.
  6. #ifndef MAMBA_CORE_UTIL_CAST_HPP
  7. #define MAMBA_CORE_UTIL_CAST_HPP
  8. #include <limits>
  9. #include <stdexcept>
  10. #include <type_traits>
  11. #include <fmt/format.h>
  12. #include "mamba/util/compare.hpp"
  13. namespace mamba::util
  14. {
  /**
   * A safe cast between arithmetic types.
   *
   * Integral-to-integral conversions are checked for range only. Whenever a floating point
   * type is involved (as source, destination, or both), the converted value must also
   * convert back to the original value unchanged.
   *
   * @tparam To   Destination arithmetic type.
   * @tparam From Source arithmetic type (deduced from ``val``).
   * @param val   Value to convert.
   * @return ``val`` converted to ``To``.
   * @throws std::overflow_error If ``val`` lies outside ``To``'s
   *         ``[lowest(), max()]`` range.
   * @throws std::runtime_error If a conversion involving a floating point type is inexact,
   *         e.g. ``1.5`` to ``int``, ``0.1`` (``double``) to ``float``, or
   *         ``(std::int64_t{1} << 53) + 1`` to ``double``.
   */
  template <typename To, typename From>
  constexpr auto safe_num_cast(const From& val) -> To;
  24. /********************
  25. * Implementation *
  26. ********************/
  27. namespace detail
  28. {
  29. template <typename To, typename From>
  30. constexpr auto make_overflow_error(const From& val)
  31. {
  32. return std::overflow_error{ fmt::format(
  33. "Value to cast ({}) is out of destination range ([{}, {}])",
  34. val,
  35. std::numeric_limits<To>::lowest(),
  36. std::numeric_limits<To>::max()
  37. ) };
  38. };
  39. }
  40. template <typename To, typename From>
  41. constexpr auto safe_num_cast(const From& val) -> To
  42. {
  43. static_assert(std::is_arithmetic_v<From>);
  44. static_assert(std::is_arithmetic_v<To>);
  45. constexpr auto to_lowest = std::numeric_limits<To>::lowest();
  46. constexpr auto to_max = std::numeric_limits<To>::max();
  47. constexpr auto from_lowest = std::numeric_limits<From>::lowest();
  48. constexpr auto from_max = std::numeric_limits<From>::max();
  49. if constexpr (std::is_same_v<From, To>)
  50. {
  51. return val;
  52. }
  53. else if constexpr (std::is_integral_v<From> && std::is_integral_v<To>)
  54. {
  55. if constexpr (cmp_less(from_lowest, to_lowest))
  56. {
  57. if (cmp_less(val, to_lowest))
  58. {
  59. throw detail::make_overflow_error<To>(val);
  60. }
  61. }
  62. if constexpr (cmp_greater(from_max, to_max))
  63. {
  64. if (cmp_greater(val, to_max))
  65. {
  66. throw detail::make_overflow_error<To>(val);
  67. }
  68. }
  69. return static_cast<To>(val);
  70. }
  71. else
  72. {
  73. using float_type = std::common_type_t<From, To>;
  74. constexpr auto float_cast = [](const auto& x) { return static_cast<float_type>(x); };
  75. if constexpr (float_cast(from_lowest) < float_cast(to_lowest))
  76. {
  77. if (float_cast(val) < float_cast(to_lowest))
  78. {
  79. throw detail::make_overflow_error<To>(val);
  80. }
  81. }
  82. if constexpr (float_cast(from_max) > float_cast(to_max))
  83. {
  84. if (float_cast(val) > float_cast(to_max))
  85. {
  86. throw detail::make_overflow_error<To>(val);
  87. }
  88. }
  89. To cast = static_cast<To>(val);
  90. From cast_back = static_cast<From>(cast);
  91. if (cast_back != val)
  92. {
  93. throw std::runtime_error{
  94. fmt::format("Casting from {} to {} loses precision", val, cast)
  95. };
  96. }
  97. return cast;
  98. }
  99. }
  100. }
  101. #endif