diff options
author | Zachary DeVito <zdevito@gmail.com> | 2018-02-02 17:45:59 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-02 17:45:59 -0800 |
commit | c308e03f3eadf8d214884eedf43428941171fa3d (patch) | |
tree | c05c7b0f8f5c75f428db233b5a8751096736c50b /aten | |
parent | 4ae05799fa08ff040be66ae8bba66c2a73eb4993 (diff) | |
download | pytorch-c308e03f3eadf8d214884eedf43428941171fa3d.tar.gz pytorch-c308e03f3eadf8d214884eedf43428941171fa3d.tar.bz2 pytorch-c308e03f3eadf8d214884eedf43428941171fa3d.zip |
Initial GraphExecutor Implementation. (#4982)
This adds the initial implementation of graph executor for the new JIT design. It includes a few python tests ensuring that nograd, backward, and double-backward cases work for simple examples and some corner cases. More work needs to be done to performance optimize as there are many extra copies and places where we hold onto variables longer than we should. These are noted in the comments.
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/ATenAssert.h | 10 | ||||
-rw-r--r-- | aten/src/ATen/ArrayRef.h | 10 | ||||
-rw-r--r-- | aten/src/ATen/Utils.h | 6 |
3 files changed, 16 insertions, 10 deletions
diff --git a/aten/src/ATen/ATenAssert.h b/aten/src/ATen/ATenAssert.h new file mode 100644 index 0000000000..3e840fe2f1 --- /dev/null +++ b/aten/src/ATen/ATenAssert.h @@ -0,0 +1,10 @@ +#include "ATenGeneral.h" + +namespace at { + +#define AT_ASSERT(cond, ...) if (! (cond) ) { at::runtime_error(__VA_ARGS__); } + +[[noreturn]] +AT_API void runtime_error(const char *format, ...); + +} diff --git a/aten/src/ATen/ArrayRef.h b/aten/src/ATen/ArrayRef.h index ff6ae09dec..fe425adfa5 100644 --- a/aten/src/ATen/ArrayRef.h +++ b/aten/src/ATen/ArrayRef.h @@ -17,6 +17,7 @@ #include <assert.h> #include <array> #include <vector> +#include "ATenAssert.h" namespace at { /// ArrayRef - Represent a constant reference to an array (0 or more elements @@ -104,13 +105,13 @@ namespace at { /// front - Get the first element. const T &front() const { - assert(!empty()); + AT_ASSERT(!empty(), "Empty list!"); return Data[0]; } /// back - Get the last element. const T &back() const { - assert(!empty()); + AT_ASSERT(!empty(), "Empty list!"); return Data[Length-1]; } @@ -124,7 +125,7 @@ namespace at { /// slice(n, m) - Chop off the first N elements of the array, and keep M /// elements in the array. ArrayRef<T> slice(size_t N, size_t M) const { - assert(N+M <= size() && "Invalid specifier"); + AT_ASSERT(N+M <= size(), "Invalid specifier"); return ArrayRef<T>(data()+N, M); } @@ -135,13 +136,12 @@ namespace at { /// @name Operator Overloads /// @{ const T &operator[](size_t Index) const { - assert(Index < Length && "Invalid index!"); return Data[Index]; } /// Vector compatibility const T &at(size_t Index) const { - assert(Index < Length && "Invalid index!"); + AT_ASSERT(Index < Length, "Invalid index!"); return Data[Index]; } diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 9400741fdd..d87a46a5d6 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -6,14 +6,10 @@ #include <algorithm> #include <sstream> #include <typeinfo> +#include "ATenAssert.h" namespace at { -#define AT_ASSERT(cond, ...) if (! (cond) ) { at::runtime_error(__VA_ARGS__); } - -[[noreturn]] -AT_API void runtime_error(const char *format, ...); - template <typename T, typename Base> static inline T* checked_cast_storage(Base* expr, const char * name, int pos) { if (typeid(*expr) != typeid(T)) |