diff options
author | efaust <efaust@devvm2775.prn3.facebook.com> | 2019-04-23 23:26:04 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 23:36:48 -0700 |
commit | 8273b9b3cbc30ba392cbe366fd2a729916de6a4f (patch) | |
tree | d20bb77aa457b1649616f2926b8b147b7a5ba94f | |
parent | 309c15e2df3ed300e0c09bdbb4fbfe2ba98267ad (diff) | |
download | pytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.tar.gz pytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.tar.bz2 pytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.zip |
Enforce consistent dict iteration order for trace inputs. (#19528)
Summary:
Stack:
:black_circle: **#19528 [pytorch] Enforce consistent dict iteration order for trace inputs.** [:yellow_heart:](https://our.intern.facebook.com/intern/diff/D15023656/)
Don't iterate down unordered_maps and expect ordering. Should fix test flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19528
Differential Revision: D15023656
Pulled By: efaust
fbshipit-source-id: 91c9a31a8652fcf93ae0e942bea4cec67bb490c9
-rw-r--r-- | aten/src/ATen/core/ivalue.cpp | 20 | ||||
-rw-r--r-- | aten/src/ATen/core/ivalue.h | 3 | ||||
-rw-r--r-- | torch/csrc/jit/pickler.cpp | 21 | ||||
-rw-r--r-- | torch/csrc/jit/register_prim_ops.cpp | 19 | ||||
-rw-r--r-- | torch/csrc/jit/tracer.cpp | 14 |
5 files changed, 42 insertions, 35 deletions
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 7e8ab5e832..2dbcd69818 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -118,4 +118,24 @@ void ivalue::Object::resizeObject(size_t slot) { slots_.resize(type()->numAttributes()); } +static bool CompareIValue(const std::pair<IValue, IValue>& aWrap, + const std::pair<IValue, IValue>& bWrap) { + const auto a = aWrap.first; + const auto b = bWrap.first; + if (a.isString() && b.isString()) { + return a.toStringRef().compare(b.toStringRef()) < 0; + } else if (a.isInt() && b.isInt()) { + return a.toInt() < b.toInt(); + } else if (a.isDouble() && b.isDouble()) { + return a.toDouble() < b.toDouble(); + } + AT_ERROR("Illegal dict key"); +} + +const ivalue::GenericDict::IterationOrder ivalue::GenericDict::iterationOrder() const { + IterationOrder ordered(elements().begin(), elements().end()); + std::sort(ordered.begin(), ordered.end(), CompareIValue); + return ordered; +} + } // namespace c10 diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index e374b49ab1..8cfac1b2b4 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -747,6 +747,9 @@ struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target { operator UnorderedMap&() { return elements(); } + + using IterationOrder = std::vector<std::pair<IValue, IValue>>; + const IterationOrder iterationOrder() const; }; #undef TORCH_FORALL_TAGS diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp index f75ef2aec5..c8090bc22c 100644 --- a/torch/csrc/jit/pickler.cpp +++ b/torch/csrc/jit/pickler.cpp @@ -207,23 +207,6 @@ void Pickler::pushDouble(const IValue& ivalue) { } } -using ivalue_pair = std::pair<IValue, IValue>; - -struct IValuePairComparator { - bool operator()(const ivalue_pair& lhs, const ivalue_pair& rhs) const { - if (lhs.first.isString()) { - return lhs.first.toStringRef() < rhs.first.toStringRef(); - } - if (lhs.first.isInt()) { - return lhs.first.toInt() < rhs.first.toInt(); - } - if (lhs.first.isDouble()) { - return lhs.first.toDouble() < rhs.first.toDouble(); - } - AT_ERROR("Uncomparable IValue types"); - } -}; - void Pickler::pushDict(const IValue& ivalue) { auto dict = ivalue.toGenericDictRef(); @@ -233,9 +216,7 @@ void Pickler::pushDict(const IValue& ivalue) { push<OpCode>(OpCode::MARK); // Sort the dict for deterministic keys - std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end()); - std::sort(dict_items.begin(), dict_items.end(), IValuePairComparator()); - + auto dict_items = ivalue.toGenericDict()->iterationOrder(); for (const auto& pair : dict_items) { addIValue(pair.first); addIValue(pair.second); diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index e15e58b3d1..3aa5f3daf2 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1572,10 +1572,11 @@ int dictKeys(Stack& stack) { } template <typename Elem> -std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) { +std::vector<Elem> makeListForDictValues( + const c10::ivalue::GenericDict::IterationOrder &order) { std::vector<Elem> values; - values.reserve(dict.size()); - for (auto item : dict) { + values.reserve(order.size()); + for (auto item : order) { values.push_back(item.second.to<Elem>()); } return values; @@ -1584,17 +1585,17 @@ std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) { Operation dictValues(const Node* n) { auto outputType = n->output()->type()->expect<ListType>(); return [=](Stack& stack) -> int { - auto dict = pop(stack).toGenericDictRef(); + const auto &order = pop(stack).toGenericDict()->iterationOrder(); if (outputType->getElementType()->isSubtypeOf(TensorType::get())) { - push(stack, makeListForDictValues<at::Tensor>(dict)); + push(stack, makeListForDictValues<at::Tensor>(order)); } else if (outputType->getElementType() == IntType::get()) { - push(stack, makeListForDictValues<int64_t>(dict)); + push(stack, makeListForDictValues<int64_t>(order)); } else if (outputType->getElementType() == FloatType::get()) { - push(stack, makeListForDictValues<double>(dict)); + push(stack, makeListForDictValues<double>(order)); } else if (outputType->getElementType() == BoolType::get()) { - push(stack, makeListForDictValues<bool>(dict)); + push(stack, makeListForDictValues<bool>(order)); } else { - push(stack, makeListForDictValues<IValue>(dict)); + push(stack, makeListForDictValues<IValue>(order)); } return 0; }; diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 785dd89f97..716eac03dc 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -235,20 +235,22 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) { } return Tuple::create(std::move(elems)); } else if (auto dict_type = type->cast<DictType>()) { - auto elem_pairs = input.toGenericDict()->elements(); + auto dict = input.toGenericDict(); + auto dict_size = dict->elements().size(); auto unpack_to_list = state->graph->insert(aten::values, {value}); - auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size()); + auto list_unpack = state->graph->createListUnpack(unpack_to_list, dict_size); auto unpack_node = state->graph->insertNode(list_unpack); auto elem_values = unpack_node->outputs(); - AT_ASSERT(elem_pairs.size() == elem_values.size()); + const auto order = dict->iterationOrder(); + AT_ASSERT(order.size() == elem_values.size()); size_t i = 0; - for (const auto &pair : elem_pairs) { - elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]); + for (const auto &pair : order) { + dict->elements()[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]); } - return c10::ivalue::GenericDict::create(std::move(elem_pairs)); + return c10::ivalue::GenericDict::create(std::move(dict->elements())); } else { AT_ERROR( "Only tensors or (possibly nested) dict or tuples of tensors can be " |