Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ pub enum SingleExpressionInner {
Call(Call),
/// Match expression.
Match(Match),
/// If expression.
If(If),
}

/// Call of a user-defined or of a builtin function.
Expand Down Expand Up @@ -403,6 +405,38 @@ impl MatchArm {
}
}

#[derive(Clone, Debug)]
pub struct If {
scrutinee: Arc<Expression>,
then_arm: Arc<Expression>,
else_arm: Arc<Expression>,
span: Span,
}

impl If {
/// Access the expression who's output is deconstructed in the `if`.
pub fn scrutinee(&self) -> &Expression {
&self.scrutinee
}

/// Access the branch that handles the `true` portion of the `if`.
pub fn then_arm(&self) -> &Expression {
&self.then_arm
}

/// Access the branch that handles the `false` or `else` portion of the `if`.
pub fn else_arm(&self) -> &Expression {
&self.else_arm
}

/// Access the span of the if statement.
pub fn span(&self) -> &Span {
&self.span
}
}

impl_eq_hash!(If; scrutinee, then_arm, else_arm);

/// Item when analyzing modules.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleItem {
Expand Down Expand Up @@ -462,6 +496,7 @@ pub enum ExprTree<'a> {
Single(&'a SingleExpression),
Call(&'a Call),
Match(&'a Match),
If(&'a If),
}

impl TreeLike for ExprTree<'_> {
Expand Down Expand Up @@ -502,13 +537,19 @@ impl TreeLike for ExprTree<'_> {
}
S::Call(call) => Tree::Unary(Self::Call(call)),
S::Match(match_) => Tree::Unary(Self::Match(match_)),
S::If(if_) => Tree::Unary(Self::If(if_)),
},
Self::Call(call) => Tree::Nary(call.args().iter().map(Self::Expression).collect()),
Self::Match(match_) => Tree::Nary(Arc::new([
Self::Expression(match_.scrutinee()),
Self::Expression(match_.left().expression()),
Self::Expression(match_.right().expression()),
])),
Self::If(if_) => Tree::Nary(Arc::new([
Self::Expression(if_.scrutinee()),
Self::Expression(if_.then_arm()),
Self::Expression(if_.else_arm()),
])),
}
}
}
Expand Down Expand Up @@ -1059,6 +1100,9 @@ impl AbstractSyntaxTree for SingleExpression {
parse::SingleExpressionInner::Match(match_) => {
Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)?
}
parse::SingleExpressionInner::If(if_) => {
If::analyze(if_, ty, scope).map(SingleExpressionInner::If)?
}
};

Ok(Self {
Expand Down Expand Up @@ -1426,6 +1470,28 @@ impl AbstractSyntaxTree for Match {
}
}

impl AbstractSyntaxTree for If {
type From = parse::If;

fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result<Self, RichError> {
let scrutinee =
Expression::analyze(from.scrutinee(), &ResolvedType::boolean(), scope).map(Arc::new)?;
scope.push_scope();
let ast_then = Expression::analyze(from.then_arm(), ty, scope).map(Arc::new)?;
scope.pop_scope();
scope.push_scope();
let ast_else = Expression::analyze(from.else_arm(), ty, scope).map(Arc::new)?;
scope.pop_scope();

Ok(Self {
scrutinee,
then_arm: ast_then,
else_arm: ast_else,
span: *from.as_ref(),
})
}
}

fn analyze_named_module(
name: ModuleName,
from: &parse::ModuleProgram,
Expand Down Expand Up @@ -1559,6 +1625,12 @@ impl AsRef<Span> for Match {
}
}

impl AsRef<Span> for If {
fn as_ref(&self) -> &Span {
&self.span
}
}

impl AsRef<Span> for Module {
fn as_ref(&self) -> &Span {
&self.span
Expand All @@ -1570,3 +1642,158 @@ impl AsRef<Span> for ModuleAssignment {
&self.span
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::parse::{self, ParseFromStr};
use crate::types::UIntType;

/// Helper to check if an expression is a constant, unwrapping blocks if needed
fn is_constant_expr(expr: &Expression) -> bool {
match expr.inner() {
ExpressionInner::Single(single) => {
matches!(single.inner(), SingleExpressionInner::Constant(_))
}
ExpressionInner::Block(_, Some(inner_expr)) => is_constant_expr(inner_expr),
_ => false,
}
}

/// Helper to check if an expression is a block with statements
fn is_block_with_statements(expr: &Expression) -> bool {
matches!(expr.inner(), ExpressionInner::Block(stmts, Some(_)) if !stmts.is_empty())
}

fn parse_if(input: &str) -> parse::If {
// Parse the if expression
let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse");

// Extract the parsed If from the expression
let parsed_if = match parsed_expr.inner() {
parse::ExpressionInner::Single(single) => match single.inner() {
parse::SingleExpressionInner::If(if_) => if_.clone(),
_ => panic!("Expected If expression"),
},
_ => panic!("Expected Single expression"),
};
parsed_if
}

#[test]
fn test_if_expression_analyze() {
let input = "if true { 0 } else { 1 }";

let parsed_if = &parse_if(input);

// Analyze the if expression with u8 as the expected type
let expected_type = ResolvedType::from(UIntType::U8);
let mut scope = Scope::default();
let ast_if = If::analyze(parsed_if, &expected_type, &mut scope)
.expect("Failed to analyze If expression");

// Verify the structure
assert_eq!(
ast_if.scrutinee().ty(),
&ResolvedType::boolean(),
"Scrutinee should be boolean type"
);
assert_eq!(
ast_if.then_arm().ty(),
&expected_type,
"Then arm should have u8 type"
);
assert_eq!(
ast_if.else_arm().ty(),
&expected_type,
"Else arm should have u8 type"
);

// Verify scrutinee is a boolean constant
match ast_if.scrutinee().inner() {
ExpressionInner::Single(single) => match single.inner() {
SingleExpressionInner::Constant(_) => {
// Boolean constant verified
}
_ => panic!("Expected boolean constant for scrutinee"),
},
_ => panic!("Expected single expression for scrutinee"),
}

// Verify both arms are constants (may be wrapped in blocks)
assert!(
is_constant_expr(ast_if.then_arm()),
"Then arm should be a constant"
);
assert!(
is_constant_expr(ast_if.else_arm()),
"Else arm should be a constant"
);
}

#[test]
fn test_if_expression_with_complex_arms() {
let input = "if false { let x: u8 = 5; x } else { 10 }";

let parsed_expr = parse::Expression::parse_from_str(input).expect("Failed to parse");
let expected_type = ResolvedType::from(UIntType::U8);

// Analyze the entire expression (which will handle the if internally)
let ast_expr = Expression::analyze_const(&parsed_expr, &expected_type)
.expect("Failed to analyze expression");

// Verify the expression is an If
match ast_expr.inner() {
ExpressionInner::Single(single) => match single.inner() {
SingleExpressionInner::If(ast_if) => {
assert_eq!(ast_if.scrutinee().ty(), &ResolvedType::boolean());
assert_eq!(ast_if.then_arm().ty(), &expected_type);
assert_eq!(ast_if.else_arm().ty(), &expected_type);

// Verify then arm is a block with statements and else arm is a constant
assert!(
is_block_with_statements(ast_if.then_arm()),
"Then arm should be a block with statements"
);
assert!(
is_constant_expr(ast_if.else_arm()),
"Else arm should be a constant"
);
}
_ => panic!("Expected If expression"),
},
_ => panic!("Expected Single expression"),
}
}

#[test]
fn test_if_valid_parse_but_invalid_ast() {
let input = "if false { let x: u8 = 5; } else { 10 }";

let parsed_if = &parse_if(input);
let expected_type = ResolvedType::from(UIntType::U8);
let mut scope = Scope::default();
let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope);

assert!(ast_if_result
.err()
.map(|e| matches!(e.error(), Error::ExpressionTypeMismatch(..)))
.unwrap());
}

#[test]
fn test_if_valid_parse_but_invalid_scrutinee() {
let input = "if (()) { 1 } else { 10 }";

let parsed_if = &parse_if(input);
let expected_type = ResolvedType::from(UIntType::U8);
let mut scope = Scope::default();
let ast_if_result = If::analyze(parsed_if, &expected_type, &mut scope);

// Expected type of scrutinee is `bool`
assert!(ast_if_result
.err()
.map(|e| matches!(e.error(), Error::ExpressionUnexpectedType(..)))
.unwrap());
}
}
77 changes: 76 additions & 1 deletion src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use simplicity::{types, Cmr, FailEntropy};
use self::builtins::array_fold;
use crate::array::{BTreeSlice, Partition};
use crate::ast::{
Call, CallName, Expression, ExpressionInner, Match, Program, SingleExpression,
Call, CallName, Expression, ExpressionInner, If, Match, Program, SingleExpression,
SingleExpressionInner, Statement,
};
use crate::debug::CallTracker;
Expand Down Expand Up @@ -355,6 +355,7 @@ impl SingleExpression {
}
SingleExpressionInner::Call(call) => call.compile(scope)?,
SingleExpressionInner::Match(match_) => match_.compile(scope)?,
SingleExpressionInner::If(if_) => if_.compile(scope)?,
};

scope
Expand Down Expand Up @@ -680,3 +681,77 @@ impl Match {
input.comp(&output).with_span(self)
}
}

impl If {
fn compile<'brand>(
&self,
scope: &mut Scope<'brand>,
) -> Result<PairBuilder<ProgNode<'brand>>, RichError> {
scope.push_scope();
scope.insert(Pattern::Ignore);
let then_arm = self.then_arm().compile(scope)?;
scope.pop_scope();
scope.push_scope();
scope.insert(Pattern::Ignore);
let else_arm = self.else_arm().compile(scope)?;
scope.pop_scope();

let scrutinee = self.scrutinee().compile(scope)?;
let input = scrutinee.pair(PairBuilder::iden(scope.ctx()));
// Left = false, right = true
let output = ProgNode::case(else_arm.as_ref(), then_arm.as_ref()).with_span(self)?;
input.comp(&output).with_span(self)
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use super::*;
use crate::parse::ParseFromStr;
use crate::witness::Arguments;
use crate::{ast, parse};

fn compile_program(
input: &str,
) -> Result<Arc<named::CommitNode<Elements>>, crate::error::RichError> {
let parse_program = parse::Program::parse_from_str(input).expect("Failed to parse");
let ast_program = ast::Program::analyze(&parse_program).expect("Failed to analyze");
ast_program.compile(Arguments::default(), false)
}

#[test]
fn match_equivalent_to_if_compiles() {
// The same logic expressed using `match`, which is known to compile correctly.
// Used as a baseline to confirm the test infrastructure works.
let input_match = r#"fn main() {
let x: u16 = 2;
let _s: (bool, u16) = match true {
true => jet::add_16(x, 2),
false => jet::add_16(x, 3),
};
}"#;
let match_node = compile_program(input_match).expect("Match expression should compile");
// Verifies that an if expression with non-unit arms compiles correctly.
//
// This works because in Simplicity types are binary tries: `1 × A = A`
// definitionally (unit contributes zero bits), so the bool scrutinee
// `Either<1, 1>` pairs correctly with the `case` combinator's type
// `(1+1) × Input`.
let input_if = r#"fn main() {
let x: u16 = 2;
let _u: (bool, u16) = if true {
jet::add_16(x, 2)
} else {
jet::add_16(x, 3)
};
}"#;
let if_node = compile_program(input_if).expect("If expression should compile");

assert_eq!(
match_node.display_expr().to_string(),
if_node.display_expr().to_string()
);
}
}
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,8 @@ pub enum Error {
ModuleRedefined(ModuleName),
ArgumentMissing(WitnessName),
ArgumentTypeMismatch(WitnessName, ResolvedType, ResolvedType),
IfThenArmMissingBraces,
IfElseArmMissingBraces,
}

#[rustfmt::skip]
Expand Down Expand Up @@ -582,6 +584,8 @@ impl fmt::Display for Error {
f,
"Parameter `{name}` was declared with type `{declared}` but its assigned argument is of type `{assigned}`"
),
Error::IfThenArmMissingBraces | Error::IfElseArmMissingBraces => write!(f, "If statement must have enclosing braces")

}
}
}
Expand Down
Loading