Skip to content

Commit

Permalink
fix: avoid reloading datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
PJEstrada committed Sep 1, 2021
1 parent 0d85553 commit a13977c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 10 deletions.
10 changes: 6 additions & 4 deletions sdk/diffgram/core/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,22 @@ def to_pytorch(self, transform = None):
Transforms the file list inside the dataset into a pytorch dataset.
:return:
"""
file_id_list = self.all_file_ids()
file_id_list = self.file_id_list
pytorch_dataset = DiffgramPytorchDataset(
project = self.client,
diffgram_file_id_list = file_id_list,
transform = transform
transform = transform,
validate_ids = False

)
return pytorch_dataset

def to_tensorflow(self):
file_id_list = self.all_file_ids()
file_id_list = self.file_id_list
diffgram_tensorflow_dataset = DiffgramTensorflowDataset(
project = self.client,
diffgram_file_id_list = file_id_list
diffgram_file_id_list = file_id_list,
validate_ids = False
)
return diffgram_tensorflow_dataset

Expand Down
6 changes: 4 additions & 2 deletions sdk/diffgram/core/sliced_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def to_pytorch(self, transform = None):
pytorch_dataset = DiffgramPytorchDataset(
project = self.client,
diffgram_file_id_list = self.file_id_list,
transform = transform
transform = transform,
validate_ids = False

)
return pytorch_dataset
Expand All @@ -59,6 +60,7 @@ def to_tensorflow(self):
file_id_list = self.all_file_ids()
diffgram_tensorflow_dataset = DiffgramTensorflowDataset(
project = self.client,
diffgram_file_id_list = file_id_list
diffgram_file_id_list = file_id_list,
validate_ids = False
)
return diffgram_tensorflow_dataset
4 changes: 2 additions & 2 deletions sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

class DiffgramPytorchDataset(DiffgramDatasetIterator, Dataset):

def __init__(self, project, diffgram_file_id_list = None, transform = None):
def __init__(self, project, diffgram_file_id_list = None, transform = None, validate_ids = True):
"""
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
:param transform (callable, optional): Optional transforms to be applied on a sample
"""
super(DiffgramPytorchDataset, self).__init__(project, diffgram_file_id_list)
super(DiffgramPytorchDataset, self).__init__(project, diffgram_file_id_list, validate_ids)

self.diffgram_file_id_list = diffgram_file_id_list

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

class DiffgramTensorflowDataset(DiffgramDatasetIterator):

def __init__(self, project, diffgram_file_id_list):
def __init__(self, project, diffgram_file_id_list, validate_ids = True):
"""
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
:param diffgram_file_list (list): An arbitrary number of file ID's from Diffgram.
:param transform (callable, optional): Optional transforms to be applied on a sample
"""
super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list)
super(DiffgramTensorflowDataset, self).__init__(project, diffgram_file_id_list, validate_ids)

self.diffgram_file_id_list = diffgram_file_id_list

Expand Down

0 comments on commit a13977c

Please sign in to comment.