
Commit c06a440

bump: version 0.3.0 new DeepTensor API (#33)
* standardize optimizer class
* update
* version 0.3.0
1 parent ecb0e82 commit c06a440


4 files changed: +93, -155 lines


csrc/main.cc

Lines changed: 7 additions & 1 deletion

@@ -196,27 +196,32 @@ PYBIND11_MODULE(_core, m) {

   // Optimizer class
   py::class_<Optimizer, std::shared_ptr<Optimizer>>(m, "Optimizer")
-      .def("step", &Optimizer::step);
+      .def("step", &Optimizer::step)
+      .def("zero_grad", &Optimizer::zero_grad);

   py::class_<SGD, std::shared_ptr<SGD>>(m, "SGD")
       .def(py::init<std::shared_ptr<Model>, double>())
       .def_readwrite("learning_rate", &SGD::learning_rate)
+      .def("zero_grad", &SGD::zero_grad)
       .def("step", &SGD::step);

   py::class_<Momentum, std::shared_ptr<Momentum>>(m, "Momentum")
       .def(py::init<std::shared_ptr<Model>, double, double>())
       .def_readwrite("learning_rate", &Momentum::learning_rate)
+      .def("zero_grad", &Momentum::zero_grad)
       .def_readwrite("decay_factor", &Momentum::decay_factor)
       .def("step", &Momentum::step);

   py::class_<AdaGrad, std::shared_ptr<AdaGrad>>(m, "AdaGrad")
       .def(py::init<std::shared_ptr<Model>, double>())
       .def_readwrite("learning_rate", &AdaGrad::learning_rate)
+      .def("zero_grad", &AdaGrad::zero_grad)
       .def("step", &AdaGrad::step);

   py::class_<RMSprop, std::shared_ptr<RMSprop>>(m, "RMSprop")
       .def(py::init<std::shared_ptr<Model>, double>())
       .def(py::init<std::shared_ptr<Model>, double, double>())
+      .def("zero_grad", &RMSprop::zero_grad)
       .def_readwrite("learning_rate", &RMSprop::learning_rate)
       .def_readwrite("decay_factor", &RMSprop::decay_factor)
       .def("step", &RMSprop::step);

@@ -225,6 +230,7 @@ PYBIND11_MODULE(_core, m) {
       .def(py::init<std::shared_ptr<Model>, double>())
       .def(py::init<std::shared_ptr<Model>, double, double, double>())
       .def_readwrite("learning_rate", &Adam::learning_rate)
+      .def("zero_grad", &Adam::zero_grad)
       .def_readwrite("beta1", &Adam::beta1)
       .def_readwrite("beta2", &Adam::beta2)
       .def("step", &Adam::step);
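With zero_grad() bound alongside step() on every optimizer class, a Python training loop can clear accumulated gradients after each update. A minimal sketch of the new API follows; the deeptensor import path, the model construction, and the loss/backward lines are illustrative assumptions, while the SGD constructor and the step()/zero_grad() calls match the bindings above.

# Sketch only: lines marked "assumed" are not part of this commit.
from deeptensor import SGD          # assumed re-export of the _core module

model = ...                         # assumed: build a deeptensor Model here
optimizer = SGD(model, 0.01)        # py::init<shared_ptr<Model>, double>

for batch in dataset:               # assumed: any iterable of training data
    loss = model(batch)             # assumed forward pass producing a loss
    loss.backward()                 # assumed autograd-style backward pass
    optimizer.step()                # SGD update (existed before this commit)
    optimizer.zero_grad()           # new in 0.3.0: clears model gradients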

csrc/optimizer.h

Lines changed: 25 additions & 0 deletions

@@ -9,6 +9,7 @@ class Optimizer {
  public:
   virtual ~Optimizer() = default;
   virtual void step() = 0;
+  virtual void zero_grad() = 0;
 };

 // stochastic gradient descent
@@ -27,6 +28,10 @@ class SGD : public Optimizer {
       e->data = e->data - this->learning_rate * e->grad;
     }
   }
+
+  void zero_grad() override {
+    m->zero_grad();
+  }
 };

 // SGD with Momentum
@@ -55,6 +60,10 @@ class Momentum : public Optimizer {
       m_para[i]->data = m_para[i]->data - this->learning_rate * velocity[i];
     }
   }
+
+  void zero_grad() override {
+    m->zero_grad();
+  }
 };

 // Nesterov Accelerated Gradient (NAG) - we need to compute gradient at
@@ -88,6 +97,10 @@ class Momentum : public Optimizer {
 //       m_para[i]->data = m_para[i]->data - this->learning_rate * velocity[i];
 //     }
 //   }
+
+//   void zero_grad() override {
+//     m->zero_grad();
+//   }
 // };

 // AdaGrad (Adaptive Gradient Algorithm) - great for sparse datasets
@@ -114,6 +127,10 @@ class AdaGrad : public Optimizer {
           std::sqrt(prev_grad_square[i] + this->epsilon);
     }
   }
+
+  void zero_grad() override {
+    m->zero_grad();
+  }
 };

 // RMSProp (Root Mean Square Propagation)
@@ -153,6 +170,10 @@ class RMSprop : public Optimizer {
           std::sqrt(prev_grad_square[i] + epsilon);
     }
   }
+
+  void zero_grad() override {
+    m->zero_grad();
+  }
 };

 // ADAM (Adaptive Moment Estimation)
@@ -215,4 +236,8 @@ class Adam : public Optimizer {
     }
     this->time++;
   }
+
+  void zero_grad() override {
+    m->zero_grad();
+  }
 };
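Because zero_grad() is pure virtual on the Optimizer base class and every concrete optimizer forwards it to m->zero_grad(), code written against the base interface works unchanged with any optimizer. A hedged sketch: constructor arguments follow the pybind11 signatures in csrc/main.cc, while the import path and the model line are assumptions.

# Sketch only: swapping optimizers leaves the step()/zero_grad() loop intact.
from deeptensor import SGD, Momentum, Adam     # assumed re-exports

model = ...                                    # assumed: a constructed Model

for opt in (SGD(model, 0.01),                  # (model, learning_rate)
            Momentum(model, 0.01, 0.9),        # (model, lr, decay_factor)
            Adam(model, 0.001)):               # (model, learning_rate)
    opt.step()       # each subclass applies its own update rule
    opt.zero_grad()  # every override delegates to Model::zero_grad()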

demo/new_model_api.ipynb

Lines changed: 60 additions & 153 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "deeptensor"
-version = "0.2.0" # new api
+version = "0.3.0" # new api
 url = "https://github.com/deependujha/deeptensor"
 readme = "README.md"
 authors = [
