From 3f13177fac78aecbef0930a64dc0fef7de3eada3 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 24 Feb 2026 13:47:05 +0100 Subject: [PATCH 1/3] added OOC cbind --- .../instructions/OOCInstructionParser.java | 3 + .../ooc/AppendOOCInstruction.java | 181 ++++++++++++++++++ .../instructions/ooc/OOCInstruction.java | 2 +- .../sysds/test/functions/ooc/CBindTest.java | 131 +++++++++++++ src/test/scripts/functions/ooc/CBindTest.dml | 26 +++ 5 files changed, 342 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java create mode 100644 src/test/scripts/functions/ooc/CBindTest.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 607acbb3a0c..973ef4be146 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -42,6 +42,7 @@ import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -106,6 +107,8 @@ else if(parts.length == 4) return IndexingOOCInstruction.parseInstruction(str); case Rand: return DataGenOOCInstruction.parseInstruction(str); + case Append: + return AppendOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java new file mode 100644 index 00000000000..7df3791f342 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +public class AppendOOCInstruction extends BinaryOOCInstruction { + + public enum AppendType { + CBIND + } + + protected final AppendType _type; + + protected AppendOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, AppendType type, + String opcode, String istr) { + super(OOCType.Append, op, in1, in2, out, opcode, istr); + _type = type; + } + + public static AppendOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 5, 4); + + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[parts.length-2]); + boolean cbind = Boolean.parseBoolean(parts[parts.length-1]); + + if(in1.getDataType() != Types.DataType.MATRIX || in2.getDataType() != Types.DataType.MATRIX || !cbind){ + throw new DMLRuntimeException("Only matrix-matrix cbind is supported"); + } + AppendType type = AppendType.CBIND; + + Operator op = new ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1)); + return new AppendOOCInstruction(op, in1, in2, out, type, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) { + MatrixObject in1 = ec.getMatrixObject(input1); + MatrixObject in2 = ec.getMatrixObject(input2); + validateInput(in1, in2); + + OOCStream qIn1 = in1.getStreamHandle(); + OOCStream qIn2 = in2.getStreamHandle(); + + int blksize = in1.getBlocksize(); + int rem1 = (int) in1.getNumColumns()%blksize; + int rem2 = (int) in2.getNumColumns()%blksize; + int cblk1 = (int) in1.getDataCharacteristics().getNumColBlocks(); + int cblk2 = (int) in2.getDataCharacteristics().getNumColBlocks(); + + if(rem1+rem2 == 0){ + // no shifting needed + OOCStream out = new SubscribableTaskQueue<>(); + mapOOC(qIn2, out, imv -> new IndexedMatrixValue( + new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1+imv.getIndexes().getColumnIndex()), imv.getValue())); + + ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(List.of(qIn1, out))); + return; + } + + List> split1 = splitOOCStream(qIn1, imv -> imv.getIndexes().getColumnIndex()==cblk1? 1 : 0, 2); + List> split2 = splitOOCStream(qIn2, imv -> (int) imv.getIndexes().getColumnIndex()-1, cblk2); + + OOCStream head = split1.get(0); + OOCStream lastCol = split1.get(1); + OOCStream firstCol = split2.get(0); + + CachingStream firstColCache = new CachingStream(firstCol); + OOCStream firstColForCritical = firstColCache.getReadStream(); + OOCStream firstColForTail = firstColCache.getReadStream(); + + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + Function rowKey = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), 1); + + // combine cols both matrices + joinOOC(lastCol, firstColForCritical, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + int stop = cblk2>1? blksize-rem1 : rem2; + MatrixBlock combined = cbindBlocks(lb, sliceCols(rb, 0, stop)); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); + }, rowKey); + + List> outStreams = new ArrayList<>(); + outStreams.add(head); + outStreams.add(out); + + // shift cols second matrix + OOCStream fst = firstColForTail; + OOCStream sec = null; + for(int i=0; i(); + CachingStream secCachingStream = new CachingStream(split2.get(i+1)); + sec = secCachingStream.getReadStream(); + + int finalI = i; + joinOOC(fst, sec, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + int stop = finalI+2==cblk2 ? rem2 : blksize-rem1; + MatrixBlock combined = cbindBlocks(sliceCols(lb, blksize-rem1, blksize), sliceCols(rb, 0, stop)); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), cblk1 + left.getIndexes().getColumnIndex()), + combined); + }, rowKey); + + fst = secCachingStream.getReadStream(); + outStreams.add(out); + } + + if(rem1+rem2 > blksize){ + // overflow + int remSize = (rem1+rem2)%blksize; + out = new SubscribableTaskQueue<>(); + mapOOC(fst, out, imv -> new IndexedMatrixValue( + new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1+imv.getIndexes().getColumnIndex()), + sliceCols((MatrixBlock) imv.getValue(), rem2-remSize, rem2))); + + outStreams.add(out); + } + ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(outStreams)); + } + + public AppendType getAppendType() { + return _type; + } + + private void validateInput(MatrixObject m1, MatrixObject m2) { + if(_type == AppendType.CBIND && m1.getNumRows() != m2.getNumRows()) { + throw new DMLRuntimeException( + "Append-cbind is not possible for input matrices " + input1.getName() + " and " + input2.getName() + + " with different number of rows: " + m1.getNumRows() + " vs " + m2.getNumRows()); + } + } + + private static MatrixBlock sliceCols(MatrixBlock in, int colStart, int colEndExclusive) { + // slice is inclusive + return in.slice(0, in.getNumRows()-1, colStart, colEndExclusive-1); + } + + private static MatrixBlock cbindBlocks(MatrixBlock left, MatrixBlock right) { + return left.append(right, new MatrixBlock()); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index f7cefe635df..931d45e0f45 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -75,7 +75,7 @@ public abstract class OOCInstruction extends Instruction { public enum OOCType { Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ, - MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand + MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java new file mode 100644 index 00000000000..e91a3c3902f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class CBindTest extends AutomatedTestBase { + + private static final String TEST_NAME = "CBindTest"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + CBindTest.class.getSimpleName() + "/"; + + private final static double eps = 1e-8; + private static final String INPUT_NAME_1 = "A"; + private static final String INPUT_NAME_2 = "B"; + private static final String OUTPUT_NAME = "res"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testCBindAppendBlock() { runCBindTest(1000, 1000, 1000, 1000);} + + @Test + public void testCBindPartialFillSingleRightBlock() { runCBindTest(1000, 1100, 1000, 100);} + + @Test + public void testCBindTotalFillSingleBlockEachSide() { runCBindTest(1000, 500, 1000, 500);} + + @Test + public void testCBindTotalFillTwoRightBlocks() { runCBindTest(1000, 3500, 1000, 1500);} + + @Test + public void testCBindPartialFillMultipleBlocksEachSide() { runCBindTest(1000, 3100, 1000, 3200);} + + @Test + public void testCBindTotalFillMultipleBlocksEachSide() { runCBindTest(1000, 3500, 1000, 3500);} + + @Test + public void testCBindOverflowSingleRightBlock() { runCBindTest(1000, 1600, 1000, 600);} + + @Test + public void testCBindOverflowMultipleRightBlocks() { runCBindTest(1000, 1600, 1000, 2600);} + + @Test + public void testCBindMultipleRows() { runCBindTest(2500, 1500, 2500, 1500);} + + @Test + public void testCBind() {runCBindTest(2300, 1655, 2300, 2542);} + + public void runCBindTest(int r1, int c1, int r2, int c2) { + Types.ExecMode platformOld = rtplatform; + rtplatform = Types.ExecMode.SINGLE_NODE; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + double[][] A = TestUtils.floor(getRandomMatrix(r1, c1, -1, 1, 1.0, 7)); + double[][] B = TestUtils.floor(getRandomMatrix(r2, c2, -1, 1, 1.0, 13)); + + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(A), input(INPUT_NAME_1), r1, c1, 1000, r1*c1); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(B), input(INPUT_NAME_2), r2, c2, 1000, r2*c2); + + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(r1, c1, 1000, r1*c1), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(r2, c2, 1000, r2*c2), Types.FileFormat.BINARY); + + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", + input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)}; + runTest(true, false, null, -1); + + Assert.assertTrue("OOC wasn't used for cbind", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.APPEND)); + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", + input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare results + MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, r1, c1+c2, 1000); + MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, r1, c1+c2, 1000); + TestUtils.compareMatrices(ret1, ret2, eps); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/CBindTest.dml b/src/test/scripts/functions/ooc/CBindTest.dml new file mode 100644 index 00000000000..edfbddafc0f --- /dev/null +++ b/src/test/scripts/functions/ooc/CBindTest.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1) +B = read($2) +res = cbind(A, B) + +write(res, $3, format="binary"); From 4a9c8aa5d08d5a4ee087e9d5eff1338855b5a0eb Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Tue, 24 Feb 2026 18:32:38 +0100 Subject: [PATCH 2/3] minor fix + new test --- .../instructions/ooc/AppendOOCInstruction.java | 12 +++++++----- .../apache/sysds/test/functions/ooc/CBindTest.java | 10 ++++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java index 7df3791f342..45c4cfcb6ab 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java @@ -83,8 +83,9 @@ public void processInstruction(ExecutionContext ec) { int rem2 = (int) in2.getNumColumns()%blksize; int cblk1 = (int) in1.getDataCharacteristics().getNumColBlocks(); int cblk2 = (int) in2.getDataCharacteristics().getNumColBlocks(); + int cblkRes = (int) Math.ceil((double)(in1.getNumColumns()+in2.getNumColumns())/blksize); - if(rem1+rem2 == 0){ + if(rem1==0){ // no shifting needed OOCStream out = new SubscribableTaskQueue<>(); mapOOC(qIn2, out, imv -> new IndexedMatrixValue( @@ -108,11 +109,12 @@ public void processInstruction(ExecutionContext ec) { SubscribableTaskQueue out = new SubscribableTaskQueue<>(); Function rowKey = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), 1); + int fullRem2 = rem2==0? blksize : rem2; // combine cols both matrices joinOOC(lastCol, firstColForCritical, out, (left, right) -> { MatrixBlock lb = (MatrixBlock) left.getValue(); MatrixBlock rb = (MatrixBlock) right.getValue(); - int stop = cblk2>1? blksize-rem1 : rem2; + int stop = cblk2==1 && blksize-rem1>fullRem2? fullRem2 : blksize-rem1; MatrixBlock combined = cbindBlocks(lb, sliceCols(rb, 0, stop)); return new IndexedMatrixValue( new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); @@ -134,7 +136,7 @@ public void processInstruction(ExecutionContext ec) { joinOOC(fst, sec, out, (left, right) -> { MatrixBlock lb = (MatrixBlock) left.getValue(); MatrixBlock rb = (MatrixBlock) right.getValue(); - int stop = finalI+2==cblk2 ? rem2 : blksize-rem1; + int stop = finalI+2==cblk2 && blksize-rem1>fullRem2? fullRem2 : blksize-rem1; MatrixBlock combined = cbindBlocks(sliceCols(lb, blksize-rem1, blksize), sliceCols(rb, 0, stop)); return new IndexedMatrixValue( new MatrixIndexes(left.getIndexes().getRowIndex(), cblk1 + left.getIndexes().getColumnIndex()), @@ -145,13 +147,13 @@ public void processInstruction(ExecutionContext ec) { outStreams.add(out); } - if(rem1+rem2 > blksize){ + if(cblk1+cblk2==cblkRes){ // overflow int remSize = (rem1+rem2)%blksize; out = new SubscribableTaskQueue<>(); mapOOC(fst, out, imv -> new IndexedMatrixValue( new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1+imv.getIndexes().getColumnIndex()), - sliceCols((MatrixBlock) imv.getValue(), rem2-remSize, rem2))); + sliceCols((MatrixBlock) imv.getValue(), fullRem2-remSize, fullRem2))); outStreams.add(out); } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java index e91a3c3902f..16a9107caab 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java @@ -54,6 +54,10 @@ public void setUp() { @Test public void testCBindAppendBlock() { runCBindTest(1000, 1000, 1000, 1000);} + // TODO: fix OOC internals + // @Test + public void testCBindAppendBlockTwoLeftBlocks() {runCBindTest(1000, 2000, 1000, 1000);} + @Test public void testCBindPartialFillSingleRightBlock() { runCBindTest(1000, 1100, 1000, 100);} @@ -75,6 +79,12 @@ public void setUp() { @Test public void testCBindOverflowMultipleRightBlocks() { runCBindTest(1000, 1600, 1000, 2600);} + @Test + public void testCBindOverflowTotalFilledSingleRightBlock() {runCBindTest(1000, 1100, 1000, 1000);} + + @Test + public void testCBindOverflowTotalFilledTwoRightBlocks() {runCBindTest(1000, 1100, 1000, 2000);} + @Test public void testCBindMultipleRows() { runCBindTest(2500, 1500, 2500, 1500);} From 2e5c22a6a2d3ae8fcb8df55478398414be96ae85 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Wed, 25 Feb 2026 09:51:57 +0100 Subject: [PATCH 3/3] Fix: Handle GroupQueueCallbacks to Avoid Miscounts --- .../sysds/runtime/ooc/stream/MergedOOCStream.java | 11 +++++++++++ .../apache/sysds/test/functions/ooc/CBindTest.java | 3 +-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java index 7d0a27932f1..0c036d16c20 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java @@ -85,6 +85,17 @@ public MergedOOCStream(List> sources) { if(_failed.get()) return; + if(cb instanceof OOCStream.GroupQueueCallback) { + OOCStream.GroupQueueCallback group = (OOCStream.GroupQueueCallback) cb; + for(int i = 0; i < group.size(); i++) { + OOCStream.QueueCallback sub = group.getCallback(i); + try(sub) { + _taskQueue.enqueue(sub.keepOpen()); + } + } + return; + } + _taskQueue.enqueue(cb.keepOpen()); } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java index 16a9107caab..f645bbba23f 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java @@ -54,8 +54,7 @@ public void setUp() { @Test public void testCBindAppendBlock() { runCBindTest(1000, 1000, 1000, 1000);} - // TODO: fix OOC internals - // @Test + @Test public void testCBindAppendBlockTwoLeftBlocks() {runCBindTest(1000, 2000, 1000, 1000);} @Test