Skip to content

Latest commit

 

History

History
131 lines (98 loc) · 3.43 KB

trunc_op.md

File metadata and controls

131 lines (98 loc) · 3.43 KB

Truncates the values of one input data (Tensor) at a specified bitwidth and produces one output data (Tensor). Additionally, takes four float tensors as input, which define the scale, zero-point, input bit-width and output bit-width of the quantization. The attribute rounding_mode defines how truncated values are rounded.

Version

This operator is not part of the ONNX standard and is not currently versioned.

Attributes

rounding_mode : string (default is "FLOOR")
Defines how rounding should be applied during truncation. Currently available modes are: "ROUND", "CEIL" and "FLOOR". Here "ROUND" implies a round-to-even operation. Lowercase variants for the rounding mode string are also supported: "round", "ceil", "floor".

Inputs

X (differentiable) : tensor(float32)
input tensor to truncate
scale : float32
The scale factor
zeropt : float32
The zero-point
in_bitwidth : int32
The number of bits used at the input of the truncation
out_bitwidth : int32
The number of bits used at the output of the truncation

Outputs

Y (differentiable) : tensor(float32)
Output tensor

Examples

Trunc
from onnx import helper
import numpy as np

# Define node settings and input
x = np.random.randn(100).astype(np.float32)*10.
scale = np.array(1.)
zeropt = np.array(0.)
in_bitwidth = np.array(10)
out_bitwidth = np.array(4)
rounding_mode = "ROUND"

# Create node
node = helper.make_node(
    'Trunc',
    domain='finn.custom_op.general',
    inputs=['x', 'scale', 'zeropt', 'in_bitwidth', 'out_bitwidth'],
    outputs=['y'],
    rounding_mode=rounding_mode,
)

# Execute the same settings with the reference implementation (trunc)
# See the sample implementation for more details on trunc.
output_ref = trunc(inp_tensor, scale, zeropt, in_bitwidth, out_bitwidth, rounding_mode)

# Execute node and compare
expect(node, inputs=[x, scale, zeropt, bitwidth], outputs=[output_ref], name='test_trunc')

Sample Implementation

Trunc
# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

def trunc(inp_tensor, scale, zeropt, input_bit_width, output_bit_width, rounding_mode):
    # Port of TruncIntQuant class from Brevitas: https://bit.ly/3wzIpTR

    # Scaling
    y = inp_tensor / scale
    y = y + zeropt
    # Rounding
    y = np.round(y)
    # Truncate
    trunc_bit_width = input_bit_width - output_bit_width
    trunc_scale = 2.0 ** trunc_bit_width
    y = y / trunc_scale

    # To int
    rounding_fx = resolve_rounding_mode(rounding_mode)
    y = rounding_fx(y)

    # Rescale
    y = y - zeropt
    y = y * scale

    return y

def resolve_rounding_mode(mode_string):
    """Resolve the rounding mode string of Quant and Trunc ops
    to the corresponding numpy functions."""
    if mode_string == "ROUND":
        return np.round
    elif mode_string == "CEIL":
        return np.ceil
    elif mode_string == "FLOOR":
        return np.floor
    else:
        raise ValueError(f"Could not resolve rounding mode called: {mode_string}")