diff options
29 files changed, 1817 insertions, 314 deletions
diff --git a/src/System.Private.CoreLib/System.Private.CoreLib.csproj b/src/System.Private.CoreLib/System.Private.CoreLib.csproj index 5930aa0334..5e98274494 100644 --- a/src/System.Private.CoreLib/System.Private.CoreLib.csproj +++ b/src/System.Private.CoreLib/System.Private.CoreLib.csproj @@ -278,6 +278,7 @@ <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsInfo.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" /> <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsMethod.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" /> <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\ComEventsSink.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" /> + <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\Variant.cs" Condition="'$(FeatureClassicCominterop)' == 'true'" /> <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\ComDataHelpers.cs" /> <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumVariantViewOfEnumerator.cs" /> <Compile Include="$(BclSourcesRoot)\System\Runtime\InteropServices\CustomMarshalers\EnumerableToDispatchMarshaler.cs" /> @@ -344,6 +345,7 @@ </ItemGroup> <ItemGroup Condition="'$(TargetsWindows)' == 'true'"> <Compile Include="$(BclSourcesRoot)\System\DateTime.Windows.cs" /> + <Compile Include="$(BclSourcesRoot)\Interop\Windows\OleAut32\Interop.VariantClear.cs" /> <Compile Include="$(BclSourcesRoot)\System\ApplicationModel.Windows.cs" /> <Compile Include="$(BclSourcesRoot)\System\Globalization\GlobalizationMode.Windows.cs" /> <Compile Include="$(BclSourcesRoot)\System\Threading\ClrThreadPoolBoundHandle.Windows.cs" /> diff --git a/src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs b/src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs new file mode 100644 index 0000000000..ae95fe550c --- /dev/null +++ b/src/System.Private.CoreLib/src/Interop/Windows/OleAut32/Interop.VariantClear.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +internal partial class Interop +{ + internal partial class OleAut32 + { + [DllImport(Libraries.OleAut32)] + internal static extern void VariantClear(IntPtr variant); + } +} diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsHelper.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsHelper.cs index 50e9ea6fee..a5d431dae2 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsHelper.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsHelper.cs @@ -2,101 +2,85 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - -/*============================================================ -** -** -** Purpose: ComEventHelpers APIs allow binding -** managed delegates to COM's connection point based events. -** -**/ // -// #ComEventsFeature -// -// code:#ComEventsFeature defines two public methods allowing to add/remove .NET delegates handling -// events from COM objects. Those methods are defined as part of code:ComEventsHelper static class -// * code:ComEventsHelper.Combine - will create/reuse-an-existing COM event sink and register the +// ComEventsFeature +// +// ComEventsFeature defines two public methods allowing to add/remove .NET delegates handling +// events from COM objects. Those methods are defined as part of ComEventsHelper static class +// * ComEventsHelper.Combine - will create/reuse-an-existing COM event sink and register the // specified delegate to be raised when corresponding COM event is raised -// * code:ComEventsHelper.Remove -// -// -// To bind an event handler to the COM object you need to provide the following data: -// * rcw - the instance of the COM object you want to bind to -// * iid - Guid of the source interface you want the sink to implement -// * dispid - dispatch identifier of the event on the source interface you are interested in -// * d - delegate to invoked when corresponding COM event is raised. -// -// #ComEventsArchitecture: -// In COM world, events are handled by so-called event sinks. What these are? COM-based Object Models -// (OMs) define "source" interfaces that need to be implemented by the COM clients to receive events. So, -// event sinks are COM objects implementing a source interfaces. Once an event sink is passed to the COM +// * ComEventsHelper.Remove +// +// ComEventsArchitecture: +// In COM world, events are handled by so-called event sinks. These are COM-based Object Models +// (OMs) that define "source" interfaces that need to be implemented by COM clients to receive events. So, +// event sinks are COM objects implementing source interfaces. Once an event sink is passed to the COM // server (through a mechanism known as 'binding/advising to connection point'), COM server will be -// calling source interface methods to "fire events" (advising, connection points, firing events etc. - -// is all COM jargon). -// +// calling source interface methods to "fire events". +// See https://docs.microsoft.com/cpp/mfc/connection-points +// // There are few interesting obervations about source interfaces. Usually source interfaces are defined // as 'dispinterface' - meaning that only late-bound invocations on this interface are allowed. Even // though it is not illegal to use early bound invocations on source interfaces - the practice is // discouraged because of versioning concerns. -// +// // Notice also that each COM server object might define multiple source interfaces and hence have // multiple connection points (each CP handles exactly one source interface). COM objects that want to -// fire events are required to implement IConnectionPointContainer interface which is used by the COM -// clients to discovery connection poitns - objects implementing IConnectionPoint interface. Once +// fire events are required to implement the IConnectionPointContainer interface which is used by COM +// clients to discovery connection points - objects implementing IConnectionPoint interface. Once a // connection point is found - clients can bind to it using IConnectionPoint::Advise (see -// code:ComEventsSink.Advise). -// -// The idea behind code:#ComEventsFeature is to write a "universal event sink" COM component that is +// ComEventsSink.Advise). +// +// The idea behind ComEventsFeature is to write a "universal event sink" COM component that is // generic enough to handle all late-bound event firings and invoke corresponding COM delegates (through // reflection). -// -// When delegate is registered (using code:ComEventsHelper.Combine) we will verify we have corresponding +// +// When delegate is registered (using ComEventsHelper.Combine) we will verify we have corresponding // event sink created and bound. -// -// But what happens when COM events are fired? code:ComEventsSink.Invoke implements IDispatch::Invoke method -// and this is the entry point that is called. Once our event sink is invoked, we need to find the -// corresponding delegate to invoke . We need to match the dispid of the call that is coming in to a -// dispid of .NET delegate that has been registered for this object. Once this is found we do call the -// delegates using reflection (code:ComEventsMethod.Invoke). -// -// #ComEventsArgsMarshalling +// +// When COM events are fired, ComEventsSink.Invoke implements IDispatch and the Invoke method +// is the entry point that is called. Once our event sink is invoked, we need to find the +// corresponding delegate to invoke. We need to match the dispid of the call that is coming in to a +// dispid of .NET delegate that has been registered for this object. Once this is found we call the +// delegates using reflection (see ComEventsMethod.Invoke). +// +// ComEventsArgsMarshalling // Notice, that we may not have a delegate registered against every method on the source interface. If we // were to marshal all the input parameters for methods that do not reach user code - we would end up // generatic RCWs that are not reachable for user code (the inconvenience it might create is there will // be RCWs that users can not call Marshal.ReleaseComObject on to explicitly manage the lifetime of these -// COM objects). The above behavior was one of the shortcoimings of legacy TLBIMP's implementation of COM +// COM objects). The above behavior was one of the shortcomings of legacy TLBIMP's implementation of COM // event sinking. In our code we will not marshal any data if there is no delegate registered to handle -// the event. (code:ComEventsMethod.Invoke) -// -// #ComEventsFinalization: +// the event. (see ComEventsMethod.Invoke) +// +// ComEventsFinalization: // Additional area of interest is when COM sink should be unadvised from the connection point. Legacy // TLBIMP's implementation of COM event sinks will unadvises the sink when corresponding RCW is GCed. // This is achieved by rooting the event sinks in a finalizable object stored in RCW's property bag // (using Marshal.SetComObjectData). Hence, once RCW is no longer reachable - the finalizer is called and // it would unadvise all the event sinks. We are employing the same strategy here. See storing an -// instance in the RCW at code:ComEventsInfo.FromObject and undadvsing the sinks at -// code:ComEventsInfo.~ComEventsInfo -// +// instance in the RCW at ComEventsInfo.FromObject and unadvising the sinks in ComEventsInfo.~ComEventsInfo +// // Classes of interest: -// * code:ComEventsHelpers - defines public methods but there are also a number of internal classes that -// implement the actual COM event sink: -// * code:ComEventsInfo - represents a finalizable container for all event sinks for a particular RCW. +// * ComEventsHelpers - defines public methods but there are also a number of internal classes that +// implement the actual COM event sink +// * ComEventsInfo - represents a finalizable container for all event sinks for a particular RCW. // Lifetime of this instance corresponds to the lifetime of the RCW object -// * code:ComEventsSink - represents a single event sink. Maintains an internal pointer to the next -// instance (in a singly linked list). A collection of code:ComEventsSink is stored at -// code:ComEventsInfo._sinks -// * code:ComEventsMethod - represents a single method from the source interface which has .NET delegates +// * ComEventsSink - represents a single event sink. Maintains an internal pointer to the next +// instance (in a singly linked list). A collection of ComEventsSink is stored at +// ComEventsInfo._sinks +// * ComEventsMethod - represents a single method from the source interface which has .NET delegates // attached to it. Maintains an internal pointer to the next instance (in a singly linked list). A -// collection of code:ComEventMethod is stored at code:ComEventsSink._methods -// -// #ComEventsRetValIssue: +// collection of ComEventMethod is stored at ComEventsSink._methods +// +// ComEventsRetValIssue: // Issue: normally, COM events would not return any value. However, it may happen as described in // http://support.microsoft.com/kb/810228. Such design might represent a problem for us - e.g. what is // the return value of a chain of delegates - is it the value of the last call in the chain or the the // first one? As the above KB article indicates, in cases where OM has events returning values, it is // suggested that people implement their event sink by explicitly implementing the source interface. This // means that the problem is already quite complex and we should not be dealing with it - see -// code:ComEventsMethod.Invoke +// ComEventsMethod.Invoke using System; @@ -115,10 +99,8 @@ namespace System.Runtime.InteropServices /// <param name="iid">identifier of the source interface used by COM object to fire events</param> /// <param name="dispid">dispatch identifier of the method on the source interface</param> /// <param name="d">delegate to invoke when specified COM event is fired</param> - public static void Combine(object rcw, Guid iid, int dispid, System.Delegate d) + public static void Combine(object rcw, Guid iid, int dispid, Delegate d) { - rcw = UnwrapIfTransparentProxy(rcw); - lock (rcw) { ComEventsInfo eventsInfo = ComEventsInfo.FromObject(rcw); @@ -129,7 +111,6 @@ namespace System.Runtime.InteropServices sink = eventsInfo.AddSink(ref iid); } - ComEventsMethod method = sink.FindMethod(dispid); if (method == null) { @@ -147,22 +128,27 @@ namespace System.Runtime.InteropServices /// <param name="iid">identifier of the source interface used by COM object to fire events</param> /// <param name="dispid">dispatch identifier of the method on the source interface</param> /// <param name="d">delegate to remove from the invocation list</param> - /// <returns></returns> - public static Delegate Remove(object rcw, Guid iid, int dispid, System.Delegate d) + public static Delegate Remove(object rcw, Guid iid, int dispid, Delegate d) { - rcw = UnwrapIfTransparentProxy(rcw); - lock (rcw) { ComEventsInfo eventsInfo = ComEventsInfo.Find(rcw); if (eventsInfo == null) + { return null; + } + ComEventsSink sink = eventsInfo.FindSink(ref iid); if (sink == null) + { return null; + } + ComEventsMethod method = sink.FindMethod(dispid); if (method == null) + { return null; + } method.RemoveDelegate(d); @@ -171,11 +157,13 @@ namespace System.Runtime.InteropServices // removed the last event handler for this dispid - need to remove dispid handler method = sink.RemoveMethod(method); } + if (method == null) { // removed last dispid handler for this sink - need to remove the sink sink = eventsInfo.RemoveSink(sink); } + if (sink == null) { // removed last sink for this rcw - need to remove all traces of event info @@ -186,10 +174,5 @@ namespace System.Runtime.InteropServices return d; } } - - internal static object UnwrapIfTransparentProxy(object rcw) - { - return rcw; - } } } diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsInfo.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsInfo.cs index 0fbe34db8d..8b47683e84 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsInfo.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsInfo.cs @@ -2,33 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - -/*============================================================ -** -** -** Purpose: part of ComEventHelpers APIs which allow binding -** managed delegates to COM's connection point based events. -** -**/ +using System; +using ComTypes = System.Runtime.InteropServices.ComTypes; namespace System.Runtime.InteropServices { - using System; - using ComTypes = System.Runtime.InteropServices.ComTypes; - - // see code:ComEventsHelper#ComEventsArchitecture internal class ComEventsInfo { - #region fields - private ComEventsSink _sinks; private object _rcw; - #endregion - - - #region ctor/dtor - private ComEventsInfo(object rcw) { _rcw = rcw; @@ -36,22 +19,17 @@ namespace System.Runtime.InteropServices ~ComEventsInfo() { - // see code:ComEventsHelper#ComEventsFinalization + // see notes in ComEventsHelper.cs regarding ComEventsFinalization _sinks = ComEventsSink.RemoveAll(_sinks); } - #endregion - - - #region static methods - - internal static ComEventsInfo Find(object rcw) + public static ComEventsInfo Find(object rcw) { return (ComEventsInfo)Marshal.GetComObjectData(rcw, typeof(ComEventsInfo)); } // it is caller's responsibility to call this method under lock(rcw) - internal static ComEventsInfo FromObject(object rcw) + public static ComEventsInfo FromObject(object rcw) { ComEventsInfo eventsInfo = Find(rcw); if (eventsInfo == null) @@ -62,18 +40,13 @@ namespace System.Runtime.InteropServices return eventsInfo; } - #endregion - - - #region internal methods - - internal ComEventsSink FindSink(ref Guid iid) + public ComEventsSink FindSink(ref Guid iid) { return ComEventsSink.Find(_sinks, ref iid); } // it is caller's responsibility to call this method under lock(rcw) - internal ComEventsSink AddSink(ref Guid iid) + public ComEventsSink AddSink(ref Guid iid) { ComEventsSink sink = new ComEventsSink(_rcw, iid); _sinks = ComEventsSink.Add(_sinks, sink); @@ -87,7 +60,5 @@ namespace System.Runtime.InteropServices _sinks = ComEventsSink.Remove(_sinks, sink); return _sinks; } - - #endregion } } diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsMethod.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsMethod.cs index 8b1bcdcfab..ce36100d15 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsMethod.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsMethod.cs @@ -2,96 +2,138 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - -/*============================================================ -** -** -** Purpose: part of ComEventHelpers APIs which allow binding -** managed delegates to COM's connection point based events. -** -**/ - using System; using System.Collections.Generic; -using System.Text; using System.Diagnostics; +using System.Text; using System.Runtime.InteropServices; using System.Reflection; - namespace System.Runtime.InteropServices { - // see code:ComEventsHelper#ComEventsArchitecture + /// <summary> + /// Part of ComEventHelpers APIs which allow binding + /// managed delegates to COM's connection point based events. + /// </summary> internal class ComEventsMethod { - // This delegate wrapper class handles dynamic invocation of delegates. The reason for the wrapper's - // existence is that under certain circumstances we need to coerce arguments to types expected by the - // delegates signature. Normally, reflection (Delegate.DynamicInvoke) handles types coercion - // correctly but one known case is when the expected signature is 'ref Enum' - in this case - // reflection by design does not do the coercion. Since we need to be compatible with COM interop - // handling of this scenario - we are pre-processing delegate's signature by looking for 'ref enums' - // and cache the types required for such coercion. - internal class DelegateWrapper + /// <summary> + /// This delegate wrapper class handles dynamic invocation of delegates. The reason for the wrapper's + /// existence is that under certain circumstances we need to coerce arguments to types expected by the + /// delegates signature. Normally, reflection (Delegate.DynamicInvoke) handles type coercion + /// correctly but one known case is when the expected signature is 'ref Enum' - in this case + /// reflection by design does not do the coercion. Since we need to be compatible with COM interop + /// handling of this scenario - we are pre-processing delegate's signature by looking for 'ref enums' + /// and cache the types required for such coercion. + /// </summary> + public class DelegateWrapper { - private Delegate _d; + private bool _once = false; + private int _expectedParamsCount; + private Type[] _cachedTargetTypes; public DelegateWrapper(Delegate d) { - _d = d; + Delegate = d; } - public Delegate Delegate + public Delegate Delegate { get; set; } + + public object Invoke(object[] args) { - get { return _d; } - set { _d = value; } + if (Delegate == null) + { + return null; + } + + if (_once == false) + { + PreProcessSignature(); + _once = true; + } + + if (_cachedTargetTypes != null && _expectedParamsCount == args.Length) + { + for (int i = 0; i < _expectedParamsCount; i++) + { + if (_cachedTargetTypes[i] != null) + { + args[i] = Enum.ToObject(_cachedTargetTypes[i], args[i]); + } + } + } + + return Delegate.DynamicInvoke(args); } - } - #region private fields + private void PreProcessSignature() + { + ParameterInfo[] parameters = Delegate.Method.GetParameters(); + _expectedParamsCount = parameters.Length; + + bool needToHandleCoercion = false; + + var targetTypes = new List<Type>(); + foreach (ParameterInfo pi in parameters) + { + Type targetType = null; + + // recognize only 'ref Enum' signatures and cache + // both enum type and the underlying type. + if (pi.ParameterType.IsByRef + && pi.ParameterType.HasElementType + && pi.ParameterType.GetElementType().IsEnum) + { + needToHandleCoercion = true; + targetType = pi.ParameterType.GetElementType(); + } + + targetTypes.Add(targetType); + } + + if (needToHandleCoercion) + { + _cachedTargetTypes = targetTypes.ToArray(); + } + } + } /// <summary> /// Invoking ComEventsMethod means invoking a multi-cast delegate attached to it. /// Since multicast delegate's built-in chaining supports only chaining instances of the same type, /// we need to complement this design by using an explicit linked list data structure. /// </summary> - private DelegateWrapper[] _delegateWrappers; + private List<DelegateWrapper> _delegateWrappers = new List<DelegateWrapper>(); - private int _dispid; + private readonly int _dispid; private ComEventsMethod _next; - #endregion - - - #region ctor - - internal ComEventsMethod(int dispid) + public ComEventsMethod(int dispid) { - _delegateWrappers = null; _dispid = dispid; } - #endregion - - - #region internal static methods - - internal static ComEventsMethod Find(ComEventsMethod methods, int dispid) + public static ComEventsMethod Find(ComEventsMethod methods, int dispid) { while (methods != null && methods._dispid != dispid) { methods = methods._next; } + return methods; } - internal static ComEventsMethod Add(ComEventsMethod methods, ComEventsMethod method) + public static ComEventsMethod Add(ComEventsMethod methods, ComEventsMethod method) { method._next = methods; return method; } - internal static ComEventsMethod Remove(ComEventsMethod methods, ComEventsMethod method) + public static ComEventsMethod Remove(ComEventsMethod methods, ComEventsMethod method) { + Debug.Assert(methods != null, "removing method from empty methods collection"); + Debug.Assert(method != null, "specify method is null"); + if (methods == method) { methods = methods._next; @@ -100,99 +142,100 @@ namespace System.Runtime.InteropServices { ComEventsMethod current = methods; while (current != null && current._next != method) + { current = current._next; + } + if (current != null) + { current._next = method._next; + } } return methods; } - #endregion - #region public properties / methods - - internal bool Empty - { - get { return _delegateWrappers == null || _delegateWrappers.Length == 0; } - } - - internal void AddDelegate(Delegate d) + public bool Empty { - int count = 0; - if (_delegateWrappers != null) + get { - count = _delegateWrappers.Length; + lock (_delegateWrappers) + { + return _delegateWrappers.Count == 0; + } } + } - for (int i = 0; i < count; i++) + public void AddDelegate(Delegate d) + { + lock (_delegateWrappers) { - if (_delegateWrappers[i].Delegate.GetType() == d.GetType()) + // Update an existing delegate wrapper + foreach (DelegateWrapper wrapper in _delegateWrappers) { - _delegateWrappers[i].Delegate = Delegate.Combine(_delegateWrappers[i].Delegate, d); - return; + if (wrapper.Delegate.GetType() == d.GetType()) + { + wrapper.Delegate = Delegate.Combine(wrapper.Delegate, d); + return; + } } - } - DelegateWrapper[] newDelegateWrappers = new DelegateWrapper[count + 1]; - if (count > 0) - { - _delegateWrappers.CopyTo(newDelegateWrappers, 0); + var newWrapper = new DelegateWrapper(d); + _delegateWrappers.Add(newWrapper); } - - DelegateWrapper wrapper = new DelegateWrapper(d); - newDelegateWrappers[count] = wrapper; - - _delegateWrappers = newDelegateWrappers; } - internal void RemoveDelegate(Delegate d) + public void RemoveDelegate(Delegate d) { - int count = _delegateWrappers.Length; - int removeIdx = -1; - - for (int i = 0; i < count; i++) + lock (_delegateWrappers) { - if (_delegateWrappers[i].Delegate.GetType() == d.GetType()) + // Find delegate wrapper index + int removeIdx = -1; + DelegateWrapper wrapper = null; + for (int i = 0; i < _delegateWrappers.Count; i++) { - removeIdx = i; - break; + DelegateWrapper wrapperMaybe = _delegateWrappers[i]; + if (wrapperMaybe.Delegate.GetType() == d.GetType()) + { + removeIdx = i; + wrapper = wrapperMaybe; + break; + } } - } - if (removeIdx < 0) - return; + if (removeIdx < 0) + { + // Not present in collection + return; + } - Delegate newDelegate = Delegate.Remove(_delegateWrappers[removeIdx].Delegate, d); - if (newDelegate != null) - { - _delegateWrappers[removeIdx].Delegate = newDelegate; - return; + // Update wrapper or remove from collection + Delegate newDelegate = Delegate.Remove(wrapper.Delegate, d); + if (newDelegate != null) + { + wrapper.Delegate = newDelegate; + } + else + { + _delegateWrappers.RemoveAt(removeIdx); + } } + } - // now remove the found entry from the _delegates array - - if (count == 1) - { - _delegateWrappers = null; - return; - } + public object Invoke(object[] args) + { + Debug.Assert(!Empty); + object result = null; - DelegateWrapper[] newDelegateWrappers = new DelegateWrapper[count - 1]; - int j = 0; - while (j < removeIdx) + lock (_delegateWrappers) { - newDelegateWrappers[j] = _delegateWrappers[j]; - j++; - } - while (j < count - 1) - { - newDelegateWrappers[j] = _delegateWrappers[j + 1]; - j++; + foreach (DelegateWrapper wrapper in _delegateWrappers) + { + result = wrapper.Invoke(args); + } } - _delegateWrappers = newDelegateWrappers; + return result; } - - #endregion } } diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs index 9281d24d03..c5262a6558 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComEventsSink.cs @@ -2,48 +2,32 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - -/*============================================================ -** -** -** Purpose: part of ComEventHelpers APIs which allow binding -** managed delegates to COM's connection point based events. -** -**/ - using System; using System.Diagnostics; +using Variant = System.Runtime.InteropServices.Variant; + namespace System.Runtime.InteropServices { - // see code:ComEventsHelper#ComEventsArchitecture - internal class ComEventsSink : ICustomQueryInterface + /// <summary> + /// Part of ComEventHelpers APIs which allow binding + /// managed delegates to COM's connection point based events. + /// </summary> + internal class ComEventsSink : IDispatch, ICustomQueryInterface { - #region private fields - private Guid _iidSourceItf; private ComTypes.IConnectionPoint _connectionPoint; private int _cookie; private ComEventsMethod _methods; private ComEventsSink _next; - #endregion - - - #region ctor - - internal ComEventsSink(object rcw, Guid iid) + public ComEventsSink(object rcw, Guid iid) { _iidSourceItf = iid; this.Advise(rcw); } - #endregion - - - #region static members - - internal static ComEventsSink Find(ComEventsSink sinks, ref Guid iid) + public static ComEventsSink Find(ComEventsSink sinks, ref Guid iid) { ComEventsSink sink = sinks; while (sink != null && sink._iidSourceItf != iid) @@ -54,13 +38,13 @@ namespace System.Runtime.InteropServices return sink; } - internal static ComEventsSink Add(ComEventsSink sinks, ComEventsSink sink) + public static ComEventsSink Add(ComEventsSink sinks, ComEventsSink sink) { sink._next = sinks; return sink; } - internal static ComEventsSink RemoveAll(ComEventsSink sinks) + public static ComEventsSink RemoveAll(ComEventsSink sinks) { while (sinks != null) { @@ -71,7 +55,7 @@ namespace System.Runtime.InteropServices return null; } - internal static ComEventsSink Remove(ComEventsSink sinks, ComEventsSink sink) + public static ComEventsSink Remove(ComEventsSink sinks, ComEventsSink sink) { Debug.Assert(sinks != null, "removing event sink from empty sinks collection"); Debug.Assert(sink != null, "specify event sink is null"); @@ -84,7 +68,9 @@ namespace System.Runtime.InteropServices { ComEventsSink current = sinks; while (current != null && current._next != sink) + { current = current._next; + } if (current != null) { @@ -97,11 +83,6 @@ namespace System.Runtime.InteropServices return sinks; } - #endregion - - - #region public methods - public ComEventsMethod RemoveMethod(ComEventsMethod method) { _methods = ComEventsMethod.Remove(_methods, method); @@ -120,7 +101,136 @@ namespace System.Runtime.InteropServices return method; } - #endregion + int IDispatch.GetTypeInfoCount() + { + return 0; + } + + ComTypes.ITypeInfo IDispatch.GetTypeInfo(int iTInfo, int lcid) + { + throw new NotImplementedException(); + } + + void IDispatch.GetIDsOfNames(ref Guid iid, string[] names, int cNames, int lcid, int[] rgDispId) + { + throw new NotImplementedException(); + } + + private const VarEnum VT_BYREF_VARIANT = VarEnum.VT_BYREF | VarEnum.VT_VARIANT; + private const VarEnum VT_TYPEMASK = (VarEnum) 0x0fff; + private const VarEnum VT_BYREF_TYPEMASK = VT_TYPEMASK | VarEnum.VT_BYREF; + + private static unsafe ref Variant GetVariant(ref Variant pSrc) + { + if (pSrc.VariantType == VT_BYREF_VARIANT) + { + // For VB6 compatibility reasons, if the VARIANT is a VT_BYREF | VT_VARIANT that + // contains another VARIANT with VT_BYREF | VT_VARIANT, then we need to extract the + // inner VARIANT and use it instead of the outer one. Note that if the inner VARIANT + // is VT_BYREF | VT_VARIANT | VT_ARRAY, it will pass the below test too. + Span<Variant> pByRefVariant = new Span<Variant>(pSrc.AsByRefVariant.ToPointer(), 1); + if ((pByRefVariant[0].VariantType & VT_BYREF_TYPEMASK) == VT_BYREF_VARIANT) + { + return ref pByRefVariant[0]; + } + } + + return ref pSrc; + } + + unsafe void IDispatch.Invoke( + int dispid, + ref Guid riid, + int lcid, + InvokeFlags wFlags, + ref ComTypes.DISPPARAMS pDispParams, + IntPtr pVarResult, + IntPtr pExcepInfo, + IntPtr puArgErr) + { + ComEventsMethod method = FindMethod(dispid); + if (method == null) + { + return; + } + + // notice the unsafe pointers we are using. This is to avoid unnecessary + // arguments marshalling. see code:ComEventsHelper#ComEventsArgsMarshalling + + const int InvalidIdx = -1; + object [] args = new object[pDispParams.cArgs]; + int [] byrefsMap = new int[pDispParams.cArgs]; + bool [] usedArgs = new bool[pDispParams.cArgs]; + + int totalCount = pDispParams.cNamedArgs + pDispParams.cArgs; + var vars = new Span<Variant>(pDispParams.rgvarg.ToPointer(), totalCount); + var namedArgs = new Span<int>(pDispParams.rgdispidNamedArgs.ToPointer(), totalCount); + + // copy the named args (positional) as specified + int i; + int pos; + for (i = 0; i < pDispParams.cNamedArgs; i++) + { + pos = namedArgs[i]; + ref Variant pvar = ref GetVariant(ref vars[i]); + args[pos] = pvar.ToObject(); + usedArgs[pos] = true; + + int byrefIdx = InvalidIdx; + if (pvar.IsByRef) + { + byrefIdx = i; + } + + byrefsMap[pos] = byrefIdx; + } + + // copy the rest of the arguments in the reverse order + pos = 0; + for (; i < pDispParams.cArgs; i++) + { + // find the next unassigned argument + while (usedArgs[pos]) + { + pos++; + } + + ref Variant pvar = ref GetVariant(ref vars[pDispParams.cArgs - 1 - i]); + args[pos] = pvar.ToObject(); + + int byrefIdx = InvalidIdx; + if (pvar.IsByRef) + { + byrefIdx = pDispParams.cArgs - 1 - i; + } + + byrefsMap[pos] = byrefIdx; + + pos++; + } + + // Do the actual delegate invocation + object result = method.Invoke(args); + + // convert result to VARIANT + if (pVarResult != IntPtr.Zero) + { + Marshal.GetNativeVariantForObject(result, pVarResult); + } + + // Now we need to marshal all the byrefs back + for (i = 0; i < pDispParams.cArgs; i++) + { + int idxToPos = byrefsMap[i]; + if (idxToPos == InvalidIdx) + { + continue; + } + + ref Variant pvar = ref GetVariant(ref vars[idxToPos]); + pvar.CopyFromIndirect(args[i]); + } + } CustomQueryInterfaceResult ICustomQueryInterface.GetInterface(ref Guid iid, out IntPtr ppv) { @@ -134,9 +244,6 @@ namespace System.Runtime.InteropServices return CustomQueryInterfaceResult.NotHandled; } - #region private methods - - private void Advise(object rcw) { Debug.Assert(_connectionPoint == null, "comevent sink is already advised"); @@ -161,7 +268,7 @@ namespace System.Runtime.InteropServices _connectionPoint.Unadvise(_cookie); Marshal.ReleaseComObject(_connectionPoint); } - catch (System.Exception) + catch (Exception) { // swallow all exceptions on unadvise // the host may not be available at this point @@ -171,7 +278,5 @@ namespace System.Runtime.InteropServices _connectionPoint = null; } } - - #endregion }; } diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs index 8de914ecef..67bb393c67 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/CustomMarshalers/EnumerableViewOfDispatch.cs @@ -5,6 +5,8 @@ using System.Collections; using System.Runtime.InteropServices.ComTypes; +using Variant = System.Runtime.InteropServices.Variant; + namespace System.Runtime.InteropServices.CustomMarshalers { internal class EnumerableViewOfDispatch : ICustomAdapter, System.Collections.IEnumerable @@ -23,19 +25,25 @@ namespace System.Runtime.InteropServices.CustomMarshalers public IEnumerator GetEnumerator() { - DISPPARAMS dispParams = new DISPPARAMS(); - Guid guid = Guid.Empty; - Dispatch.Invoke( - DISPID_NEWENUM, - ref guid, - LCID_DEFAULT, - InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET, - ref dispParams, - out object result, - IntPtr.Zero, - IntPtr.Zero); - - if (!(result is IEnumVARIANT enumVariant)) + Variant result; + unsafe + { + void *resultLocal = &result; + DISPPARAMS dispParams = new DISPPARAMS(); + Guid guid = Guid.Empty; + Dispatch.Invoke( + DISPID_NEWENUM, + ref guid, + LCID_DEFAULT, + InvokeFlags.DISPATCH_METHOD | InvokeFlags.DISPATCH_PROPERTYGET, + ref dispParams, + new IntPtr(resultLocal), + IntPtr.Zero, + IntPtr.Zero); + } + + object resultAsObject = result.ToObject(); + if (!(resultAsObject is IEnumVARIANT enumVariant)) { throw new InvalidOperationException(SR.InvalidOp_InvalidNewEnumVariant); } diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs index 183efa5f96..fbe70fed73 100644 --- a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/IDispatch.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; -using System.Runtime.InteropServices.ComTypes; using System.Text; namespace System.Runtime.InteropServices @@ -14,6 +13,12 @@ namespace System.Runtime.InteropServices [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] internal interface IDispatch { + int GetTypeInfoCount(); + + ComTypes.ITypeInfo GetTypeInfo( + int iTInfo, + int lcid); + void GetIDsOfNames( ref Guid riid, [MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr, SizeParamIndex = 2), In] @@ -22,21 +27,17 @@ namespace System.Runtime.InteropServices int lcid, [Out] int[] rgDispId); - ITypeInfo GetTypeInfo( - int iTInfo, - int lcid); - - int GetTypeInfoCount(); - + // The last 3 parameters of Invoke() are optional and must be defined + // as IntPtr in C#, since there is no language feature for optional ref/out. void Invoke( int dispIdMember, ref Guid riid, int lcid, InvokeFlags wFlags, - ref DISPPARAMS pDispParams, - out object pVarResult, - IntPtr pExcepInfo, - IntPtr puArgErr); + ref ComTypes.DISPPARAMS pDispParams, + /* out/optional */ IntPtr pVarResult, + /* out/optional */ IntPtr pExcepInfo, + /* out/optional */ IntPtr puArgErr); } [Flags] diff --git a/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs new file mode 100644 index 0000000000..c726255c1a --- /dev/null +++ b/src/System.Private.CoreLib/src/System/Runtime/InteropServices/Variant.cs @@ -0,0 +1,710 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; + +namespace System.Runtime.InteropServices +{ + /// <summary> + /// Variant is the basic COM type for late-binding. It can contain any other COM data type. + /// This type definition precisely matches the unmanaged data layout so that the struct can be passed + /// to and from COM calls. + /// </summary> + [StructLayout(LayoutKind.Explicit)] + internal struct Variant + { +#if DEBUG + static Variant() + { + // Variant size is the size of 4 pointers (16 bytes) on a 32-bit processor, + // and 3 pointers (24 bytes) on a 64-bit processor. + int variantSize = Marshal.SizeOf(typeof(Variant)); + if (IntPtr.Size == 4) + { + Debug.Assert(variantSize == (4 * IntPtr.Size)); + } + else + { + Debug.Assert(IntPtr.Size == 8); + Debug.Assert(variantSize == (3 * IntPtr.Size)); + } + } +#endif + + // Most of the data types in the Variant are carried in _typeUnion + [FieldOffset(0)] private TypeUnion _typeUnion; + + // Decimal is the largest data type and it needs to use the space that is normally unused in TypeUnion._wReserved1, etc. + // Hence, it is declared to completely overlap with TypeUnion. A Decimal does not use the first two bytes, and so + // TypeUnion._vt can still be used to encode the type. + [FieldOffset(0)] private Decimal _decimal; + + [StructLayout(LayoutKind.Sequential)] + private struct TypeUnion + { + public ushort _vt; + public ushort _wReserved1; + public ushort _wReserved2; + public ushort _wReserved3; + + public UnionTypes _unionTypes; + } + + [StructLayout(LayoutKind.Sequential)] + private struct Record + { + public IntPtr _record; + public IntPtr _recordInfo; + } + + [StructLayout(LayoutKind.Explicit)] + private struct UnionTypes + { + [FieldOffset(0)] public sbyte _i1; + [FieldOffset(0)] public short _i2; + [FieldOffset(0)] public int _i4; + [FieldOffset(0)] public long _i8; + [FieldOffset(0)] public byte _ui1; + [FieldOffset(0)] public ushort _ui2; + [FieldOffset(0)] public uint _ui4; + [FieldOffset(0)] public ulong _ui8; + [FieldOffset(0)] public int _int; + [FieldOffset(0)] public uint _uint; + [FieldOffset(0)] public short _bool; + [FieldOffset(0)] public int _error; + [FieldOffset(0)] public float _r4; + [FieldOffset(0)] public double _r8; + [FieldOffset(0)] public long _cy; + [FieldOffset(0)] public double _date; + [FieldOffset(0)] public IntPtr _bstr; + [FieldOffset(0)] public IntPtr _unknown; + [FieldOffset(0)] public IntPtr _dispatch; + [FieldOffset(0)] public IntPtr _pvarVal; + [FieldOffset(0)] public IntPtr _byref; + [FieldOffset(0)] public Record _record; + } + + /// <summary> + /// Primitive types are the basic COM types. It includes valuetypes like ints, but also reference types + /// like BStrs. It does not include composite types like arrays and user-defined COM types (IUnknown/IDispatch). + /// </summary> + public static bool IsPrimitiveType(VarEnum varEnum) + { + switch(varEnum) + { + case VarEnum.VT_I1: + case VarEnum.VT_I2: + case VarEnum.VT_I4: + case VarEnum.VT_I8: + case VarEnum.VT_UI1: + case VarEnum.VT_UI2: + case VarEnum.VT_UI4: + case VarEnum.VT_UI8: + case VarEnum.VT_INT: + case VarEnum.VT_UINT: + case VarEnum.VT_BOOL: + case VarEnum.VT_R4: + case VarEnum.VT_R8: + case VarEnum.VT_DECIMAL: + case VarEnum.VT_DATE: + case VarEnum.VT_BSTR: + return true; + } + + return false; + } + + unsafe public void CopyFromIndirect(object value) + { + VarEnum vt = (VarEnum)(((int)this.VariantType) & ~((int)VarEnum.VT_BYREF)); + + if (value == null) + { + if (vt == VarEnum.VT_DISPATCH || vt == VarEnum.VT_UNKNOWN || vt == VarEnum.VT_BSTR) + { + *(IntPtr*)this._typeUnion._unionTypes._byref = IntPtr.Zero; + } + return; + } + + if ((vt & VarEnum.VT_ARRAY) != 0) + { + Variant vArray; + Marshal.GetNativeVariantForObject(value, (IntPtr)(void*)&vArray); + *(IntPtr*)this._typeUnion._unionTypes._byref = vArray._typeUnion._unionTypes._byref; + return; + } + + switch (vt) + { + case VarEnum.VT_I1: + *(sbyte*)this._typeUnion._unionTypes._byref = (sbyte)value; + break; + + case VarEnum.VT_UI1: + *(byte*)this._typeUnion._unionTypes._byref = (byte)value; + break; + + case VarEnum.VT_I2: + *(short*)this._typeUnion._unionTypes._byref = (short)value; + break; + + case VarEnum.VT_UI2: + *(ushort*)this._typeUnion._unionTypes._byref = (ushort)value; + break; + + case VarEnum.VT_BOOL: + // VARIANT_TRUE = -1 + // VARIANT_FALSE = 0 + *(short*)this._typeUnion._unionTypes._byref = (bool)value ? (short)-1 : (short)0; + break; + + case VarEnum.VT_I4: + case VarEnum.VT_INT: + *(int*)this._typeUnion._unionTypes._byref = (int)value; + break; + + case VarEnum.VT_UI4: + case VarEnum.VT_UINT: + *(uint*)this._typeUnion._unionTypes._byref = (uint)value; + break; + + case VarEnum.VT_ERROR: + *(int*)this._typeUnion._unionTypes._byref = ((ErrorWrapper)value).ErrorCode; + break; + + case VarEnum.VT_I8: + *(Int64*)this._typeUnion._unionTypes._byref = (Int64)value; + break; + + case VarEnum.VT_UI8: + *(UInt64*)this._typeUnion._unionTypes._byref = (UInt64)value; + break; + + case VarEnum.VT_R4: + *(float*)this._typeUnion._unionTypes._byref = (float)value; + break; + + case VarEnum.VT_R8: + *(double*)this._typeUnion._unionTypes._byref = (double)value; + break; + + case VarEnum.VT_DATE: + *(double*)this._typeUnion._unionTypes._byref = ((DateTime)value).ToOADate(); + break; + + case VarEnum.VT_UNKNOWN: + *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.GetIUnknownForObject(value); + break; + + case VarEnum.VT_DISPATCH: + *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.GetIDispatchForObject(value); + break; + + case VarEnum.VT_BSTR: + *(IntPtr*)this._typeUnion._unionTypes._byref = Marshal.StringToBSTR((string)value); + break; + + case VarEnum.VT_CY: + *(long*)this._typeUnion._unionTypes._byref = decimal.ToOACurrency((decimal)value); + break; + + case VarEnum.VT_DECIMAL: + *(decimal*)this._typeUnion._unionTypes._byref = (decimal)value; + break; + + case VarEnum.VT_VARIANT: + Marshal.GetNativeVariantForObject(value, this._typeUnion._unionTypes._byref); + break; + + default: + throw new ArgumentException(); + } + } + + /// <summary> + /// Get the managed object representing the Variant. + /// </summary> + /// <returns></returns> + public object ToObject() + { + // Check the simple case upfront + if (IsEmpty) + { + return null; + } + + switch (VariantType) + { + case VarEnum.VT_NULL: + return DBNull.Value; + + case VarEnum.VT_I1: return AsI1; + case VarEnum.VT_I2: return AsI2; + case VarEnum.VT_I4: return AsI4; + case VarEnum.VT_I8: return AsI8; + case VarEnum.VT_UI1: return AsUi1; + case VarEnum.VT_UI2: return AsUi2; + case VarEnum.VT_UI4: return AsUi4; + case VarEnum.VT_UI8: return AsUi8; + case VarEnum.VT_INT: return AsInt; + case VarEnum.VT_UINT: return AsUint; + case VarEnum.VT_BOOL: return AsBool; + case VarEnum.VT_ERROR: return AsError; + case VarEnum.VT_R4: return AsR4; + case VarEnum.VT_R8: return AsR8; + case VarEnum.VT_DECIMAL: return AsDecimal; + case VarEnum.VT_CY: return AsCy; + case VarEnum.VT_DATE: return AsDate; + case VarEnum.VT_BSTR: return AsBstr; + case VarEnum.VT_UNKNOWN: return AsUnknown; + case VarEnum.VT_DISPATCH: return AsDispatch; + + default: + unsafe + { + fixed (void* pThis = &this) + { + return Marshal.GetObjectForNativeVariant((System.IntPtr)pThis); + } + } + } + } + + /// <summary> + /// Release any unmanaged memory associated with the Variant + /// </summary> + /// <returns></returns> + public void Clear() + { + // We do not need to call OLE32's VariantClear for primitive types or ByRefs + // to save ourselves the cost of interop transition. + // ByRef indicates the memory is not owned by the VARIANT itself while + // primitive types do not have any resources to free up. + // Hence, only safearrays, BSTRs, interfaces and user types are + // handled differently. + VarEnum vt = VariantType; + if ((vt & VarEnum.VT_BYREF) != 0) + { + VariantType = VarEnum.VT_EMPTY; + } + else if (((vt & VarEnum.VT_ARRAY) != 0) + || ((vt) == VarEnum.VT_BSTR) + || ((vt) == VarEnum.VT_UNKNOWN) + || ((vt) == VarEnum.VT_DISPATCH) + || ((vt) == VarEnum.VT_VARIANT) + || ((vt) == VarEnum.VT_RECORD) + || ((vt) == VarEnum.VT_VARIANT)) + { + unsafe + { + fixed (void* pThis = &this) + { + Interop.OleAut32.VariantClear((IntPtr)pThis); + } + } + + Debug.Assert(IsEmpty); + } + else + { + VariantType = VarEnum.VT_EMPTY; + } + } + + public VarEnum VariantType + { + get => (VarEnum)_typeUnion._vt; + set => _typeUnion._vt = (ushort)value; + } + + public bool IsEmpty => _typeUnion._vt == ((ushort)VarEnum.VT_EMPTY); + + public bool IsByRef => (_typeUnion._vt & ((ushort)VarEnum.VT_BYREF)) != 0; + + public void SetAsNULL() + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_NULL; + } + + // VT_I1 + + public sbyte AsI1 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_I1); + return _typeUnion._unionTypes._i1; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_I1; + _typeUnion._unionTypes._i1 = value; + } + } + + // VT_I2 + + public short AsI2 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_I2); + return _typeUnion._unionTypes._i2; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_I2; + _typeUnion._unionTypes._i2 = value; + } + } + + // VT_I4 + + public int AsI4 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_I4); + return _typeUnion._unionTypes._i4; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_I4; + _typeUnion._unionTypes._i4 = value; + } + } + + // VT_I8 + + public long AsI8 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_I8); + return _typeUnion._unionTypes._i8; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_I8; + _typeUnion._unionTypes._i8 = value; + } + } + + // VT_UI1 + + public byte AsUi1 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UI1); + return _typeUnion._unionTypes._ui1; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UI1; + _typeUnion._unionTypes._ui1 = value; + } + } + + // VT_UI2 + + public ushort AsUi2 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UI2); + return _typeUnion._unionTypes._ui2; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UI2; + _typeUnion._unionTypes._ui2 = value; + } + } + + // VT_UI4 + + public uint AsUi4 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UI4); + return _typeUnion._unionTypes._ui4; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UI4; + _typeUnion._unionTypes._ui4 = value; + } + } + + // VT_UI8 + + public ulong AsUi8 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UI8); + return _typeUnion._unionTypes._ui8; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UI8; + _typeUnion._unionTypes._ui8 = value; + } + } + + // VT_INT + + public int AsInt + { + get + { + Debug.Assert(VariantType == VarEnum.VT_INT); + return _typeUnion._unionTypes._int; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_INT; + _typeUnion._unionTypes._int = value; + } + } + + // VT_UINT + + public uint AsUint + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UINT); + return _typeUnion._unionTypes._uint; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UINT; + _typeUnion._unionTypes._uint = value; + } + } + + // VT_BOOL + + public bool AsBool + { + get + { + Debug.Assert(VariantType == VarEnum.VT_BOOL); + return _typeUnion._unionTypes._bool != 0; + } + set + { + Debug.Assert(IsEmpty); + // VARIANT_TRUE = -1 + // VARIANT_FALSE = 0 + VariantType = VarEnum.VT_BOOL; + _typeUnion._unionTypes._bool = value ? (short)-1 : (short)0; + } + } + + // VT_ERROR + + public int AsError + { + get + { + Debug.Assert(VariantType == VarEnum.VT_ERROR); + return _typeUnion._unionTypes._error; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_ERROR; + _typeUnion._unionTypes._error = value; + } + } + + // VT_R4 + + public float AsR4 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_R4); + return _typeUnion._unionTypes._r4; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_R4; + _typeUnion._unionTypes._r4 = value; + } + } + + // VT_R8 + + public double AsR8 + { + get + { + Debug.Assert(VariantType == VarEnum.VT_R8); + return _typeUnion._unionTypes._r8; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_R8; + _typeUnion._unionTypes._r8 = value; + } + } + + // VT_DECIMAL + + public Decimal AsDecimal + { + get + { + Debug.Assert(VariantType == VarEnum.VT_DECIMAL); + // The first byte of Decimal is unused, but usually set to 0 + Variant v = this; + v._typeUnion._vt = 0; + return v._decimal; + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_DECIMAL; + _decimal = value; + // _vt overlaps with _decimal, and should be set after setting _decimal + _typeUnion._vt = (ushort)VarEnum.VT_DECIMAL; + } + } + + // VT_CY + + public Decimal AsCy + { + get + { + Debug.Assert(VariantType == VarEnum.VT_CY); + return Decimal.FromOACurrency(_typeUnion._unionTypes._cy); + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_CY; + _typeUnion._unionTypes._cy = Decimal.ToOACurrency(value); + } + } + + // VT_DATE + + public DateTime AsDate + { + get + { + Debug.Assert(VariantType == VarEnum.VT_DATE); + return DateTime.FromOADate(_typeUnion._unionTypes._date); + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_DATE; + _typeUnion._unionTypes._date = value.ToOADate(); + } + } + + // VT_BSTR + + public string AsBstr + { + get + { + Debug.Assert(VariantType == VarEnum.VT_BSTR); + return (string)Marshal.PtrToStringBSTR(this._typeUnion._unionTypes._bstr); + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_BSTR; + this._typeUnion._unionTypes._bstr = Marshal.StringToBSTR(value); + } + } + + // VT_UNKNOWN + + public object AsUnknown + { + get + { + Debug.Assert(VariantType == VarEnum.VT_UNKNOWN); + if (_typeUnion._unionTypes._unknown == IntPtr.Zero) + { + return null; + } + return Marshal.GetObjectForIUnknown(_typeUnion._unionTypes._unknown); + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_UNKNOWN; + if (value == null) + { + _typeUnion._unionTypes._unknown = IntPtr.Zero; + } + else + { + _typeUnion._unionTypes._unknown = Marshal.GetIUnknownForObject(value); + } + } + } + + // VT_DISPATCH + + public object AsDispatch + { + get + { + Debug.Assert(VariantType == VarEnum.VT_DISPATCH); + if (_typeUnion._unionTypes._dispatch == IntPtr.Zero) + { + return null; + } + return Marshal.GetObjectForIUnknown(_typeUnion._unionTypes._dispatch); + } + set + { + Debug.Assert(IsEmpty); + VariantType = VarEnum.VT_DISPATCH; + if (value == null) + { + _typeUnion._unionTypes._dispatch = IntPtr.Zero; + } + else + { + _typeUnion._unionTypes._dispatch = Marshal.GetIDispatchForObject(value); + } + } + } + + public IntPtr AsByRefVariant + { + get + { + Debug.Assert(VariantType == (VarEnum.VT_BYREF | VarEnum.VT_VARIANT)); + return _typeUnion._unionTypes._pvarVal; + } + } + } +} diff --git a/tests/src/Interop/COM/NETClients/Events/App.manifest b/tests/src/Interop/COM/NETClients/Events/App.manifest new file mode 100644 index 0000000000..833743f83d --- /dev/null +++ b/tests/src/Interop/COM/NETClients/Events/App.manifest @@ -0,0 +1,18 @@ +<?xml version="1.0" encoding="utf-8"?> +<assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1"> + <assemblyIdentity + type="win32" + name="NetPrimitivesEvents" + version="1.0.0.0" /> + + <dependency> + <dependentAssembly> + <!-- RegFree COM --> + <assemblyIdentity + type="win32" + name="COMNativeServer.X" + version="1.0.0.0"/> + </dependentAssembly> + </dependency> + +</assembly> diff --git a/tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj b/tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj new file mode 100644 index 0000000000..fccbee9258 --- /dev/null +++ b/tests/src/Interop/COM/NETClients/Events/NETClientEvents.csproj @@ -0,0 +1,43 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="12.0" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.props))\dir.props" /> + <PropertyGroup> + <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> + <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> + <AssemblyName>NETClientEvents</AssemblyName> + <SchemaVersion>2.0</SchemaVersion> + <ProjectGuid>{85C57688-DA98-4DE3-AC9B-526E4747434C}</ProjectGuid> + <OutputType>Exe</OutputType> + <ProjectTypeGuids>{209912F9-0DA1-4184-9CC1-8D583BAF4A28};{87799F5D-CEBD-499D-BDBA-B2C6105CD766}</ProjectTypeGuids> + <ApplicationManifest>App.manifest</ApplicationManifest> + + <!-- Blocked on ILAsm supporting embedding resources. See https://github.com/dotnet/coreclr/issues/20819 --> + <IlrtTestKind>BuildOnly</IlrtTestKind> + + <!-- Blocked on CrossGen.exe supporting embedding resources. See https://github.com/dotnet/coreclr/issues/21006 --> + <CrossGenTest>false</CrossGenTest> + + <!-- Test unsupported outside of windows --> + <TestUnsupportedOutsideWindows>true</TestUnsupportedOutsideWindows> + <DisableProjectBuild Condition="'$(TargetsUnix)' == 'true'">true</DisableProjectBuild> + <!-- This test would require the runincontext.exe to include App.manifest describing the COM interfaces --> + <UnloadabilityIncompatible>true</UnloadabilityIncompatible> + </PropertyGroup> + <!-- Default configurations to help VS understand the configurations --> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Debug|x64'"> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)' == 'Release|x64'"> + </PropertyGroup> + <ItemGroup> + <Compile Include="Program.cs" /> + <Compile Include="../../ServerContracts/NativeServers.cs" /> + <Compile Include="../../ServerContracts/Server.Contracts.cs" /> + <Compile Include="../../ServerContracts/Server.Events.cs" /> + <Compile Include="../../ServerContracts/ServerGuids.cs" /> + </ItemGroup> + <ItemGroup> + <ProjectReference Include="../../NativeServer/CMakeLists.txt" /> + <ProjectReference Include="../../../../Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" /> + </ItemGroup> + <Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" /> +</Project> diff --git a/tests/src/Interop/COM/NETClients/Events/Program.cs b/tests/src/Interop/COM/NETClients/Events/Program.cs new file mode 100644 index 0000000000..49359243d9 --- /dev/null +++ b/tests/src/Interop/COM/NETClients/Events/Program.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace NetClient +{ + using System; + using System.Reflection; + using System.Runtime.InteropServices; + + using TestLibrary; + using Server.Contract; + using Server.Contract.Servers; + using Server.Contract.Events; + + class Program + { + static void Validate_BasicCOMEvent() + { + Console.WriteLine($"{nameof(Validate_BasicCOMEvent)}..."); + + var eventTesting = (EventTesting)new EventTestingClass(); + + // Verify event handler subscription + + // Add event + eventTesting.OnEvent += OnEventEventHandler; + + bool eventFired = false; + string message = string.Empty; + eventTesting.FireEvent(); + + Assert.IsTrue(eventFired, "Event didn't fire"); + Assert.AreEqual(nameof(EventTesting.FireEvent), message, "Event message is incorrect"); + + // Remove event + eventTesting.OnEvent -= OnEventEventHandler; + + // Verify event handler removed + + eventFired = false; + eventTesting.FireEvent(); + + Assert.IsFalse(eventFired, "Event shouldn't fire"); + + void OnEventEventHandler(string msg) + { + eventFired = true; + message = msg; + } + } + +#pragma warning disable 618 // Must test deprecated features + + // The ComAwareEventInfo is used by the compiler when PIAs + // containing COM Events are embedded. + static void Validate_COMEventViaComAwareEventInfo() + { + Console.WriteLine($"{nameof(Validate_COMEventViaComAwareEventInfo)}..."); + + var eventTesting = (EventTesting)new EventTestingClass(); + + // Verify event handler subscription + + // Add event + var comAwareEventInfo = new ComAwareEventInfo(typeof(TestingEvents_Event), nameof(TestingEvents_Event.OnEvent)); + var handler = new TestingEvents_OnEventEventHandler(OnEventEventHandler); + comAwareEventInfo.AddEventHandler(eventTesting, handler); + + bool eventFired = false; + string message = string.Empty; + eventTesting.FireEvent(); + + Assert.IsTrue(eventFired, "Event didn't fire"); + Assert.AreEqual(nameof(EventTesting.FireEvent), message, "Event message is incorrect"); + + comAwareEventInfo.RemoveEventHandler(eventTesting, handler); + + // Verify event handler removed + + eventFired = false; + eventTesting.FireEvent(); + + Assert.IsFalse(eventFired, "Event shouldn't fire"); + + void OnEventEventHandler(string msg) + { + eventFired = true; + message = msg; + } + } + +#pragma warning restore 618 // Must test deprecated features + + static int Main(string[] doNotUse) + { + // RegFree COM is not supported on Windows Nano + if (Utilities.IsWindowsNanoServer) + { + return 100; + } + + try + { + Validate_BasicCOMEvent(); + Validate_COMEventViaComAwareEventInfo(); + } + catch (Exception e) + { + Console.WriteLine($"Test Failure: {e}"); + return 101; + } + + return 100; + } + } +} diff --git a/tests/src/Interop/COM/NativeServer/ArrayTesting.h b/tests/src/Interop/COM/NativeServer/ArrayTesting.h index 8366a1c149..b653a5346f 100644 --- a/tests/src/Interop/COM/NativeServer/ArrayTesting.h +++ b/tests/src/Interop/COM/NativeServer/ArrayTesting.h @@ -340,7 +340,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<ArrayTesting, IArrayTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IArrayTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest b/tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest index 1569d52c9b..4509ee9d65 100644 --- a/tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest +++ b/tests/src/Interop/COM/NativeServer/COMNativeServer.X.manifest @@ -32,6 +32,11 @@ clsid="{0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726}" threadingModel="Both" /> + <!-- EventTesting --> + <comClass + clsid="{4DBD9B61-E372-499F-84DE-EFC70AA8A009}" + threadingModel="Both" /> + <!-- AggregationTesting --> <comClass clsid="{4CEFE36D-F377-4B6E-8C34-819A8BB9CB04}" diff --git a/tests/src/Interop/COM/NativeServer/ColorTesting.h b/tests/src/Interop/COM/NativeServer/ColorTesting.h index 5d6e1740df..6357ce2720 100644 --- a/tests/src/Interop/COM/NativeServer/ColorTesting.h +++ b/tests/src/Interop/COM/NativeServer/ColorTesting.h @@ -34,7 +34,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<ColorTesting, IColorTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IColorTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/NativeServer/DispatchTesting.h b/tests/src/Interop/COM/NativeServer/DispatchTesting.h index af9cd6c957..166d6fa749 100644 --- a/tests/src/Interop/COM/NativeServer/DispatchTesting.h +++ b/tests/src/Interop/COM/NativeServer/DispatchTesting.h @@ -415,7 +415,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<DispatchTesting, IDispatch, IDispatchTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IDispatch *>(this), static_cast<IDispatchTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/NativeServer/ErrorMarshalTesting.h b/tests/src/Interop/COM/NativeServer/ErrorMarshalTesting.h index 28c0011601..c28fa0f645 100644 --- a/tests/src/Interop/COM/NativeServer/ErrorMarshalTesting.h +++ b/tests/src/Interop/COM/NativeServer/ErrorMarshalTesting.h @@ -26,7 +26,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<ErrorMarshalTesting, IErrorMarshalTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IErrorMarshalTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/NativeServer/EventTesting.h b/tests/src/Interop/COM/NativeServer/EventTesting.h new file mode 100644 index 0000000000..f700db8d5b --- /dev/null +++ b/tests/src/Interop/COM/NativeServer/EventTesting.h @@ -0,0 +1,237 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma once + +#include "Servers.h" + +class EventTesting : + public UnknownImpl, + public IEventTesting, + public IConnectionPointContainer, + public IConnectionPoint +{ +private: // static + static const WCHAR * const Names[]; + static const int NamesCount; + +private: + IDispatch *_eventConnections[32]; + +public: + EventTesting() + { + // Ensure connections array is null + ::memset(_eventConnections, 0, sizeof(_eventConnections)); + } + +public: // IDispatch + virtual HRESULT STDMETHODCALLTYPE GetTypeInfoCount( + /* [out] */ __RPC__out UINT *pctinfo) + { + *pctinfo = 0; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE GetTypeInfo( + /* [in] */ UINT iTInfo, + /* [in] */ LCID lcid, + /* [out] */ __RPC__deref_out_opt ITypeInfo **ppTInfo) + { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE GetIDsOfNames( + /* [in] */ __RPC__in REFIID, + /* [size_is][in] */ __RPC__in_ecount_full(cNames) LPOLESTR *rgszNames, + /* [range][in] */ __RPC__in_range(0,16384) UINT cNames, + /* [in] */ LCID, + /* [size_is][out] */ __RPC__out_ecount_full(cNames) DISPID *rgDispId) + { + bool containsUnknown = false; + DISPID *curr = rgDispId; + for (UINT i = 0; i < cNames; ++i) + { + *curr = DISPID_UNKNOWN; + LPOLESTR name = rgszNames[i]; + for (int j = 1; j < NamesCount; ++j) + { + const WCHAR *nameMaybe = Names[j]; + if (::TP_wcmp_s(name, nameMaybe) == 0) + { + *curr = DISPID{ j }; + break; + } + } + + containsUnknown &= (*curr == DISPID_UNKNOWN); + curr++; + } + + return (containsUnknown) ? DISP_E_UNKNOWNNAME : S_OK; + } + + virtual /* [local] */ HRESULT STDMETHODCALLTYPE Invoke( + /* [annotation][in] */ _In_ DISPID dispIdMember, + /* [annotation][in] */ _In_ REFIID riid, + /* [annotation][in] */ _In_ LCID lcid, + /* [annotation][in] */ _In_ WORD wFlags, + /* [annotation][out][in] */ _In_ DISPPARAMS *pDispParams, + /* [annotation][out] */ _Out_opt_ VARIANT *pVarResult, + /* [annotation][out] */ _Out_opt_ EXCEPINFO *pExcepInfo, + /* [annotation][out] */ _Out_opt_ UINT *puArgErr) + { + // + // Note that arguments are received in reverse order for IDispatch::Invoke() + // + + switch (dispIdMember) + { + case 1: + { + return FireEvent(); + } + } + + return E_NOTIMPL; + } + +public: // IEventTesting + virtual HRESULT STDMETHODCALLTYPE FireEvent() + { + return FireEvent_Impl(1 /* DISPID for the FireEvent function */); + } + +public: // IConnectionPointContainer + virtual HRESULT STDMETHODCALLTYPE EnumConnectionPoints( + /* [out] */ __RPC__deref_out_opt IEnumConnectionPoints **ppEnum) + { + return E_NOTIMPL; + } + virtual HRESULT STDMETHODCALLTYPE FindConnectionPoint( + /* [in] */ __RPC__in REFIID riid, + /* [out] */ __RPC__deref_out_opt IConnectionPoint **ppCP) + { + if (riid != IID_TestingEvents) + return CONNECT_E_NOCONNECTION; + + return QueryInterface(__uuidof(*ppCP), (void**)ppCP); + } + +public: // IConnectionPoint + virtual HRESULT STDMETHODCALLTYPE GetConnectionInterface( + /* [out] */ __RPC__out IID *pIID) + { + return E_NOTIMPL; + } + virtual HRESULT STDMETHODCALLTYPE GetConnectionPointContainer( + /* [out] */ __RPC__deref_out_opt IConnectionPointContainer **ppCPC) + { + return E_NOTIMPL; + } + virtual HRESULT STDMETHODCALLTYPE Advise( + /* [in] */ __RPC__in_opt IUnknown *pUnkSink, + /* [out] */ __RPC__out DWORD *pdwCookie) + { + if (pUnkSink == nullptr || pdwCookie == nullptr) + return E_POINTER; + + for (DWORD i = 0; i < ARRAYSIZE(_eventConnections); ++i) + { + if (_eventConnections[i] == nullptr) + { + IDispatch *handler; + HRESULT hr = pUnkSink->QueryInterface(IID_IDispatch, (void**)&handler); + if (hr != S_OK) + return CONNECT_E_CANNOTCONNECT; + + _eventConnections[i] = handler; + *pdwCookie = i; + return S_OK; + } + } + + return CONNECT_E_ADVISELIMIT; + } + virtual HRESULT STDMETHODCALLTYPE Unadvise( + /* [in] */ DWORD dwCookie) + { + if (0 <= dwCookie && dwCookie < ARRAYSIZE(_eventConnections)) + { + IDispatch *handler = _eventConnections[dwCookie]; + if (handler != nullptr) + { + _eventConnections[dwCookie] = nullptr; + handler->Release(); + return S_OK; + } + } + + return E_POINTER; + } + virtual HRESULT STDMETHODCALLTYPE EnumConnections( + /* [out] */ __RPC__deref_out_opt IEnumConnections **ppEnum) + { + return E_NOTIMPL; + } + +private: + HRESULT FireEvent_Impl(_In_ int dispId) + { + HRESULT hr = S_OK; + + VARIANTARG arg; + ::VariantInit(&arg); + + arg.vt = VT_BSTR; + arg.bstrVal = TP_SysAllocString(Names[dispId]); + + for (DWORD i = 0; i < ARRAYSIZE(_eventConnections); ++i) + { + IDispatch *handler = _eventConnections[i]; + if (handler != nullptr) + { + DISPPARAMS params{}; + params.rgvarg = &arg; + params.cArgs = 1; + hr = handler->Invoke( + DISPATCHTESTINGEVENTS_DISPID_ONEVENT, + IID_NULL, + 0, + DISPATCH_METHOD, + ¶ms, + nullptr, + nullptr, + nullptr); + + if (FAILED(hr)) + break; + } + } + + return ::VariantClear(&arg); + } + +public: // IUnknown + STDMETHOD(QueryInterface)( + /* [in] */ REFIID riid, + /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) + { + return DoQueryInterface(riid, ppvObject, + static_cast<IDispatch *>(this), + static_cast<IEventTesting *>(this), + static_cast<IConnectionPointContainer *>(this), + static_cast<IConnectionPoint *>(this)); + } + + DEFINE_REF_COUNTING(); +}; + +const WCHAR * const EventTesting::Names[] = +{ + W("__RESERVED__"), + W("FireEvent"), +}; + +const int EventTesting::NamesCount = ARRAYSIZE(EventTesting::Names); diff --git a/tests/src/Interop/COM/NativeServer/NumericTesting.h b/tests/src/Interop/COM/NativeServer/NumericTesting.h index aa703be468..d30427aa3d 100644 --- a/tests/src/Interop/COM/NativeServer/NumericTesting.h +++ b/tests/src/Interop/COM/NativeServer/NumericTesting.h @@ -288,7 +288,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<NumericTesting, INumericTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<INumericTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/NativeServer/Servers.cpp b/tests/src/Interop/COM/NativeServer/Servers.cpp index 80da1b7bb0..3bf7072f81 100644 --- a/tests/src/Interop/COM/NativeServer/Servers.cpp +++ b/tests/src/Interop/COM/NativeServer/Servers.cpp @@ -165,6 +165,7 @@ STDAPI DllRegisterServer(void) RETURN_IF_FAILED(RegisterClsid(__uuidof(StringTesting), L"Both")); RETURN_IF_FAILED(RegisterClsid(__uuidof(ErrorMarshalTesting), L"Both")); RETURN_IF_FAILED(RegisterClsid(__uuidof(DispatchTesting), L"Both")); + RETURN_IF_FAILED(RegisterClsid(__uuidof(EventTesting), L"Both")); RETURN_IF_FAILED(RegisterClsid(__uuidof(AggregationTesting), L"Both")); RETURN_IF_FAILED(RegisterClsid(__uuidof(ColorTesting), L"Both")); @@ -180,6 +181,7 @@ STDAPI DllUnregisterServer(void) RETURN_IF_FAILED(RemoveClsid(__uuidof(StringTesting))); RETURN_IF_FAILED(RemoveClsid(__uuidof(ErrorMarshalTesting))); RETURN_IF_FAILED(RemoveClsid(__uuidof(DispatchTesting))); + RETURN_IF_FAILED(RemoveClsid(__uuidof(EventTesting))); RETURN_IF_FAILED(RemoveClsid(__uuidof(AggregationTesting))); RETURN_IF_FAILED(RemoveClsid(__uuidof(ColorTesting))); @@ -203,6 +205,9 @@ STDAPI DllGetClassObject(_In_ REFCLSID rclsid, _In_ REFIID riid, _Out_ LPVOID FA if (rclsid == __uuidof(DispatchTesting)) return ClassFactoryBasic<DispatchTesting>::Create(riid, ppv); + if (rclsid == __uuidof(EventTesting)) + return ClassFactoryBasic<EventTesting>::Create(riid, ppv); + if (rclsid == __uuidof(AggregationTesting)) return ClassFactoryAggregate<AggregationTesting>::Create(riid, ppv); diff --git a/tests/src/Interop/COM/NativeServer/Servers.h b/tests/src/Interop/COM/NativeServer/Servers.h index 38983e3f1f..7a2a1ff6a7 100644 --- a/tests/src/Interop/COM/NativeServer/Servers.h +++ b/tests/src/Interop/COM/NativeServer/Servers.h @@ -15,6 +15,7 @@ class DECLSPEC_UUID("B99ABE6A-DFF6-440F-BFB6-55179B8FE18E") ArrayTesting; class DECLSPEC_UUID("C73C83E8-51A2-47F8-9B5C-4284458E47A6") StringTesting; class DECLSPEC_UUID("71CF5C45-106C-4B32-B418-43A463C6041F") ErrorMarshalTesting; class DECLSPEC_UUID("0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726") DispatchTesting; +class DECLSPEC_UUID("4DBD9B61-E372-499F-84DE-EFC70AA8A009") EventTesting; class DECLSPEC_UUID("4CEFE36D-F377-4B6E-8C34-819A8BB9CB04") AggregationTesting; class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting; @@ -23,6 +24,7 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting; #define CLSID_StringTesting __uuidof(StringTesting) #define CLSID_ErrorMarshalTesting __uuidof(ErrorMarshalTesting) #define CLSID_DispatchTesting __uuidof(DispatchTesting) +#define CLSID_EventTesting __uuidof(EventTesting) #define CLSID_AggregationTesting __uuidof(AggregationTesting) #define CLSID_ColorTesting __uuidof(ColorTesting) @@ -31,6 +33,8 @@ class DECLSPEC_UUID("C222F472-DA5A-4FC6-9321-92F4F7053A65") ColorTesting; #define IID_IStringTesting __uuidof(IStringTesting) #define IID_IErrorMarshalTesting __uuidof(IErrorMarshalTesting) #define IID_IDispatchTesting __uuidof(IDispatchTesting) +#define IID_TestingEvents __uuidof(TestingEvents) +#define IID_IEventTesting __uuidof(IEventTesting) #define IID_IAggregationTesting __uuidof(IAggregationTesting) #define IID_IColorTesting __uuidof(IColorTesting) @@ -67,6 +71,7 @@ private: #include "StringTesting.h" #include "ErrorMarshalTesting.h" #include "DispatchTesting.h" + #include "EventTesting.h" #include "AggregationTesting.h" #include "ColorTesting.h" #endif diff --git a/tests/src/Interop/COM/NativeServer/StringTesting.h b/tests/src/Interop/COM/NativeServer/StringTesting.h index 06f13db2a4..cd01fbe317 100644 --- a/tests/src/Interop/COM/NativeServer/StringTesting.h +++ b/tests/src/Interop/COM/NativeServer/StringTesting.h @@ -317,7 +317,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<StringTesting, IStringTesting>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IStringTesting *>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/COM/ServerContracts/NativeServers.cs b/tests/src/Interop/COM/ServerContracts/NativeServers.cs index 6f70bd10a0..e868345328 100644 --- a/tests/src/Interop/COM/ServerContracts/NativeServers.cs +++ b/tests/src/Interop/COM/ServerContracts/NativeServers.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. #pragma warning disable IDE1006 // Naming Styles +#pragma warning disable 618 // Must test deprecated features namespace Server.Contract.Servers { @@ -91,7 +92,7 @@ namespace Server.Contract.Servers [ComImport] [CoClass(typeof(DispatchTestingClass))] [Guid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")] - internal interface DispatchTesting : Server.Contract.IDispatchTesting + internal interface DispatchTesting : Server.Contract.IDispatchTesting { } @@ -143,4 +144,5 @@ namespace Server.Contract.Servers } } +#pragma warning restore 618 // Must test deprecated features #pragma warning restore IDE1006 // Naming Styles diff --git a/tests/src/Interop/COM/ServerContracts/Server.Contracts.cs b/tests/src/Interop/COM/ServerContracts/Server.Contracts.cs index 3d2fb96593..f401c48176 100644 --- a/tests/src/Interop/COM/ServerContracts/Server.Contracts.cs +++ b/tests/src/Interop/COM/ServerContracts/Server.Contracts.cs @@ -234,6 +234,24 @@ namespace Server.Contract } [ComVisible(true)] + [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface IEventTesting + { + [DispId(1)] + void FireEvent(); + } + + [ComImport] + [Guid("28ea6635-42ab-4f5b-b458-4152e78b8e86")] + [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)] + public interface TestingEvents + { + [DispId(100)] + void OnEvent([MarshalAs(UnmanagedType.BStr)] string msg); + }; + + [ComVisible(true)] [Guid("98cc27f0-d521-4f79-8b63-e980e3a92974")] [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] public interface IAggregationTesting diff --git a/tests/src/Interop/COM/ServerContracts/Server.Contracts.h b/tests/src/Interop/COM/ServerContracts/Server.Contracts.h index c9be5b0c16..25025679e6 100644 --- a/tests/src/Interop/COM/ServerContracts/Server.Contracts.h +++ b/tests/src/Interop/COM/ServerContracts/Server.Contracts.h @@ -19,6 +19,8 @@ struct __declspec(uuid("592386a5-6837-444d-9de3-250815d18556")) /* interface */ IErrorMarshalTesting; struct __declspec(uuid("a5e04c1c-474e-46d2-bbc0-769d04e12b54")) /* interface */ IDispatchTesting; +struct __declspec(uuid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")) +/* interface */ IEventTesting; struct __declspec(uuid("98cc27f0-d521-4f79-8b63-e980e3a92974")) /* interface */ IAggregationTesting; struct __declspec(uuid("E6D72BA7-0936-4396-8A69-3B76DA1108DA")) @@ -33,6 +35,7 @@ _COM_SMARTPTR_TYPEDEF(IArrayTesting, __uuidof(IArrayTesting)); _COM_SMARTPTR_TYPEDEF(IStringTesting, __uuidof(IStringTesting)); _COM_SMARTPTR_TYPEDEF(IErrorMarshalTesting, __uuidof(IErrorMarshalTesting)); _COM_SMARTPTR_TYPEDEF(IDispatchTesting, __uuidof(IDispatchTesting)); +_COM_SMARTPTR_TYPEDEF(IEventTesting, __uuidof(IEventTesting)); _COM_SMARTPTR_TYPEDEF(IAggregationTesting, __uuidof(IAggregationTesting)); _COM_SMARTPTR_TYPEDEF(IColorTesting, __uuidof(IColorTesting)); @@ -444,6 +447,19 @@ IDispatchTesting : IDispatch /*[out,retval]*/ HFA_4 *pRetVal) = 0; }; +struct __declspec(uuid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")) +IEventTesting : IDispatch +{ + virtual HRESULT STDMETHODCALLTYPE FireEvent() = 0; +}; + +struct __declspec(uuid("28ea6635-42ab-4f5b-b458-4152e78b8e86")) +TestingEvents : IDispatch +{ +#define DISPATCHTESTINGEVENTS_DISPID_ONEVENT 100 + // void OnEvent(_In_z_ BSTR t); +}; + struct __declspec(uuid("98cc27f0-d521-4f79-8b63-e980e3a92974")) IAggregationTesting : IUnknown { diff --git a/tests/src/Interop/COM/ServerContracts/Server.Events.cs b/tests/src/Interop/COM/ServerContracts/Server.Events.cs new file mode 100644 index 0000000000..16ed578d98 --- /dev/null +++ b/tests/src/Interop/COM/ServerContracts/Server.Events.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma warning disable 618 // Must test deprecated features + +namespace Server.Contract +{ + using System; + using System.Collections.Generic; + using System.Runtime.InteropServices; + + using IConnectionPoint = System.Runtime.InteropServices.ComTypes.IConnectionPoint; + using IConnectionPointContainer = System.Runtime.InteropServices.ComTypes.IConnectionPointContainer; + + namespace Servers + { + /// <summary> + /// Managed definition of CoClass + /// </summary> + [ComImport] + [CoClass(typeof(EventTestingClass))] + [Guid("83AFF8E4-C46A-45DB-9D91-2ADB5164545E")] + internal interface EventTesting : IEventTesting, Events.TestingEvents_Event + { + } + + /// <summary> + /// Managed activation for CoClass + /// </summary> + [ComImport] + [ComSourceInterfaces("Server.Contract.Events.TestingEvents\0")] + [Guid(Server.Contract.Guids.EventTesting)] + internal class EventTestingClass + { + } + } + + /// <summary> + /// Classes in the Events namespace are traditionally generated by the TlbImp tool. + /// </summary> + namespace Events + { + /// <summary> + /// Delegate used for event handler + /// </summary> + [ComVisible(false)] + public delegate void TestingEvents_OnEventEventHandler(string msg); + + /// <summary> + /// Event source interface + /// </summary> + /// <remarks> + /// Observe usage of the <see cref="ComEventInterfaceAttribute"/> attribute. + /// </remarks> + [ComVisible(false)] + [ComEventInterface(typeof(Contract.TestingEvents), typeof(TestingEvents_EventProvider))] + public interface TestingEvents_Event + { + event TestingEvents_OnEventEventHandler OnEvent; + } + + /// <summary> + /// Managed proxy for event subscription via IConnectionPointContainer and IConnectionPoint. + /// </summary> + public sealed class TestingEvents_EventProvider : TestingEvents_Event, IDisposable + { + private readonly WeakReference ConnectionPointContainer; + private readonly List<TestingEvents_SinkHelper> eventSinkHelpers = new List<TestingEvents_SinkHelper>(); + + private IConnectionPoint connectionPoint; + private bool isDisposed = false; + + public TestingEvents_EventProvider(object container) + { + this.ConnectionPointContainer = new WeakReference((IConnectionPointContainer)container, false); + } + + event TestingEvents_OnEventEventHandler TestingEvents_Event.OnEvent + { + add + { + lock (this.eventSinkHelpers) + { + if (this.connectionPoint == null) + { + this.Init(); + } + + var sinkHelper = new TestingEvents_SinkHelper(); + + int cookie; + this.connectionPoint.Advise(sinkHelper, out cookie); + + sinkHelper.Cookie = cookie; + sinkHelper.OnEventDelegate = value; + this.eventSinkHelpers.Add(sinkHelper); + } + } + remove + { + lock (this.eventSinkHelpers) + { + TestingEvents_SinkHelper sinkHelper = null; + int removeIdx = -1; + for (int i = 0; i < this.eventSinkHelpers.Count; ++i) + { + TestingEvents_SinkHelper sinkHelperMaybe = this.eventSinkHelpers[i]; + if (sinkHelperMaybe.OnEventDelegate.Equals(value)) + { + removeIdx = i; + sinkHelper = sinkHelperMaybe; + break; + } + } + + if (removeIdx < 0) + { + return; + } + + this.connectionPoint.Unadvise(sinkHelper.Cookie); + this.eventSinkHelpers.RemoveAt(removeIdx); + + if (this.eventSinkHelpers.Count == 0) + { + Marshal.ReleaseComObject(this.connectionPoint); + this.connectionPoint = null; + } + } + } + } + + void IDisposable.Dispose() + { + if (this.isDisposed) + { + return; + } + + lock (this.eventSinkHelpers) + { + foreach (TestingEvents_SinkHelper sinkHelper in this.eventSinkHelpers) + { + this.connectionPoint.Unadvise(sinkHelper.Cookie); + } + + this.eventSinkHelpers.Clear(); + } + + Marshal.ReleaseComObject(this.connectionPoint); + this.connectionPoint = null; + + this.isDisposed = true; + System.GC.SuppressFinalize(this); + } + + private void Init() + { + var container = (IConnectionPointContainer)this.ConnectionPointContainer.Target; + + Guid iid = typeof(Contract.TestingEvents).GUID; + IConnectionPoint connectionPoint; + container.FindConnectionPoint(ref iid, out connectionPoint); + + this.connectionPoint = connectionPoint; + } + } + + /// <summary> + /// Wrapper for event delegate. + /// </summary> + /// <remarks> + /// Observe usage of the <see cref="ClassInterfaceAttribute"/> attribute. + /// </remarks> + [ClassInterface(ClassInterfaceType.None)] + public class TestingEvents_SinkHelper : Contract.TestingEvents + { + public int Cookie { get; set; } + public TestingEvents_OnEventEventHandler OnEventDelegate { get; set; } + + public void OnEvent(string msg) + { + if (this.OnEventDelegate != null) + { + this.OnEventDelegate(msg); + } + } + } + } +} + +#pragma warning restore 618 // Must test deprecated features diff --git a/tests/src/Interop/COM/ServerContracts/ServerGuids.cs b/tests/src/Interop/COM/ServerContracts/ServerGuids.cs index 199f618aae..1269e6ab91 100644 --- a/tests/src/Interop/COM/ServerContracts/ServerGuids.cs +++ b/tests/src/Interop/COM/ServerContracts/ServerGuids.cs @@ -14,6 +14,7 @@ namespace Server.Contract public const string StringTesting = "C73C83E8-51A2-47F8-9B5C-4284458E47A6"; public const string ErrorMarshalTesting = "71CF5C45-106C-4B32-B418-43A463C6041F"; public const string DispatchTesting = "0F8ACD0C-ECE0-4F2A-BD1B-6BFCA93A0726"; + public const string EventTesting = "4DBD9B61-E372-499F-84DE-EFC70AA8A009"; public const string AggregationTesting = "4CEFE36D-F377-4B6E-8C34-819A8BB9CB04"; public const string ColorTesting = "C222F472-DA5A-4FC6-9321-92F4F7053A65"; } diff --git a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h index 47456bf988..9e3f1f1500 100644 --- a/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h +++ b/tests/src/Interop/PInvoke/IEnumerator/IEnumeratorNative.h @@ -63,7 +63,7 @@ public: REFIID riid, void** ppvObject) { - return DoQueryInterface<IntegerEnumerator, IEnumVARIANT>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IEnumVARIANT *>(this)); } DEFINE_REF_COUNTING(); @@ -146,7 +146,7 @@ public: REFIID riid, void** ppvObject) { - return DoQueryInterface<IntegerEnumerable, IDispatch>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IDispatch*>(this)); } DEFINE_REF_COUNTING(); diff --git a/tests/src/Interop/common/ComHelpers.h b/tests/src/Interop/common/ComHelpers.h index fd8963996f..c4d9e6cbdb 100644 --- a/tests/src/Interop/common/ComHelpers.h +++ b/tests/src/Interop/common/ComHelpers.h @@ -16,19 +16,15 @@ namespace Internal { - template<typename C, typename I> + template<typename I> HRESULT __QueryInterfaceImpl( - /* [in] */ C *obj, /* [in] */ REFIID riid, - /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) + /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject, + /* [in] */ I obj) { if (riid == __uuidof(I)) { - *ppvObject = static_cast<I*>(obj); - } - else if (riid == __uuidof(IUnknown)) - { - *ppvObject = static_cast<IUnknown*>(obj); + *ppvObject = static_cast<I>(obj); } else { @@ -39,19 +35,20 @@ namespace Internal return S_OK; } - template<typename C, typename I1, typename I2, typename ...R> + template<typename I1, typename ...IR> HRESULT __QueryInterfaceImpl( - /* [in] */ C *obj, /* [in] */ REFIID riid, - /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) + /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject, + /* [in] */ I1 i1, + /* [in] */ IR... remain) { if (riid == __uuidof(I1)) { - *ppvObject = static_cast<I1*>(obj); + *ppvObject = static_cast<I1>(i1); return S_OK; } - return __QueryInterfaceImpl<C, I2, R...>(obj, riid, ppvObject); + return __QueryInterfaceImpl(riid, ppvObject, remain...); } } @@ -68,21 +65,29 @@ public: UnknownImpl(UnknownImpl&&) = default; UnknownImpl& operator=(UnknownImpl&&) = default; - template<typename C, typename ...I> + template<typename I1, typename ...IR> HRESULT DoQueryInterface( - /* [in] */ C *derived, /* [in] */ REFIID riid, - /* [iid_is][out] */ _COM_Outptr_ void **ppvObject) + /* [iid_is][out] */ _COM_Outptr_ void **ppvObject, + /* [in] */ I1 i1, + /* [in] */ IR... remain) { - assert(derived != nullptr); if (ppvObject == nullptr) return E_POINTER; - HRESULT hr = Internal::__QueryInterfaceImpl<C, I...>(derived, riid, ppvObject); - if (hr == S_OK) - DoAddRef(); + if (riid == __uuidof(IUnknown)) + { + *ppvObject = static_cast<IUnknown *>(i1); + } + else + { + HRESULT hr = Internal::__QueryInterfaceImpl(riid, ppvObject, i1, remain...); + if (hr != S_OK) + return hr; + } - return hr; + DoAddRef(); + return S_OK; } ULONG DoAddRef() @@ -162,7 +167,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<ClassFactoryBasic, IClassFactory>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IClassFactory *>(this)); } DEFINE_REF_COUNTING(); @@ -221,7 +226,7 @@ public: // IUnknown /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { - return DoQueryInterface<ClassFactoryAggregate, IClassFactory>(this, riid, ppvObject); + return DoQueryInterface(riid, ppvObject, static_cast<IClassFactory *>(this)); } DEFINE_REF_COUNTING(); |