diff --git a/configs/Unintended_bias_toxic_comment_classification_RoBERTa_combined.json b/configs/Unintended_bias_toxic_comment_classification_RoBERTa_combined.json
new file mode 100644
index 0000000..869cee4
--- /dev/null
+++ b/configs/Unintended_bias_toxic_comment_classification_RoBERTa_combined.json
@@ -0,0 +1,58 @@
+{
+    "name": "Jigsaw_RoBERTa_combined",
+    "n_gpu": 1,
+    "batch_size": 10,
+    "accumulate_grad_batches": 3,
+    "num_main_classes": 7,
+    "loss_weight": 0.75,
+    "arch": {
+        "type": "ROBERTA",
+        "args": {
+            "num_classes": 16,
+            "model_type": "roberta-base",
+            "model_name": "RobertaForSequenceClassification",
+            "tokenizer_name": "RobertaTokenizer"
+        }
+    },
+    "dataset": {
+        "type": "JigsawDataBias",
+        "args": {
+            "train_csv_file": [
+                "jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/train.csv",
+                "jigsaw_data/jigsaw-toxic-comment-classification-challenge/train.csv"
+            ],
+            "test_csv_file": "jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/test_public_expanded.csv",
+            "val_fraction": null,
+            "create_val_set": false,
+            "loss_weight": 0.75,
+            "classes": [
+                "toxicity",
+                "severe_toxicity",
+                "obscene",
+                "identity_attack",
+                "insult",
+                "threat",
+                "sexual_explicit"
+            ],
+            "identity_classes": [
+                "male",
+                "female",
+                "homosexual_gay_or_lesbian",
+                "christian",
+                "jewish",
+                "muslim",
+                "black",
+                "white",
+                "psychiatric_or_mental_illness"
+            ]
+        }
+    },
+    "optimizer": {
+        "type": "Adam",
+        "args": {
+            "lr": 3e-5,
+            "weight_decay": 3e-6,
+            "amsgrad": true
+        }
+    }
+}
\ No newline at end of file
diff --git a/convert_weights.py b/convert_weights.py
new file mode 100644
index 0000000..c8c3eff
--- /dev/null
+++ b/convert_weights.py
@@ -0,0 +1,57 @@
+import torch
+from collections import OrderedDict
+import hashlib
+import argparse
+
+
+def main():
+    """Converts saved checkpoint to the expected format for detoxify.
+    """
+    checkpoint = torch.load(ARGS.checkpoint, map_location=ARGS.device)
+
+    new_state_dict = {
+        "state_dict": OrderedDict(),
+        "config": checkpoint["hyper_parameters"]["config"],
+    }
+    for k, v in checkpoint["state_dict"].items():
+        if k.startswith("model."):
+            k = k[6:]  # remove `model.`
+        new_state_dict["state_dict"][k] = v
+
+    torch.save(new_state_dict, ARGS.save_to)
+
+    if ARGS.hash:
+        with open(ARGS.save_to, "rb") as f:
+            bytes = f.read()  # read entire file as bytes
+            readable_hash = hashlib.sha256(bytes).hexdigest()
+            print("Hash: ", readable_hash)
+
+        torch.save(new_state_dict, ARGS.save_to[:-5] + f"-{readable_hash[:8]}.ckpt")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        help="path to model checkpoint",
+    )
+    parser.add_argument(
+        "--save_to",
+        type=str,
+        help="path to save the model to",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        help="device to load the checkpoint on",
+    )
+    parser.add_argument(
+        "--hash",
+        type=bool,
+        default=True,
+        help="option to save hash in name",
+    )
+    ARGS = parser.parse_args()
+    main()
diff --git a/detoxify/detoxify.py b/detoxify/detoxify.py
index 3380d72..56d6b64 100644
--- a/detoxify/detoxify.py
+++ b/detoxify/detoxify.py
@@ -3,7 +3,7 @@
 
 MODEL_URLS = {
     "original": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_original-c1212f89.ckpt",
-    "unbiased": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_bias-4e693588.ckpt",
+    "unbiased": "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
     "multilingual": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_multilingual-bbddc277.ckpt",
     "original-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/original-albert-0e1d6498.ckpt",
     "unbiased-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/unbiased-albert-c8519128.ckpt"
diff --git a/setup.py b/setup.py
index 25996a3..e73b7d1 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@
 
 setup(
     name="detoxify",
-    version="0.2.2",
+    version="0.3.0",
     description="A python library for detecting toxic comments",
     long_description=long_description,
     long_description_content_type="text/markdown",
diff --git a/src/data_loaders.py b/src/data_loaders.py
index 9dda074..4b561ab 100644
--- a/src/data_loaders.py
+++ b/src/data_loaders.py
@@ -58,15 +58,41 @@ def __len__(self):
 
     def load_data(self, train_csv_file):
         files = []
-        cols = ["id", "comment_text", "toxic"]
+        change_names = {
+            "target": "toxicity",
+            "toxic": "toxicity",
+            "identity_hate": "identity_attack",
+            "severe_toxic": "severe_toxicity",
+        }
         for file in tqdm(train_csv_file):
-            file_df = pd.read_csv(file)
-            file_df = file_df[cols]
-            file_df = file_df.astype({"id": "string"}, {"toxic": "float64"})
+            chunks = []
+            for chunk in pd.read_csv(file, chunksize=100000):
+                chunks.append(chunk)
+
+            file_df = pd.concat(chunks, axis=0)
+            filtered_change_names = {
+                k: v for k, v in change_names.items() if k in file_df.columns
+            }
+            if len(filtered_change_names) > 0:
+                file_df.rename(columns=filtered_change_names, inplace=True)
+            file_df = file_df.astype({"id": "string"})
             files.append(file_df)
-        train = pd.concat(files)
+
+        train = pd.concat(files, join="outer")
         return train
 
+    def filter_entry_labels(self, entry, classes, threshold=0.5, soft_labels=False):
+        target = {
+            label: -1 if label not in entry or entry[label] is None else entry[label]
+            for label in classes
+        }
+        if not soft_labels:
+            target.update({label: 1 for label in target if target[label] >= threshold})
+            target.update(
+                {label: 0 for label in target if 0 <= target[label] < threshold}
+            )
+        return target
+
 
 class JigsawDataOriginal(JigsawData):
     """Dataloader for the original Jigsaw Toxic Comment Classification Challenge.
@@ -128,10 +154,11 @@ def __init__(
         loss_weight=0.75,
         classes=["toxic"],
         identity_classes=["female"],
+        soft_labels=False,
     ):
 
         self.classes = classes
-
+        self.soft_labels = soft_labels
         self.identity_classes = identity_classes
 
         super().__init__(
@@ -156,21 +183,12 @@ def __getitem__(self, index):
         text_id = entry["id"]
         text = entry["comment_text"]
 
-        target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes}
-
-        identity_target = {
-            label: -1 if entry[label] is None else entry[label]
-            for label in self.identity_classes
-        }
-        identity_target.update(
-            {label: 1 for label in identity_target if identity_target[label] >= 0.5}
+        target_dict = self.filter_entry_labels(
+            entry,
+            self.classes + self.identity_classes,
+            threshold=0.5,
+            soft_labels=self.soft_labels,
         )
-        identity_target.update(
-            {label: 0 for label in identity_target if 0 <= identity_target[label] < 0.5}
-        )
-
-        target_dict.update(identity_target)
-
         meta["multi_target"] = torch.tensor(
             list(target_dict.values()), dtype=torch.float32
         )
@@ -192,7 +210,7 @@ def __getitem__(self, index):
         return text, meta
 
     def compute_weigths(self, train_df):
-        """Inspired from 2nd solution.
+        """Inspired by the 2nd best solution.
         Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/discussion/100661"""
         subgroup_bool = (train_df[self.identity_classes].fillna(0) >= 0.5).sum(
             axis=1
@@ -200,7 +218,8 @@ def __getitem__(self, index):
         positive_bool = train_df["toxicity"] >= 0.5
 
         weights = np.ones(len(train_df)) * 0.25
-        # Backgroud Positive and Subgroup Negative
+        # Background Positive and Subgroup Negative
+        # i.e. give extra weight to toxic comments that don't mention an identity and to non-toxic comments that do
         weights[
             ((~subgroup_bool) & (positive_bool)) | ((subgroup_bool) & (~positive_bool))
         ] += 0.25
diff --git a/train.py b/train.py
index 3091d1e..c666fce 100644
--- a/train.py
+++ b/train.py
@@ -179,7 +179,6 @@ def cli_main():
 
     if args.device is not None:
         config["device"] = args.device
-        os.environ["CUDA_VISIBLE_DEVICES"] = args.device
 
     # data
     def get_instance(module, name, config, *args, **kwargs):
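A minimal sketch of how the output of the new convert_weights.py script can be sanity-checked. The file paths below are placeholder assumptions, not taken from this change; the two expected top-level keys and the stripped "model." prefix follow from the script itself.

# Assumed invocation (hypothetical paths):
#   python convert_weights.py --checkpoint saved/Jigsaw_RoBERTa_combined/epoch=3.ckpt --save_to checkpoints/toxic_debiased.ckpt
import torch

converted = torch.load("checkpoints/toxic_debiased.ckpt", map_location="cpu")

# convert_weights.py writes exactly these two top-level keys.
assert set(converted.keys()) == {"state_dict", "config"}

# The "model." prefix added by the Lightning wrapper should have been stripped.
assert not any(key.startswith("model.") for key in converted["state_dict"])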