diff options
author | Teng Li <tengli@fb.com> | 2018-11-14 01:26:44 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-11-14 01:34:22 -0800 |
commit | 97036d3c3043b99b9015bf44f9a58eb14f106a80 (patch) | |
tree | 799f8f171c979ba936e7453b450090f7f128792f /torch/lib | |
parent | e2a7d43dfd04ae098be155fb4ea7fa49c45bf80c (diff) | |
download | pytorch-97036d3c3043b99b9015bf44f9a58eb14f106a80.tar.gz pytorch-97036d3c3043b99b9015bf44f9a58eb14f106a80.tar.bz2 pytorch-97036d3c3043b99b9015bf44f9a58eb14f106a80.zip |
FileStore auto deletes file and FileStore::add bug fix (#13708)
Summary:
This addressed: https://github.com/pytorch/pytorch/issues/11874
and we will have the identical file init_method behavior as the previous THD file init.
Also the FileStore::add bug is pretty annoying.
Two bugs:
(1) Add doesn't append to the end of the file.
(2) Cache doesn't get updated.
Both are fixed and tests are covered.
I examined the /tmp to ensure that all temp files are auto deleted after test_c10d.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13708
Reviewed By: pietern
Differential Revision: D12972810
Pulled By: teng-li
fbshipit-source-id: 917255390aa52845f6b0ad0f283875a7a704da48
Diffstat (limited to 'torch/lib')
-rw-r--r-- | torch/lib/c10d/FileStore.cpp | 56 | ||||
-rw-r--r-- | torch/lib/c10d/FileStore.hpp | 8 | ||||
-rw-r--r-- | torch/lib/c10d/example/allreduce.cpp | 2 | ||||
-rw-r--r-- | torch/lib/c10d/test/FileStoreTest.cpp | 8 | ||||
-rw-r--r-- | torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp | 2 | ||||
-rw-r--r-- | torch/lib/c10d/test/ProcessGroupGlooTest.cpp | 4 | ||||
-rw-r--r-- | torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 2 |
7 files changed, 61 insertions, 21 deletions
diff --git a/torch/lib/c10d/FileStore.cpp b/torch/lib/c10d/FileStore.cpp index fb842125e5..ccd7744ce0 100644 --- a/torch/lib/c10d/FileStore.cpp +++ b/torch/lib/c10d/FileStore.cpp @@ -8,6 +8,7 @@ #include <unistd.h> #include <chrono> +#include <cstdio> #include <functional> #include <iostream> #include <limits> @@ -181,30 +182,53 @@ off_t refresh( pos = file.tell(); } } + file.seek(0, SEEK_SET); return pos; } } // namespace -FileStore::FileStore(const std::string& path) : Store(), path_(path), pos_(0) {} +FileStore::FileStore(const std::string& path, int numWorkers) + : Store(), + path_(path), + pos_(0), + numWorkers_(numWorkers), + cleanupKey_("cleanup/"), + regularPrefix_("/") { + if (numWorkers_ < 1) { + throw std::runtime_error( + "Number of workers for FileStore should be greater than zero"); + } +} -FileStore::~FileStore() {} +FileStore::~FileStore() { + // cleanup key will be different from all rest keys since all rest keys will + // have a regular prefix. + auto numFinishedWorker = addHelper(cleanupKey_, 1); + // The last worker cleans up the file + if (numFinishedWorker == numWorkers_) { + // Best effort removal without checking the return + std::remove(path_.c_str()); + } +} void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) { + std::string regKey = regularPrefix_ + key; File file(path_, O_RDWR | O_CREAT); auto lock = file.lockExclusive(); file.seek(0, SEEK_END); - file.write(key); + file.write(regKey); file.write(value); } std::vector<uint8_t> FileStore::get(const std::string& key) { + std::string regKey = regularPrefix_ + key; const auto start = std::chrono::steady_clock::now(); - while (cache_.count(key) == 0) { + while (true) { File file(path_, O_RDONLY); auto lock = file.lockShared(); auto size = file.size(); - if (size == pos_) { + if (cache_.count(regKey) == 0 && size == pos_) { // No new entries; release the shared lock and sleep for a bit lock.unlock(); const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( @@ -215,14 +239,18 @@ std::vector<uint8_t> FileStore::get(const std::string& key) { std::this_thread::sleep_for(std::chrono::milliseconds(10)); continue; } - + // Always refresh since even though the key exists in the cache, + // it might be outdated pos_ = refresh(file, pos_, cache_); + if (cache_.count(regKey) != 0) { + break; + } } - return cache_[key]; + return cache_[regKey]; } -int64_t FileStore::add(const std::string& key, int64_t i) { +int64_t FileStore::addHelper(const std::string& key, int64_t i) { File file(path_, O_RDWR | O_CREAT); auto lock = file.lockExclusive(); pos_ = refresh(file, pos_, cache_); @@ -234,22 +262,28 @@ int64_t FileStore::add(const std::string& key, int64_t i) { auto len = value.size(); ti += std::stoll(std::string(buf, len)); } - + // Always seek to the end to write + file.seek(0, SEEK_END); // File cursor is at the end of the file now, and we have an // exclusive lock, so we can write the new value. file.write(key); file.write(std::to_string(ti)); - return ti; } +int64_t FileStore::add(const std::string& key, int64_t i) { + std::string regKey = regularPrefix_ + key; + return addHelper(regKey, i); +} + bool FileStore::check(const std::vector<std::string>& keys) { File file(path_, O_RDONLY); auto lock = file.lockShared(); pos_ = refresh(file, pos_, cache_); for (const auto& key : keys) { - if (cache_.count(key) == 0) { + std::string regKey = regularPrefix_ + key; + if (cache_.count(regKey) == 0) { return false; } } diff --git a/torch/lib/c10d/FileStore.hpp b/torch/lib/c10d/FileStore.hpp index 9bbc25ef10..17b1f05e36 100644 --- a/torch/lib/c10d/FileStore.hpp +++ b/torch/lib/c10d/FileStore.hpp @@ -10,7 +10,7 @@ namespace c10d { class FileStore : public Store { public: - explicit FileStore(const std::string& path); + explicit FileStore(const std::string& path, int numWorkers); virtual ~FileStore(); @@ -29,9 +29,15 @@ class FileStore : public Store { const std::chrono::milliseconds& timeout) override; protected: + int64_t addHelper(const std::string& key, int64_t i); + std::string path_; off_t pos_; + int numWorkers_; + const std::string cleanupKey_; + const std::string regularPrefix_; + std::unordered_map<std::string, std::vector<uint8_t>> cache_; }; diff --git a/torch/lib/c10d/example/allreduce.cpp b/torch/lib/c10d/example/allreduce.cpp index 64c7258c1c..76d6a5588f 100644 --- a/torch/lib/c10d/example/allreduce.cpp +++ b/torch/lib/c10d/example/allreduce.cpp @@ -6,7 +6,7 @@ using namespace ::c10d; int main(int argc, char** argv) { int rank = atoi(getenv("RANK")); int size = atoi(getenv("SIZE")); - auto store = std::make_shared<FileStore>("/tmp/c10d_example"); + auto store = std::make_shared<FileStore>("/tmp/c10d_example", size); ProcessGroupGloo pg(store, rank, size); // Create some tensors diff --git a/torch/lib/c10d/test/FileStoreTest.cpp b/torch/lib/c10d/test/FileStoreTest.cpp index c34ab7a094..e2cc3926cb 100644 --- a/torch/lib/c10d/test/FileStoreTest.cpp +++ b/torch/lib/c10d/test/FileStoreTest.cpp @@ -34,7 +34,7 @@ void testHelper(const std::string prefix = "") { // Basic set/get { - c10d::FileStore fileStore(path); + c10d::FileStore fileStore(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::set(store, "key0", "value0"); c10d::test::set(store, "key1", "value1"); @@ -46,7 +46,7 @@ void testHelper(const std::string prefix = "") { // Perform get on new instance { - c10d::FileStore fileStore(path); + c10d::FileStore fileStore(path, 2); c10d::PrefixStore store(prefix, fileStore); c10d::test::check(store, "key0", "value0"); } @@ -58,7 +58,7 @@ void testHelper(const std::string prefix = "") { c10d::test::Semaphore sem1, sem2; for (auto i = 0; i < numThreads; i++) { threads.push_back(std::thread([&] { - c10d::FileStore fileStore(path); + c10d::FileStore fileStore(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); @@ -75,7 +75,7 @@ void testHelper(const std::string prefix = "") { // Check that the counter has the expected value { - c10d::FileStore fileStore(path); + c10d::FileStore fileStore(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); std::string expected = std::to_string(numThreads * numIterations); c10d::test::check(store, "counter", expected); diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index b1284c0f3a..fcdd9aa7e8 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -45,7 +45,7 @@ class AsyncTest { } void start(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_); + auto store = std::make_shared<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index 360f87ab1d..fe2b292d8b 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -37,7 +37,7 @@ class SignalTest { } std::shared_ptr<::c10d::ProcessGroup::Work> run(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_); + auto store = std::make_shared<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; @@ -120,7 +120,7 @@ class CollectiveTest { } void start(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_); + auto store = std::make_shared<::c10d::FileStore>(path_, size); // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index f9d8b8a98c..158eef19b5 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -27,7 +27,7 @@ class NCCLTestBase { } void initialize(int rank, int size) { - auto store = std::make_shared<::c10d::FileStore>(path_); + auto store = std::make_shared<::c10d::FileStore>(path_, size); pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( new ::c10d::ProcessGroupNCCL(store, rank, size)); |