
Commit

[DAP/Whisper] Remove 'memref.copy' operation in 'dap.whisper_preprocess'.
taiqzheng committed Aug 30, 2024
1 parent ace1f45 commit c8e44e3
Showing 4 changed files with 51 additions and 36 deletions.
15 changes: 12 additions & 3 deletions frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h
@@ -29,15 +29,24 @@ namespace dap {
namespace detail {
// Declare the whisper preprocess C interface.
extern "C" {
void _mlir_ciface_buddy_whisperPreprocess(MemRef<double, 1> *inputRawSpeech,
MemRef<float, 3> *outputFeatures);
// The original MLIR function:
// ```mlir
// func.func @buddy_whisperPreprocess(%in : memref<?xf64>) ->
// memref<1x80x3000xf32>
// ```
//
// After applying the '-llvm-request-c-wrappers' pass, the result of the
// function (memref<1x80x3000xf32>) becomes the first argument of the
// generated C wrapper.
void _mlir_ciface_buddy_whisperPreprocess(MemRef<float, 3> *outputFeatures,
MemRef<double, 1> *inputRawSpeech);
}
} // namespace detail

// Function for Whisper preprocess
void whisperPreprocess(MemRef<double, 1> *inputRawSpeech,
MemRef<float, 3> *outputFeatures) {
detail::_mlir_ciface_buddy_whisperPreprocess(inputRawSpeech, outputFeatures);
detail::_mlir_ciface_buddy_whisperPreprocess(outputFeatures, inputRawSpeech);
}

} // namespace dap
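
For callers nothing changes: the public 'whisperPreprocess' wrapper keeps its (input, output) argument order, and only the internal C-interface call is swapped. A minimal usage sketch follows; the include path, the 'MemRef(sizes, init)' constructor, and the sample length are assumptions based on buddy frontend conventions rather than part of this commit.

```cpp
// Hedged usage sketch (not part of this commit).
// Assumptions: frontend/Interfaces is on the include path, and the buddy
// MemRef container provides a MemRef(intptr_t sizes[N], T init) constructor.
#include "buddy/DAP/DSP/WhisperPreprocess.h"
#include <cstdint>

int main() {
  // 1-D raw speech buffer of doubles; the length 93680 is illustrative.
  intptr_t inputSizes[1] = {93680};
  MemRef<double, 1> rawSpeech(inputSizes, /*init=*/0.0);

  // Output features: fixed 1x80x3000 f32 buffer, matching the op's result type.
  intptr_t outputSizes[3] = {1, 80, 3000};
  MemRef<float, 3> features(outputSizes, /*init=*/0.0f);

  // The public API keeps the (input, output) order; the argument swap required
  // by '-llvm-request-c-wrappers' stays inside the detail namespace.
  dap::whisperPreprocess(&rawSpeech, &features);
  return 0;
}
```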
6 changes: 3 additions & 3 deletions frontend/Interfaces/lib/DAP-extend.mlir
@@ -1,4 +1,4 @@
func.func @buddy_whisperPreprocess(%in : memref<?xf64>, %out : memref<1x80x3000xf32>) -> () {
dap.whisper_preprocess %in, %out : memref<?xf64>, memref<1x80x3000xf32>
return
func.func @buddy_whisperPreprocess(%in : memref<?xf64>) -> memref<1x80x3000xf32> {
%out = dap.whisper_preprocess %in : memref<?xf64> to memref<1x80x3000xf32>
return %out : memref<1x80x3000xf32>
}
28 changes: 15 additions & 13 deletions midend/include/Dialect/DAP/DAPOps.td
@@ -50,8 +50,7 @@ def DAP_FirOp : DAP_Op<"fir"> {
}];
}

def DAP_BiquadOp : DAP_Op<"biquad">
{
def DAP_BiquadOp : DAP_Op<"biquad"> {
let summary = [{Biquad filter, an infinite impulse response (IIR) filter.

```mlir
@@ -95,22 +94,25 @@ def DAP_IirOp : DAP_Op<"iir"> {
}

def DAP_WhisperPreprocessOp : DAP_Op<"whisper_preprocess"> {
let summary = [{Preprocessor for Whisper model, do features extraction for input audio.
Input MemRef stores the raw speech data, Output MemRef contains computed features with
shape memref<1x80x3000xf32>.
let summary = "preprocessor for Whisper model";
let description = [{
Preprocessor for the Whisper model; it performs feature extraction on the
input audio. The input MemRef stores the raw speech data, and the output
MemRef contains the computed features with shape memref<1x80x3000xf32>.

```mlir
dap.whisper_preprocess %input, %output : memref<?xf64>, memref<1x80x3000xf32>
```
Example:

```mlir
%output = dap.whisper_preprocess %input : memref<?xf64> to memref<1x80x3000xf32>
```
}];

let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "inputMemref",
[MemRead]>:$memrefI,
Arg<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemRead]>:$memrefO);

[MemRead]>:$memrefI);
let results = (outs Res<AnyRankedOrUnrankedMemRef, "outputMemref",
[MemAlloc]>:$memrefO);
let assemblyFormat = [{
$memrefI `,` $memrefO attr-dict `:` type($memrefI) `,` type($memrefO)
$memrefI attr-dict `:` type($memrefI) `to` type($memrefO)
}];
}
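
On the C++ side, a rewriter that previously passed both memrefs now supplies the result type and receives the output SSA value. A hedged sketch follows, assuming the default TableGen-generated builder and the 'dap::WhisperPreprocessOp' class name; 'input' stands for an existing memref<?xf64> value.

```cpp
// Hedged sketch: creating the single-result op from a pattern rewriter.
// The builder overload (result type first, then the input operand) is the
// default one TableGen generates; confirm against the generated DAPOps.h.inc.
MemRefType featureTy = MemRefType::get({1, 80, 3000}, rewriter.getF32Type());
Value features =
    rewriter.create<dap::WhisperPreprocessOp>(loc, featureTy, input);
```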

38 changes: 21 additions & 17 deletions midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp
Expand Up @@ -1283,7 +1283,7 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0,
RankedTensorType tensorTy0 = RankedTensorType::get({400}, f64Ty);
MemRefType mTp = MemRefType::get({400}, f64Ty);

// mulf_trait for linalg generic operation
// #mulf_trait for 'linalg.generic' operation
AffineMap mulFIdMap =
AffineMap::getMultiDimIdentityMap(1, rewriter.getContext());
SmallVector<AffineMap> mulFIndexingMaps = {mulFIdMap, mulFIdMap, mulFIdMap};
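
The hunks that follow assemble 'linalg.generic' operations over these traits, but the diff elides their body regions. The sketch below shows how such a #mulf_trait-style elementwise multiply is typically built with a body callback; the helper name, the 'tensor.empty' init, and the single parallel iterator are assumptions inferred from the 1-D identity maps, not code from this file.

```cpp
// Hedged sketch, not verbatim from ExtendDAPPass.cpp: an elementwise multiply
// 'linalg.generic' driven by #mulf_trait-style identity indexing maps.
static Value buildElementwiseMulF(PatternRewriter &rewriter, Location loc,
                                  RankedTensorType resultTy, Value lhs,
                                  Value rhs, ArrayRef<AffineMap> indexingMaps) {
  // The 1-D identity maps imply a single parallel loop.
  SmallVector<utils::IteratorType> iteratorTypes = {
      utils::IteratorType::parallel};
  // Init tensor carrying the result shape and element type.
  Value outInit = rewriter.create<tensor::EmptyOp>(
      loc, resultTy.getShape(), resultTy.getElementType());
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, /*resultTensorTypes=*/resultTy,
      /*inputs=*/ValueRange{lhs, rhs},
      /*outputs=*/ValueRange{outInit}, indexingMaps, iteratorTypes,
      [](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // args[0], args[1]: input elements; args[2]: current output element.
        Value prod = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
        b.create<linalg::YieldOp>(nestedLoc, prod);
      });
  return genericOp.getResultTensors()[0];
}
```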
@@ -1299,7 +1299,7 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0,
Value buffer =
rewriter.create<tensor::CastOp>(loc, tensorTy0, buffer400);

// linalg.generic operation use mulf_trait
// 'linalg.generic' operation uses #mulf_trait
auto mulfOp = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/tensorTy0,
/*inputs=*/ValueRange{buffer, window},
@@ -1368,14 +1368,14 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0,
/*outputs=*/ValueRange{melFloor});
Value spectrogramMax = linalgMaxOp.getResultTensors()[0];

// log10_trait for linalg generic operation
// #log10_trait for 'linalg.generic' operation
AffineMap log10IdMap =
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
SmallVector<AffineMap> log10IndexingMaps = {log10IdMap, log10IdMap};
SmallVector<utils::IteratorType> log10IteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel};

// linalg.generic operation use log10_trait
// 'linalg.generic' operation uses #log10_trait
auto log10Op = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/tensorTy1,
/*inputs=*/ValueRange{spectrogramMax},
@@ -1402,9 +1402,7 @@ class DAPWhisperPreprocessLowering
PatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto ctx = op->getContext();

Value input = op->getOperand(0);
Value output = op->getOperand(1);

Value c0 = rewriter.create<ConstantIndexOp>(loc, 0);
Value c1 = rewriter.create<ConstantIndexOp>(loc, 1);
@@ -1498,14 +1496,14 @@ class DAPWhisperPreprocessLowering
Value InputFeaturesF32 =
rewriter.create<tensor::SplatOp>(loc, resultTy, f0F32);

// tail_processing_trait for linalg generic operation
// #tail_processing_trait for 'linalg.generic' operation
AffineMap IdMap =
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
SmallVector<AffineMap> IndexingMaps = {IdMap, IdMap};
SmallVector<utils::IteratorType> IteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel};

// linalg.generic operation use tail_processing_trait
// 'linalg.generic' operation uses #tail_processing_trait
auto tailProcessOp = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/resultTy,
/*inputs=*/ValueRange{logSpecMax},
@@ -1538,9 +1536,9 @@ class DAPWhisperPreprocessLowering
Value resultMemRef = rewriter.create<bufferization::ToMemrefOp>(
loc, resultMemTp, resultExpand);

rewriter.create<memref::CopyOp>(loc, resultMemRef, output);

rewriter.eraseOp(op);
// Replace this operation with the generated result. The replaced op is
// erased.
rewriter.replaceOp(op, resultMemRef);
return success();
}
};
@@ -1570,8 +1568,8 @@ class ExtendDAPPass
void runOnOperation() override;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
registry.insert<affine::AffineDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<bufferization::BufferizationDialect>();
registry.insert<func::FuncDialect>();
registry.insert<linalg::LinalgDialect>();
@@ -1590,11 +1588,17 @@ void ExtendDAPPass::runOnOperation() {
ModuleOp module = getOperation();

ConversionTarget target(*context);
target.addLegalDialect<affine::AffineDialect, scf::SCFDialect,
func::FuncDialect, math::MathDialect,
memref::MemRefDialect, arith::ArithDialect,
linalg::LinalgDialect, tensor::TensorDialect,
bufferization::BufferizationDialect>();
// Add legal dialects.
target.addLegalDialect<affine::AffineDialect>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<bufferization::BufferizationDialect>();
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<linalg::LinalgDialect>();
target.addLegalDialect<math::MathDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<scf::SCFDialect>();
target.addLegalDialect<tensor::TensorDialect>();
// Add legal operations.
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();

RewritePatternSet patterns(context);
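
The pattern set is then populated and applied against the conversion target; that driver code falls outside this hunk. Below is a hedged sketch using the standard MLIR conversion APIs; the real pass may register its patterns through a populate-style helper rather than adding the class directly.

```cpp
// Hedged sketch of the conversion driver that follows (elided by the diff).
patterns.add<DAPWhisperPreprocessLowering>(context);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
  signalPassFailure();
```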

