summaryrefslogtreecommitdiff
path: root/c10
diff options
context:
space:
mode:
authorJunjie Bai <bai@in.tum.de>2018-12-05 18:35:21 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-05 18:37:32 -0800
commitba0ebe33c170a8e8aa9df8e8b8549b980b89551e (patch)
treedc1fd754947b0ae9905301cd3580351e8893e2a5 /c10
parent252e9058d45445dc44d3f818eb8bd6fa95ac826a (diff)
downloadpytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.tar.gz
pytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.tar.bz2
pytorch-ba0ebe33c170a8e8aa9df8e8b8549b980b89551e.zip
Unify device argument parsing between torch and c10
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14786 Differential Revision: D13334501 Pulled By: bddppq fbshipit-source-id: ae3536be1fe0dcd6a1552ec93629ecc9554c0d7c
Diffstat (limited to 'c10')
-rw-r--r--c10/Device.cpp30
-rw-r--r--c10/Device.h10
2 files changed, 20 insertions, 20 deletions
diff --git a/c10/Device.cpp b/c10/Device.cpp
index 44f14ea40a..c0bb507596 100644
--- a/c10/Device.cpp
+++ b/c10/Device.cpp
@@ -36,6 +36,13 @@ DeviceType parse_type(const std::string& device_string) {
}
} // namespace
+void Device::validate() {
+ AT_CHECK(index_ == -1 || index_ >= 0,
+ "Device index must be -1 or non-negative, got ", index_);
+ AT_CHECK(!is_cpu() || index_ <= 0,
+ "CPU device index must be -1 or zero, got ", index_);
+}
+
// `std::regex` is still in a very incomplete state in GCC 4.8.x,
// so we have to do our own parsing, like peasants.
// https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
@@ -64,24 +71,23 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
int index = device_string.find(":");
if (index == std::string::npos) {
type_ = parse_type(device_string);
- return;
} else {
std::string s;
s = device_string.substr(0, index);
AT_CHECK(!s.empty(), "Device string must not be empty");
type_ = parse_type(s);
+
+ std::string device_index = device_string.substr(index + 1);
+ try {
+ index_ = c10::stoi(device_index);
+ } catch (const std::exception &) {
+ AT_ERROR("Could not parse device index '", device_index,
+ "' in device string '", device_string, "'");
+ }
+ AT_CHECK(index_ >= 0,
+ "Device index must be non-negative, got ", index_);
}
- std::string device_index = device_string.substr(index + 1);
- try {
- index_ = c10::stoi(device_index);
- } catch (const std::exception&) {
- AT_ERROR(
- "Could not parse device index '",
- device_index,
- "' in device string '",
- device_string,
- "'");
- }
+ validate();
}
std::ostream& operator<<(std::ostream& stream, const Device& device) {
diff --git a/c10/Device.h b/c10/Device.h
index 3c9fafaa8c..81c5cee8f4 100644
--- a/c10/Device.h
+++ b/c10/Device.h
@@ -34,14 +34,7 @@ struct C10_API Device final {
/// index.
/* implicit */ Device(DeviceType type, DeviceIndex index = -1)
: type_(type), index_(index) {
- AT_CHECK(
- index == -1 || index >= 0,
- "Device index must be -1 or non-negative, got ",
- index);
- AT_CHECK(
- !is_cpu() || index <= 0,
- "CPU device index must be -1 or zero, got ",
- index);
+ validate();
}
/// Constructs a `Device` from a string description, for convenience.
@@ -96,6 +89,7 @@ struct C10_API Device final {
private:
DeviceType type_;
DeviceIndex index_ = -1;
+ void validate();
};
C10_API std::ostream& operator<<(