summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorefaust <efaust@devvm2775.prn3.facebook.com>2019-04-23 23:26:04 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-23 23:36:48 -0700
commit8273b9b3cbc30ba392cbe366fd2a729916de6a4f (patch)
treed20bb77aa457b1649616f2926b8b147b7a5ba94f
parent309c15e2df3ed300e0c09bdbb4fbfe2ba98267ad (diff)
downloadpytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.tar.gz
pytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.tar.bz2
pytorch-8273b9b3cbc30ba392cbe366fd2a729916de6a4f.zip
Enforce consistent dict iteration order for trace inputs. (#19528)
Summary: Stack: &nbsp;&nbsp;&nbsp;&nbsp;:black_circle:&nbsp; **#19528 [pytorch] Enforce consistent dict iteration order for trace inputs.**&nbsp;&nbsp;[:yellow_heart:](https://our.intern.facebook.com/intern/diff/D15023656/) Don't iterate down unordered_maps and expect ordering. Should fix test flakiness. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19528 Differential Revision: D15023656 Pulled By: efaust fbshipit-source-id: 91c9a31a8652fcf93ae0e942bea4cec67bb490c9
-rw-r--r--aten/src/ATen/core/ivalue.cpp20
-rw-r--r--aten/src/ATen/core/ivalue.h3
-rw-r--r--torch/csrc/jit/pickler.cpp21
-rw-r--r--torch/csrc/jit/register_prim_ops.cpp19
-rw-r--r--torch/csrc/jit/tracer.cpp14
5 files changed, 42 insertions, 35 deletions
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 7e8ab5e832..2dbcd69818 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -118,4 +118,24 @@ void ivalue::Object::resizeObject(size_t slot) {
slots_.resize(type()->numAttributes());
}
+static bool CompareIValue(const std::pair<IValue, IValue>& aWrap,
+ const std::pair<IValue, IValue>& bWrap) {
+ const auto a = aWrap.first;
+ const auto b = bWrap.first;
+ if (a.isString() && b.isString()) {
+ return a.toStringRef().compare(b.toStringRef()) < 0;
+ } else if (a.isInt() && b.isInt()) {
+ return a.toInt() < b.toInt();
+ } else if (a.isDouble() && b.isDouble()) {
+ return a.toDouble() < b.toDouble();
+ }
+ AT_ERROR("Illegal dict key");
+}
+
+const ivalue::GenericDict::IterationOrder ivalue::GenericDict::iterationOrder() const {
+ IterationOrder ordered(elements().begin(), elements().end());
+ std::sort(ordered.begin(), ordered.end(), CompareIValue);
+ return ordered;
+}
+
} // namespace c10
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index e374b49ab1..8cfac1b2b4 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -747,6 +747,9 @@ struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
operator UnorderedMap&() {
return elements();
}
+
+ using IterationOrder = std::vector<std::pair<IValue, IValue>>;
+ const IterationOrder iterationOrder() const;
};
#undef TORCH_FORALL_TAGS
diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp
index f75ef2aec5..c8090bc22c 100644
--- a/torch/csrc/jit/pickler.cpp
+++ b/torch/csrc/jit/pickler.cpp
@@ -207,23 +207,6 @@ void Pickler::pushDouble(const IValue& ivalue) {
}
}
-using ivalue_pair = std::pair<IValue, IValue>;
-
-struct IValuePairComparator {
- bool operator()(const ivalue_pair& lhs, const ivalue_pair& rhs) const {
- if (lhs.first.isString()) {
- return lhs.first.toStringRef() < rhs.first.toStringRef();
- }
- if (lhs.first.isInt()) {
- return lhs.first.toInt() < rhs.first.toInt();
- }
- if (lhs.first.isDouble()) {
- return lhs.first.toDouble() < rhs.first.toDouble();
- }
- AT_ERROR("Uncomparable IValue types");
- }
-};
-
void Pickler::pushDict(const IValue& ivalue) {
auto dict = ivalue.toGenericDictRef();
@@ -233,9 +216,7 @@ void Pickler::pushDict(const IValue& ivalue) {
push<OpCode>(OpCode::MARK);
// Sort the dict for deterministic keys
- std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
- std::sort(dict_items.begin(), dict_items.end(), IValuePairComparator());
-
+ auto dict_items = ivalue.toGenericDict()->iterationOrder();
for (const auto& pair : dict_items) {
addIValue(pair.first);
addIValue(pair.second);
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index e15e58b3d1..3aa5f3daf2 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -1572,10 +1572,11 @@ int dictKeys(Stack& stack) {
}
template <typename Elem>
-std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
+std::vector<Elem> makeListForDictValues(
+ const c10::ivalue::GenericDict::IterationOrder &order) {
std::vector<Elem> values;
- values.reserve(dict.size());
- for (auto item : dict) {
+ values.reserve(order.size());
+ for (auto item : order) {
values.push_back(item.second.to<Elem>());
}
return values;
@@ -1584,17 +1585,17 @@ std::vector<Elem> makeListForDictValues(const c10::ivalue::UnorderedMap& dict) {
Operation dictValues(const Node* n) {
auto outputType = n->output()->type()->expect<ListType>();
return [=](Stack& stack) -> int {
- auto dict = pop(stack).toGenericDictRef();
+ const auto &order = pop(stack).toGenericDict()->iterationOrder();
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
- push(stack, makeListForDictValues<at::Tensor>(dict));
+ push(stack, makeListForDictValues<at::Tensor>(order));
} else if (outputType->getElementType() == IntType::get()) {
- push(stack, makeListForDictValues<int64_t>(dict));
+ push(stack, makeListForDictValues<int64_t>(order));
} else if (outputType->getElementType() == FloatType::get()) {
- push(stack, makeListForDictValues<double>(dict));
+ push(stack, makeListForDictValues<double>(order));
} else if (outputType->getElementType() == BoolType::get()) {
- push(stack, makeListForDictValues<bool>(dict));
+ push(stack, makeListForDictValues<bool>(order));
} else {
- push(stack, makeListForDictValues<IValue>(dict));
+ push(stack, makeListForDictValues<IValue>(order));
}
return 0;
};
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 785dd89f97..716eac03dc 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -235,20 +235,22 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(TypedStack inputs) {
}
return Tuple::create(std::move(elems));
} else if (auto dict_type = type->cast<DictType>()) {
- auto elem_pairs = input.toGenericDict()->elements();
+ auto dict = input.toGenericDict();
+ auto dict_size = dict->elements().size();
auto unpack_to_list = state->graph->insert(aten::values, {value});
- auto list_unpack = state->graph->createListUnpack(unpack_to_list, elem_pairs.size());
+ auto list_unpack = state->graph->createListUnpack(unpack_to_list, dict_size);
auto unpack_node = state->graph->insertNode(list_unpack);
auto elem_values = unpack_node->outputs();
- AT_ASSERT(elem_pairs.size() == elem_values.size());
+ const auto order = dict->iterationOrder();
+ AT_ASSERT(order.size() == elem_values.size());
size_t i = 0;
- for (const auto &pair : elem_pairs) {
- elem_pairs[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
+ for (const auto &pair : order) {
+ dict->elements()[pair.first] = add_input(pair.second, dict_type->getValueType(), elem_values[i++]);
}
- return c10::ivalue::GenericDict::create(std::move(elem_pairs));
+ return c10::ivalue::GenericDict::create(std::move(dict->elements()));
} else {
AT_ERROR(
"Only tensors or (possibly nested) dict or tuples of tensors can be "