diff options
author | Junjie Bai <bai@in.tum.de> | 2018-12-05 18:35:21 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-05 18:37:32 -0800 |
commit | ba0ebe33c170a8e8aa9df8e8b8549b980b89551e (patch) | |
tree | dc1fd754947b0ae9905301cd3580351e8893e2a5 /c10 | |
parent | 252e9058d45445dc44d3f818eb8bd6fa95ac826a (diff) | |
download | pytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.tar.gz pytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.tar.bz2 pytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.zip |
Unify device argument parsing between torch and c10
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14786
Differential Revision: D13334501
Pulled By: bddppq
fbshipit-source-id: ae3536be1fe0dcd6a1552ec93629ecc9554c0d7c
Diffstat (limited to 'c10')
-rw-r--r-- | c10/Device.cpp | 30 | ||||
-rw-r--r-- | c10/Device.h | 10 |
2 files changed, 20 insertions, 20 deletions
diff --git a/c10/Device.cpp b/c10/Device.cpp index 44f14ea40a..c0bb507596 100644 --- a/c10/Device.cpp +++ b/c10/Device.cpp @@ -36,6 +36,13 @@ DeviceType parse_type(const std::string& device_string) { } } // namespace +void Device::validate() { + AT_CHECK(index_ == -1 || index_ >= 0, + "Device index must be -1 or non-negative, got ", index_); + AT_CHECK(!is_cpu() || index_ <= 0, + "CPU device index must be -1 or zero, got ", index_); +} + // `std::regex` is still in a very incomplete state in GCC 4.8.x, // so we have to do our own parsing, like peasants. // https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions @@ -64,24 +71,23 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) { int index = device_string.find(":"); if (index == std::string::npos) { type_ = parse_type(device_string); - return; } else { std::string s; s = device_string.substr(0, index); AT_CHECK(!s.empty(), "Device string must not be empty"); type_ = parse_type(s); + + std::string device_index = device_string.substr(index + 1); + try { + index_ = c10::stoi(device_index); + } catch (const std::exception &) { + AT_ERROR("Could not parse device index '", device_index, + "' in device string '", device_string, "'"); + } + AT_CHECK(index_ >= 0, + "Device index must be non-negative, got ", index_); } - std::string device_index = device_string.substr(index + 1); - try { - index_ = c10::stoi(device_index); - } catch (const std::exception&) { - AT_ERROR( - "Could not parse device index '", - device_index, - "' in device string '", - device_string, - "'"); - } + validate(); } std::ostream& operator<<(std::ostream& stream, const Device& device) { diff --git a/c10/Device.h b/c10/Device.h index 3c9fafaa8c..81c5cee8f4 100644 --- a/c10/Device.h +++ b/c10/Device.h @@ -34,14 +34,7 @@ struct C10_API Device final { /// index. /* implicit */ Device(DeviceType type, DeviceIndex index = -1) : type_(type), index_(index) { - AT_CHECK( - index == -1 || index >= 0, - "Device index must be -1 or non-negative, got ", - index); - AT_CHECK( - !is_cpu() || index <= 0, - "CPU device index must be -1 or zero, got ", - index); + validate(); } /// Constructs a `Device` from a string description, for convenience. @@ -96,6 +89,7 @@ struct C10_API Device final { private: DeviceType type_; DeviceIndex index_ = -1; + void validate(); }; C10_API std::ostream& operator<<( |