diff options
author | Aaron Robinson <arobins@microsoft.com> | 2019-05-08 15:09:41 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-08 15:09:41 -0700 |
commit | e682619de596282b30c21d624716b7a3d1f2541b (patch) | |
tree | 87c40461c3c3cd1f9d59a6b899ba25a211058539 | |
parent | 991b17eb68331e6118baa0e27a6f8d8ee9db36b2 (diff) | |
download | coreclr-e682619de596282b30c21d624716b7a3d1f2541b.tar.gz coreclr-e682619de596282b30c21d624716b7a3d1f2541b.tar.bz2 coreclr-e682619de596282b30c21d624716b7a3d1f2541b.zip |
Add support in SPCL to call into user supplied register and unregisteā¦ (#24452)
* Add support in SPCL to call into user supplied register and unregister functions
5 files changed, 430 insertions, 19 deletions
diff --git a/src/System.Private.CoreLib/Resources/Strings.resx b/src/System.Private.CoreLib/Resources/Strings.resx index 9e2c84b4d1..3edc12165c 100644 --- a/src/System.Private.CoreLib/Resources/Strings.resx +++ b/src/System.Private.CoreLib/Resources/Strings.resx @@ -3754,4 +3754,22 @@ <data name="Argument_StartupHookAssemblyLoadFailed" xml:space="preserve"> <value>Startup hook assembly '{0}' failed to load. See inner exception for details.</value> </data> + <data name="InvalidOperation_NonStaticComRegFunction" xml:space="preserve"> + <value>COM register function must be static.</value> + </data> + <data name="InvalidOperation_NonStaticComUnRegFunction" xml:space="preserve"> + <value>COM unregister function must be static.</value> + </data> + <data name="InvalidOperation_InvalidComRegFunctionSig" xml:space="preserve"> + <value>COM register function must have a System.Type parameter and a void return type.</value> + </data> + <data name="InvalidOperation_InvalidComUnRegFunctionSig" xml:space="preserve"> + <value>COM unregister function must have a System.Type parameter and a void return type.</value> + </data> + <data name="InvalidOperation_MultipleComRegFunctions" xml:space="preserve"> + <value>Type '{0}' has more than one COM registration function.</value> + </data> + <data name="InvalidOperation_MultipleComUnRegFunctions" xml:space="preserve"> + <value>Type '{0}' has more than one COM unregistration function.</value> + </data> </root> diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index b8a8aac00c..d50dfcd7e5 100644 --- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -71,16 +71,6 @@ namespace Internal.Runtime.InteropServices } [StructLayout(LayoutKind.Sequential)] - public struct ComActivationContext - { - public Guid ClassId; - public Guid InterfaceId; - public string AssemblyPath; - public string AssemblyName; - public string TypeName; - } - - [StructLayout(LayoutKind.Sequential)] [CLSCompliant(false)] public unsafe struct ComActivationContextInternal { @@ -92,6 +82,29 @@ namespace Internal.Runtime.InteropServices public IntPtr ClassFactoryDest; } + [StructLayout(LayoutKind.Sequential)] + public struct ComActivationContext + { + public Guid ClassId; + public Guid InterfaceId; + public string AssemblyPath; + public string AssemblyName; + public string TypeName; + + [CLSCompliant(false)] + public unsafe static ComActivationContext Create(ref ComActivationContextInternal cxtInt) + { + return new ComActivationContext() + { + ClassId = cxtInt.ClassId, + InterfaceId = cxtInt.InterfaceId, + AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!, + AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!, + TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))! + }; + } + } + public static class ComActivator { // Collection of all ALCs used for COM activation. In the event we want to support @@ -126,6 +139,87 @@ namespace Internal.Runtime.InteropServices } /// <summary> + /// Entry point for unmanaged COM register/unregister API from managed code + /// </summary> + /// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param> + /// <param name="register">true if called for register or false to indicate unregister</param> + public static void ClassRegisterationScenarioForType(ComActivationContext cxt, bool register) + { + // Retrieve the attribute type to use to determine if a function is the requested user defined + // registration function. + string attributeName = register ? "ComRegisterFunctionAttribute" : "ComUnregisterFunctionAttribute"; + Type? regFuncAttrType = Type.GetType($"System.Runtime.InteropServices.{attributeName}, System.Runtime.InteropServices", throwOnError: false); + if (regFuncAttrType == null) + { + // If the COM registration attributes can't be found then it is not on the type. + return; + } + + if (!Path.IsPathRooted(cxt.AssemblyPath)) + { + throw new ArgumentException(); + } + + Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName); + + // Retrieve all the methods. + MethodInfo[] methods = classType.GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static); + + bool calledFunction = false; + + // Go through all the methods and check for the custom attribute. + foreach (MethodInfo method in methods) + { + // Check to see if the method has the custom attribute. + if (method.GetCustomAttributes(regFuncAttrType!, inherit: true).Length == 0) + { + continue; + } + + // Check to see if the method is static before we call it. + if (!method.IsStatic) + { + string msg = register ? SR.InvalidOperation_NonStaticComRegFunction : SR.InvalidOperation_NonStaticComUnRegFunction; + throw new InvalidOperationException(SR.Format(msg)); + } + + // Finally validate signature + ParameterInfo[] methParams = method.GetParameters(); + if (method.ReturnType != typeof(void) + || methParams == null + || methParams.Length != 1 + || (methParams[0].ParameterType != typeof(string) && methParams[0].ParameterType != typeof(Type))) + { + string msg = register ? SR.InvalidOperation_InvalidComRegFunctionSig : SR.InvalidOperation_InvalidComUnRegFunctionSig; + throw new InvalidOperationException(SR.Format(msg)); + } + + if (calledFunction) + { + string msg = register ? SR.InvalidOperation_MultipleComRegFunctions : SR.InvalidOperation_MultipleComUnRegFunctions; + throw new InvalidOperationException(SR.Format(msg)); + } + + // The function is valid so set up the arguments to call it. + var objs = new object[1]; + if (methParams[0].ParameterType == typeof(string)) + { + // We are dealing with the string overload of the function - provide the registry key - see comhost.dll implementation + objs[0] = $"HKEY_LOCAL_MACHINE\\SOFTWARE\\Classes\\CLSID\\{cxt.ClassId.ToString("B")}"; + } + else + { + // We are dealing with the type overload of the function. + objs[0] = classType; + } + + // Invoke the COM register function. + method.Invoke(null, objs); + calledFunction = true; + } + } + + /// <summary> /// Internal entry point for unmanaged COM activation API from native code /// </summary> /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param> @@ -146,15 +240,7 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments: try { - var cxt = new ComActivationContext() - { - ClassId = cxtInt.ClassId, - InterfaceId = cxtInt.InterfaceId, - AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!, - AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!, - TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))! - }; - + var cxt = ComActivationContext.Create(ref cxtInt); object cf = GetClassFactoryForType(cxt); IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf); Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown); @@ -167,6 +253,81 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments: return 0; } + /// <summary> + /// Internal entry point for registering a managed COM server API from native code + /// </summary> + /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param> + [CLSCompliant(false)] + public unsafe static int RegisterClassForTypeInternal(ref ComActivationContextInternal cxtInt) + { + if (IsLoggingEnabled()) + { + Log( +$@"{nameof(RegisterClassForTypeInternal)} arguments: + {cxtInt.ClassId} + {cxtInt.InterfaceId} + 0x{(ulong)cxtInt.AssemblyPathBuffer:x} + 0x{(ulong)cxtInt.AssemblyNameBuffer:x} + 0x{(ulong)cxtInt.TypeNameBuffer:x} + 0x{cxtInt.ClassFactoryDest.ToInt64():x}"); + } + + if (cxtInt.InterfaceId != Guid.Empty + || cxtInt.ClassFactoryDest != IntPtr.Zero) + { + throw new ArgumentException(); + } + + try + { + var cxt = ComActivationContext.Create(ref cxtInt); + ClassRegisterationScenarioForType(cxt, register: true); + } + catch (Exception e) + { + return e.HResult; + } + + return 0; + } + + /// <summary> + /// Internal entry point for unregistering a managed COM server API from native code + /// </summary> + [CLSCompliant(false)] + public unsafe static int UnregisterClassForTypeInternal(ref ComActivationContextInternal cxtInt) + { + if (IsLoggingEnabled()) + { + Log( +$@"{nameof(UnregisterClassForTypeInternal)} arguments: + {cxtInt.ClassId} + {cxtInt.InterfaceId} + 0x{(ulong)cxtInt.AssemblyPathBuffer:x} + 0x{(ulong)cxtInt.AssemblyNameBuffer:x} + 0x{(ulong)cxtInt.TypeNameBuffer:x} + 0x{cxtInt.ClassFactoryDest.ToInt64():x}"); + } + + if (cxtInt.InterfaceId != Guid.Empty + || cxtInt.ClassFactoryDest != IntPtr.Zero) + { + throw new ArgumentException(); + } + + try + { + var cxt = ComActivationContext.Create(ref cxtInt); + ClassRegisterationScenarioForType(cxt, register: false); + } + catch (Exception e) + { + return e.HResult; + } + + return 0; + } + private static bool IsLoggingEnabled() { #if COM_ACTIVATOR_DEBUG diff --git a/tests/src/Interop/COM/Activator/Program.cs b/tests/src/Interop/COM/Activator/Program.cs index 451ecc7755..eb925ad163 100644 --- a/tests/src/Interop/COM/Activator/Program.cs +++ b/tests/src/Interop/COM/Activator/Program.cs @@ -136,6 +136,111 @@ namespace Activator Assert.AreNotEqual(typeCFromAssemblyA, typeCFromAssemblyB, "Types should be from different AssemblyLoadContexts"); } + static void ValidateUserDefinedRegistrationCallbacks() + { + Console.WriteLine($"Running {nameof(ValidateUserDefinedRegistrationCallbacks)}..."); + + string assemblySubPath = Path.Combine(Environment.CurrentDirectory, "Servers"); + string assemblyAPath = Path.Combine(assemblySubPath, "AssemblyA.dll"); + string assemblyBPath = Path.Combine(assemblySubPath, "AssemblyB.dll"); + string assemblyCPath = Path.Combine(assemblySubPath, "AssemblyC.dll"); + string assemblyPaths = $"{assemblyAPath}{Path.PathSeparator}{assemblyBPath}{Path.PathSeparator}{assemblyCPath}"; + + HostPolicyMock.Initialize(Environment.CurrentDirectory, null); + + var CLSID_NotUsed = Guid.Empty; // During this phase of activation the GUID is not used. + Guid iid = typeof(IValidateRegistrationCallbacks).GUID; + + using (HostPolicyMock.Mock_corehost_resolve_component_dependencies( + 0, + assemblyPaths, + string.Empty, + string.Empty)) + { + foreach (string typename in new[] { "ValidRegistrationTypeCallbacks", "ValidRegistrationStringCallbacks" }) + { + Console.WriteLine($"Validating {typename}..."); + + var cxt = new ComActivationContext() + { + ClassId = CLSID_NotUsed, + InterfaceId = typeof(IClassFactory).GUID, + AssemblyPath = assemblyAPath, + AssemblyName = "AssemblyA", + TypeName = typename + }; + + var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); + + object svr; + factory.CreateInstance(null, ref iid, out svr); + + var inst = (IValidateRegistrationCallbacks)svr; + Assert.IsFalse(inst.DidRegister()); + Assert.IsFalse(inst.DidUnregister()); + + cxt.InterfaceId = Guid.Empty; + ComActivator.ClassRegisterationScenarioForType(cxt, register: true); + ComActivator.ClassRegisterationScenarioForType(cxt, register: false); + + Assert.IsTrue(inst.DidRegister()); + Assert.IsTrue(inst.DidUnregister()); + } + } + + using (HostPolicyMock.Mock_corehost_resolve_component_dependencies( + 0, + assemblyPaths, + string.Empty, + string.Empty)) + { + foreach (string typename in new[] { "NoRegistrationCallbacks", "InvalidArgRegistrationCallbacks", "InvalidInstanceRegistrationCallbacks", "MultipleRegistrationCallbacks" }) + { + Console.WriteLine($"Validating {typename}..."); + + var cxt = new ComActivationContext() + { + ClassId = CLSID_NotUsed, + InterfaceId = typeof(IClassFactory).GUID, + AssemblyPath = assemblyAPath, + AssemblyName = "AssemblyA", + TypeName = typename + }; + + var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt); + + object svr; + factory.CreateInstance(null, ref iid, out svr); + + var inst = (IValidateRegistrationCallbacks)svr; + cxt.InterfaceId = Guid.Empty; + bool exceptionThrown = false; + try + { + ComActivator.ClassRegisterationScenarioForType(cxt, register: true); + } + catch + { + exceptionThrown = true; + } + + Assert.IsTrue(exceptionThrown || !inst.DidRegister()); + + exceptionThrown = false; + try + { + ComActivator.ClassRegisterationScenarioForType(cxt, register: false); + } + catch + { + exceptionThrown = true; + } + + Assert.IsTrue(exceptionThrown || !inst.DidUnregister()); + } + } + } + static int Main(string[] doNotUse) { try @@ -144,6 +249,7 @@ namespace Activator ClassNotRegistered(); NonrootedAssemblyPath(); ValidateAssemblyIsolation(); + ValidateUserDefinedRegistrationCallbacks(); } catch (Exception e) { diff --git a/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs b/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs index 606f064c92..8646a6a856 100644 --- a/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs +++ b/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Runtime.InteropServices; public class ClassFromA : IGetTypeFromC { @@ -16,4 +17,121 @@ public class ClassFromA : IGetTypeFromC { return this._fromC.GetType(); } +} + +public class ValidRegistrationTypeCallbacks : IValidateRegistrationCallbacks +{ + [ComRegisterFunctionAttribute] + public static void RegisterFunction(Type t) => s_didRegister = true; + + [ComUnregisterFunctionAttribute] + public static void UnregisterFunction(Type t) => s_didUnregister = true; + + private static bool s_didRegister = false; + private static bool s_didUnregister = false; + + bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister; + + bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister; + + void IValidateRegistrationCallbacks.Reset() + { + s_didRegister = false; + s_didUnregister = false; + } +} + +public class ValidRegistrationStringCallbacks : IValidateRegistrationCallbacks +{ + [ComRegisterFunctionAttribute] + public static void RegisterFunction(string t) => s_didRegister = true; + + [ComUnregisterFunctionAttribute] + public static void UnregisterFunction(string t) => s_didUnregister = true; + + private static bool s_didRegister = false; + private static bool s_didUnregister = false; + + bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister; + + bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister; + + void IValidateRegistrationCallbacks.Reset() + { + s_didRegister = false; + s_didUnregister = false; + } +} + +public class NoRegistrationCallbacks : IValidateRegistrationCallbacks +{ + // Not attributed function + public static void RegisterFunction(Type t) => s_didRegister = true; + + // Not attributed function + public static void UnregisterFunction(Type t) => s_didRegister = true; + + private static bool s_didRegister = false; + private static bool s_didUnregister = false; + + bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister; + + bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister; + + void IValidateRegistrationCallbacks.Reset() + { + s_didRegister = false; + s_didUnregister = false; + } +} + +public class InvalidArgRegistrationCallbacks : IValidateRegistrationCallbacks +{ + [ComRegisterFunctionAttribute] + public static void RegisterFunction(int i) => throw new Exception(); + + [ComUnregisterFunctionAttribute] + public static void UnregisterFunction(int i) => throw new Exception(); + + bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException(); + + bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException(); + + void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException(); +} + +public class InvalidInstanceRegistrationCallbacks : IValidateRegistrationCallbacks +{ + [ComRegisterFunctionAttribute] + public void RegisterFunction(Type t) => throw new Exception(); + + [ComUnregisterFunctionAttribute] + public void UnregisterFunction(Type t) => throw new Exception(); + + bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException(); + + bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException(); + + void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException(); +} + +public class MultipleRegistrationCallbacks : IValidateRegistrationCallbacks +{ + [ComRegisterFunctionAttribute] + public static void RegisterFunction(string t) { } + + [ComUnregisterFunctionAttribute] + public static void UnregisterFunction(string t) { } + + [ComRegisterFunctionAttribute] + public static void RegisterFunction2(string t) { } + + [ComUnregisterFunctionAttribute] + public static void UnregisterFunction2(string t) { } + + bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException(); + + bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException(); + + void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException(); }
\ No newline at end of file diff --git a/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs b/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs index deb7b6312b..ebbfb90569 100644 --- a/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs +++ b/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs @@ -10,4 +10,12 @@ using System.Runtime.InteropServices; public interface IGetTypeFromC { object GetTypeFromC(); +} + +[Guid("DA746E78-E1E8-44DD-8184-203AB57B3002")] +public interface IValidateRegistrationCallbacks +{ + bool DidRegister(); + bool DidUnregister(); + void Reset(); }
\ No newline at end of file |