diff --git a/paddlenlp/transformers/artist/tokenizer.py b/paddlenlp/transformers/artist/tokenizer.py index 2a4074e2f114..94329201ef76 100644 --- a/paddlenlp/transformers/artist/tokenizer.py +++ b/paddlenlp/transformers/artist/tokenizer.py @@ -225,6 +225,7 @@ def __call__( return_offsets_mapping=False, add_special_tokens=True, pad_to_multiple_of=None, + padding_side=None, return_tensors=None, verbose: bool = True, **kwargs @@ -247,6 +248,7 @@ def __call__( return_offsets_mapping, add_special_tokens, pad_to_multiple_of, + padding_side, return_tensors, verbose, **kwargs, diff --git a/paddlenlp/transformers/bloom/tokenizer.py b/paddlenlp/transformers/bloom/tokenizer.py index 4ba02b9b9551..e95462d1bc9c 100644 --- a/paddlenlp/transformers/bloom/tokenizer.py +++ b/paddlenlp/transformers/bloom/tokenizer.py @@ -18,7 +18,7 @@ import os import shutil from functools import lru_cache -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, Union import numpy as np from paddle.utils import try_import @@ -360,6 +360,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -375,13 +376,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). 
+ padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -394,7 +398,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/chatglm/tokenizer.py b/paddlenlp/transformers/chatglm/tokenizer.py index 08b8ad9d4720..6f5222a7b7d9 100644 --- a/paddlenlp/transformers/chatglm/tokenizer.py +++ b/paddlenlp/transformers/chatglm/tokenizer.py @@ -14,7 +14,7 @@ """Tokenization classes for ChatGLM.""" import os -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Optional, Union import numpy as np import sentencepiece as spm @@ -218,13 +218,15 @@ def _pad( max_length: Optional[int] = None, padding_strategy=PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: # Load from model defaults if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names or "attention_mask" in encoded_inputs - assert self.padding_side == "left" + padding_side = padding_side if padding_side is not None else self.padding_side + assert padding_side == "left" required_input = encoded_inputs[self.model_input_names[0]] seq_length = len(required_input) diff --git a/paddlenlp/transformers/chatglm_v2/tokenizer.py 
b/paddlenlp/transformers/chatglm_v2/tokenizer.py index 6913418a0f04..2a966e358a85 100644 --- a/paddlenlp/transformers/chatglm_v2/tokenizer.py +++ b/paddlenlp/transformers/chatglm_v2/tokenizer.py @@ -15,7 +15,7 @@ import os import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import numpy as np from sentencepiece import SentencePieceProcessor @@ -244,6 +244,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -259,18 +260,22 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability - `>= 7.5` (Volta). + >= 7.5 (Volta). + padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. 
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ # Load from model defaults - assert self.padding_side == "left" + padding_side = padding_side if padding_side is not None else self.padding_side + assert padding_side == "left" required_input = encoded_inputs[self.model_input_names[0]] seq_length = len(required_input) diff --git a/paddlenlp/transformers/dallebart/tokenizer.py b/paddlenlp/transformers/dallebart/tokenizer.py index c9d25946abe7..13335b6bc646 100644 --- a/paddlenlp/transformers/dallebart/tokenizer.py +++ b/paddlenlp/transformers/dallebart/tokenizer.py @@ -464,6 +464,7 @@ def __call__( return_offsets_mapping=False, add_special_tokens=True, pad_to_multiple_of=None, + padding_side=None, return_tensors=None, verbose: bool = True, **kwargs @@ -497,6 +498,7 @@ def __call__( return_offsets_mapping, add_special_tokens, pad_to_multiple_of, + padding_side, return_tensors, verbose, **kwargs, diff --git a/paddlenlp/transformers/gemma/tokenizer.py b/paddlenlp/transformers/gemma/tokenizer.py index 200be8345e36..e6790c958151 100644 --- a/paddlenlp/transformers/gemma/tokenizer.py +++ b/paddlenlp/transformers/gemma/tokenizer.py @@ -15,7 +15,7 @@ import os from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import sentencepiece as spm @@ -323,6 +323,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -345,6 +346,9 @@ def _pad( pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). 
+ padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -359,7 +363,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/gpt/tokenizer.py b/paddlenlp/transformers/gpt/tokenizer.py index bb0876e2dd74..5eed7b0b09ee 100644 --- a/paddlenlp/transformers/gpt/tokenizer.py +++ b/paddlenlp/transformers/gpt/tokenizer.py @@ -17,7 +17,7 @@ import os import shutil from functools import lru_cache -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, Union import jieba import numpy as np @@ -584,6 +584,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -599,13 +600,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided 
value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). + padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -620,7 +624,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/llama/tokenizer.py b/paddlenlp/transformers/llama/tokenizer.py index 2bae61e67b4e..9dab88e0bd2b 100644 --- a/paddlenlp/transformers/llama/tokenizer.py +++ b/paddlenlp/transformers/llama/tokenizer.py @@ -15,7 +15,7 @@ import os from shutil import copyfile -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import numpy as np import sentencepiece as spm @@ -232,6 +232,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -247,13 +248,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the 
sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). + padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -268,7 +272,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask @@ -521,6 +525,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -536,13 +541,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). 
+ padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -557,7 +565,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/mamba/tokenizer.py b/paddlenlp/transformers/mamba/tokenizer.py index 679a5e67c509..9d86b1084f91 100644 --- a/paddlenlp/transformers/mamba/tokenizer.py +++ b/paddlenlp/transformers/mamba/tokenizer.py @@ -18,7 +18,7 @@ import os import shutil from functools import lru_cache -from typing import Dict, Optional, Union +from typing import Dict, Literal, Optional, Union import numpy as np from paddle.utils import try_import @@ -302,6 +302,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -317,13 +318,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to 
a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). + padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -338,7 +342,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/qwen/tokenizer.py b/paddlenlp/transformers/qwen/tokenizer.py index 16e881ef7831..a126541ac7b1 100644 --- a/paddlenlp/transformers/qwen/tokenizer.py +++ b/paddlenlp/transformers/qwen/tokenizer.py @@ -17,7 +17,7 @@ import base64 import os import unicodedata -from typing import Collection, Dict, List, Optional, Set, Tuple, Union +from typing import Collection, Dict, List, Literal, Optional, Set, Tuple, Union import numpy as np @@ -255,6 +255,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -270,13 +271,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 
'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). + padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -291,7 +295,7 @@ def _pad( required_input = encoded_inputs[self.model_input_names[0]] encoded_inputs = super()._pad( - encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, return_attention_mask + encoded_inputs, max_length, padding_strategy, pad_to_multiple_of, padding_side, return_attention_mask ) if attention_mask is not None and len(np.shape(attention_mask)) > 2: encoded_inputs["attention_mask"] = attention_mask diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py index 048c2fc40a75..ca320eb52289 100644 --- a/paddlenlp/transformers/tokenizer_utils.py +++ b/paddlenlp/transformers/tokenizer_utils.py @@ -25,7 +25,7 @@ import unicodedata from collections import OrderedDict from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy import numpy as np @@ -1338,6 +1338,7 @@ def _encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_position_ids: Optional[bool] = None, return_token_type_ids: Optional[bool] = None, @@ -1389,6 +1390,7 @@ def get_input_ids(text): max_length=max_length, 
stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, prepend_batch_axis=True, return_position_ids=return_position_ids, @@ -1419,6 +1421,7 @@ def _batch_encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_position_ids: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -1487,6 +1490,7 @@ def get_input_ids(text): max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_position_ids=return_position_ids, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, @@ -1511,6 +1515,7 @@ def _batch_prepare_for_model( max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_position_ids: Optional[bool] = None, return_tensors: Optional[str] = None, return_token_type_ids: Optional[bool] = None, @@ -1623,6 +1628,7 @@ def _batch_prepare_for_model( max_length=max_length, stride=stride, pad_to_multiple_of=None, # we pad in batch afterward + padding_side=padding_side, # we pad in batch afterward return_position_ids=return_position_ids, # we pad in batch afterward return_attention_mask=False, # we pad in batch afterward return_token_type_ids=return_token_type_ids, @@ -1645,6 +1651,7 @@ def _batch_prepare_for_model( padding=padding_strategy.value, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) if return_dict: diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py index 6af5cc29e5d4..c9cd9f54d832 100644 --- a/paddlenlp/transformers/tokenizer_utils_base.py +++ 
b/paddlenlp/transformers/tokenizer_utils_base.py @@ -25,7 +25,17 @@ from collections import UserDict from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) import aistudio_sdk import numpy as np @@ -2110,6 +2120,7 @@ def __call__( return_offsets_mapping: bool = False, add_special_tokens: bool = True, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, verbose: bool = True, **kwargs @@ -2219,6 +2230,9 @@ def __call__( If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). Defaults to `None`. + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_tensors (str or [TensorType], optional): If set, will return tensors instead of list of python integers. 
Acceptable values are: @@ -2331,6 +2345,7 @@ def _is_valid_text_input(t): return_offsets_mapping=return_offsets_mapping, add_special_tokens=add_special_tokens, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, verbose=verbose, **kwargs, @@ -2353,6 +2368,7 @@ def _is_valid_text_input(t): return_offsets_mapping=return_offsets_mapping, add_special_tokens=add_special_tokens, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, verbose=verbose, **kwargs, @@ -2369,6 +2385,7 @@ def encode( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -2423,6 +2440,7 @@ def encode( stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_position_ids=return_position_ids, return_token_type_ids=return_token_type_ids, @@ -2445,6 +2463,7 @@ def encode_plus( max_length: Optional[int] = None, stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -2496,6 +2515,7 @@ def encode_plus( stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, @@ -2518,6 +2538,7 @@ def _encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_position_ids: 
Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -2557,6 +2578,7 @@ def batch_encode( return_offsets_mapping=False, add_special_tokens=True, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, verbose: bool = True, **kwargs @@ -2607,6 +2629,7 @@ def batch_encode( stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_tensors=return_tensors, return_position_ids=return_position_ids, return_token_type_ids=return_token_type_ids, @@ -2637,6 +2660,7 @@ def _batch_encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_position_ids: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -2662,6 +2686,7 @@ def pad( ], padding: Union[bool, str, PaddingStrategy] = True, max_length: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, @@ -2706,6 +2731,9 @@ def pad( This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask (`bool`, *optional*): Whether to return the attention mask. If left to the default, will return the attention mask according to the specific tokenizer's default, defined by the `return_outputs` attribute. 
@@ -2772,6 +2800,7 @@ def pad( max_length=max_length, padding_strategy=padding_strategy, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) return BatchEncoding(encoded_inputs, tensor_type=return_tensors) @@ -2792,6 +2821,7 @@ def pad( inputs, max_length=max_length, padding_strategy=padding_strategy, + padding_side=padding_side, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, ) @@ -2872,6 +2902,7 @@ def prepare_for_model( max_length: Optional[int] = None, stride: int = 0, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_position_ids=None, return_token_type_ids: Optional[bool] = None, @@ -3002,6 +3033,7 @@ def prepare_for_model( max_length=max_length, padding=padding_strategy.value, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_attention_mask=return_attention_mask, ) @@ -3141,6 +3173,7 @@ def _pad( max_length: Optional[int] = None, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_attention_mask: Optional[bool] = None, ) -> dict: """ @@ -3156,13 +3189,16 @@ def _pad( - PaddingStrategy.LONGEST Pad to the longest sequence in the batch - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) - PaddingStrategy.DO_NOT_PAD: Do not pad - The tokenizer padding sides are defined in self.padding_side: + The tokenizer padding sides are defined in `padding_side` argument: - 'left': pads on the left of the sequences - 'right': pads on the right of the sequences pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability >= 7.5 (Volta). 
+ padding_side: (optional) The side on which the model should have padding applied. + Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ @@ -3186,8 +3222,9 @@ def _pad( if needs_to_be_padded: difference = max_length - len(required_input) + padding_side = padding_side if padding_side is not None else self.padding_side - if self.padding_side == "right": + if padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference @@ -3207,7 +3244,7 @@ def _pad( if "end_positions" in encoded_inputs and isinstance(encoded_inputs["end_positions"], list): encoded_inputs["end_positions"] = encoded_inputs["end_positions"] + [0] * difference encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference - elif self.padding_side == "left": + elif padding_side == "left": if return_attention_mask: encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] if "token_type_ids" in encoded_inputs: @@ -3226,7 +3263,7 @@ def _pad( encoded_inputs["end_positions"] = [0] * difference + encoded_inputs["end_positions"] encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input else: - raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + raise ValueError("Invalid padding strategy:" + str(padding_side)) return encoded_inputs diff --git a/paddlenlp/transformers/tokenizer_utils_fast.py b/paddlenlp/transformers/tokenizer_utils_fast.py index d6a854fdd667..6d49ac7fde71 100644 --- a/paddlenlp/transformers/tokenizer_utils_fast.py +++ b/paddlenlp/transformers/tokenizer_utils_fast.py @@ -22,7 +22,7 @@ import json import os from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, 
Dict, List, Literal, Optional, Tuple, Union import tokenizers.pre_tokenizers as pre_tokenizers_fast from tokenizers import Encoding as EncodingFast @@ -398,6 +398,7 @@ def set_truncation_and_padding( max_length: int, stride: int, pad_to_multiple_of: Optional[int], + padding_side: Optional[Literal["right", "left"]], ): """ Define the truncation and the padding strategies for fast tokenizers (provided by PaddleNLP's fast_tokenizer @@ -419,6 +420,9 @@ def set_truncation_and_padding( pad_to_multiple_of (`int`, *optional*): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. """ _truncation = self._tokenizer.truncation _padding = self._tokenizer.padding @@ -453,7 +457,7 @@ def set_truncation_and_padding( length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None target = { "length": length, - "direction": self.padding_side, + "direction": padding_side if padding_side is not None else self.padding_side, "pad_id": self.pad_token_id, "pad_token": self.pad_token, "pad_type_id": self.pad_token_type_id, @@ -479,6 +483,7 @@ def _batch_encode_plus( stride: int = 0, is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_tensors: Optional[str] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, @@ -504,6 +509,7 @@ def _batch_encode_plus( max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, ) if self._tokenizer.encode_special_tokens != split_special_tokens: @@ -571,6 +577,7 @@ def _encode_plus( stride: int = 0, is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None, + padding_side: Optional[Literal["right", "left"]] = None, return_position_ids: Optional[bool] = None, return_tensors: Optional[bool] = None, return_token_type_ids: Optional[bool] = None, @@ -593,6 +600,7 @@ def _encode_plus( max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, return_position_ids=return_position_ids, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, diff --git a/tests/transformers/chatglm/test_tokenizer.py b/tests/transformers/chatglm/test_tokenizer.py index 4017a8290c25..14b7b63482fc 100644 --- a/tests/transformers/chatglm/test_tokenizer.py +++ b/tests/transformers/chatglm/test_tokenizer.py @@ -16,6 +16,7 @@ import unittest import numpy as np +from parameterized import parameterized from paddlenlp.transformers import ChatGLMTokenizer from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer @@ -217,7 +218,8 @@ def test_pretrained_model_lists(self): self.assertGreaterEqual(len(self.tokenizer_class.pretrained_resource_files_map), 1) self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_resource_files_map.values())[0]), 1) - def test_encode_plus_with_padding(self): + @parameterized.expand([(True,), (False,)]) + def test_encode_plus_with_padding(self, use_padding_as_call_kwarg: bool): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): @@ -233,14 +235,32 @@ def test_encode_plus_with_padding(self): special_tokens_mask = encoded_sequence["special_tokens_mask"] sequence_length = len(input_ids) + # Test right padding + tokenizer_kwargs_right = { + "max_length": sequence_length + padding_size, + "padding": "max_length", + "return_special_tokens_mask": True, + } + + if not use_padding_as_call_kwarg: + tokenizer.padding_side = "right" + else: + tokenizer_kwargs_right["padding_side"] = "right" + self.assertRaises(AssertionError, 
lambda: tokenizer.encode_plus(sequence, **tokenizer_kwargs_right)) + # Test left padding - tokenizer.padding_side = "left" - left_padded_sequence = tokenizer.encode( - sequence, - max_length=sequence_length + padding_size, - padding="max_length", - return_special_tokens_mask=True, - ) + tokenizer_kwargs_left = { + "max_length": sequence_length + padding_size, + "padding": "max_length", + "return_special_tokens_mask": True, + } + + if not use_padding_as_call_kwarg: + tokenizer.padding_side = "left" + else: + tokenizer_kwargs_left["padding_side"] = "left" + + left_padded_sequence = tokenizer.encode_plus(sequence, **tokenizer_kwargs_left) left_padded_input_ids = left_padded_sequence["input_ids"] left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"] left_padded_sequence_length = len(left_padded_input_ids) diff --git a/tests/transformers/test_tokenizer_common.py b/tests/transformers/test_tokenizer_common.py index 7d78bfb09e0f..f3596afa6abe 100644 --- a/tests/transformers/test_tokenizer_common.py +++ b/tests/transformers/test_tokenizer_common.py @@ -27,6 +27,8 @@ from pathlib import Path from typing import Any, Dict, List, Tuple +from parameterized import parameterized + from paddlenlp.transformers import PretrainedTokenizer from paddlenlp.transformers.tokenizer_utils import AddedToken, Trie from paddlenlp.transformers.tokenizer_utils_base import PretrainedTokenizerBase @@ -1487,7 +1489,15 @@ def test_padding_with_attention_mask(self): else: self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0]]) - def test_encode_plus_with_padding(self): + @parameterized.expand([(True,), (False,)]) + def test_encode_plus_with_padding(self, use_padding_as_call_kwarg: bool): + """ + This test checks that padding works as expected when tokenizing a sequence. + Padding is expected to have no effect when the input is a single sequence and + the padding-strategy is not `max_length`. 
Otherwise it pads to the specified max-length + using tokenizer classes `padding_side` attribute. Also, we check that passing `padding_side` + as call time kwarg works same way as when one sets `tokenizer.padding_side` attribute. + """ tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): @@ -1506,7 +1516,6 @@ def test_encode_plus_with_padding(self): sequence_length = len(input_ids) # Test 'longest' and 'no_padding' don't do anything - tokenizer.padding_side = "right" not_padded_sequence = tokenizer.encode( sequence, @@ -1537,14 +1546,18 @@ def test_encode_plus_with_padding(self): self.assertEqual(special_tokens_mask, not_padded_special_tokens_mask) # Test right padding - tokenizer.padding_side = "right" + tokenizer_kwargs_right = { + "max_length": sequence_length + padding_size, + "padding": "max_length", + "return_special_tokens_mask": True, + } + + if not use_padding_as_call_kwarg: + tokenizer.padding_side = "right" + else: + tokenizer_kwargs_right["padding_side"] = "right" - right_padded_sequence = tokenizer.encode( - sequence, - max_length=sequence_length + padding_size, - padding="max_length", - return_special_tokens_mask=True, - ) + right_padded_sequence = tokenizer.encode_plus(sequence, **tokenizer_kwargs_right) right_padded_input_ids = right_padded_sequence["input_ids"] right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"] @@ -1555,13 +1568,18 @@ def test_encode_plus_with_padding(self): self.assertEqual(special_tokens_mask + [1] * padding_size, right_padded_special_tokens_mask) # Test left padding - tokenizer.padding_side = "left" - left_padded_sequence = tokenizer.encode( - sequence, - max_length=sequence_length + padding_size, - padding="max_length", - return_special_tokens_mask=True, - ) + tokenizer_kwargs_left = { + "max_length": sequence_length + padding_size, + "padding": "max_length", + "return_special_tokens_mask": True, + } + + if not 
use_padding_as_call_kwarg: + tokenizer.padding_side = "left" + else: + tokenizer_kwargs_left["padding_side"] = "left" + + left_padded_sequence = tokenizer.encode_plus(sequence, **tokenizer_kwargs_left) left_padded_input_ids = left_padded_sequence["input_ids"] left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"] left_padded_sequence_length = len(left_padded_input_ids)