diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 90ea445be8d..2e2a874741d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -4515,4 +4515,342 @@ private static int eulerTotient(int[] primes, int[] exponents, int[] iExponents, } return count; } + + + + + + + + private static long[] getStridesForPermutation(int[] dims) { + long[] strides = new long[dims.length]; + long stride = 1; + for( int i = dims.length - 1; i >= 0; i-- ) { + strides[i] = stride; + stride *= dims[i]; + } + return strides; + } + + public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) { + return permute(in, inDims, perm, 1); + } + + public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) { + int rank = inDims.length; + + boolean isIdentity = true; + for( int i = 0; i < rank; i++ ) { + if( perm[i] != i ) { + isIdentity = false; + break; + } + } + + if( isIdentity ) { + return new MatrixBlock(in); + } + + int[] outDims = new int[rank]; + for( int i = 0; i < rank; i++ ) { + outDims[i] = inDims[perm[i]]; + } + + long length = 1; + for( int d : outDims ) { + length *= d; + } + + MatrixBlock out = new MatrixBlock(1, (int)length, false); + out.allocateDenseBlock(); + + DenseBlock inDB = in.getDenseBlock(); + DenseBlock outDB = out.getDenseBlock(); + + long[] inStrides = getStridesForPermutation(inDims); + long[] outStrides = getStridesForPermutation(outDims); + + long[] permutedStrides = new long[rank]; + for( int i = 0; i < rank; i++ ) { + permutedStrides[i] = outStrides[perm[i]]; + } + + boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD; + int numThreads = k == -1 ? 
Runtime.getRuntime().availableProcessors() : k; + + if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) { + double[] inData = inDB.valuesAt(0); + double[] outData = outDB.valuesAt(0); + + if( useParallel && rank > 0 ) { + permuteSingleBlockParallel(inData, outData, inDims, inStrides, + permutedStrides, numThreads, length); + } else { + permuteSingleBlock(inData, outData, inDims, inStrides, + permutedStrides, 0, 0, 0); + } + } else { + if( useParallel && rank > 0 ) { + permuteMultiBlockParallel(inDB, outDB, inDims, inStrides, + permutedStrides, numThreads, length); + } else { + permuteMultiBlock(inDB, outDB, inDims, inStrides, + permutedStrides, 0, 0L, 0L); + } + } + return out; + } + + private static void permuteSingleBlock( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, int inOffset, int outOffset) { + + if( dim == inDims.length - 1 ) { + int len = inDims[dim]; + int outStride = (int) permutedStrides[dim]; + + if( outStride == 1 ) { + System.arraycopy(inData, inOffset, outData, outOffset, len); + } else { + transposeRow(inData, outData, inOffset, outOffset, outStride, len); + } + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + final int BLOCK_SIZE = 128; + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { + int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for( int i = bi; i < bimin; i++ ) { + permuteSingleBlock( + inData, outData, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + (int)(i * inStep), + outOffset + (int)(i * outStep) + ); + } + } + } + + private static void permuteSingleBlockParallel( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + int k, long totalElements) { + + final long elementsPerThread = Math.max(1024, (totalElements + k - 1) / k); + final int actualThreads = (int) Math.min(k, (totalElements + elementsPerThread - 1) / elementsPerThread); + + final 
ExecutorService pool = CommonThreadPool.get(actualThreads); + try { + final ArrayList tasks = new ArrayList<>(); + + for( int t = 0; t < actualThreads; t++ ) { + final long start = t * elementsPerThread; + final long end = Math.min(start + elementsPerThread, totalElements); + + if( start >= totalElements ) { + break; + } + + tasks.add(new PermuteSingleBlockTask(inData, outData, inDims, + inStrides, permutedStrides, start, end)); + } + + for( Future task : pool.invokeAll(tasks) ) { + task.get(); + } + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } + + private static void permuteMultiBlock( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, long inOffset, long outOffset) { + + if( dim == inDims.length - 1 ) { + int len = inDims[dim]; + long outStride = permutedStrides[dim]; + + int inBlockSize = inDB.blockSize(); + int outBlockSize = outDB.blockSize(); + + for( int i = 0; i < len; i++ ) { + long currentInAbs = inOffset + i * inStrides[dim]; + long currentOutAbs = outOffset + i * outStride; + + int inBlockIdx = (int) (currentInAbs / inBlockSize); + int inRelIdx = (int) (currentInAbs % inBlockSize); + + int outBlockIdx = (int) (currentOutAbs / outBlockSize); + int outRelIdx = (int) (currentOutAbs % outBlockSize); + + double[] inArr = inDB.valuesAt(inBlockIdx); + double[] outArr = outDB.valuesAt(outBlockIdx); + + if( inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length ) { + outArr[outRelIdx] = inArr[inRelIdx]; + } + } + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + final int BLOCK_SIZE = 128; + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { + int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for( int i = bi; i < bimin; i++ ) { + permuteMultiBlock( + inDB, outDB, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + i * inStep, + outOffset + i 
* outStep + ); + } + } + } + + private static void permuteMultiBlockParallel( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int k, long totalElements) { + + final long elementsPerThread = Math.max(1024, (totalElements + k - 1) / k); + final int actualThreads = (int) Math.min(k, (totalElements + elementsPerThread - 1) / elementsPerThread); + + final ExecutorService pool = CommonThreadPool.get(actualThreads); + try { + final ArrayList tasks = new ArrayList<>(); + + for( int t = 0; t < actualThreads; t++ ) { + final long start = t * elementsPerThread; + final long end = Math.min(start + elementsPerThread, totalElements); + + if( start >= totalElements ) { + break; + } + + tasks.add(new PermuteMultiBlockTask(inDB, outDB, inDims, + inStrides, permutedStrides, start, end)); + } + + for( Future task : pool.invokeAll(tasks) ) { + task.get(); + } + + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } + + private static class PermuteSingleBlockTask implements Callable { + private final double[] inData; + private final double[] outData; + private final int[] inDims; + private final long[] inStrides; + private final long[] permutedStrides; + private final long start; + private final long end; + + protected PermuteSingleBlockTask(double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + long start, long end) { + this.inData = inData; + this.outData = outData; + this.inDims = inDims; + this.inStrides = inStrides; + this.permutedStrides = permutedStrides; + this.start = start; + this.end = end; + } + + @Override + public Object call() { + for( long idx = start; idx < end; idx++ ) { + long inIdx = 0; + long outIdx = 0; + long remaining = idx; + + for( int d = 0; d < inDims.length; d++ ) { + long coord = remaining / inStrides[d]; + remaining = remaining % inStrides[d]; + inIdx += coord * inStrides[d]; + outIdx += coord * permutedStrides[d]; + } 
+ + outData[(int)outIdx] = inData[(int)inIdx]; + } + return null; + } + } + + private static class PermuteMultiBlockTask implements Callable { + private final DenseBlock inDB; + private final DenseBlock outDB; + private final int[] inDims; + private final long[] inStrides; + private final long[] permutedStrides; + private final long start; + private final long end; + + protected PermuteMultiBlockTask(DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + long start, long end) { + this.inDB = inDB; + this.outDB = outDB; + this.inDims = inDims; + this.inStrides = inStrides; + this.permutedStrides = permutedStrides; + this.start = start; + this.end = end; + } + + @Override + public Object call() { + int inBlockSize = inDB.blockSize(); + int outBlockSize = outDB.blockSize(); + + for( long idx = start; idx < end; idx++ ) { + long inIdx = 0; + long outIdx = 0; + long remaining = idx; + + for( int d = 0; d < inDims.length; d++ ) { + long coord = remaining / inStrides[d]; + remaining = remaining % inStrides[d]; + inIdx += coord * inStrides[d]; + outIdx += coord * permutedStrides[d]; + } + + int inBlockIdx = (int) (inIdx / inBlockSize); + int inRelIdx = (int) (inIdx % inBlockSize); + + int outBlockIdx = (int) (outIdx / outBlockSize); + int outRelIdx = (int) (outIdx % outBlockSize); + + double[] inArr = inDB.valuesAt(inBlockIdx); + double[] outArr = outDB.valuesAt(outBlockIdx); + + if( inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length ) { + outArr[outRelIdx] = inArr[inRelIdx]; + } + } + return null; + } + } } + diff --git a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java new file mode 100644 index 00000000000..2b09c0ca0ae --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java @@ -0,0 +1,418 @@ +package 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.component.matrix.libMatrixReorg;

import java.util.Arrays;

import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Component tests for {@code LibMatrixReorg.permute} on linearized tensors.
 * Inputs are single-row blocks filled with 0, 1, 2, ... so that every value
 * equals its row-major linear input index.
 */
public class PermuteTest {

	@Test
	public void testBasicPermute() {
		int[] shape = {2, 3, 4};
		MatrixBlock tensor = generateMatrixBlock(shape);

		Assert.assertEquals(24, tensor.getNumRows() * tensor.getNumColumns());

		double[] data = tensor.getDenseBlockValues();
		Assert.assertEquals(23.0, data[1 * 4 * 3 + 2 * 4 + 3], 0.001);
		Assert.assertEquals(0.0, data[0], 0.001);

		int[] permutation = {1, 0, 2};
		MatrixBlock outTensor = LibMatrixReorg.permute(tensor, shape, permutation);

		double[] outData = outTensor.getDenseBlockValues();
		Assert.assertEquals(24, outData.length);
		Assert.assertEquals(4.0, outData[8], 0.001);
		Assert.assertEquals(15.0, outData[7], 0.001);
	}

	@Test
	public void testPermute2D_Transpose() {
		assertPermute(new int[] {10, 5}, new int[] {1, 0});
	}

	@Test
	public void testPermute3D_Simple() {
		assertPermute(new int[] {2, 3, 4}, new int[] {1, 0, 2});
	}

	@Test
	public void testPermute3D_Identity() {
		assertPermute(new int[] {5, 5, 5}, new int[] {0, 1, 2});
	}

	@Test
	public void testPermute4D_Reverse() {
		assertPermute(new int[] {2, 3, 4, 5}, new int[] {3, 2, 1, 0});
	}

	@Test
	public void testPermuteHighRank() {
		assertPermute(new int[] {2, 2, 2, 2, 2, 2}, new int[] {5, 0, 4, 1, 3, 2});
	}

	@Test
	public void testLargeBlockLogic_Mocked() {
		assertPermuteWithReportedBlocks(new int[] {10, 10, 10}, new int[] {2, 0, 1}, 2);
	}

	@Test
	public void testLargeBlockLogic_Mocked_InputAndOutput() {
		assertPermuteWithReportedBlocks(new int[] {4, 4, 4}, new int[] {2, 1, 0}, 5);
	}

	@Test
	public void testPermute3D_Parallel() {
		int[] shape = {100, 100, 100};
		int[] perm = {2, 0, 1};

		MatrixBlock in = generateMatrixBlock(shape);
		MatrixBlock out = LibMatrixReorg.permute(in, shape, perm, -1);

		verifyPermutation(in, out, shape, perm);
	}

	@Test
	public void testPerformance_SingleVsMultiThreaded() {
		compareSingleAndMultiThreaded(new int[] {100, 100, 100}, new int[] {2, 0, 1});
	}

	@Test
	public void testPerformance_LargeMatrix_SingleVsMulti() {
		// 9M cells keep the memory footprint of a unit test reasonable;
		// NOTE(review): assumed to exceed the parallelization threshold - confirm
		compareSingleAndMultiThreaded(new int[] {1, 3000, 3000}, new int[] {0, 2, 1});
	}

	@Test
	public void testPerformance_PermuteVsNativeTranspose() {
		int size = 1000;
		MatrixBlock in = new MatrixBlock(size, size, false);
		in.allocateDenseBlock();
		double[] data = in.getDenseBlockValues();
		for( int i = 0; i < size; i++ ) {
			for( int j = 0; j < size; j++ ) {
				data[i * size + j] = i * size + j;
			}
		}

		int[] shape = {size, size};
		int[] perm = {1, 0};

		long startPermute = System.nanoTime();
		MatrixBlock outPermute = LibMatrixReorg.permute(in, shape, perm, -1);
		long timePermute = System.nanoTime() - startPermute;

		long startTranspose = System.nanoTime();
		LibMatrixReorg.transpose(in);
		long timeTranspose = System.nanoTime() - startTranspose;

		// timings are informational only, they are too noisy to assert on
		System.out.println("Transpose Performance (" + size + "x" + size + "):");
		System.out.println("Permute function: " + timePermute / 1_000_000 + " ms");
		System.out.println("Native transpose: " + timeTranspose / 1_000_000 + " ms");
		System.out.println("Ratio: " + String.format("%.2fx", (double) timePermute / timeTranspose));

		double[] permuteData = outPermute.getDenseBlockValues();
		for( int i = 0; i < size; i++ ) {
			for( int j = 0; j < size; j++ ) {
				double expected = in.get(j, i);
				double actual = permuteData[i * size + j];
				Assert.assertEquals("Mismatch at (" + i + "," + j + ")", expected, actual, 0.0001);
			}
		}
	}

	@Test
	public void testEdgeCase_SingleElement() {
		assertPermute(new int[] {1, 1, 1}, new int[] {2, 1, 0});
	}

	@Test
	public void testEdgeCase_OneDimensionOne() {
		assertPermute(new int[] {5, 1, 10}, new int[] {2, 0, 1});
	}

	@Test
	public void testEdgeCase_TwoDimensionsOne() {
		assertPermute(new int[] {1, 1, 100}, new int[] {2, 1, 0});
	}

	@Test
	public void testConsecutivePermutations() {
		int[] shape = {3, 4, 5};
		int[] perm1 = {1, 0, 2};
		int[] perm2 = {2, 0, 1};

		MatrixBlock in = generateMatrixBlock(shape);
		MatrixBlock temp = LibMatrixReorg.permute(in, shape, perm1);
		verifyPermutation(in, temp, shape, perm1);

		int[] tempShape = {shape[perm1[0]], shape[perm1[1]], shape[perm1[2]]};
		MatrixBlock out = LibMatrixReorg.permute(temp, tempShape, perm2);
		verifyPermutation(temp, out, tempShape, perm2);
	}

	@Test
	public void testDifferentThreadCounts() {
		int[] shape = {50, 50, 50};
		int[] perm = {2, 0, 1};

		MatrixBlock in = generateMatrixBlock(shape);
		double[] expected = LibMatrixReorg.permute(in, shape, perm, 1).getDenseBlockValues();

		for( int k : new int[] {2, 4, 8} ) {
			double[] actual = LibMatrixReorg.permute(in, shape, perm, k).getDenseBlockValues();
			Assert.assertArrayEquals("Result differs for k=" + k, expected, actual, 0.0001);
		}
	}

	@Test
	public void testPermute_AllDimensionsCyclic() {
		assertPermute(new int[] {3, 4, 5, 2}, new int[] {1, 2, 3, 0});
	}

	@Test
	public void testPermute_NonContiguousStrides() {
		assertPermute(new int[] {7, 11, 13}, new int[] {2, 0, 1});
	}

	@Test
	public void testPermute_LargePrimeStrides() {
		assertPermute(new int[] {17, 19}, new int[] {1, 0});
	}

	/** Permutes a generated tensor single-threaded and checks every output cell. */
	private void assertPermute(int[] shape, int[] perm) {
		MatrixBlock in = generateMatrixBlock(shape);
		MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);
		verifyPermutation(in, out, shape, perm);
	}

	/**
	 * Permutes a tensor whose dense block is spied to report the given number
	 * of blocks, which steers permute away from its single-block code path.
	 */
	private void assertPermuteWithReportedBlocks(int[] shape, int[] perm, int numBlocks) {
		MatrixBlock in = generateMatrixBlock(shape);
		DenseBlock spyDB = Mockito.spy(in.getDenseBlock());
		Mockito.when(spyDB.numBlocks()).thenReturn(numBlocks);
		in.setDenseBlock(spyDB);

		MatrixBlock out = LibMatrixReorg.permute(in, shape, perm);

		// verify against an untouched input with identical content
		verifyPermutation(generateMatrixBlock(shape), out, shape, perm);
	}

	/**
	 * Runs permute with k=1 and k=-1, verifies the single-threaded result and
	 * requires both results to be identical; run times are only printed.
	 */
	private void compareSingleAndMultiThreaded(int[] shape, int[] perm) {
		MatrixBlock in = generateMatrixBlock(shape);

		long startSingle = System.nanoTime();
		MatrixBlock outSingle = LibMatrixReorg.permute(in, shape, perm, 1);
		long timeSingle = System.nanoTime() - startSingle;

		long startMulti = System.nanoTime();
		MatrixBlock outMulti = LibMatrixReorg.permute(in, shape, perm, -1);
		long timeMulti = System.nanoTime() - startMulti;

		verifyPermutation(in, outSingle, shape, perm);
		Assert.assertArrayEquals("Single- and multi-threaded results differ",
			outSingle.getDenseBlockValues(), outMulti.getDenseBlockValues(), 0.0);

		System.out.println("Large Matrix " + Arrays.toString(shape) + ":");
		System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms");
		System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms");
		System.out.println("Speedup: " + String.format("%.2fx", (double) timeSingle / timeMulti));
	}

	/** Creates a dense single-row block whose cell i holds the value i. */
	private MatrixBlock generateMatrixBlock(int[] shape) {
		long len = 1;
		for( int d : shape )
			len *= d;

		MatrixBlock mb = new MatrixBlock(1, (int) len, false);
		mb.allocateDenseBlock();
		double[] data = mb.getDenseBlockValues();
		for( int i = 0; i < data.length; i++ )
			data[i] = i;
		return mb;
	}

	/**
	 * Checks every output cell against the input: the cell at output
	 * coordinates {@code o} must equal the input cell at coordinates {@code c}
	 * with {@code c[perm[d]] = o[d]}.
	 */
	private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) {
		double[] inData = in.getDenseBlockValues();
		double[] outData = out.getDenseBlockValues();
		Assert.assertNotNull("Input has no dense values", inData);
		Assert.assertNotNull("Output has no dense values", outData);

		int rank = inShape.length;
		int[] outShape = new int[rank];
		for( int i = 0; i < rank; i++ )
			outShape[i] = inShape[perm[i]];

		long[] outStrides = getStrides(outShape);
		long[] inStrides = getStrides(inShape);

		long len = 1;
		for( int d : outShape )
			len *= d;
		Assert.assertEquals("Unexpected number of input cells", len, inData.length);
		Assert.assertEquals("Unexpected number of output cells", len, outData.length);

		int[] outCoords = new int[rank];
		int[] inCoords = new int[rank];
		for( long i = 0; i < len; i++ ) {
			long temp = i;
			for( int d = 0; d < rank; d++ ) {
				outCoords[d] = (int) (temp / outStrides[d]);
				temp = temp % outStrides[d];
			}

			for( int d = 0; d < rank; d++ )
				inCoords[perm[d]] = outCoords[d];

			long inIndex = 0;
			for( int d = 0; d < rank; d++ )
				inIndex += inCoords[d] * inStrides[d];

			double expectedValue = inData[(int) inIndex];
			double actualValue = outData[(int) i];

			if( Math.abs(expectedValue - actualValue) > 0.0001 ) {
				Assert.fail("Mismatch at linear output index " + i
					+ ". Output coords " + Arrays.toString(outCoords)
					+ ". Input coords " + Arrays.toString(inCoords)
					+ ". Expected " + expectedValue + " but got " + actualValue);
			}
		}
	}

	/** Row-major strides of the given shape. */
	private long[] getStrides(int[] dims) {
		long[] strides = new long[dims.length];
		long stride = 1;
		for( int i = dims.length - 1; i >= 0; i-- ) {
			strides[i] = stride;
			stride *= dims[i];
		}
		return strides;
	}
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.component.tensor;

import java.util.Arrays;
import java.util.Locale;

import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.data.TensorBlock;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests the reference tensor permutation defined in this file
 * ({@link PermuteIt}) on a small tensor filled with 0, 1, 2, ... in
 * row-major order.
 */
public class TransposeLinDataTest {

	@Test
	public void Testrightelem() {
		int[] shape = {2, 3, 4};
		TensorBlock tensor = TensorUtils.createArangeTensor(shape);

		Assert.assertArrayEquals(new int[] {2, 3, 4}, tensor.getDims());
		assertCell(0.0, tensor, 0, 0, 0);
		assertCell(23.0, tensor, 1, 2, 3);
		assertCell(6.0, tensor, 0, 1, 2);
		assertCell(12.0, tensor, 1, 0, 0);
		printTensor(tensor);

		int[] permutation = {1, 0, 2};
		TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
		printTensor(outTensor);

		Assert.assertArrayEquals(new int[] {3, 2, 4}, outTensor.getDims());
		assertCell(0.0, outTensor, 0, 0, 0);
		assertCell(23.0, outTensor, 2, 1, 3);
		assertCell(12.0, outTensor, 0, 1, 0);
		assertCell(17.0, outTensor, 1, 1, 1);

		int[] second_permutation = {2, 1, 0};
		TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation);
		printTensor(perm2Block);

		Assert.assertArrayEquals(new int[] {4, 3, 2}, perm2Block.getDims());
		assertCell(0.0, perm2Block, 0, 0, 0);
		assertCell(12.0, perm2Block, 0, 0, 1);
		assertCell(11.0, perm2Block, 3, 2, 0);
		assertCell(23.0, perm2Block, 3, 2, 1);
	}

	/** Asserts the exact numeric value of a single tensor cell. */
	private static void assertCell(double expected, TensorBlock tb, int... ix) {
		// cast first so that the (double, double, delta) overload is used
		double actual = (double) tb.get(ix);
		Assert.assertEquals("Mismatch at " + Arrays.toString(ix), expected, actual, 0.0);
	}

	/** Helpers to create test tensors. */
	public static class TensorUtils {

		/**
		 * Creates an FP64 tensor whose cells hold 0, 1, 2, ... in row-major order.
		 *
		 * @param shape tensor dimensions
		 * @return filled tensor
		 */
		public static TensorBlock createArangeTensor(int[] shape) {
			TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
			tb.allocateBlock();
			// single-element array serves as a counter shared by all recursion levels
			double[] counter = {0.0};
			int[] currentIndices = new int[shape.length];

			fillRecursively(tb, shape, 0, currentIndices, counter);

			return tb;
		}

		private static void fillRecursively(TensorBlock tb, int[] shape, int dim,
			int[] currentIndices, double[] counter)
		{
			if( dim == shape.length ) {
				tb.set(currentIndices, counter[0]);
				counter[0]++;
				return;
			}

			for( int i = 0; i < shape[dim]; i++ ) {
				currentIndices[dim] = i;
				fillRecursively(tb, shape, dim + 1, currentIndices, counter);
			}
		}
	}

	/** Straightforward cell-by-cell reference implementation of a tensor permutation. */
	public static class PermuteIt {

		/**
		 * Permutes the dimensions of a tensor; output dimension {@code i} is
		 * input dimension {@code permute_dims[i]}.
		 *
		 * @param tensor       input tensor
		 * @param permute_dims permutation of the dimensions
		 * @return newly allocated permuted tensor
		 */
		public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
			int numDims = tensor.getNumDims();
			int[] dims = tensor.getDims();
			ValueType tensorType = tensor.getValueType();

			int[] outShape = new int[numDims];
			for( int idx = 0; idx < numDims; idx++ )
				outShape[idx] = dims[permute_dims[idx]];

			TensorBlock outTensor = new TensorBlock(tensorType, outShape);
			outTensor.allocateBlock();

			int[] inIndex = new int[numDims];
			int[] outIndex = new int[numDims];

			rekursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
			return outTensor;
		}

		/**
		 * Enumerates all input coordinates from dimension {@code dim} on and
		 * copies each cell to its permuted output coordinates.
		 *
		 * @param inTensor    input tensor
		 * @param outTensor   output tensor
		 * @param permutation permutation of the dimensions
		 * @param inShape     input dimensions
		 * @param dim         dimension enumerated at this recursion level
		 * @param inIndex     scratch array holding the current input coordinates
		 * @param outIndex    scratch array for the output coordinates
		 */
		public static void rekursion(TensorBlock inTensor, TensorBlock outTensor,
			int[] permutation, int[] inShape, int dim, int[] inIndex, int[] outIndex)
		{
			if( dim == inShape.length ) {
				for( int idx = 0; idx < permutation.length; idx++ )
					outIndex[idx] = inIndex[permutation[idx]];
				double val = (double) inTensor.get(inIndex);
				outTensor.set(outIndex, val);
				return;
			}

			for( int idx = 0; idx < inShape[dim]; idx++ ) {
				inIndex[dim] = idx;
				rekursion(inTensor, outTensor, permutation, inShape, dim + 1, inIndex, outIndex);
			}
		}
	}

	/**
	 * Prints the tensor as nested, indented brackets to standard out.
	 *
	 * @param tb tensor to print
	 */
	public static void printTensor(TensorBlock tb) {
		StringBuilder sb = new StringBuilder();
		int[] shape = tb.getDims();
		int[] currentIndices = new int[shape.length];

		sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
		printRecursive(tb, shape, 0, currentIndices, sb, 0);

		System.out.println(sb.toString());
	}

	private static void printRecursive(TensorBlock tb, int[] shape, int dim,
		int[] indices, StringBuilder sb, int indent)
	{
		appendIndent(sb, indent);
		sb.append("[");

		if( dim == shape.length - 1 ) {
			// innermost dimension: values on a single line
			for( int i = 0; i < shape[dim]; i++ ) {
				indices[dim] = i;
				double val = (double) tb.get(indices);
				sb.append(String.format(Locale.ROOT, "%.1f", val));
				if( i < shape[dim] - 1 )
					sb.append(", ");
			}
			sb.append("]");
		}
		else {
			sb.append("\n");
			for( int i = 0; i < shape[dim]; i++ ) {
				indices[dim] = i;
				printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);

				if( i < shape[dim] - 1 ) {
					sb.append(",");
					sb.append("\n");
					// extra blank line between slices of three or more remaining dimensions
					if( shape.length - dim > 2 )
						sb.append("\n");
				}
			}
			sb.append("\n");
			appendIndent(sb, indent);
			sb.append("]");
		}
	}

	private static void appendIndent(StringBuilder sb, int indent) {
		for( int k = 0; k < indent; k++ )
			sb.append(" ");
	}
}