Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.random.Well1024a;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataGenOp;
Expand Down Expand Up @@ -177,6 +178,7 @@ else if(method == Types.OpOpDG.SEQ) {
@Override
public void processInstruction(ExecutionContext ec) {
final OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);

// process specific datagen operator
if (method == Types.OpOpDG.RAND) {
Expand All @@ -188,9 +190,6 @@ public void processInstruction(ExecutionContext ec) {
long lcols = ec.getScalarInput(cols).getLongValue();
checkValidDimensions(lrows, lcols);

if (!pdf.equalsIgnoreCase("uniform") || minValue != maxValue)
throw new NotImplementedException(); // TODO modified version of rng as in LibMatrixDatagen to handle blocks independently

OOCStream<MatrixIndexes> qIn = createWritableStream();
int nrb = (int)((lrows-1) / blen)+1;
int ncb = (int)((lcols-1) / blen)+1;
Expand All @@ -210,10 +209,37 @@ public void processInstruction(ExecutionContext ec) {
return;
}

if(sparsity == 1.0 && minValue == maxValue) {
mapOOC(qIn, qOut, idx -> {
long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen);
long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen);
return new IndexedMatrixValue(idx, new MatrixBlock((int)rlen, (int)clen, minValue));
});
return;
}

Well1024a bigrand = LibMatrixDatagen.setupSeedsForRand(lSeed);
int nb = nrb * ncb;
long[] seeds = new long[nb];
for(int i = 0; i < nb; i++) seeds[i] = bigrand.nextLong();

mapOOC(qIn, qOut, idx -> {
long rlen = Math.min(blen, lrows - (idx.getRowIndex()-1) * blen);
long clen = Math.min(blen, lcols - (idx.getColumnIndex()-1) * blen);
MatrixBlock mout = MatrixBlock.randOperations(getGenerator(rlen, clen), lSeed);

int r = (int) idx.getRowIndex()-1;
int c = (int) idx.getColumnIndex()-1;
long bSeed = seeds[r*ncb+c];

final long estnnz = ((minValue==0.0 && maxValue==0.0) ? 0 : (long)(sparsity * rlen * clen));
boolean lsparse = MatrixBlock.evalSparseFormatInMemory(rlen, clen, estnnz);

MatrixBlock mout = new MatrixBlock();
mout.reset((int) rlen, (int) clen, lsparse, estnnz);
mout.allocateBlock();

LibMatrixDatagen.genRandomNumbers(false, 0, 1, 0, 1, mout, getGenerator(rlen, clen), bSeed, null);
mout.recomputeNonZeros();
return new IndexedMatrixValue(idx, mout);
});
}
Expand Down Expand Up @@ -263,8 +289,6 @@ else if(method == Types.OpOpDG.SEQ) {
}
else
throw new NotImplementedException();

ec.getMatrixObject(output).setStreamHandle(qOut);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ private static long[] sliceSeedsForCP(long[] seeds, int rl, int ru, int cl, int
return lseeds;
}

private static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int cl, int cu, MatrixBlock out, RandomMatrixGenerator rgen, long bSeed, long[] seeds) {
public static void genRandomNumbers(boolean invokedFromCP, int rl, int ru, int cl, int cu, MatrixBlock out, RandomMatrixGenerator rgen, long bSeed, long[] seeds) {
int rows = rgen._rows;
int cols = rgen._cols;
int blen = rgen._blocksize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ public void setUp() {
addTestConfiguration(TEST_NAME_2, config2);
}

// Actual rand operation not yet supported
/*@Test
@Test
public void testRand() {
runRandTest(TEST_NAME_1);
}*/
}

@Test
public void testConstInit() {
Expand Down
2 changes: 1 addition & 1 deletion src/test/scripts/functions/ooc/Rand1.dml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
#
#-------------------------------------------------------------

res = rand(rows=1500, cols=1200, min=-1, max=1);
res = rand(rows=1500, cols=1200, min=-1, max=1, seed=42);

write(res, $2, format="binary");