diff --git a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h index a642243e7..cf997f518 100644 --- a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h +++ b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h @@ -29,15 +29,24 @@ namespace dap { namespace detail { // Declare the whisper preprocess C interface. extern "C" { -void _mlir_ciface_buddy_whisperPreprocess(MemRef *inputRawSpeech, - MemRef *outputFeatures); +// The original MLIR function: +// ```mlir +// func.func @buddy_whisperPreprocess(%in : memref) -> +// memref<1x80x3000xf32> +// ``` +// +// After applying the '-llvm-request-c-wrappers' pass: +// The result of the function (memref<1x80x3000xf32>) is modified to be the +// first operand of the function. +void _mlir_ciface_buddy_whisperPreprocess(MemRef *outputFeatures, + MemRef *inputRawSpeech); } } // namespace detail // Function for Whisper preprocess void whisperPreprocess(MemRef *inputRawSpeech, MemRef *outputFeatures) { - detail::_mlir_ciface_buddy_whisperPreprocess(inputRawSpeech, outputFeatures); + detail::_mlir_ciface_buddy_whisperPreprocess(outputFeatures, inputRawSpeech); } } // namespace dap diff --git a/frontend/Interfaces/lib/DAP-extend.mlir b/frontend/Interfaces/lib/DAP-extend.mlir index 8a8455f50..c77fe3873 100644 --- a/frontend/Interfaces/lib/DAP-extend.mlir +++ b/frontend/Interfaces/lib/DAP-extend.mlir @@ -1,4 +1,4 @@ -func.func @buddy_whisperPreprocess(%in : memref, %out : memref<1x80x3000xf32>) -> () { - dap.whisper_preprocess %in, %out : memref, memref<1x80x3000xf32> - return +func.func @buddy_whisperPreprocess(%in : memref) -> memref<1x80x3000xf32> { + %out = dap.whisper_preprocess %in : memref to memref<1x80x3000xf32> + return %out : memref<1x80x3000xf32> } diff --git a/midend/include/Dialect/DAP/DAPOps.td b/midend/include/Dialect/DAP/DAPOps.td index 98bac167d..0855fdaad 100644 --- a/midend/include/Dialect/DAP/DAPOps.td +++ b/midend/include/Dialect/DAP/DAPOps.td @@ -50,8 +50,7 @@ def DAP_FirOp : DAP_Op<"fir"> { }]; } -def DAP_BiquadOp : DAP_Op<"biquad"> -{ +def DAP_BiquadOp : DAP_Op<"biquad"> { let summary = [{Biquad filter, a infinite impulse response (IIR) filter. ```mlir @@ -95,22 +94,25 @@ def DAP_IirOp : DAP_Op<"iir"> { } def DAP_WhisperPreprocessOp : DAP_Op<"whisper_preprocess"> { - let summary = [{Preprocessor for Whisper model, do features extraction for input audio. - Input MemRef stores the raw speech data, Output MemRef contains computed features with - shape memref<1x80x3000xf32>. + let summary = "preprocessor for Whisper model"; + let description = [{ + Preprocessor for Whisper model, do features extraction for input audio. + Input MemRef stores the raw speech data, Output MemRef contains computed + features with shape memref<1x80x3000xf32>. - ```mlir - dap.whisper_preprocess %input, %output : memref, memref<1x80x3000xf32> - ``` + Example: + + ```mlir + %output = dap.whisper_preprocess %input : memref to memref<1x80x3000xf32> + ``` }]; let arguments = (ins Arg:$memrefI, - Arg:$memrefO); - + [MemRead]>:$memrefI); + let results = (outs Res:$memrefO); let assemblyFormat = [{ - $memrefI `,` $memrefO attr-dict `:` type($memrefI) `,` type($memrefO) + $memrefI attr-dict `:` type($memrefI) `to` type($memrefO) }]; } diff --git a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp index 85aad217b..8981ea769 100644 --- a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp +++ b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp @@ -1283,7 +1283,7 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, RankedTensorType tensorTy0 = RankedTensorType::get({400}, f64Ty); MemRefType mTp = MemRefType::get({400}, f64Ty); - // mulf_trait for linalg generic operation + // #mulf_trait for 'linalg.generic' operation AffineMap mulFIdMap = AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()); SmallVector mulFIndexingMaps = {mulFIdMap, mulFIdMap, mulFIdMap}; @@ -1299,7 +1299,7 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, Value buffer = rewriter.create(loc, tensorTy0, buffer400); - // linalg.generic operation use mulf_trait + // 'linalg.generic' operation use #mulf_trait auto mulfOp = rewriter.create( loc, /*resultTensorTypes=*/tensorTy0, /*inputs=*/ValueRange{buffer, window}, @@ -1368,14 +1368,14 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, /*outputs=*/ValueRange{melFloor}); Value spectrogramMax = linalgMaxOp.getResultTensors()[0]; - // log10_trait for linalg generic operation + // #log10_trait for 'linalg.generic' operation AffineMap log10IdMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); SmallVector log10IndexingMaps = {log10IdMap, log10IdMap}; SmallVector log10IteratorTypes = { utils::IteratorType::parallel, utils::IteratorType::parallel}; - // linalg.generic operation use log10_trait + // 'linalg.generic' operation use #log10_trait auto log10Op = rewriter.create( loc, /*resultTensorTypes=*/tensorTy1, /*inputs=*/ValueRange{spectrogramMax}, @@ -1402,9 +1402,7 @@ class DAPWhisperPreprocessLowering PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto ctx = op->getContext(); - Value input = op->getOperand(0); - Value output = op->getOperand(1); Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); @@ -1498,14 +1496,14 @@ class DAPWhisperPreprocessLowering Value InputFeaturesF32 = rewriter.create(loc, resultTy, f0F32); - // tail_processing_trait for linalg generic operation + // #tail_processing_trait for 'linalg.generic' operation AffineMap IdMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); SmallVector IndexingMaps = {IdMap, IdMap}; SmallVector IteratorTypes = { utils::IteratorType::parallel, utils::IteratorType::parallel}; - // linalg.generic operation use tail_processing_trait + // 'linalg.generic' operation use #tail_processing_trait auto tailProcessOp = rewriter.create( loc, /*resultTensorTypes=*/resultTy, /*inputs=*/ValueRange{logSpecMax}, @@ -1538,9 +1536,9 @@ class DAPWhisperPreprocessLowering Value resultMemRef = rewriter.create( loc, resultMemTp, resultExpand); - rewriter.create(loc, resultMemRef, output); - - rewriter.eraseOp(op); + // Replace this operation with the generated result. The replaced op is + // erased. + rewriter.replaceOp(op, resultMemRef); return success(); } }; @@ -1570,8 +1568,8 @@ class ExtendDAPPass void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -1590,11 +1588,17 @@ void ExtendDAPPass::runOnOperation() { ModuleOp module = getOperation(); ConversionTarget target(*context); - target.addLegalDialect(); + // Add legal dialects. + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + // Add legal operations. target.addLegalOp(); RewritePatternSet patterns(context);