File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -33,11 +33,31 @@ DomainRobust includes the following datasets:
3333
3434## Quick start
3535
36+ To download the datasets:
37+ ``` sh
38+ python3 -m domainrobust.download \
39+ --data_dir=./domainrobust/data
40+ ```
41+
42+ We first generate a pretrained model using DANN (this pretrained model will be used by algorithms like SROUDA and DART):
43+
44+ ``` sh
45+ python3 -m domainrobust.scripts.train\
46+ --data_dir=/my/datasets/path\
47+ --output_dir=/my/pretrained/model/path\
48+ --algorithm DANN\
49+ --dataset DIGIT\
50+ --task domain_adaptation\
51+ --source_envs 0\
52+ --target_envs 2
53+ ```
54+
3655To train a single model:
3756
3857``` sh
39- python3 -m scripts.train\
58+ python3 -m domainrobust. scripts.train\
4059 --data_dir=/my/datasets/path\
60+ --output_dir=/output/path\
4161 --algorithm AT\
4262 --dataset DIGIT\
4363 --task domain_adaptation\
@@ -55,7 +75,7 @@ python3 -m scripts.train\
5575To launch a sweep (over a range of hyperparameters and possibly multiple algorithms and datasets):
5676
5777``` sh
58- python -m scripts.sweep launch\
78+ python -m domainrobust. scripts.sweep launch\
5979 --data_dir=/my/datasets/path\
6080 --output_dir=/my/sweep/output/path\
6181 --command_launcher MyLauncher\
@@ -100,4 +120,4 @@ python -m domainrobust.scripts.test\
100120 --atk_lr 0.004 \
101121 --atk_iter 20 \
102122 --attack pgd
103- ````
123+ ````
Original file line number Diff line number Diff line change @@ -54,10 +54,6 @@ def download_digit(data_dir):
5454 download_and_extract ("https://drive.google.com/uc?id=1wnY5M74Zj-lPuOQzcmVO9KQVLVhgK7Vu" ,
5555 os .path .join (data_dir , "DIGIT.zip" ))
5656
57- os .rename (os .path .join (data_dir , "kfold" ),
58- full_path )
59-
60-
6157
6258# PACS ########################################################################
6359
@@ -94,8 +90,6 @@ def download_visda(data_dir):
9490 download_and_extract ("https://drive.google.com/uc?id=1VrIlU6yrm-XTcwfpRIWuALdTAYulgdYp" ,
9591 os .path .join (data_dir , "VISDA.zip" ))
9692
97- os .rename (os .path .join (data_dir , "kfold" ),
98- full_path )
9993
10094if __name__ == "__main__" :
10195 parser = argparse .ArgumentParser (description = 'Download datasets' )
Original file line number Diff line number Diff line change 1- numpy == 1.20.3
1+ numpy == 1.21.6
22wilds == 1.2.2
3- imageio == 2.9.0
4- gdown == 3.13.0
5- torchvision == 0.8.2
6- torch == 1.7 .1
7- tqdm == 4.62 .2
3+ imageio == 2.31.2
4+ gdown == 4.7.3
5+ torchvision == 0.14.1
6+ torch == 1.13 .1
7+ tqdm == 4.66 .2
88backpack == 0.1
99parameterized == 0.8.1
1010Pillow == 8.3.2
You can’t perform that action at this time.
0 commit comments