Roll back fairseq1 masking implementation.
kauterry committed Aug 26, 2024
1 parent fa22158 commit 2960b28
Showing 5 changed files with 8 additions and 236 deletions.
3 changes: 0 additions & 3 deletions src/fairseq2/models/wav2vec2/asr/factory.py
@@ -55,8 +55,6 @@ class Wav2Vec2AsrConfig:
use_masking: bool = True
"""If ``True``, masks features as regularization."""

mask_codebase: str = "fairseq2"

temporal_mask_span_len: int = 10
"""The length of each temporal mask span that is applied over time steps."""

@@ -144,7 +142,6 @@ def build_masker(self) -> Wav2Vec2Masker | None:
return None

return Wav2Vec2Masker(
self._config.mask_codebase,
self._config.encoder_config.model_dim,
self._config.temporal_mask_span_len,
self._config.max_temporal_mask_prob,
3 changes: 0 additions & 3 deletions src/fairseq2/models/wav2vec2/factory.py
@@ -78,8 +78,6 @@ class Wav2Vec2Config:
from the transformer. """

# Mask
mask_codebase: str = "fairseq2"

temporal_mask_span_len: int = 10
"""The length of each temporal mask span that is applied over time steps."""

@@ -299,7 +297,6 @@ def build_model(self) -> Wav2Vec2Model:
def build_masker(self) -> Wav2Vec2Masker:
"""Build a feature masker."""
return Wav2Vec2Masker(
self._config.mask_codebase,
self._config.encoder_config.model_dim,
self._config.temporal_mask_span_len,
self._config.max_temporal_mask_prob,
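With `mask_codebase` removed, both `build_masker` factories now pass the remaining config fields straight to `Wav2Vec2Masker`. Below is a minimal sketch of the post-rollback construction; the import path and the concrete values are assumptions for illustration, and any constructor parameters not visible in these hunks are left at their defaults.

# Example (not part of the diff): constructing the masker after the rollback.
from fairseq2.models.wav2vec2 import Wav2Vec2Masker  # assumed import path

masker = Wav2Vec2Masker(
    768,   # model_dim (arbitrary illustrative encoder width)
    10,    # temporal_span_len (default shown in the diff)
    0.65,  # max_temporal_mask_prob (default shown in the diff)
)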
35 changes: 8 additions & 27 deletions src/fairseq2/models/wav2vec2/masker.py
@@ -14,7 +14,6 @@
from torch.nn import Module, Parameter

from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.utils.fairseq1_mask import compute_mask_indices
from fairseq2.nn.utils.mask import compute_row_mask
from fairseq2.typing import DataType, Device

@@ -32,7 +31,6 @@ class Wav2Vec2Masker(Module):

def __init__(
self,
mask_codebase: str,
model_dim: int,
temporal_span_len: int = 10,
max_temporal_mask_prob: float = 0.65,
@@ -64,7 +62,6 @@ def __init__(
if max_temporal_mask_prob == 0.0:
raise ValueError("`max_temporal_mask_prob` must be greater than 0.")

self.mask_codebase = mask_codebase
self.temporal_span_len = temporal_span_len
self.max_temporal_mask_prob = max_temporal_mask_prob
self.min_num_temporal_mask_spans = min_num_temporal_mask_spans
@@ -105,30 +102,14 @@ def forward(
batch_size, seq_len, model_dim = seqs.shape

# Temporal mask over time steps.
if self.mask_codebase == "fairseq2":
temporal_mask = compute_row_mask(
shape=(batch_size, seq_len),
span_len=self.temporal_span_len,
max_mask_prob=self.max_temporal_mask_prob,
row_lens=padding_mask.seq_lens if padding_mask is not None else None,
min_num_spans=self.min_num_temporal_mask_spans,
device=seqs.device,
)
else:
mask_indices = compute_mask_indices(
(batch_size, seq_len),
None,
self.max_temporal_mask_prob,
self.temporal_span_len,
mask_type="static",
mask_other=0.0,
min_masks=self.min_num_temporal_mask_spans,
no_overlap=False,
min_space=1,
require_same_masks=True,
mask_dropout=0.0,
)
temporal_mask = torch.from_numpy(mask_indices).to(seqs.device)
temporal_mask = compute_row_mask(
shape=(batch_size, seq_len),
span_len=self.temporal_span_len,
max_mask_prob=self.max_temporal_mask_prob,
row_lens=padding_mask.seq_lens if padding_mask is not None else None,
min_num_spans=self.min_num_temporal_mask_spans,
device=seqs.device,
)

assert temporal_mask is not None

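With the fairseq1 branch gone, the temporal mask always comes from fairseq2's own `compute_row_mask`. A small standalone sketch of that call using example values; the keyword names are the ones visible in the hunk above, and note that the function may return `None` (hence the assert in `forward`).

# Example (not part of the diff): calling compute_row_mask directly.
import torch

from fairseq2.nn.utils.mask import compute_row_mask

batch_size, seq_len = 4, 500  # arbitrary example shape

# Boolean mask of shape (batch_size, seq_len): spans of `span_len` time steps
# are masked per row, with roughly up to `max_mask_prob` of each row covered.
temporal_mask = compute_row_mask(
    shape=(batch_size, seq_len),
    span_len=10,
    max_mask_prob=0.65,
    row_lens=None,    # no padding in this toy example
    min_num_spans=1,  # example value
    device=torch.device("cpu"),
)

if temporal_mask is not None:
    print(temporal_mask.shape, temporal_mask.dtype)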
147 changes: 0 additions & 147 deletions src/fairseq2/nn/utils/fairseq1_mask.py

This file was deleted.

56 changes: 0 additions & 56 deletions src/fairseq2/recipes/wav2vec2/train.py
@@ -185,53 +185,10 @@ class Wav2Vec2TrainConfig:
wav2vec2_train_preset = wav2vec2_train_presets.decorator


@wav2vec2_train_preset("base_960h_fs2_mask")
def _base_960h_fs2_mask() -> Wav2Vec2TrainConfig:
config = Wav2Vec2TrainConfig()
config.model_config.encoder_config.first_pass_dropout_p = 0.1
return config


@wav2vec2_train_preset("base_960h")
def _base_960h() -> Wav2Vec2TrainConfig:
config = Wav2Vec2TrainConfig()

config.model_config.encoder_config.first_pass_dropout_p = 0.1

Check failure on line 191 in src/fairseq2/recipes/wav2vec2/train.py (GitHub Actions / Lint Python / Lint):
    Item "DataClass" of "DataClass | None" has no attribute "encoder_config"
    Item "None" of "DataClass | None" has no attribute "encoder_config"
config.model_config.max_temporal_mask_prob = 0.65
config.model_config.mask_codebase = "fairseq1"

return config


@wav2vec2_train_preset("base_960h_perf")
def _base_960h_perf() -> Wav2Vec2TrainConfig:
config = _base_960h()

assert isinstance(config.lr_scheduler_config, PolynomialDecayLRConfig)

config.max_num_steps = 10000
config.lr_scheduler_config.num_warmup_steps = 800
return config


@wav2vec2_train_preset("large_960h_fs2_mask")
def _large_960h_fs2_mask() -> Wav2Vec2TrainConfig:
config = Wav2Vec2TrainConfig()
config.model_arch = "large"

assert isinstance(config.optimizer_config, AdamWConfig)
assert isinstance(config.lr_scheduler_config, PolynomialDecayLRConfig)

model_config = wav2vec2_archs.get("large", return_empty=False)
model_config.encoder_config.first_pass_dropout_p = 0.1
config.model_config = model_config

config.publish_metrics_every_n_steps = 100
config.max_audio_len = 320_000
config.max_num_elements = 1_200_000
config.max_num_steps = 250_000
config.optimizer_config.lr = 3e-04
config.lr_scheduler_config.num_warmup_steps = 20_000
return config


@@ -245,8 +202,6 @@ def _large_960h() -> Wav2Vec2TrainConfig:

model_config = wav2vec2_archs.get("large", return_empty=False)
model_config.encoder_config.first_pass_dropout_p = 0.1
model_config.max_temporal_mask_prob = 0.65
model_config.mask_codebase = "fairseq1"
config.model_config = model_config

config.publish_metrics_every_n_steps = 100
@@ -258,17 +213,6 @@ def _large_960h() -> Wav2Vec2TrainConfig:
return config


@wav2vec2_train_preset("large_960h_perf")
def _large_960h_perf() -> Wav2Vec2TrainConfig:
config = _large_960h()

assert isinstance(config.lr_scheduler_config, PolynomialDecayLRConfig)

config.max_num_steps = 10000
config.lr_scheduler_config.num_warmup_steps = 800
return config


def load_wav2vec2_trainer(
config: Wav2Vec2TrainConfig, output_dir: Path
) -> Trainer[Tensor]:
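For context, presets in this file are registered with the `@wav2vec2_train_preset(...)` decorator shown above. A hypothetical sketch of defining one after the rollback (the preset name is made up, and the attribute access assumes `model_config` is populated, which is exactly what the lint annotations above complain about):

# Example (not part of the diff): a hypothetical custom preset.
@wav2vec2_train_preset("base_960h_custom")
def _base_960h_custom() -> Wav2Vec2TrainConfig:
    config = Wav2Vec2TrainConfig()

    # Mirrors the dropout override used by the removed base presets.
    config.model_config.encoder_config.first_pass_dropout_p = 0.1

    return config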
