summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAaron Robinson <arobins@microsoft.com>2019-05-08 15:09:41 -0700
committerGitHub <noreply@github.com>2019-05-08 15:09:41 -0700
commite682619de596282b30c21d624716b7a3d1f2541b (patch)
tree87c40461c3c3cd1f9d59a6b899ba25a211058539
parent991b17eb68331e6118baa0e27a6f8d8ee9db36b2 (diff)
downloadcoreclr-e682619de596282b30c21d624716b7a3d1f2541b.tar.gz
coreclr-e682619de596282b30c21d624716b7a3d1f2541b.tar.bz2
coreclr-e682619de596282b30c21d624716b7a3d1f2541b.zip
Add support in SPCL to call into user supplied register and unregisteā€¦ (#24452)
* Add support in SPCL to call into user supplied register and unregister functions
-rw-r--r--src/System.Private.CoreLib/Resources/Strings.resx18
-rw-r--r--src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs199
-rw-r--r--tests/src/Interop/COM/Activator/Program.cs106
-rw-r--r--tests/src/Interop/COM/Activator/Servers/AssemblyA.cs118
-rw-r--r--tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs8
5 files changed, 430 insertions, 19 deletions
diff --git a/src/System.Private.CoreLib/Resources/Strings.resx b/src/System.Private.CoreLib/Resources/Strings.resx
index 9e2c84b4d1..3edc12165c 100644
--- a/src/System.Private.CoreLib/Resources/Strings.resx
+++ b/src/System.Private.CoreLib/Resources/Strings.resx
@@ -3754,4 +3754,22 @@
<data name="Argument_StartupHookAssemblyLoadFailed" xml:space="preserve">
<value>Startup hook assembly '{0}' failed to load. See inner exception for details.</value>
</data>
+ <data name="InvalidOperation_NonStaticComRegFunction" xml:space="preserve">
+ <value>COM register function must be static.</value>
+ </data>
+ <data name="InvalidOperation_NonStaticComUnRegFunction" xml:space="preserve">
+ <value>COM unregister function must be static.</value>
+ </data>
+ <data name="InvalidOperation_InvalidComRegFunctionSig" xml:space="preserve">
+ <value>COM register function must have a System.Type parameter and a void return type.</value>
+ </data>
+ <data name="InvalidOperation_InvalidComUnRegFunctionSig" xml:space="preserve">
+ <value>COM unregister function must have a System.Type parameter and a void return type.</value>
+ </data>
+ <data name="InvalidOperation_MultipleComRegFunctions" xml:space="preserve">
+ <value>Type '{0}' has more than one COM registration function.</value>
+ </data>
+ <data name="InvalidOperation_MultipleComUnRegFunctions" xml:space="preserve">
+ <value>Type '{0}' has more than one COM unregistration function.</value>
+ </data>
</root>
diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
index b8a8aac00c..d50dfcd7e5 100644
--- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
+++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
@@ -71,16 +71,6 @@ namespace Internal.Runtime.InteropServices
}
[StructLayout(LayoutKind.Sequential)]
- public struct ComActivationContext
- {
- public Guid ClassId;
- public Guid InterfaceId;
- public string AssemblyPath;
- public string AssemblyName;
- public string TypeName;
- }
-
- [StructLayout(LayoutKind.Sequential)]
[CLSCompliant(false)]
public unsafe struct ComActivationContextInternal
{
@@ -92,6 +82,29 @@ namespace Internal.Runtime.InteropServices
public IntPtr ClassFactoryDest;
}
+ [StructLayout(LayoutKind.Sequential)]
+ public struct ComActivationContext
+ {
+ public Guid ClassId;
+ public Guid InterfaceId;
+ public string AssemblyPath;
+ public string AssemblyName;
+ public string TypeName;
+
+ [CLSCompliant(false)]
+ public unsafe static ComActivationContext Create(ref ComActivationContextInternal cxtInt)
+ {
+ return new ComActivationContext()
+ {
+ ClassId = cxtInt.ClassId,
+ InterfaceId = cxtInt.InterfaceId,
+ AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
+ AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
+ TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
+ };
+ }
+ }
+
public static class ComActivator
{
// Collection of all ALCs used for COM activation. In the event we want to support
@@ -126,6 +139,87 @@ namespace Internal.Runtime.InteropServices
}
/// <summary>
+ /// Entry point for unmanaged COM register/unregister API from managed code
+ /// </summary>
+ /// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param>
+ /// <param name="register">true if called for register or false to indicate unregister</param>
+ public static void ClassRegisterationScenarioForType(ComActivationContext cxt, bool register)
+ {
+ // Retrieve the attribute type to use to determine if a function is the requested user defined
+ // registration function.
+ string attributeName = register ? "ComRegisterFunctionAttribute" : "ComUnregisterFunctionAttribute";
+ Type? regFuncAttrType = Type.GetType($"System.Runtime.InteropServices.{attributeName}, System.Runtime.InteropServices", throwOnError: false);
+ if (regFuncAttrType == null)
+ {
+ // If the COM registration attributes can't be found then it is not on the type.
+ return;
+ }
+
+ if (!Path.IsPathRooted(cxt.AssemblyPath))
+ {
+ throw new ArgumentException();
+ }
+
+ Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
+
+ // Retrieve all the methods.
+ MethodInfo[] methods = classType.GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
+
+ bool calledFunction = false;
+
+ // Go through all the methods and check for the custom attribute.
+ foreach (MethodInfo method in methods)
+ {
+ // Check to see if the method has the custom attribute.
+ if (method.GetCustomAttributes(regFuncAttrType!, inherit: true).Length == 0)
+ {
+ continue;
+ }
+
+ // Check to see if the method is static before we call it.
+ if (!method.IsStatic)
+ {
+ string msg = register ? SR.InvalidOperation_NonStaticComRegFunction : SR.InvalidOperation_NonStaticComUnRegFunction;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ // Finally validate signature
+ ParameterInfo[] methParams = method.GetParameters();
+ if (method.ReturnType != typeof(void)
+ || methParams == null
+ || methParams.Length != 1
+ || (methParams[0].ParameterType != typeof(string) && methParams[0].ParameterType != typeof(Type)))
+ {
+ string msg = register ? SR.InvalidOperation_InvalidComRegFunctionSig : SR.InvalidOperation_InvalidComUnRegFunctionSig;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ if (calledFunction)
+ {
+ string msg = register ? SR.InvalidOperation_MultipleComRegFunctions : SR.InvalidOperation_MultipleComUnRegFunctions;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ // The function is valid so set up the arguments to call it.
+ var objs = new object[1];
+ if (methParams[0].ParameterType == typeof(string))
+ {
+ // We are dealing with the string overload of the function - provide the registry key - see comhost.dll implementation
+ objs[0] = $"HKEY_LOCAL_MACHINE\\SOFTWARE\\Classes\\CLSID\\{cxt.ClassId.ToString("B")}";
+ }
+ else
+ {
+ // We are dealing with the type overload of the function.
+ objs[0] = classType;
+ }
+
+ // Invoke the COM register function.
+ method.Invoke(null, objs);
+ calledFunction = true;
+ }
+ }
+
+ /// <summary>
/// Internal entry point for unmanaged COM activation API from native code
/// </summary>
/// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param>
@@ -146,15 +240,7 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
try
{
- var cxt = new ComActivationContext()
- {
- ClassId = cxtInt.ClassId,
- InterfaceId = cxtInt.InterfaceId,
- AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
- AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
- TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
- };
-
+ var cxt = ComActivationContext.Create(ref cxtInt);
object cf = GetClassFactoryForType(cxt);
IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf);
Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown);
@@ -167,6 +253,81 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
return 0;
}
+ /// <summary>
+ /// Internal entry point for registering a managed COM server API from native code
+ /// </summary>
+ /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param>
+ [CLSCompliant(false)]
+ public unsafe static int RegisterClassForTypeInternal(ref ComActivationContextInternal cxtInt)
+ {
+ if (IsLoggingEnabled())
+ {
+ Log(
+$@"{nameof(RegisterClassForTypeInternal)} arguments:
+ {cxtInt.ClassId}
+ {cxtInt.InterfaceId}
+ 0x{(ulong)cxtInt.AssemblyPathBuffer:x}
+ 0x{(ulong)cxtInt.AssemblyNameBuffer:x}
+ 0x{(ulong)cxtInt.TypeNameBuffer:x}
+ 0x{cxtInt.ClassFactoryDest.ToInt64():x}");
+ }
+
+ if (cxtInt.InterfaceId != Guid.Empty
+ || cxtInt.ClassFactoryDest != IntPtr.Zero)
+ {
+ throw new ArgumentException();
+ }
+
+ try
+ {
+ var cxt = ComActivationContext.Create(ref cxtInt);
+ ClassRegisterationScenarioForType(cxt, register: true);
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return 0;
+ }
+
+ /// <summary>
+ /// Internal entry point for unregistering a managed COM server API from native code
+ /// </summary>
+ [CLSCompliant(false)]
+ public unsafe static int UnregisterClassForTypeInternal(ref ComActivationContextInternal cxtInt)
+ {
+ if (IsLoggingEnabled())
+ {
+ Log(
+$@"{nameof(UnregisterClassForTypeInternal)} arguments:
+ {cxtInt.ClassId}
+ {cxtInt.InterfaceId}
+ 0x{(ulong)cxtInt.AssemblyPathBuffer:x}
+ 0x{(ulong)cxtInt.AssemblyNameBuffer:x}
+ 0x{(ulong)cxtInt.TypeNameBuffer:x}
+ 0x{cxtInt.ClassFactoryDest.ToInt64():x}");
+ }
+
+ if (cxtInt.InterfaceId != Guid.Empty
+ || cxtInt.ClassFactoryDest != IntPtr.Zero)
+ {
+ throw new ArgumentException();
+ }
+
+ try
+ {
+ var cxt = ComActivationContext.Create(ref cxtInt);
+ ClassRegisterationScenarioForType(cxt, register: false);
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return 0;
+ }
+
private static bool IsLoggingEnabled()
{
#if COM_ACTIVATOR_DEBUG
diff --git a/tests/src/Interop/COM/Activator/Program.cs b/tests/src/Interop/COM/Activator/Program.cs
index 451ecc7755..eb925ad163 100644
--- a/tests/src/Interop/COM/Activator/Program.cs
+++ b/tests/src/Interop/COM/Activator/Program.cs
@@ -136,6 +136,111 @@ namespace Activator
Assert.AreNotEqual(typeCFromAssemblyA, typeCFromAssemblyB, "Types should be from different AssemblyLoadContexts");
}
+ static void ValidateUserDefinedRegistrationCallbacks()
+ {
+ Console.WriteLine($"Running {nameof(ValidateUserDefinedRegistrationCallbacks)}...");
+
+ string assemblySubPath = Path.Combine(Environment.CurrentDirectory, "Servers");
+ string assemblyAPath = Path.Combine(assemblySubPath, "AssemblyA.dll");
+ string assemblyBPath = Path.Combine(assemblySubPath, "AssemblyB.dll");
+ string assemblyCPath = Path.Combine(assemblySubPath, "AssemblyC.dll");
+ string assemblyPaths = $"{assemblyAPath}{Path.PathSeparator}{assemblyBPath}{Path.PathSeparator}{assemblyCPath}";
+
+ HostPolicyMock.Initialize(Environment.CurrentDirectory, null);
+
+ var CLSID_NotUsed = Guid.Empty; // During this phase of activation the GUID is not used.
+ Guid iid = typeof(IValidateRegistrationCallbacks).GUID;
+
+ using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(
+ 0,
+ assemblyPaths,
+ string.Empty,
+ string.Empty))
+ {
+ foreach (string typename in new[] { "ValidRegistrationTypeCallbacks", "ValidRegistrationStringCallbacks" })
+ {
+ Console.WriteLine($"Validating {typename}...");
+
+ var cxt = new ComActivationContext()
+ {
+ ClassId = CLSID_NotUsed,
+ InterfaceId = typeof(IClassFactory).GUID,
+ AssemblyPath = assemblyAPath,
+ AssemblyName = "AssemblyA",
+ TypeName = typename
+ };
+
+ var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
+
+ object svr;
+ factory.CreateInstance(null, ref iid, out svr);
+
+ var inst = (IValidateRegistrationCallbacks)svr;
+ Assert.IsFalse(inst.DidRegister());
+ Assert.IsFalse(inst.DidUnregister());
+
+ cxt.InterfaceId = Guid.Empty;
+ ComActivator.ClassRegisterationScenarioForType(cxt, register: true);
+ ComActivator.ClassRegisterationScenarioForType(cxt, register: false);
+
+ Assert.IsTrue(inst.DidRegister());
+ Assert.IsTrue(inst.DidUnregister());
+ }
+ }
+
+ using (HostPolicyMock.Mock_corehost_resolve_component_dependencies(
+ 0,
+ assemblyPaths,
+ string.Empty,
+ string.Empty))
+ {
+ foreach (string typename in new[] { "NoRegistrationCallbacks", "InvalidArgRegistrationCallbacks", "InvalidInstanceRegistrationCallbacks", "MultipleRegistrationCallbacks" })
+ {
+ Console.WriteLine($"Validating {typename}...");
+
+ var cxt = new ComActivationContext()
+ {
+ ClassId = CLSID_NotUsed,
+ InterfaceId = typeof(IClassFactory).GUID,
+ AssemblyPath = assemblyAPath,
+ AssemblyName = "AssemblyA",
+ TypeName = typename
+ };
+
+ var factory = (IClassFactory)ComActivator.GetClassFactoryForType(cxt);
+
+ object svr;
+ factory.CreateInstance(null, ref iid, out svr);
+
+ var inst = (IValidateRegistrationCallbacks)svr;
+ cxt.InterfaceId = Guid.Empty;
+ bool exceptionThrown = false;
+ try
+ {
+ ComActivator.ClassRegisterationScenarioForType(cxt, register: true);
+ }
+ catch
+ {
+ exceptionThrown = true;
+ }
+
+ Assert.IsTrue(exceptionThrown || !inst.DidRegister());
+
+ exceptionThrown = false;
+ try
+ {
+ ComActivator.ClassRegisterationScenarioForType(cxt, register: false);
+ }
+ catch
+ {
+ exceptionThrown = true;
+ }
+
+ Assert.IsTrue(exceptionThrown || !inst.DidUnregister());
+ }
+ }
+ }
+
static int Main(string[] doNotUse)
{
try
@@ -144,6 +249,7 @@ namespace Activator
ClassNotRegistered();
NonrootedAssemblyPath();
ValidateAssemblyIsolation();
+ ValidateUserDefinedRegistrationCallbacks();
}
catch (Exception e)
{
diff --git a/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs b/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs
index 606f064c92..8646a6a856 100644
--- a/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs
+++ b/tests/src/Interop/COM/Activator/Servers/AssemblyA.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Runtime.InteropServices;
public class ClassFromA : IGetTypeFromC
{
@@ -16,4 +17,121 @@ public class ClassFromA : IGetTypeFromC
{
return this._fromC.GetType();
}
+}
+
+public class ValidRegistrationTypeCallbacks : IValidateRegistrationCallbacks
+{
+ [ComRegisterFunctionAttribute]
+ public static void RegisterFunction(Type t) => s_didRegister = true;
+
+ [ComUnregisterFunctionAttribute]
+ public static void UnregisterFunction(Type t) => s_didUnregister = true;
+
+ private static bool s_didRegister = false;
+ private static bool s_didUnregister = false;
+
+ bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister;
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister;
+
+ void IValidateRegistrationCallbacks.Reset()
+ {
+ s_didRegister = false;
+ s_didUnregister = false;
+ }
+}
+
+public class ValidRegistrationStringCallbacks : IValidateRegistrationCallbacks
+{
+ [ComRegisterFunctionAttribute]
+ public static void RegisterFunction(string t) => s_didRegister = true;
+
+ [ComUnregisterFunctionAttribute]
+ public static void UnregisterFunction(string t) => s_didUnregister = true;
+
+ private static bool s_didRegister = false;
+ private static bool s_didUnregister = false;
+
+ bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister;
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister;
+
+ void IValidateRegistrationCallbacks.Reset()
+ {
+ s_didRegister = false;
+ s_didUnregister = false;
+ }
+}
+
+public class NoRegistrationCallbacks : IValidateRegistrationCallbacks
+{
+ // Not attributed function
+ public static void RegisterFunction(Type t) => s_didRegister = true;
+
+ // Not attributed function
+ public static void UnregisterFunction(Type t) => s_didRegister = true;
+
+ private static bool s_didRegister = false;
+ private static bool s_didUnregister = false;
+
+ bool IValidateRegistrationCallbacks.DidRegister() => s_didRegister;
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => s_didUnregister;
+
+ void IValidateRegistrationCallbacks.Reset()
+ {
+ s_didRegister = false;
+ s_didUnregister = false;
+ }
+}
+
+public class InvalidArgRegistrationCallbacks : IValidateRegistrationCallbacks
+{
+ [ComRegisterFunctionAttribute]
+ public static void RegisterFunction(int i) => throw new Exception();
+
+ [ComUnregisterFunctionAttribute]
+ public static void UnregisterFunction(int i) => throw new Exception();
+
+ bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException();
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException();
+
+ void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException();
+}
+
+public class InvalidInstanceRegistrationCallbacks : IValidateRegistrationCallbacks
+{
+ [ComRegisterFunctionAttribute]
+ public void RegisterFunction(Type t) => throw new Exception();
+
+ [ComUnregisterFunctionAttribute]
+ public void UnregisterFunction(Type t) => throw new Exception();
+
+ bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException();
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException();
+
+ void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException();
+}
+
+public class MultipleRegistrationCallbacks : IValidateRegistrationCallbacks
+{
+ [ComRegisterFunctionAttribute]
+ public static void RegisterFunction(string t) { }
+
+ [ComUnregisterFunctionAttribute]
+ public static void UnregisterFunction(string t) { }
+
+ [ComRegisterFunctionAttribute]
+ public static void RegisterFunction2(string t) { }
+
+ [ComUnregisterFunctionAttribute]
+ public static void UnregisterFunction2(string t) { }
+
+ bool IValidateRegistrationCallbacks.DidRegister() => throw new NotImplementedException();
+
+ bool IValidateRegistrationCallbacks.DidUnregister() => throw new NotImplementedException();
+
+ void IValidateRegistrationCallbacks.Reset() => throw new NotImplementedException();
} \ No newline at end of file
diff --git a/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs b/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs
index deb7b6312b..ebbfb90569 100644
--- a/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs
+++ b/tests/src/Interop/COM/Activator/Servers/AssemblyContracts.cs
@@ -10,4 +10,12 @@ using System.Runtime.InteropServices;
public interface IGetTypeFromC
{
object GetTypeFromC();
+}
+
+[Guid("DA746E78-E1E8-44DD-8184-203AB57B3002")]
+public interface IValidateRegistrationCallbacks
+{
+ bool DidRegister();
+ bool DidUnregister();
+ void Reset();
} \ No newline at end of file