Skip to content

Commit

Permalink
Do not force set path for cache/data
Browse files Browse the repository at this point in the history
  • Loading branch information
iamgroot42 committed Jan 10, 2024
1 parent f9e0e15 commit b02694a
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 33 deletions.
8 changes: 2 additions & 6 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@ jobs:
test:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: [3.9, 3.10]

steps:
- name: Checkout Repository
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
python-version: '3.9' # Specify your Python version

- name: Install pdoc
run: pip install pdoc3 # Install pdoc or pdoc3
Expand Down
62 changes: 40 additions & 22 deletions mimir/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@
from dataclasses import dataclass
from typing import Optional, List
from simple_parsing.helpers import Serializable, field
from mimir.utils import CACHE_PATH, DATA_SOURCE
from mimir.utils import get_cache_path, get_data_source


@dataclass
class ReferenceConfig(Serializable):
"""
Config for attacks that use reference models.
Config for attacks that use reference models.
"""

models: List[str]
"""Reference model names"""


@dataclass
class NeighborhoodConfig(Serializable):
"""
Config for neighborhood attack
Config for neighborhood attack
"""

model: str
"""Mask-filling model"""
n_perturbation_list: List[int] = field(default_factory=lambda: [1, 10])
Expand All @@ -35,7 +37,7 @@ class NeighborhoodConfig(Serializable):
"""Swap out token in original text with neighbor token, instead of re-generating text"""
pct_swap_bert: Optional[float] = 0.05
"""Percentage of tokens per neighbor that are different from the original text"""
neighbor_strategy: Optional[str] = 'deterministic'
neighbor_strategy: Optional[str] = "deterministic"
"""Strategy for generating neighbors. One of ['deterministic', 'random']. Deterministic uses only one-word neighbors"""
# T-5 specific hyper-parameters
span_length: Optional[int] = 2
Expand Down Expand Up @@ -65,39 +67,42 @@ def __post_init__(self):
@dataclass
class EnvironmentConfig(Serializable):
"""
Config for environment-specific parameters
Config for environment-specific parameters
"""

cache_dir: Optional[str] = None
"""Path to cache directory"""
data_source: Optional[str] = None
"""Path where data is stored"""
device: Optional[str] = 'cuda:1'
device: Optional[str] = "cuda:1"
"""Device (GPU) to load main model on"""
device_map: Optional[str] = None
"""Configuration for device map if needing to split model across gpus"""
device_aux: Optional[str] = 'cuda:0'
device_aux: Optional[str] = "cuda:0"
"""Device (GPU) to load any auxiliary model(s) on"""
compile: Optional[bool] = True
"""Compile models?"""
int8: Optional[bool] = False
"""Use int8 quantization?"""
half: Optional[bool] = False
"""Use half precision?"""
results: Optional[str] = 'results'
results: Optional[str] = "results"
"""Path for saving final results"""
tmp_results: Optional[str] = 'tmp_results'
tmp_results: Optional[str] = "tmp_results"

def __post_init__(self):
if self.cache_dir is None:
self.cache_dir = CACHE_PATH
self.cache_dir = get_cache_path()
if self.data_source is None:
self.data_source = DATA_SOURCE
self.data_source = get_data_source()


@dataclass
class OpenAIConfig(Serializable):
"""
Config for OpenAI calls
Config for OpenAI calls
"""

key: str
"""OpenAI API key"""
model: str
Expand All @@ -107,17 +112,19 @@ class OpenAIConfig(Serializable):
@dataclass
class ExtractionConfig(Serializable):
"""
Config for model-extraction
Config for model-extraction
"""

prompt_len: Optional[int] = 30
"""Prompt length"""


@dataclass
class ExperimentConfig(Serializable):
"""
Config for attacks
Config for attacks
"""

base_model: str
"""Base model name"""
dataset_member: str
Expand All @@ -132,25 +139,31 @@ class ExperimentConfig(Serializable):
"""Path to presampled dataset source for members"""
presampled_dataset_nonmember: Optional[str] = None
"""Path to presampled dataset source for mpmmembers"""
token_frequency_map: Optional[str] = None # TODO: Handling auxiliary data structures
token_frequency_map: Optional[
str
] = None # TODO: Handling auxiliary data structures
"""Path to a pre-computed token frequency map"""
dataset_key: Optional[str] = None
"""Dataset key"""
output_name: Optional[str] = None
"""Output name for sub-directory. Defaults to nothing"""
specific_source: Optional[str] = None
"""Specific sub-source to focus on. Only valid for the_pile"""
full_doc: Optional[bool] = False # TODO: refactor full_doc design?
full_doc: Optional[bool] = False # TODO: refactor full_doc design?
"""Determines whether MIA will be performed over entire doc or not"""
max_substrs: Optional[int] = 20
"""If full_doc, determines the maximum number of sample substrs to evaluate on"""
dump_cache: Optional[bool] = False
"Dump data to cache? Exits program after dumping"
load_from_cache: Optional[bool] = False
"""Load data from cache?"""
blackbox_attacks: Optional[List[str]] = field(default_factory=lambda: None) # Can replace with "default" attacks if we want
"""List of attacks to evaluate"""
baselines_only: Optional[bool] = False # TODO: to be removed after Neighborhood attack is implemented into blackbox attack flow
blackbox_attacks: Optional[List[str]] = field(
default_factory=lambda: None
) # Can replace with "default" attacks if we want
"""List of attacks to evaluate"""
baselines_only: Optional[
bool
] = False # TODO: to be removed after Neighborhood attack is implemented into blackbox attack flow
"""Evaluate only baselines?"""
tokenization_attack: Optional[bool] = False
"""Run tokenization attack?"""
Expand Down Expand Up @@ -204,7 +217,12 @@ class ExperimentConfig(Serializable):
def __post_init__(self):
if self.dump_cache and self.load_from_cache:
raise ValueError("Cannot dump and load cache at the same time")

if self.neighborhood_config:
if (self.neighborhood_config.dump_cache or self.neighborhood_config.load_from_cache) and not (self.load_from_cache or self.dump_cache):
raise ValueError("Using dump/load for neighborhood cache without dumping/loading main cache does not make sense")
if (
self.neighborhood_config.dump_cache
or self.neighborhood_config.load_from_cache
) and not (self.load_from_cache or self.dump_cache):
raise ValueError(
"Using dump/load for neighborhood cache without dumping/loading main cache does not make sense"
)
27 changes: 22 additions & 5 deletions mimir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,27 @@
import os


# Read environment variables
CACHE_PATH = os.environ.get('MIMIR_CACHE_PATH', None)
if CACHE_PATH is None:
raise ValueError('MIMIR_CACHE_PATH environment variable not set')

DATA_SOURCE = os.environ.get('MIMIR_DATA_SOURCE', None)
if DATA_SOURCE is None:
raise ValueError('MIMIR_DATA_SOURCE environment variable not set')


def get_cache_path():
"""
Get path to cache directory.
Returns:
str: path to cache directory
"""
if CACHE_PATH is None:
raise ValueError('MIMIR_CACHE_PATH environment variable not set')


def get_data_source():
"""
Get path to data source directory.
Returns:
str: path to data source directory
"""
if DATA_SOURCE is None:
raise ValueError('MIMIR_DATA_SOURCE environment variable not set')
return DATA_SOURCE
Empty file added tests/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions tests/test_attacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Test attack implementations. Consists of basic execution tests to make sure attack works as expected and returns values as expected.
"""
import pytest


class TestAttack:
def test_attack_shape(self):
# Check 1 - attack accepts inputs in given shape, and works for both text-based and tokenized inputs
pass

def test_attack_scores_shape(self):
# Check 2 - scores returned match exepected shape
pass

def test_attack_score_range(self):
# Check 3 - scores match expected range
pass

def test_attack_auc(self):
# Check 4 (TODO) - Attack AUC is not horribly bad
pass

0 comments on commit b02694a

Please sign in to comment.