summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorZachary DeVito <zdevito@gmail.com>2018-02-02 17:45:59 -0800
committerGitHub <noreply@github.com>2018-02-02 17:45:59 -0800
commitc308e03f3eadf8d214884eedf43428941171fa3d (patch)
treec05c7b0f8f5c75f428db233b5a8751096736c50b /aten
parent4ae05799fa08ff040be66ae8bba66c2a73eb4993 (diff)
downloadpytorch-c308e03f3eadf8d214884eedf43428941171fa3d.tar.gz
pytorch-c308e03f3eadf8d214884eedf43428941171fa3d.tar.bz2
pytorch-c308e03f3eadf8d214884eedf43428941171fa3d.zip
Initial GraphExecutor Implementation. (#4982)
This adds the initial implementation of graph executor for the new JIT design. It includes a few python tests ensuring that nograd, backward, and double-backward cases work for simple examples and some corner cases. More work needs to be done to performance optimize as there are many extra copies and places where we hold onto variables longer than we should. These are noted in the comments.
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/ATenAssert.h10
-rw-r--r--aten/src/ATen/ArrayRef.h10
-rw-r--r--aten/src/ATen/Utils.h6
3 files changed, 16 insertions, 10 deletions
diff --git a/aten/src/ATen/ATenAssert.h b/aten/src/ATen/ATenAssert.h
new file mode 100644
index 0000000000..3e840fe2f1
--- /dev/null
+++ b/aten/src/ATen/ATenAssert.h
@@ -0,0 +1,10 @@
+#include "ATenGeneral.h"
+
+namespace at {
+
+#define AT_ASSERT(cond, ...) if (! (cond) ) { at::runtime_error(__VA_ARGS__); }
+
+[[noreturn]]
+AT_API void runtime_error(const char *format, ...);
+
+}
diff --git a/aten/src/ATen/ArrayRef.h b/aten/src/ATen/ArrayRef.h
index ff6ae09dec..fe425adfa5 100644
--- a/aten/src/ATen/ArrayRef.h
+++ b/aten/src/ATen/ArrayRef.h
@@ -17,6 +17,7 @@
#include <assert.h>
#include <array>
#include <vector>
+#include "ATenAssert.h"
namespace at {
/// ArrayRef - Represent a constant reference to an array (0 or more elements
@@ -104,13 +105,13 @@ namespace at {
/// front - Get the first element.
const T &front() const {
- assert(!empty());
+ AT_ASSERT(!empty(), "Empty list!");
return Data[0];
}
/// back - Get the last element.
const T &back() const {
- assert(!empty());
+ AT_ASSERT(!empty(), "Empty list!");
return Data[Length-1];
}
@@ -124,7 +125,7 @@ namespace at {
/// slice(n, m) - Chop off the first N elements of the array, and keep M
/// elements in the array.
ArrayRef<T> slice(size_t N, size_t M) const {
- assert(N+M <= size() && "Invalid specifier");
+ AT_ASSERT(N+M <= size(), "Invalid specifier");
return ArrayRef<T>(data()+N, M);
}
@@ -135,13 +136,12 @@ namespace at {
/// @name Operator Overloads
/// @{
const T &operator[](size_t Index) const {
- assert(Index < Length && "Invalid index!");
return Data[Index];
}
/// Vector compatibility
const T &at(size_t Index) const {
- assert(Index < Length && "Invalid index!");
+ AT_ASSERT(Index < Length, "Invalid index!");
return Data[Index];
}
diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h
index 9400741fdd..d87a46a5d6 100644
--- a/aten/src/ATen/Utils.h
+++ b/aten/src/ATen/Utils.h
@@ -6,14 +6,10 @@
#include <algorithm>
#include <sstream>
#include <typeinfo>
+#include "ATenAssert.h"
namespace at {
-#define AT_ASSERT(cond, ...) if (! (cond) ) { at::runtime_error(__VA_ARGS__); }
-
-[[noreturn]]
-AT_API void runtime_error(const char *format, ...);
-
template <typename T, typename Base>
static inline T* checked_cast_storage(Base* expr, const char * name, int pos) {
if (typeid(*expr) != typeid(T))