-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy path test_preprocess.py
More file actions
27 lines (22 loc) · 1011 Bytes
/
test_preprocess.py
File metadata and controls
27 lines (22 loc) · 1011 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
from gluonts.dataset.common import TrainDatasets
from gluonts.dataset.artificial import constant_dataset
from preprocess import MaxNormalize
def _assert_max_normalized(entries, expected_count):
    """Assert that every series in ``entries`` looks max-normalized.

    Relies on the layout of ``constant_dataset``: series ``i`` is presumed to
    be a constant target equal to ``i`` (TODO confirm against the installed
    gluonts version). Dividing each series by its own maximum should then
    leave the first series at all zeros and turn every other series into all
    ones.

    Args:
        entries: Iterable of data entries, each carrying a ``"target"`` field.
        expected_count: Number of series the iterable must yield. Guards
            against the per-series checks passing vacuously on an empty or
            truncated dataset.
    """
    seen = 0
    for i, entry in enumerate(entries):
        target = np.asarray(entry["target"])
        # Series 0 is presumably the all-zero series (max 0); the test expects
        # it to stay zero rather than turn into NaN from 0 / 0.
        expected_value = 0.0 if i == 0 else 1.0
        # Comparing against a scalar keeps the original "every element equals
        # the value" semantics while reporting which series/elements differ.
        np.testing.assert_array_equal(
            target, expected_value, err_msg=f"series {i}"
        )
        seen += 1
    assert seen == expected_count, (
        f"expected {expected_count} normalized series, got {seen}"
    )


def test_max_normalize():
    """MaxNormalize should map each constant_dataset series to 0s or 1s and
    leave the dataset metadata unchanged."""
    info, train_ds, test_ds = constant_dataset()
    datasets = TrainDatasets(info.metadata, train_ds, test_ds)
    # Count the input series up front so the per-series checks cannot pass
    # on an empty or truncated normalized dataset.
    n_train = sum(1 for _ in datasets.train)
    n_test = sum(1 for _ in datasets.test)

    normalize = MaxNormalize(datasets).apply()

    assert normalize.datasets.metadata == datasets.metadata
    _assert_max_normalized(normalize.datasets.train, n_train)
    assert normalize.datasets.test is not None
    _assert_max_normalized(normalize.datasets.test, n_test)