diff --git a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp index d5910628470..d899b02ef59 100644 --- a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp +++ b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp @@ -971,12 +971,90 @@ static smt_termt convert_to_smt_shift( return factory(first_operand, second_operand); } } +// Helper function to convert rotation expressions to SMT +static smt_termt convert_rotation_to_smt( + const shift_exprt &rotate, + const sub_expression_mapt &converted, + bool is_left) +{ + const smt_termt &value = converted.at(rotate.op0()); + const smt_termt &distance = converted.at(rotate.op1()); + + const auto value_sort = value.get_sort().cast(); + INVARIANT(value_sort, "Rotation value must have bit-vector sort"); + + const auto bit_width = value_sort->bit_width(); + + // Try to extract constant distance for optimized SMT2 rotate operations + if( + const auto constant_distance = + expr_try_dynamic_cast(rotate.op1())) + { + const auto distance_value = numeric_cast_v(*constant_distance); + const auto normalized_distance = distance_value % bit_width; + + if(is_left) + return smt_bit_vector_theoryt::rotate_left(normalized_distance)(value); + else + return smt_bit_vector_theoryt::rotate_right(normalized_distance)(value); + } + + // For dynamic rotation, implement using shifts and or + // rotate_left(x, n) = (x << n) | (x >> (width - n)) + // rotate_right(x, n) = (x >> n) | (x << (width - n)) + + const auto distance_sort = distance.get_sort().cast(); + INVARIANT(distance_sort, "Rotation distance must have bit-vector sort"); + + const std::size_t distance_width = distance_sort->bit_width(); + + // Normalize distance to bit_width to match value's width if needed + smt_termt normalized_distance = distance; + if(distance_width < bit_width) + { + normalized_distance = + smt_bit_vector_theoryt::zero_extend(bit_width - distance_width)(distance); + } + else if(distance_width > bit_width) + { + normalized_distance = + smt_bit_vector_theoryt::extract(bit_width - 1, 0)(distance); + } + + // Calculate complementary distance: width - n + const auto width_constant = + smt_bit_vector_constant_termt{bit_width, bit_width}; + const auto complementary_distance = + smt_bit_vector_theoryt::subtract(width_constant, normalized_distance); + + smt_termt shifted_left; + smt_termt shifted_right; + + if(is_left) + { + // rotate_left: (x << n) | (x >> (width - n)) + shifted_left = + smt_bit_vector_theoryt::shift_left(value, normalized_distance); + shifted_right = smt_bit_vector_theoryt::logical_shift_right( + value, complementary_distance); + } + else + { + // rotate_right: (x >> n) | (x << (width - n)) + shifted_right = + smt_bit_vector_theoryt::logical_shift_right(value, normalized_distance); + shifted_left = + smt_bit_vector_theoryt::shift_left(value, complementary_distance); + } + + return smt_bit_vector_theoryt::make_or(shifted_left, shifted_right); +} static smt_termt convert_expr_to_smt( const shift_exprt &shift, const sub_expression_mapt &converted) { - // TODO: Dispatch for rotation expressions. A `shift_exprt` can be a rotation. + // Handle rotation expressions if(const auto left_shift = expr_try_dynamic_cast(shift)) { return convert_to_smt_shift( @@ -996,6 +1074,14 @@ static smt_termt convert_expr_to_smt( *right_arith_shift, converted); } + if(shift.id() == ID_rol) + { + return convert_rotation_to_smt(shift, converted, true); + } + if(shift.id() == ID_ror) + { + return convert_rotation_to_smt(shift, converted, false); + } UNIMPLEMENTED_FEATURE( "Generation of SMT formula for shift expression: " + shift.pretty()); } @@ -1441,27 +1527,155 @@ static smt_termt convert_expr_to_smt( const popcount_exprt &population_count, const sub_expression_mapt &converted) { - UNIMPLEMENTED_FEATURE( - "Generation of SMT formula for population count expression: " + - population_count.pretty()); + const auto operand = converted.at(population_count.op()); + const auto operand_sort = operand.get_sort().cast(); + INVARIANT(operand_sort, "Population count operand must have bit-vector sort"); + + const auto bit_width = operand_sort->bit_width(); + + // Build a sum of each bit in the operand + // For bit vector (_ BitVec n), extract each bit and sum them + smt_termt result = smt_bit_vector_constant_termt{0, bit_width}; + + for(std::size_t i = 0; i < bit_width; ++i) + { + // Extract bit i and zero-extend it to bit_width + const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand); + const auto extended_bit = + smt_bit_vector_theoryt::zero_extend(bit_width - 1)(bit); + result = smt_bit_vector_theoryt::add(result, extended_bit); + } + + return result; } static smt_termt convert_expr_to_smt( const count_leading_zeros_exprt &count_leading_zeros, const sub_expression_mapt &converted) { - UNIMPLEMENTED_FEATURE( - "Generation of SMT formula for count leading zeros expression: " + - count_leading_zeros.pretty()); + const auto operand = converted.at(count_leading_zeros.op()); + const auto operand_sort = operand.get_sort().cast(); + INVARIANT( + operand_sort, "Count leading zeros operand must have bit-vector sort"); + + const auto bit_width = operand_sort->bit_width(); + + // Count leading zeros by checking each bit from MSB to LSB + // Result is: if operand[n-1] == 0 then (if operand[n-2] == 0 then ... else n-2) else n-1 + smt_termt result = smt_bit_vector_constant_termt{bit_width, bit_width}; + + for(std::size_t i = 0; i < bit_width; ++i) + { + const std::size_t bit_index = bit_width - 1 - i; + const auto bit = + smt_bit_vector_theoryt::extract(bit_index, bit_index)(operand); + const auto zero_bit = smt_bit_vector_constant_termt{0, 1}; + const auto bit_is_zero = smt_core_theoryt::equal(bit, zero_bit); + const auto count_value = smt_bit_vector_constant_termt{i, bit_width}; + result = smt_core_theoryt::if_then_else(bit_is_zero, result, count_value); + } + + return result; } static smt_termt convert_expr_to_smt( const count_trailing_zeros_exprt &count_trailing_zeros, const sub_expression_mapt &converted) { - UNIMPLEMENTED_FEATURE( - "Generation of SMT formula for count trailing zeros expression: " + - count_trailing_zeros.pretty()); + const auto operand = converted.at(count_trailing_zeros.op()); + const auto operand_sort = operand.get_sort().cast(); + INVARIANT( + operand_sort, "Count trailing zeros operand must have bit-vector sort"); + + const auto bit_width = operand_sort->bit_width(); + + // Count trailing zeros by checking each bit from LSB to MSB + // Result is: if operand[0] == 0 then (if operand[1] == 0 then ... else 1) else 0 + smt_termt result = smt_bit_vector_constant_termt{bit_width, bit_width}; + + for(std::size_t i = 0; i < bit_width; ++i) + { + const std::size_t bit_index = bit_width - 1 - i; + const auto bit = + smt_bit_vector_theoryt::extract(bit_index, bit_index)(operand); + const auto zero_bit = smt_bit_vector_constant_termt{0, 1}; + const auto bit_is_zero = smt_core_theoryt::equal(bit, zero_bit); + const auto count_value = + smt_bit_vector_constant_termt{bit_index, bit_width}; + result = smt_core_theoryt::if_then_else(bit_is_zero, result, count_value); + } + + return result; +} + +static smt_termt convert_expr_to_smt( + const find_first_set_exprt &find_first_set, + const sub_expression_mapt &converted) +{ + const auto operand = converted.at(find_first_set.operand()); + const auto operand_sort = operand.get_sort().cast(); + INVARIANT(operand_sort, "Find first set operand must have bit-vector sort"); + + const auto bit_width = operand_sort->bit_width(); + + // Find first set returns index of first set bit (1-indexed) or 0 if no bit is set + // Result is: if operand[0] == 1 then 1 else (if operand[1] == 1 then 2 else ...) + smt_termt result = smt_bit_vector_constant_termt{0, bit_width}; + + for(std::size_t i = 0; i < bit_width; ++i) + { + const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand); + const auto one_bit = smt_bit_vector_constant_termt{1, 1}; + const auto bit_is_one = smt_core_theoryt::equal(bit, one_bit); + const auto index_value = smt_bit_vector_constant_termt{i + 1, bit_width}; + result = smt_core_theoryt::if_then_else(bit_is_one, index_value, result); + } + + return result; +} + +static smt_termt convert_expr_to_smt( + const bitreverse_exprt &bit_reverse, + const sub_expression_mapt &converted) +{ + const auto operand = converted.at(bit_reverse.operand()); + const auto operand_sort = operand.get_sort().cast(); + INVARIANT(operand_sort, "Bit reverse operand must have bit-vector sort"); + + const auto bit_width = operand_sort->bit_width(); + + // Reverse bits by extracting each bit and concatenating in reverse order + if(bit_width == 1) + return operand; + + smt_termt result = smt_bit_vector_theoryt::extract(0, 0)(operand); + + for(std::size_t i = 1; i < bit_width; ++i) + { + const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand); + result = smt_bit_vector_theoryt::concat(result, bit); + } + + return result; +} + +static smt_termt convert_expr_to_smt( + const bitnand_exprt &bit_nand, + const sub_expression_mapt &converted) +{ + if(operands_are_of_type(bit_nand)) + { + // NAND is equivalent to NOT(AND(...)) + const auto bit_and = convert_multiary_operator_to_terms( + bit_nand, converted, smt_bit_vector_theoryt::make_and); + return smt_bit_vector_theoryt::make_not(bit_and); + } + else + { + UNIMPLEMENTED_FEATURE( + "Generation of SMT formula for bitwise nand expression: " + + bit_nand.pretty()); + } } static smt_termt convert_expr_to_smt( @@ -1842,6 +2056,20 @@ static smt_termt dispatch_expr_to_smt_conversion( { return convert_expr_to_smt(*prophecy_pointer_in_range, converted); } + if( + const auto find_first_set = + expr_try_dynamic_cast(expr)) + { + return convert_expr_to_smt(*find_first_set, converted); + } + if(const auto bit_reverse = expr_try_dynamic_cast(expr)) + { + return convert_expr_to_smt(*bit_reverse, converted); + } + if(const auto bit_nand = expr_try_dynamic_cast(expr)) + { + return convert_expr_to_smt(*bit_nand, converted); + } UNIMPLEMENTED_FEATURE( "Generation of SMT formula for unknown kind of expression: " +