Commit ab54804

Merge pull request #1 from yunjuanwang/fix-readme-download-requirements
Fix README, requirements, download files
2 parents: 2907f7b + 69e97a8

3 files changed: 29 additions & 15 deletions

File tree:

- README.md
- domainrobust/download.py
- requirements.txt

README.md

Lines changed: 23 additions & 3 deletions
````diff
@@ -33,11 +33,31 @@ DomainRobust includes the following datasets:
 
 ## Quick start
 
+To download the datasets:
+```sh
+python3 -m domainrobust.download \
+--data_dir=./domainrobust/data
+```
+
+We first generate a pretrained model using DANN (this pretrained model will be used by algorithms like SROUDA and DART):
+
+```sh
+python3 -m domainrobust.scripts.train\
+--data_dir=/my/datasets/path\
+--output_dir=/my/pretrained/model/path\
+--algorithm DANN\
+--dataset DIGIT\
+--task domain_adaptation\
+--source_envs 0\
+--target_envs 2
+```
+
 To train a single model:
 
 ```sh
-python3 -m scripts.train\
+python3 -m domainrobust.scripts.train\
 --data_dir=/my/datasets/path\
+--output_dir=/output/path\
 --algorithm AT\
 --dataset DIGIT\
 --task domain_adaptation\
@@ -55,7 +75,7 @@ python3 -m scripts.train\
 To launch a sweep (over a range of hyperparameters and possibly multiple algorithms and datasets):
 
 ```sh
-python -m scripts.sweep launch\
+python -m domainrobust.scripts.sweep launch\
 --data_dir=/my/datasets/path\
 --output_dir=/my/sweep/output/path\
 --command_launcher MyLauncher\
@@ -100,4 +120,4 @@ python -m domainrobust.scripts.test\
 --atk_lr 0.004 \
 --atk_iter 20 \
 --attack pgd
-````
+````
````
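The updated Quick start implies a two-stage workflow: first pretrain with DANN, then run another algorithm such as AT (with SROUDA and DART later consuming the DANN checkpoint). A small Python sketch of how those command lines fit together; `build_train_cmd` is a hypothetical helper, not part of DomainRobust, and the module path and flag names are taken from the README's Quick start commands:

```python
def build_train_cmd(output_dir, algorithm, data_dir="/my/datasets/path",
                    dataset="DIGIT", source_envs="0", target_envs="2"):
    """Assemble the argv list for `python3 -m domainrobust.scripts.train`.

    Hypothetical convenience helper; flag names mirror the README,
    where --data_dir/--output_dir use `=` and the rest use spaces.
    """
    return [
        "python3", "-m", "domainrobust.scripts.train",
        f"--data_dir={data_dir}",
        f"--output_dir={output_dir}",
        "--algorithm", algorithm,
        "--dataset", dataset,
        "--task", "domain_adaptation",
        "--source_envs", source_envs,
        "--target_envs", target_envs,
    ]

# Stage 1: DANN pretraining (checkpoint reused by SROUDA/DART).
pretrain = build_train_cmd("/my/pretrained/model/path", "DANN")
# Stage 2: adversarial training on the same source/target split.
train = build_train_cmd("/output/path", "AT")
```

Either list can then be passed to `subprocess.run(...)` instead of typing the commands by hand.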

domainrobust/download.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -54,10 +54,6 @@ def download_digit(data_dir):
     download_and_extract("https://drive.google.com/uc?id=1wnY5M74Zj-lPuOQzcmVO9KQVLVhgK7Vu",
                          os.path.join(data_dir, "DIGIT.zip"))
 
-    os.rename(os.path.join(data_dir, "kfold"),
-              full_path)
-
-
 
 # PACS ########################################################################
 
@@ -94,8 +90,6 @@ def download_visda(data_dir):
     download_and_extract("https://drive.google.com/uc?id=1VrIlU6yrm-XTcwfpRIWuALdTAYulgdYp",
                          os.path.join(data_dir, "VISDA.zip"))
 
-    os.rename(os.path.join(data_dir, "kfold"),
-              full_path)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Download datasets')
```
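Dropping the post-extract `os.rename(..., "kfold")` calls suggests the hosted DIGIT and VISDA archives now unpack directly into their final directory names, so no rename step is needed. A minimal sketch of what such a helper could look like; `extract_in_place` and this `download_and_extract` body are illustrative assumptions, since only the helper's call sites appear in the diff:

```python
import os
import zipfile


def extract_in_place(archive_path):
    """Unpack a zip archive next to itself, then delete the archive.

    If the zip's top-level folder already carries the final dataset name
    (e.g. DIGIT/), no follow-up os.rename is required -- which is what
    removing the `kfold` rename in this commit implies.
    """
    data_dir = os.path.dirname(archive_path)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(data_dir)
    os.remove(archive_path)


def download_and_extract(url, dst):
    """Hypothetical stand-in for the helper used by domainrobust.download."""
    import gdown  # deferred import; gdown is pinned in requirements.txt
    gdown.download(url, dst, quiet=False)
    extract_in_place(dst)
```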

requirements.txt

Lines changed: 6 additions & 6 deletions
```diff
@@ -1,10 +1,10 @@
-numpy==1.20.3
+numpy==1.21.6
 wilds==1.2.2
-imageio==2.9.0
-gdown==3.13.0
-torchvision==0.8.2
-torch==1.7.1
-tqdm==4.62.2
+imageio==2.31.2
+gdown==4.7.3
+torchvision==0.14.1
+torch==1.13.1
+tqdm==4.66.2
 backpack==0.1
 parameterized==0.8.1
 Pillow==8.3.2
```
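The new pins form a consistent set (torchvision 0.14.1 is the release matched to torch 1.13.1). A quick sketch of checking an installed environment against these pins with the standard-library `importlib.metadata`; the `mismatched_pins` helper is hypothetical and not part of the repo, and `PINS` lists only the versions changed in this commit:

```python
from importlib import metadata

# Pins updated by this commit (subset of requirements.txt).
PINS = {
    "numpy": "1.21.6",
    "imageio": "2.31.2",
    "gdown": "4.7.3",
    "torchvision": "0.14.1",
    "torch": "1.13.1",
    "tqdm": "4.66.2",
}


def mismatched_pins(pins):
    """Return {name: (installed, pinned)} for installed packages whose
    version differs from the pinned one; missing packages are skipped
    (a plain `pip install -r requirements.txt` would pull them in)."""
    bad = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            continue
        if installed != pinned:
            bad[name] = (installed, pinned)
    return bad
```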
