Skip to content

Commit 7331c31

Browse files
committed
Add Pytorch training code
1 parent 82d6e00 commit 7331c31

118 files changed

Lines changed: 11216 additions & 2979 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

demo.ipynb

Lines changed: 445 additions & 1 deletion
Large diffs are not rendered by default.

environment_comfortable_py10.yml

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# run: conda env create --file environment.yml
2+
name: py10
3+
channels:
4+
- defaults
5+
dependencies:
6+
- python=3.10
7+
- pip
8+
- cachetools
9+
- av<9
10+
- cython
11+
- ezc3d
12+
- ffmpeg
13+
- imageio
14+
- matplotlib
15+
- mkl
16+
- trimesh
17+
- numba
18+
- numpy<2.0
19+
- libiconv
20+
- pandas
21+
- Pillow
22+
- scikit-image
23+
- scikit-learn
24+
- scikit-sparse
25+
- tqdm
26+
- conda-forge::mayavi>=4.8
27+
- conda-forge::PySide6==6.7.1
28+
- conda-forge::ffmpeg-python
29+
- pip:
30+
- distutils
31+
- addict
32+
- tensorflow==2.15
33+
- tensorflow-hub
34+
- chumpy
35+
- embreex
36+
- einops
37+
- imageio-ffmpeg
38+
- importlib_resources
39+
- jpeg4py
40+
- more_itertools
41+
- opencv-python
42+
- pyrender
43+
- tetgen
44+
- pymeshfix
45+
- -e $CODE_DIR/fleras
46+
- -e $CODE_DIR/affine_combining_autoencoder
47+
- -e $CODE_DIR/cameralib
48+
- -e $CODE_DIR/boxlib
49+
- -e $CODE_DIR/poseviz
50+
- -e $CODE_DIR/smplfitter
51+
- -e $CODE_DIR/simplepyutils
52+
- -e $CODE_DIR/rlemasklib
53+
- -e $CODE_DIR/barecat3
54+
- -e $CODE_DIR/tensorflow-inputs
55+
- -e $CODE_DIR/posepile
56+
- -e $CODE_DIR/blendipose
57+
- -e $CODE_DIR/blendify
58+
- -e $CODE_DIR/nlf-pipeline
59+
# - git+https://github.com/isarandi/fleras.git
60+
# - git+https://github.com/isarandi/cameralib.git
61+
# - git+https://github.com/isarandi/boxlib.git
62+
# - git+https://github.com/isarandi/poseviz.git
63+
# - git+https://github.com/isarandi/smplfitter.git
64+
# - git+https://github.com/isarandi/simplepyutils.git
65+
# - git+https://github.com/isarandi/tf-parallel-map.git
66+
# - git+https://github.com/isarandi/BareCat.git
67+
# - git+https://github.com/isarandi/rlemasklib.git
68+
# - git+https://github.com/isarandi/tensorflow-inputs.git

install_dependencies.sh

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#!/usr/bin/env bash
2+
set -euo pipefail
3+
4+
# This was tested on Ubuntu 18.04 and 22.04
5+
sudo apt update
6+
sudo apt install build-essential --yes wget curl gfortran git ncurses-dev libncursesw5-dev unzip tar libxcb-xinerama0
7+
8+
################
9+
# This part is ONLY needed if you'll work with the original Human3.6M files.
10+
# For this, we need to install the [CDF library](https://cdf.gsfc.nasa.gov/) since the annotations are in CDF files.
11+
# We will use [SpacePy](https://spacepy.github.io/) as a wrapper, which in turn depends on this CDF library.
12+
CDF_VERSION=39_0
13+
wget "https://spdf.gsfc.nasa.gov/pub/software/cdf/dist/cdf${CDF_VERSION}/linux/cdf${CDF_VERSION}-dist-cdf.tar.gz"
14+
tar xf "cdf${CDF_VERSION}-dist-cdf.tar.gz"
15+
rm "cdf${CDF_VERSION}-dist-cdf.tar.gz"
16+
pushd "cdf${CDF_VERSION}-dist"
17+
make OS=linux ENV=gnu CURSES=no FORTRAN=no UCOPTIONS=-O2 SHARED=yes -j"$(nproc)" all
18+
# If you have sudo rights, simply run `sudo make install`. If you have no `sudo` rights, set the env var
19+
# `CDF_LIB` to `cdf${CDF_VERSION}-dist/src/lib`
20+
# add the export line to ~/.bashrc for permanent effect, or use GNU Stow.
21+
# mv src ~/.local/stow/cdf-${CDF_VERSION}
22+
# stow cdf-${CDF_VERSION}
23+
# The following will work temporarily:
24+
export CDF_LIB=$PWD/src/lib
25+
popd
26+
####################
27+
28+
# Micromamba is the simplest way to install the dependencies
29+
# If you don't have it yet, install it as follows:
30+
export MAMBA_ROOT_PREFIX=$HOME/micromamba
31+
mkdir -p $MAMBA_ROOT_PREFIX
32+
curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xvj -C $MAMBA_ROOT_PREFIX bin/micromamba
33+
# Execute and add to ~/.bashrc the following two lines
34+
export MAMBA_ROOT_PREFIX=$HOME/micromamba
35+
eval "$($MAMBA_ROOT_PREFIX/bin/micromamba shell hook -s posix)"
36+
#
37+
38+
# Create a new environment and install the dependencies
39+
envsubst < environment_comfortable_py10.yml > env_subst.yml
40+
micromamba env create --name=nlf --file=env_subst.yml -y
41+
micromamba activate nlf
42+
pip install --no-build-isolation git+https://github.com/spacepy/spacepy
43+
44+
45+
micromamba install -y -c conda-forge \
46+
cachetools \
47+
cython \
48+
ezc3d \
49+
ffmpeg \
50+
imageio \
51+
matplotlib \
52+
mkl \
53+
trimesh \
54+
numba \
55+
"numpy<2.0" \
56+
libiconv \
57+
pandas \
58+
pillow \
59+
scikit-image \
60+
scikit-learn \
61+
scikit-sparse \
62+
tqdm \
63+
conda-forge::mayavi>=4.8 \
64+
conda-forge::PySide6==6.7.1 \
65+
conda-forge::ffmpeg-python
66+
67+
pip install \
68+
setuptools \
69+
addict \
70+
tensorflow==2.15 \
71+
tensorflow-hub \
72+
torch \
73+
torchvision \
74+
torchdata \
75+
chumpy \
76+
embreex \
77+
einops \
78+
imageio-ffmpeg \
79+
importlib_resources \
80+
jpeg4py \
81+
more_itertools \
82+
opencv-python \
83+
pyrender \
84+
tetgen \
85+
pymeshfix
86+
87+
88+
# Optional:
89+
# Install libjpeg-turbo for faster JPEG decoding.
90+
# wget https://sourceforge.net/projects/libjpeg-turbo/files/2.0.5/libjpeg-turbo-2.0.5.tar.gz
91+
# Then compile it.
92+
# Or use the repo:
93+
# git clone https://github.com/libjpeg-turbo/libjpeg-turbo.git
94+
# cd libjpeg-turbo
95+
# PACKAGE_NAME=libjpeg-turbo
96+
# TARGET=$HOME/.local
97+
# sudo apt install nasm
98+
# cmake -DCMAKE_INSTALL_PREFIX="$TARGET" -DCMAKE_POSITION_INDEPENDENT_CODE=ON -G"Unix Makefiles" .
99+
# TEMP_DESTDIR=$(mktemp --directory --tmpdir="$STOW_DIR")
100+
# make -j "$(nproc)" install DESTDIR="$TEMP_DESTDIR"
101+
# mv -T "$TEMP_DESTDIR/$TARGET" "$STOW_DIR/$PACKAGE_NAME"
102+
# rm -rf "$TEMP_DESTDIR"
103+
# stow "$PACKAGE_NAME" --target="$TARGET"
104+
105+
- posepile image barecat
106+
- anno barecats (4)
107+
- code projects
108+
- micromamba, env install
109+
- cuda, cudnn copy
110+
- set up project initializer bashrc command that sets envvars, activates env, cd to project
111+
- wacv23_models
112+
- stuff from $DATA_ROOT/cache
113+
- projects/localizerfields
114+
115+
116+
#sudo apt-get install libavformat-dev libavdevice-dev
117+
#pip install av --no-binary av
118+
# https://stackoverflow.com/questions/72604912/cant-show-image-with-opencv-when-importing-av
Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1+
import cv2
12
import numpy as np
23
from simplepyutils import FLAGS
34

4-
import nlf.tf.augmentation.color as coloraug
5-
from nlf.tf import improc, util
6-
from nlf.tf.augmentation import voc_loader
7-
from nlf.tf.augmentation.border import augment_border
8-
from nlf.tf.util import TRAIN
9-
import cv2
5+
import nlf.common.augmentation.color as coloraug
6+
from nlf.common import improc, util
7+
from nlf.common.util import TRAIN
8+
from nlf.common.augmentation import voc_loader
9+
from nlf.common.augmentation.border import augment_border
1010

1111

1212
def augment_appearance(im, learning_phase, occlude_prob, border_value, rng):
@@ -46,8 +46,9 @@ def augment_appearance(im, learning_phase, occlude_prob, border_value, rng):
4646
if FLAGS.jpeg_aug_prob:
4747
im = jpeg_artifact(im, jpeg_rng)
4848

49-
if FLAGS.augment_border and border_value is not None:
50-
im = augment_border(im, border_value, border_rng)
49+
if FLAGS.augment_border_prob and border_value is not None:
50+
if border_rng.uniform(0.0, 1.0) < FLAGS.augment_border_prob:
51+
im = augment_border(im, border_value, border_rng)
5152

5253
return im
5354

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import functools
2+
import os.path as osp
23

4+
import barecat
35
import cameralib
46
import cv2
57
import numpy as np
68
import simplepyutils as spu
79
from posepile.paths import DATA_ROOT
810
from simplepyutils import FLAGS
9-
import barecat
10-
from nlf.tf import improc, util
11-
import os.path as osp
11+
12+
from nlf.common import improc, util
13+
1214

1315
@functools.lru_cache()
1416
def get_inria_holiday_background_paths():
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import cv2
22
import numpy as np
3-
from simplepyutils import FLAGS
43

54

65
def augment_border(im, border_value, rng):
@@ -11,15 +10,15 @@ def augment_border(im, border_value, rng):
1110
size = 0.08
1211
angle = np.deg2rad(8)
1312

14-
def get_angle():
13+
def random_angle():
1514
if rng.uniform(0, 1) < 0.9:
1615
return rng.uniform(-angle, angle)
1716
else:
1817
return rng.uniform(-angle, angle) * 1.5
1918

2019
# top:
2120
if rng.uniform(0, 1) < p:
22-
alpha = get_angle()
21+
alpha = random_angle()
2322
d = rng.uniform(0, h * size)
2423
y1 = - np.tan(alpha) * h / 2 + d
2524
y2 = + np.tan(alpha) * h / 2 + d
@@ -30,7 +29,7 @@ def get_angle():
3029

3130
# bottom:
3231
if rng.uniform(0, 1) < p:
33-
alpha = get_angle()
32+
alpha = random_angle()
3433
d = rng.uniform(0, h * size)
3534
y1 = h - np.tan(alpha) * h / 2 - d
3635
y2 = h + np.tan(alpha) * h / 2 - d
@@ -41,7 +40,7 @@ def get_angle():
4140

4241
# left:
4342
if rng.uniform(0, 1) < p:
44-
alpha = get_angle()
43+
alpha = random_angle()
4544
d = rng.uniform(0, w * size)
4645
x1 = - np.tan(alpha) * w / 2 + d
4746
x2 = + np.tan(alpha) * w / 2 + d
@@ -52,7 +51,7 @@ def get_angle():
5251

5352
# right:
5453
if rng.uniform(0, 1) < p:
55-
alpha = get_angle()
54+
alpha = random_angle()
5655
d = rng.uniform(0, w * size)
5756
x1 = w - np.tan(alpha) * w / 2 - d
5857
x2 = w + np.tan(alpha) * w / 2 - d

nlf/common/augmentation/color.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import cv2
2+
import numpy as np
3+
import numba
4+
5+
6+
def augment_color(im, rng, out_dtype=None):
7+
if out_dtype is None:
8+
out_dtype = im.dtype
9+
10+
if im.dtype == np.uint8:
11+
im = cv2.divide(im, (255, 255, 255, 255), dtype=cv2.CV_32F)
12+
13+
augmentation_functions = [
14+
augment_brightness,
15+
augment_contrast,
16+
augment_hue,
17+
augment_saturation,
18+
]
19+
rng.shuffle(augmentation_functions)
20+
21+
colorspace = 'rgb'
22+
for fn in augmentation_functions:
23+
colorspace = fn(im, colorspace, rng)
24+
25+
if colorspace != 'rgb':
26+
cv2.cvtColor(im, cv2.COLOR_HSV2RGB, dst=im)
27+
28+
np.clip(im, 0, 1, out=im)
29+
30+
if out_dtype == np.uint8:
31+
return (im * 255).astype(np.uint8)
32+
else:
33+
return im
34+
35+
36+
def augment_brightness(im, in_colorspace, rng):
37+
if in_colorspace != 'rgb':
38+
cv2.cvtColor(im, cv2.COLOR_HSV2RGB, dst=im)
39+
40+
im += rng.uniform(-0.125, 0.125)
41+
return 'rgb'
42+
43+
44+
def augment_contrast(im, in_colorspace, rng):
45+
if in_colorspace != 'rgb':
46+
cv2.cvtColor(im, cv2.COLOR_HSV2RGB, dst=im)
47+
# im -= 0.5
48+
# im *= rng.uniform(0.5, 1.5)
49+
# im += 0.5
50+
_augment_contrast_nb(im, rng.uniform(0.5, 1.5))
51+
return 'rgb'
52+
53+
54+
@numba.njit(cache=True)
55+
def _augment_contrast_nb(im, factor):
56+
im_flat = im.reshape(-1)
57+
factor32 = np.float32(factor)
58+
offset = np.float32(-0.5) * factor32 + np.float32(0.5)
59+
for i in range(im_flat.shape[0]):
60+
im_flat[i] = im_flat[i] * factor32 + offset
61+
62+
63+
def augment_hue(im, in_colorspace, rng):
64+
if in_colorspace != 'hsv':
65+
np.clip(im, 0, 1, out=im)
66+
cv2.cvtColor(im, cv2.COLOR_RGB2HSV, dst=im)
67+
# hue = im[:, :, 0]
68+
# hue += rng.uniform(-72, 72)
69+
# hue[hue < 0] += 360
70+
# hue[hue > 360] -= 360
71+
_augment_hue_nb(im, rng.uniform(-72, 72))
72+
return 'hsv'
73+
74+
75+
@numba.njit(cache=True)
76+
def _augment_hue_nb(im, offset):
77+
im_flat = im.reshape(-1, 3)
78+
for i in range(im_flat.shape[0]):
79+
im_flat[i, 0] += offset
80+
if im_flat[i, 0] < 0:
81+
im_flat[i, 0] += 360
82+
elif im_flat[i, 0] > 360:
83+
im_flat[i, 0] -= 360
84+
85+
86+
def augment_saturation(im, in_colorspace, rng):
87+
if in_colorspace != 'hsv':
88+
np.clip(im, 0, 1, out=im)
89+
cv2.cvtColor(im, cv2.COLOR_RGB2HSV, dst=im)
90+
91+
# saturation = im[:, :, 1]
92+
# saturation *= rng.uniform(0.5, 1.5)
93+
# saturation[saturation > 1] = 1
94+
_augment_saturation_nb(im, rng.uniform(0.5, 1.5))
95+
return 'hsv'
96+
97+
98+
@numba.njit(cache=True)
99+
def _augment_saturation_nb(im, factor):
100+
im_flat = im.reshape(-1, 3)
101+
for i in range(im_flat.shape[0]):
102+
im_flat[i, 1] *= factor
103+
if im_flat[i, 1] > 1:
104+
im_flat[i, 1] = 1

0 commit comments

Comments
 (0)