Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 338 additions & 0 deletions src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
Original file line number Diff line number Diff line change
Expand Up @@ -4515,4 +4515,342 @@ private static int eulerTotient(int[] primes, int[] exponents, int[] iExponents,
}
return count;
}







/**
 * Computes the row-major strides of a tensor with the given axis extents.
 * The stride of the last axis is 1, and the stride of every other axis is
 * the product of the extents of all axes that follow it.
 *
 * @param dims extent of every axis, outermost axis first
 * @return array of the same length as {@code dims} holding one stride per axis
 */
private static long[] getStridesForPermutation(int[] dims) {
	final int rank = dims.length;
	final long[] result = new long[rank];
	if( rank == 0 )
		return result;

	// innermost axis is contiguous; every outer stride builds on its inner neighbor
	result[rank - 1] = 1;
	for( int axis = rank - 2; axis >= 0; axis-- )
		result[axis] = result[axis + 1] * dims[axis + 1];
	return result;
}

/**
 * Permutes the axes of a linearized tensor without multi-threading.
 * Delegates to {@link #permute(MatrixBlock, int[], int[], int)} with a
 * degree of parallelism of 1.
 *
 * @param in     input block holding the tensor values
 * @param inDims extent of every input axis
 * @param perm   axis permutation, forwarded unchanged
 * @return the block produced by the four-argument overload
 */
public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) {
	final int singleThreaded = 1;
	return permute(in, inDims, perm, singleThreaded);
}

/**
 * Permutes the axes of a tensor whose cells are stored linearized in
 * row-major order in the dense block of {@code in}.
 *
 * Output axis {@code i} takes its extent (and data) from input axis
 * {@code perm[i]}. For the identity permutation a copy of the input block is
 * returned; otherwise the result is a newly allocated dense block with one
 * row and as many columns as the tensor has cells.
 *
 * NOTE(review): the input is read via {@code in.getDenseBlock()} without a
 * null check, so this presumably requires an allocated dense input - confirm
 * how sparse or empty inputs should be handled by callers.
 *
 * @param in     input block holding the tensor values
 * @param inDims extent of every input axis, outermost axis first
 * @param perm   permutation of {@code 0..inDims.length-1}
 * @param k      degree of parallelism; -1 uses all available processors,
 *               values of 1 or less (other than -1) run single-threaded
 * @return block holding the permuted tensor
 * @throws DMLRuntimeException if {@code perm} is not a valid permutation of
 *         the input axes or the tensor has more cells than fit into an int
 */
public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) {
	final int rank = inDims.length;
	if( perm.length != rank ) {
		throw new DMLRuntimeException("Invalid permutation: expected "
			+ rank + " axes but got " + perm.length + ".");
	}

	// validate that perm is a bijection on [0, rank) and detect the identity
	final boolean[] seen = new boolean[rank];
	boolean isIdentity = true;
	for( int i = 0; i < rank; i++ ) {
		final int srcAxis = perm[i];
		if( srcAxis < 0 || srcAxis >= rank || seen[srcAxis] ) {
			throw new DMLRuntimeException("Invalid permutation: axis "
				+ srcAxis + " at position " + i + " is out of range or duplicated.");
		}
		seen[srcAxis] = true;
		isIdentity &= (srcAxis == i);
	}

	if( isIdentity ) {
		return new MatrixBlock(in);
	}

	final int[] outDims = new int[rank];
	long length = 1;
	for( int i = 0; i < rank; i++ ) {
		outDims[i] = inDims[perm[i]];
		length *= outDims[i];
	}

	// the output is allocated with an int column count; fail instead of truncating
	if( length > Integer.MAX_VALUE ) {
		throw new DMLRuntimeException("Permute output with " + length
			+ " cells exceeds the maximum supported size of " + Integer.MAX_VALUE + ".");
	}

	MatrixBlock out = new MatrixBlock(1, (int)length, false);
	out.allocateDenseBlock();

	DenseBlock inDB = in.getDenseBlock();
	DenseBlock outDB = out.getDenseBlock();

	long[] inStrides = getStridesForPermutation(inDims);
	long[] outStrides = getStridesForPermutation(outDims);

	// permutedStrides is indexed by INPUT axis (the kernels walk inDims):
	// input axis perm[i] becomes output axis i, hence it advances by outStrides[i].
	// Indexing the other way round (outStrides[perm[i]]) is only correct for
	// permutations that are their own inverse, e.g. plain axis swaps.
	long[] permutedStrides = new long[rank];
	for( int i = 0; i < rank; i++ ) {
		permutedStrides[perm[i]] = outStrides[i];
	}

	boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD;
	int numThreads = k == -1 ? Runtime.getRuntime().availableProcessors() : k;

	if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) {
		double[] inData = inDB.valuesAt(0);
		double[] outData = outDB.valuesAt(0);

		if( useParallel ) {
			permuteSingleBlockParallel(inData, outData, inDims, inStrides,
				permutedStrides, numThreads, length);
		}
		else {
			permuteSingleBlock(inData, outData, inDims, inStrides,
				permutedStrides, 0, 0, 0);
		}
	}
	else {
		if( useParallel ) {
			permuteMultiBlockParallel(inDB, outDB, inDims, inStrides,
				permutedStrides, numThreads, length);
		}
		else {
			permuteMultiBlock(inDB, outDB, inDims, inStrides,
				permutedStrides, 0, 0L, 0L);
		}
	}

	// a permutation only moves cells around, so the number of non-zeros is unchanged
	out.setNonZeros(in.getNonZeros());
	return out;
}

/**
 * Recursively copies a tensor from {@code inData} to {@code outData} while
 * applying the axis permutation encoded in {@code permutedStrides}.
 *
 * Each recursion level iterates one input axis in ascending order. On the
 * last axis the whole run of cells is moved at once: as a contiguous copy if
 * the output stride is 1, otherwise via {@code transposeRow}.
 *
 * @param inData          source values
 * @param outData         target values
 * @param inDims          extent of every input axis
 * @param inStrides       input stride per input axis
 * @param permutedStrides output stride per input axis
 * @param dim             input axis handled by this call
 * @param inOffset        start position in {@code inData} for this call
 * @param outOffset       start position in {@code outData} for this call
 */
private static void permuteSingleBlock(
	double[] inData, double[] outData,
	int[] inDims, long[] inStrides, long[] permutedStrides,
	int dim, int inOffset, int outOffset) {

	final int extent = inDims[dim];
	final boolean innermost = (dim == inDims.length - 1);

	if( !innermost ) {
		final long srcStep = inStrides[dim];
		final long dstStep = permutedStrides[dim];

		// visit all positions of this axis in ascending order and descend
		for( int pos = 0; pos < extent; pos++ ) {
			final int childIn = inOffset + (int)(pos * srcStep);
			final int childOut = outOffset + (int)(pos * dstStep);
			permuteSingleBlock(inData, outData, inDims, inStrides,
				permutedStrides, dim + 1, childIn, childOut);
		}
		return;
	}

	// last axis: input cells are contiguous, output cells are dstStride apart
	final int dstStride = (int) permutedStrides[dim];
	if( dstStride != 1 )
		transposeRow(inData, outData, inOffset, outOffset, dstStride, extent);
	else
		System.arraycopy(inData, inOffset, outData, outOffset, extent);
}

/**
 * Multi-threaded permutation for the case where input and output each live
 * in a single array. The linear input index range {@code [0, totalElements)}
 * is cut into consecutive chunks of at least 1024 cells, and every chunk is
 * handled by one {@code PermuteSingleBlockTask}.
 *
 * @param inData          source values
 * @param outData         target values
 * @param inDims          extent of every input axis
 * @param inStrides       input stride per input axis
 * @param permutedStrides output stride per input axis
 * @param k               requested number of threads
 * @param totalElements   number of cells to process
 * @throws DMLRuntimeException wrapping any exception raised while running the tasks
 */
private static void permuteSingleBlockParallel(
	double[] inData, double[] outData,
	int[] inDims, long[] inStrides, long[] permutedStrides,
	int k, long totalElements) {

	// chunk size: even split over k threads, but never below 1024 cells
	final long chunk = Math.max(1024, (totalElements + k - 1) / k);
	final long numChunks = (totalElements + chunk - 1) / chunk;
	final int actualThreads = (int) Math.min(k, numChunks);

	final ExecutorService pool = CommonThreadPool.get(actualThreads);
	try {
		final ArrayList<PermuteSingleBlockTask> tasks = new ArrayList<>();

		long lower = 0;
		for( int t = 0; t < actualThreads && lower < totalElements; t++, lower += chunk ) {
			final long upper = Math.min(lower + chunk, totalElements);
			tasks.add(new PermuteSingleBlockTask(inData, outData, inDims,
				inStrides, permutedStrides, lower, upper));
		}

		// wait for completion and surface any task failure
		for( Future<Object> pending : pool.invokeAll(tasks) )
			pending.get();
	}
	catch (Exception ex) {
		throw new DMLRuntimeException(ex);
	}
	finally {
		pool.shutdown();
	}
}

/**
 * Recursive permutation for inputs or outputs that span several dense blocks.
 * Works like the single-array variant, but on the last axis every cell is
 * copied individually: its absolute position is split into a block index and
 * a position inside that block on both sides.
 *
 * Cells for which either block array is {@code null}, or whose in-block
 * position is not smaller than the array length, are skipped silently.
 *
 * NOTE(review): absolute cell positions are divided by {@code blockSize()},
 * i.e. the value is treated as the number of cells per block - confirm this
 * matches the DenseBlock contract (it may denote rows per block instead).
 *
 * @param inDB            source dense block
 * @param outDB           target dense block
 * @param inDims          extent of every input axis
 * @param inStrides       input stride per input axis
 * @param permutedStrides output stride per input axis
 * @param dim             input axis handled by this call
 * @param inOffset        absolute start position in the input for this call
 * @param outOffset       absolute start position in the output for this call
 */
private static void permuteMultiBlock(
	DenseBlock inDB, DenseBlock outDB,
	int[] inDims, long[] inStrides, long[] permutedStrides,
	int dim, long inOffset, long outOffset) {

	final int extent = inDims[dim];

	if( dim != inDims.length - 1 ) {
		final long srcStep = inStrides[dim];
		final long dstStep = permutedStrides[dim];

		// visit all positions of this axis in ascending order and descend
		for( int pos = 0; pos < extent; pos++ ) {
			permuteMultiBlock(inDB, outDB, inDims, inStrides, permutedStrides,
				dim + 1, inOffset + pos * srcStep, outOffset + pos * dstStep);
		}
		return;
	}

	// last axis: copy cell by cell, resolving the block on both sides
	final long dstStride = permutedStrides[dim];
	final int srcBlockLen = inDB.blockSize();
	final int dstBlockLen = outDB.blockSize();

	for( int pos = 0; pos < extent; pos++ ) {
		final long srcAbs = inOffset + pos * inStrides[dim];
		final long dstAbs = outOffset + pos * dstStride;

		final int srcBlock = (int) (srcAbs / srcBlockLen);
		final int srcRel = (int) (srcAbs % srcBlockLen);
		final int dstBlock = (int) (dstAbs / dstBlockLen);
		final int dstRel = (int) (dstAbs % dstBlockLen);

		final double[] srcVals = inDB.valuesAt(srcBlock);
		final double[] dstVals = outDB.valuesAt(dstBlock);

		// skip cells that cannot be resolved to a valid array position
		if( srcVals == null || dstVals == null
			|| srcRel >= srcVals.length || dstRel >= dstVals.length )
			continue;

		dstVals[dstRel] = srcVals[srcRel];
	}
}

/**
 * Multi-threaded permutation for inputs or outputs that span several dense
 * blocks. The linear input index range {@code [0, totalElements)} is cut
 * into consecutive chunks of at least 1024 cells, and every chunk is handled
 * by one {@code PermuteMultiBlockTask}.
 *
 * @param inDB            source dense block
 * @param outDB           target dense block
 * @param inDims          extent of every input axis
 * @param inStrides       input stride per input axis
 * @param permutedStrides output stride per input axis
 * @param k               requested number of threads
 * @param totalElements   number of cells to process
 * @throws DMLRuntimeException wrapping any exception raised while running the tasks
 */
private static void permuteMultiBlockParallel(
	DenseBlock inDB, DenseBlock outDB,
	int[] inDims, long[] inStrides, long[] permutedStrides,
	int k, long totalElements) {

	// chunk size: even split over k threads, but never below 1024 cells
	final long chunk = Math.max(1024, (totalElements + k - 1) / k);
	final long numChunks = (totalElements + chunk - 1) / chunk;
	final int actualThreads = (int) Math.min(k, numChunks);

	final ExecutorService pool = CommonThreadPool.get(actualThreads);
	try {
		final ArrayList<PermuteMultiBlockTask> tasks = new ArrayList<>();

		long lower = 0;
		for( int t = 0; t < actualThreads && lower < totalElements; t++, lower += chunk ) {
			final long upper = Math.min(lower + chunk, totalElements);
			tasks.add(new PermuteMultiBlockTask(inDB, outDB, inDims,
				inStrides, permutedStrides, lower, upper));
		}

		// wait for completion and surface any task failure
		for( Future<Object> pending : pool.invokeAll(tasks) )
			pending.get();
	}
	catch (Exception ex) {
		throw new DMLRuntimeException(ex);
	}
	finally {
		pool.shutdown();
	}
}

/**
 * Task that permutes the cells with linear input index in
 * {@code [start, end)} between two plain arrays.
 *
 * For every linear index the per-axis coordinates are recovered by repeated
 * division with the input strides; the same coordinates are then combined
 * with the permuted strides to obtain the output position.
 */
private static class PermuteSingleBlockTask implements Callable<Object> {
	private final double[] src;
	private final double[] dst;
	private final int[] dims;
	private final long[] srcStrides;
	private final long[] dstStrides;
	private final long from;
	private final long to;

	/**
	 * @param inData          source values
	 * @param outData         target values
	 * @param inDims          extent of every input axis (only its length is used)
	 * @param inStrides       input stride per input axis
	 * @param permutedStrides output stride per input axis
	 * @param start           first linear input index to process (inclusive)
	 * @param end             last linear input index to process (exclusive)
	 */
	protected PermuteSingleBlockTask(double[] inData, double[] outData,
		int[] inDims, long[] inStrides, long[] permutedStrides,
		long start, long end) {
		src = inData;
		dst = outData;
		dims = inDims;
		srcStrides = inStrides;
		dstStrides = permutedStrides;
		from = start;
		to = end;
	}

	/**
	 * Copies every cell of the assigned range to its permuted position.
	 *
	 * @return always {@code null}
	 */
	@Override
	public Object call() {
		final int rank = dims.length;
		long linear = from;
		while( linear < to ) {
			long rest = linear;
			long srcPos = 0;
			long dstPos = 0;

			// peel off one coordinate per axis, outermost axis first
			for( int axis = 0; axis < rank; axis++ ) {
				final long stride = srcStrides[axis];
				final long coord = rest / stride;
				rest -= coord * stride; // remainder of the division above
				srcPos += coord * stride;
				dstPos += coord * dstStrides[axis];
			}

			dst[(int) dstPos] = src[(int) srcPos];
			linear++;
		}
		return null;
	}
}

/**
 * Task that permutes the cells with linear input index in
 * {@code [start, end)} between two dense blocks that may consist of several
 * underlying arrays.
 *
 * Per-axis coordinates are recovered from the linear index via the input
 * strides and recombined with the permuted strides. Both absolute positions
 * are then split into a block index and an in-block position. Cells for
 * which either block array is {@code null}, or whose in-block position is
 * not smaller than the array length, are skipped silently.
 *
 * NOTE(review): absolute cell positions are divided by {@code blockSize()},
 * i.e. the value is treated as the number of cells per block - confirm this
 * matches the DenseBlock contract (it may denote rows per block instead).
 */
private static class PermuteMultiBlockTask implements Callable<Object> {
	private final DenseBlock src;
	private final DenseBlock dst;
	private final int[] dims;
	private final long[] srcStrides;
	private final long[] dstStrides;
	private final long from;
	private final long to;

	/**
	 * @param inDB            source dense block
	 * @param outDB           target dense block
	 * @param inDims          extent of every input axis (only its length is used)
	 * @param inStrides       input stride per input axis
	 * @param permutedStrides output stride per input axis
	 * @param start           first linear input index to process (inclusive)
	 * @param end             last linear input index to process (exclusive)
	 */
	protected PermuteMultiBlockTask(DenseBlock inDB, DenseBlock outDB,
		int[] inDims, long[] inStrides, long[] permutedStrides,
		long start, long end) {
		src = inDB;
		dst = outDB;
		dims = inDims;
		srcStrides = inStrides;
		dstStrides = permutedStrides;
		from = start;
		to = end;
	}

	/**
	 * Copies every resolvable cell of the assigned range to its permuted position.
	 *
	 * @return always {@code null}
	 */
	@Override
	public Object call() {
		final int srcBlockLen = src.blockSize();
		final int dstBlockLen = dst.blockSize();
		final int rank = dims.length;

		for( long linear = from; linear < to; linear++ ) {
			long rest = linear;
			long srcPos = 0;
			long dstPos = 0;

			// peel off one coordinate per axis, outermost axis first
			for( int axis = 0; axis < rank; axis++ ) {
				final long stride = srcStrides[axis];
				final long coord = rest / stride;
				rest -= coord * stride; // remainder of the division above
				srcPos += coord * stride;
				dstPos += coord * dstStrides[axis];
			}

			final int srcBlock = (int) (srcPos / srcBlockLen);
			final int srcRel = (int) (srcPos % srcBlockLen);
			final int dstBlock = (int) (dstPos / dstBlockLen);
			final int dstRel = (int) (dstPos % dstBlockLen);

			final double[] srcVals = src.valuesAt(srcBlock);
			final double[] dstVals = dst.valuesAt(dstBlock);

			// skip cells that cannot be resolved to a valid array position
			if( srcVals == null || dstVals == null
				|| srcRel >= srcVals.length || dstRel >= dstVals.length )
				continue;

			dstVals[dstRel] = srcVals[srcRel];
		}
		return null;
	}
}
}

Loading