Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 238 additions & 10 deletions src/solvers/smt2_incremental/convert_expr_to_smt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -971,12 +971,90 @@ static smt_termt convert_to_smt_shift(
return factory(first_operand, second_operand);
}
}
// Helper function to convert rotation expressions to SMT
static smt_termt convert_rotation_to_smt(
const shift_exprt &rotate,
const sub_expression_mapt &converted,
bool is_left)
{
const smt_termt &value = converted.at(rotate.op0());
const smt_termt &distance = converted.at(rotate.op1());

const auto value_sort = value.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(value_sort, "Rotation value must have bit-vector sort");

const auto bit_width = value_sort->bit_width();

// Try to extract constant distance for optimized SMT2 rotate operations
if(
const auto constant_distance =
expr_try_dynamic_cast<constant_exprt>(rotate.op1()))
{
const auto distance_value = numeric_cast_v<std::size_t>(*constant_distance);
const auto normalized_distance = distance_value % bit_width;

if(is_left)
return smt_bit_vector_theoryt::rotate_left(normalized_distance)(value);
else
return smt_bit_vector_theoryt::rotate_right(normalized_distance)(value);
}

// For dynamic rotation, implement using shifts and or
// rotate_left(x, n) = (x << n) | (x >> (width - n))
// rotate_right(x, n) = (x >> n) | (x << (width - n))

const auto distance_sort = distance.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(distance_sort, "Rotation distance must have bit-vector sort");

const std::size_t distance_width = distance_sort->bit_width();

// Normalize distance to bit_width to match value's width if needed
smt_termt normalized_distance = distance;
if(distance_width < bit_width)
{
normalized_distance =
smt_bit_vector_theoryt::zero_extend(bit_width - distance_width)(distance);
}
else if(distance_width > bit_width)
{
normalized_distance =
smt_bit_vector_theoryt::extract(bit_width - 1, 0)(distance);
}

// Calculate complementary distance: width - n
const auto width_constant =
smt_bit_vector_constant_termt{bit_width, bit_width};
const auto complementary_distance =
smt_bit_vector_theoryt::subtract(width_constant, normalized_distance);

smt_termt shifted_left;
smt_termt shifted_right;

if(is_left)
{
// rotate_left: (x << n) | (x >> (width - n))
shifted_left =
smt_bit_vector_theoryt::shift_left(value, normalized_distance);
shifted_right = smt_bit_vector_theoryt::logical_shift_right(
value, complementary_distance);
}
else
{
// rotate_right: (x >> n) | (x << (width - n))
shifted_right =
smt_bit_vector_theoryt::logical_shift_right(value, normalized_distance);
shifted_left =
smt_bit_vector_theoryt::shift_left(value, complementary_distance);
}

return smt_bit_vector_theoryt::make_or(shifted_left, shifted_right);
}

static smt_termt convert_expr_to_smt(
const shift_exprt &shift,
const sub_expression_mapt &converted)
{
// TODO: Dispatch for rotation expressions. A `shift_exprt` can be a rotation.
// Handle rotation expressions
if(const auto left_shift = expr_try_dynamic_cast<shl_exprt>(shift))
{
return convert_to_smt_shift(
Expand All @@ -996,6 +1074,14 @@ static smt_termt convert_expr_to_smt(
*right_arith_shift,
converted);
}
if(shift.id() == ID_rol)
{
return convert_rotation_to_smt(shift, converted, true);
}
if(shift.id() == ID_ror)
{
return convert_rotation_to_smt(shift, converted, false);
}
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for shift expression: " + shift.pretty());
}
Expand Down Expand Up @@ -1441,27 +1527,155 @@ static smt_termt convert_expr_to_smt(
const popcount_exprt &population_count,
const sub_expression_mapt &converted)
{
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for population count expression: " +
population_count.pretty());
const auto operand = converted.at(population_count.op());
const auto operand_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(operand_sort, "Population count operand must have bit-vector sort");

const auto bit_width = operand_sort->bit_width();

// Build a sum of each bit in the operand
// For bit vector (_ BitVec n), extract each bit and sum them
smt_termt result = smt_bit_vector_constant_termt{0, bit_width};

for(std::size_t i = 0; i < bit_width; ++i)
{
// Extract bit i and zero-extend it to bit_width
const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand);
const auto extended_bit =
smt_bit_vector_theoryt::zero_extend(bit_width - 1)(bit);
result = smt_bit_vector_theoryt::add(result, extended_bit);
}

return result;
}

static smt_termt convert_expr_to_smt(
const count_leading_zeros_exprt &count_leading_zeros,
const sub_expression_mapt &converted)
{
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for count leading zeros expression: " +
count_leading_zeros.pretty());
const auto operand = converted.at(count_leading_zeros.op());
const auto operand_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(
operand_sort, "Count leading zeros operand must have bit-vector sort");

const auto bit_width = operand_sort->bit_width();

// Count leading zeros by checking each bit from MSB to LSB
// Result is: if operand[n-1] == 0 then (if operand[n-2] == 0 then ... else n-2) else n-1
smt_termt result = smt_bit_vector_constant_termt{bit_width, bit_width};

for(std::size_t i = 0; i < bit_width; ++i)
{
const std::size_t bit_index = bit_width - 1 - i;
const auto bit =
smt_bit_vector_theoryt::extract(bit_index, bit_index)(operand);
const auto zero_bit = smt_bit_vector_constant_termt{0, 1};
const auto bit_is_zero = smt_core_theoryt::equal(bit, zero_bit);
const auto count_value = smt_bit_vector_constant_termt{i, bit_width};
result = smt_core_theoryt::if_then_else(bit_is_zero, result, count_value);
}

return result;
}

static smt_termt convert_expr_to_smt(
const count_trailing_zeros_exprt &count_trailing_zeros,
const sub_expression_mapt &converted)
{
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for count trailing zeros expression: " +
count_trailing_zeros.pretty());
const auto operand = converted.at(count_trailing_zeros.op());
const auto operand_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(
operand_sort, "Count trailing zeros operand must have bit-vector sort");

const auto bit_width = operand_sort->bit_width();

// Count trailing zeros by checking each bit from LSB to MSB
// Result is: if operand[0] == 0 then (if operand[1] == 0 then ... else 1) else 0
smt_termt result = smt_bit_vector_constant_termt{bit_width, bit_width};

for(std::size_t i = 0; i < bit_width; ++i)
{
const std::size_t bit_index = bit_width - 1 - i;
const auto bit =
smt_bit_vector_theoryt::extract(bit_index, bit_index)(operand);
const auto zero_bit = smt_bit_vector_constant_termt{0, 1};
const auto bit_is_zero = smt_core_theoryt::equal(bit, zero_bit);
const auto count_value =
smt_bit_vector_constant_termt{bit_index, bit_width};
result = smt_core_theoryt::if_then_else(bit_is_zero, result, count_value);
}

return result;
}

static smt_termt convert_expr_to_smt(
const find_first_set_exprt &find_first_set,
const sub_expression_mapt &converted)
{
const auto operand = converted.at(find_first_set.operand());
const auto operand_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(operand_sort, "Find first set operand must have bit-vector sort");

const auto bit_width = operand_sort->bit_width();

// Find first set returns index of first set bit (1-indexed) or 0 if no bit is set
// Result is: if operand[0] == 1 then 1 else (if operand[1] == 1 then 2 else ...)
smt_termt result = smt_bit_vector_constant_termt{0, bit_width};

for(std::size_t i = 0; i < bit_width; ++i)
{
const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand);
const auto one_bit = smt_bit_vector_constant_termt{1, 1};
const auto bit_is_one = smt_core_theoryt::equal(bit, one_bit);
const auto index_value = smt_bit_vector_constant_termt{i + 1, bit_width};
result = smt_core_theoryt::if_then_else(bit_is_one, index_value, result);
}

return result;
}

static smt_termt convert_expr_to_smt(
const bitreverse_exprt &bit_reverse,
const sub_expression_mapt &converted)
{
const auto operand = converted.at(bit_reverse.operand());
const auto operand_sort = operand.get_sort().cast<smt_bit_vector_sortt>();
INVARIANT(operand_sort, "Bit reverse operand must have bit-vector sort");

const auto bit_width = operand_sort->bit_width();

// Reverse bits by extracting each bit and concatenating in reverse order
if(bit_width == 1)
return operand;

smt_termt result = smt_bit_vector_theoryt::extract(0, 0)(operand);

for(std::size_t i = 1; i < bit_width; ++i)
{
const auto bit = smt_bit_vector_theoryt::extract(i, i)(operand);
result = smt_bit_vector_theoryt::concat(result, bit);
}

return result;
}

static smt_termt convert_expr_to_smt(
const bitnand_exprt &bit_nand,
const sub_expression_mapt &converted)
{
if(operands_are_of_type<bitvector_typet>(bit_nand))
{
// NAND is equivalent to NOT(AND(...))
const auto bit_and = convert_multiary_operator_to_terms(
bit_nand, converted, smt_bit_vector_theoryt::make_and);
return smt_bit_vector_theoryt::make_not(bit_and);
}
else
{
UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for bitwise nand expression: " +
bit_nand.pretty());
}
}

static smt_termt convert_expr_to_smt(
Expand Down Expand Up @@ -1842,6 +2056,20 @@ static smt_termt dispatch_expr_to_smt_conversion(
{
return convert_expr_to_smt(*prophecy_pointer_in_range, converted);
}
if(
const auto find_first_set =
expr_try_dynamic_cast<find_first_set_exprt>(expr))
{
return convert_expr_to_smt(*find_first_set, converted);
}
if(const auto bit_reverse = expr_try_dynamic_cast<bitreverse_exprt>(expr))
{
return convert_expr_to_smt(*bit_reverse, converted);
}
if(const auto bit_nand = expr_try_dynamic_cast<bitnand_exprt>(expr))
{
return convert_expr_to_smt(*bit_nand, converted);
}

UNIMPLEMENTED_FEATURE(
"Generation of SMT formula for unknown kind of expression: " +
Expand Down
Loading