
Commit bdc84bd
Merge pull request #29 from unitaryai/updated_bias_model
Updated bias model training code
laurahanu committed Sep 3, 2021
2 parents 67cc43a + e0415cb commit bdc84bd
Showing 6 changed files with 158 additions and 25 deletions.
@@ -0,0 +1,58 @@
{
"name": "Jigsaw_RoBERTa_combined",
"n_gpu": 1,
"batch_size": 10,
"accumulate_grad_batches": 3,
"num_main_classes": 7,
"loss_weight": 0.75,
"arch": {
"type": "ROBERTA",
"args": {
"num_classes": 16,
"model_type": "roberta-base",
"model_name": "RobertaForSequenceClassification",
"tokenizer_name": "RobertaTokenizer"
}
},
"dataset": {
"type": "JigsawDataBias",
"args": {
"train_csv_file": [
"jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/train.csv",
"jigsaw_data/jigsaw-toxic-comment-classification-challenge/train.csv"
],
"test_csv_file": "jigsaw_data/jigsaw-unintended-bias-in-toxicity-classification/test_public_expanded.csv",
"val_fraction": null,
"create_val_set": false,
"loss_weight": 0.75,
"classes": [
"toxicity",
"severe_toxicity",
"obscene",
"identity_attack",
"insult",
"threat",
"sexual_explicit"
],
"identity_classes": [
"male",
"female",
"homosexual_gay_or_lesbian",
"christian",
"jewish",
"muslim",
"black",
"white",
"psychiatric_or_mental_illness"
]
}
},
"optimizer": {
"type": "Adam",
"args": {
"lr": 3e-5,
"weight_decay": 3e-6,
"amsgrad": true
}
}
}
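For orientation: num_classes (16) in arch.args is the 7 main toxicity labels plus the 9 identity labels defined under dataset.args. A quick sanity check, assuming the config above is saved at a path such as configs/Jigsaw_RoBERTa_combined.json (hypothetical path):

import json

# Hypothetical path to the config shown above.
with open("configs/Jigsaw_RoBERTa_combined.json") as f:
    config = json.load(f)

n_main = len(config["dataset"]["args"]["classes"])               # 7 toxicity labels
n_identity = len(config["dataset"]["args"]["identity_classes"])  # 9 identity labels
assert config["num_main_classes"] == n_main
assert config["arch"]["args"]["num_classes"] == n_main + n_identity  # 16 model outputs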
57 changes: 57 additions & 0 deletions convert_weights.py
@@ -0,0 +1,57 @@
import torch
from collections import OrderedDict
import hashlib
import argparse


def main():
"""Converts saved checkpoint to the expected format for detoxify.
"""
checkpoint = torch.load(ARGS.checkpoint, map_location=ARGS.device)

new_state_dict = {
"state_dict": OrderedDict(),
"config": checkpoint["hyper_parameters"]["config"],
}
for k, v in checkpoint["state_dict"].items():
if k.startswith("model."):
k = k[6:] # remove `model.`
new_state_dict["state_dict"][k] = v

torch.save(new_state_dict, ARGS.save_to)

if ARGS.hash:
with open(ARGS.save_to, "rb") as f:
bytes = f.read() # read entire file as bytes
readable_hash = hashlib.sha256(bytes).hexdigest()
print("Hash: ", readable_hash)

torch.save(new_state_dict, ARGS.save_to[:-5] + f"-{readable_hash[:8]}.ckpt")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckeckpoint",
type=str,
help="path to model checkpoint",
)
parser.add_argument(
"--save_to",
type=str,
help="path to save the model to",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="device to load the checkpoint on",
)
parser.add_argument(
"--hash",
type=bool,
default=True,
help="option to save hash in name",
)
ARGS = parser.parse_args()
main()
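A quick way to sanity-check a converted file, assuming the script above has been run first (the path below is hypothetical):

import torch

# Hypothetical output path produced by convert_weights.py.
converted = torch.load("toxic_debiased.ckpt", map_location="cpu")

# The converted checkpoint keeps only what detoxify expects:
# a flat state_dict (leading "model." prefix stripped) plus the training config.
print(sorted(converted.keys()))             # ['config', 'state_dict']
print(next(iter(converted["state_dict"])))  # a parameter name, e.g. an embedding weight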
2 changes: 1 addition & 1 deletion detoxify/detoxify.py
@@ -3,7 +3,7 @@

MODEL_URLS = {
"original": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_original-c1212f89.ckpt",
"unbiased": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_bias-4e693588.ckpt",
"unbiased": "https://github.com/unitaryai/detoxify/releases/download/v0.3-alpha/toxic_debiased-c7548aa0.ckpt",
"multilingual": "https://github.com/unitaryai/detoxify/releases/download/v0.1-alpha/toxic_multilingual-bbddc277.ckpt",
"original-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/original-albert-0e1d6498.ckpt",
"unbiased-small": "https://github.com/unitaryai/detoxify/releases/download/v0.1.2/unbiased-albert-c8519128.ckpt"
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@

setup(
name="detoxify",
version="0.2.2",
version="0.3.0",
description="A python library for detecting toxic comments",
long_description=long_description,
long_description_content_type="text/markdown",
63 changes: 41 additions & 22 deletions src/data_loaders.py
@@ -58,15 +58,41 @@ def __len__(self):

def load_data(self, train_csv_file):
files = []
cols = ["id", "comment_text", "toxic"]
change_names = {
"target": "toxicity",
"toxic": "toxicity",
"identity_hate": "identity_attack",
"severe_toxic": "severe_toxicity",
}
for file in tqdm(train_csv_file):
file_df = pd.read_csv(file)
file_df = file_df[cols]
file_df = file_df.astype({"id": "string"}, {"toxic": "float64"})
chunks = []
for chunk in pd.read_csv(file, chunksize=100000):
chunks.append(chunk)

file_df = pd.concat(chunks, axis=0)
filtered_change_names = {
k: v for k, v in change_names.items() if k in file_df.columns
}
if len(filtered_change_names) > 0:
file_df.rename(columns=filtered_change_names, inplace=True)
file_df = file_df.astype({"id": "string"})
files.append(file_df)
train = pd.concat(files)

train = pd.concat(files, join="outer")
return train

def filter_entry_labels(self, entry, classes, threshold=0.5, soft_labels=False):
target = {
label: -1 if label not in entry or entry[label] is None else entry[label]
for label in classes
}
if not soft_labels:
target.update({label: 1 for label in target if target[label] >= threshold})
target.update(
{label: 0 for label in target if 0 <= target[label] < threshold}
)
return target
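For reference, the new filter_entry_labels helper maps missing or null labels to -1 and, unless soft_labels is set, binarises the remaining scores at the threshold. A minimal standalone reproduction with made-up entry values (not taken from the Jigsaw data):

def filter_entry_labels(entry, classes, threshold=0.5, soft_labels=False):
    # Same logic as the method above, extracted for illustration.
    target = {
        label: -1 if label not in entry or entry[label] is None else entry[label]
        for label in classes
    }
    if not soft_labels:
        target.update({label: 1 for label in target if target[label] >= threshold})
        target.update({label: 0 for label in target if 0 <= target[label] < threshold})
    return target

entry = {"toxicity": 0.8, "insult": 0.2, "female": None}
print(filter_entry_labels(entry, ["toxicity", "insult", "threat", "female"]))
# {'toxicity': 1, 'insult': 0, 'threat': -1, 'female': -1}

With soft_labels=True the raw fractional scores are kept instead of being binarised.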


class JigsawDataOriginal(JigsawData):
"""Dataloader for the original Jigsaw Toxic Comment Classification Challenge.
@@ -128,10 +154,11 @@ def __init__(
loss_weight=0.75,
classes=["toxic"],
identity_classes=["female"],
soft_labels=False,
):

self.classes = classes

self.soft_labels = soft_labels
self.identity_classes = identity_classes

super().__init__(
@@ -156,21 +183,12 @@ def __getitem__(self, index):
text_id = entry["id"]
text = entry["comment_text"]

target_dict = {label: 1 if entry[label] >= 0.5 else 0 for label in self.classes}

identity_target = {
label: -1 if entry[label] is None else entry[label]
for label in self.identity_classes
}
identity_target.update(
{label: 1 for label in identity_target if identity_target[label] >= 0.5}
target_dict = self.filter_entry_labels(
entry,
self.classes + self.identity_classes,
threshold=0.5,
soft_labels=self.soft_labels,
)
identity_target.update(
{label: 0 for label in identity_target if 0 <= identity_target[label] < 0.5}
)

target_dict.update(identity_target)

meta["multi_target"] = torch.tensor(
list(target_dict.values()), dtype=torch.float32
)
@@ -192,15 +210,16 @@ def __getitem__(self, index):
return text, meta

def compute_weigths(self, train_df):
"""Inspired from 2nd solution.
"""Inspired from 2nd best solution.
Source: https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/discussion/100661"""
subgroup_bool = (train_df[self.identity_classes].fillna(0) >= 0.5).sum(
axis=1
) > 0
positive_bool = train_df["toxicity"] >= 0.5
weights = np.ones(len(train_df)) * 0.25

# Backgroud Positive and Subgroup Negative
# Background Positive and Subgroup Negative
# i.e. give higher weight to toxic comments that don't mention identity and to non-toxic comments that do
weights[
((~subgroup_bool) & (positive_bool)) | ((subgroup_bool) & (~positive_bool))
] += 0.25
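A minimal NumPy/pandas sketch of the reweighting shown above, using a synthetic two-column frame (only the portion of compute_weigths visible in this hunk is reproduced):

import numpy as np
import pandas as pd

# Synthetic toy data: one identity column and a toxicity score.
train_df = pd.DataFrame(
    {"toxicity": [0.9, 0.1, 0.8, 0.0], "female": [0.0, 1.0, 1.0, None]}
)
identity_classes = ["female"]

subgroup_bool = (train_df[identity_classes].fillna(0) >= 0.5).sum(axis=1) > 0
positive_bool = train_df["toxicity"] >= 0.5

weights = np.ones(len(train_df)) * 0.25
# Extra weight for background-positive (toxic, no identity mention) and
# subgroup-negative (non-toxic, identity mentioned) examples.
weights[((~subgroup_bool) & positive_bool) | (subgroup_bool & (~positive_bool))] += 0.25
print(weights)  # [0.5  0.5  0.25 0.25]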
1 change: 0 additions & 1 deletion train.py
@@ -179,7 +179,6 @@ def cli_main():

if args.device is not None:
config["device"] = args.device
os.environ["CUDA_VISIBLE_DEVICES"] = args.device

# data
def get_instance(module, name, config, *args, **kwargs):
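The body of get_instance is truncated in this hunk. In the pytorch-template style this repository follows, such a helper typically looks up the class named in config[name]["type"] and instantiates it with config[name]["args"]; a sketch under that assumption:

def get_instance(module, name, config, *args, **kwargs):
    # Sketch only: assumes the usual pytorch-template convention; the actual
    # body is truncated in the hunk above.
    cls = getattr(module, config[name]["type"])          # e.g. JigsawDataBias for name="dataset"
    return cls(*args, **config[name]["args"], **kwargs)  # forward the configured arguments

# e.g. dataset = get_instance(module_data, "dataset", config)  -- hypothetical usage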
