Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions prism/prism.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
#ifndef PRISM_H
#define PRISM_H

#ifdef __cplusplus
extern "C" {
#endif

#include "prism/defines.h"
#include "prism/util/pm_buffer.h"
#include "prism/util/pm_char.h"
Expand Down Expand Up @@ -403,4 +407,8 @@ PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_method_name(const uint
* ```
*/

#ifdef __cplusplus
}
#endif

#endif
28 changes: 16 additions & 12 deletions string.c
Original file line number Diff line number Diff line change
Expand Up @@ -6226,9 +6226,11 @@ rb_str_sub_bang(int argc, VALUE *argv, VALUE str)
}
else {
repl = argv[1];
hash = rb_check_hash_type(argv[1]);
if (NIL_P(hash)) {
StringValue(repl);
if (!RB_TYPE_P(repl, T_STRING)) {
hash = rb_check_hash_type(repl);
if (NIL_P(hash)) {
StringValue(repl);
}
}
}

Expand Down Expand Up @@ -6356,15 +6358,17 @@ str_gsub(int argc, VALUE *argv, VALUE str, int bang)
break;
case 2:
repl = argv[1];
hash = rb_check_hash_type(argv[1]);
if (NIL_P(hash)) {
StringValue(repl);
}
else if (rb_hash_default_unredefined(hash) && !FL_TEST_RAW(hash, RHASH_PROC_DEFAULT)) {
mode = FAST_MAP;
}
else {
mode = MAP;
if (!RB_TYPE_P(repl, T_STRING)) {
hash = rb_check_hash_type(repl);
if (NIL_P(hash)) {
StringValue(repl);
}
else if (rb_hash_default_unredefined(hash) && !FL_TEST_RAW(hash, RHASH_PROC_DEFAULT)) {
mode = FAST_MAP;
}
else {
mode = MAP;
}
}
break;
default:
Expand Down
33 changes: 24 additions & 9 deletions zjit/src/backend/lir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2220,9 +2220,8 @@ impl Assembler

/// Compile Target::SideExit and convert it into Target::CodePtr for all instructions
pub fn compile_exits(&mut self) {
/// Compile the main side-exit code. This function takes only SideExit so
/// that it can be safely deduplicated by using SideExit as a dedup key.
fn compile_exit(asm: &mut Assembler, exit: &SideExit) {
/// Restore VM state (cfp->pc, cfp->sp, stack, locals) for the side exit.
fn compile_exit_save_state(asm: &mut Assembler, exit: &SideExit) {
let SideExit { pc, stack, locals } = exit;

// Side exit blocks are not part of the CFG at the moment,
Expand All @@ -2249,12 +2248,22 @@ impl Assembler
asm.store(Opnd::mem(64, SP, (-local_size_and_idx_to_ep_offset(locals.len(), idx) - 1) * SIZEOF_VALUE_I32), opnd);
}
}
}

/// Tear down the JIT frame and return to the interpreter.
fn compile_exit_return(asm: &mut Assembler) {
asm_comment!(asm, "exit to the interpreter");
asm.frame_teardown(&[]); // matching the setup in gen_entry_point()
asm.cret(Opnd::UImm(Qundef.as_u64()));
}

/// Compile the main side-exit code. This function takes only SideExit so
/// that it can be safely deduplicated by using SideExit as a dedup key.
fn compile_exit(asm: &mut Assembler, exit: &SideExit) {
compile_exit_save_state(asm, exit);
compile_exit_return(asm);
}

fn join_opnds(opnds: &Vec<Opnd>, delimiter: &str) -> String {
opnds.iter().map(|opnd| format!("{opnd}")).collect::<Vec<_>>().join(delimiter)
}
Expand Down Expand Up @@ -2310,13 +2319,19 @@ impl Assembler
}

if should_record_exit {
// Save VM state before the ccall so that
// rb_profile_frames sees valid cfp->pc and the
// ccall doesn't clobber caller-saved registers
// holding stack/local operands.
compile_exit_save_state(self, &exit);
asm_ccall!(self, rb_zjit_record_exit_stack, pc);
}

// If the side exit has already been compiled, jump to it.
// Otherwise, let it fall through and compile the exit next.
if let Some(&exit_label) = compiled_exits.get(&exit) {
self.jmp(Target::Label(exit_label));
compile_exit_return(self);
} else {
// If the side exit has already been compiled, jump to it.
// Otherwise, let it fall through and compile the exit next.
if let Some(&exit_label) = compiled_exits.get(&exit) {
self.jmp(Target::Label(exit_label));
}
}
Some(counted_exit)
} else {
Expand Down
10 changes: 10 additions & 0 deletions zjit/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4500,6 +4500,16 @@ impl Function {
_ => None,
})
}
Insn::FixnumMod { left, right, .. } => {
self.fold_fixnum_bop(insn_id, left, right, |l, r| match (l, r) {
(Some(l), Some(r)) if r != 0 => {
let l_obj = VALUE::fixnum_from_isize(l as isize);
let r_obj = VALUE::fixnum_from_isize(r as isize);
Some(unsafe { rb_jit_fix_mod_fix(l_obj, r_obj) }.as_fixnum())
},
_ => None,
})
}
Insn::FixnumEq { left, right, .. } => {
self.fold_fixnum_pred(insn_id, left, right, |l, r| match (l, r) {
(Some(l), Some(r)) => Some(l == r),
Expand Down
190 changes: 190 additions & 0 deletions zjit/src/hir/opt_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,196 @@ mod hir_opt_tests {
");
}


#[test]
fn test_fold_fixnum_mod_zero_by_zero() {
eval("
def test
0 % 0
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[0] = Const Value(0)
v12:Fixnum[0] = Const Value(0)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v22:Fixnum = FixnumMod v10, v12
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v22
");
}

#[test]
fn test_fold_fixnum_mod_non_zero_by_zero() {
eval("
def test
11 % 0
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[11] = Const Value(11)
v12:Fixnum[0] = Const Value(0)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v22:Fixnum = FixnumMod v10, v12
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v22
");
}

#[test]
fn test_fold_fixnum_mod_zero_by_non_zero() {
eval("
def test
0 % 11
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[0] = Const Value(0)
v12:Fixnum[11] = Const Value(11)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v24:Fixnum[0] = Const Value(0)
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v24
");
}

#[test]
fn test_fold_fixnum_mod() {
eval("
def test
11 % 3
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[11] = Const Value(11)
v12:Fixnum[3] = Const Value(3)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v24:Fixnum[2] = Const Value(2)
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v24
");
}

#[test]
fn test_fold_fixnum_mod_negative_numerator() {
eval("
def test
-7 % 3
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[-7] = Const Value(-7)
v12:Fixnum[3] = Const Value(3)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v24:Fixnum[2] = Const Value(2)
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v24
");
}

#[test]
fn test_fold_fixnum_mod_negative_denominator() {
eval("
def test
7 % -3
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[7] = Const Value(7)
v12:Fixnum[-3] = Const Value(-3)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v24:Fixnum[-2] = Const Value(-2)
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v24
");
}

#[test]
fn test_fold_fixnum_mod_negative() {
eval("
def test
-7 % -3
end
");
assert_snapshot!(hir_string("test"), @r"
fn test@<compiled>:3:
bb0():
EntryPoint interpreter
v1:BasicObject = LoadSelf
Jump bb2(v1)
bb1(v4:BasicObject):
EntryPoint JIT(0)
Jump bb2(v4)
bb2(v6:BasicObject):
v10:Fixnum[-7] = Const Value(-7)
v12:Fixnum[-3] = Const Value(-3)
PatchPoint MethodRedefined(Integer@0x1000, %@0x1008, cme:0x1010)
v24:Fixnum[-1] = Const Value(-1)
IncrCounter inline_cfunc_optimized_send_count
CheckInterrupts
Return v24
");
}

#[test]
fn test_fold_fixnum_less() {
eval("
Expand Down