diff options
Diffstat (limited to 'src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs')
-rw-r--r-- | src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs | 199 |
1 files changed, 180 insertions, 19 deletions
diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs index b8a8aac00c..d50dfcd7e5 100644 --- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs +++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs @@ -71,16 +71,6 @@ namespace Internal.Runtime.InteropServices } [StructLayout(LayoutKind.Sequential)] - public struct ComActivationContext - { - public Guid ClassId; - public Guid InterfaceId; - public string AssemblyPath; - public string AssemblyName; - public string TypeName; - } - - [StructLayout(LayoutKind.Sequential)] [CLSCompliant(false)] public unsafe struct ComActivationContextInternal { @@ -92,6 +82,29 @@ namespace Internal.Runtime.InteropServices public IntPtr ClassFactoryDest; } + [StructLayout(LayoutKind.Sequential)] + public struct ComActivationContext + { + public Guid ClassId; + public Guid InterfaceId; + public string AssemblyPath; + public string AssemblyName; + public string TypeName; + + [CLSCompliant(false)] + public unsafe static ComActivationContext Create(ref ComActivationContextInternal cxtInt) + { + return new ComActivationContext() + { + ClassId = cxtInt.ClassId, + InterfaceId = cxtInt.InterfaceId, + AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!, + AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!, + TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))! + }; + } + } + public static class ComActivator { // Collection of all ALCs used for COM activation. In the event we want to support @@ -126,6 +139,87 @@ namespace Internal.Runtime.InteropServices } /// <summary> + /// Entry point for unmanaged COM register/unregister API from managed code + /// </summary> + /// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param> + /// <param name="register">true if called for register or false to indicate unregister</param> + public static void ClassRegisterationScenarioForType(ComActivationContext cxt, bool register) + { + // Retrieve the attribute type to use to determine if a function is the requested user defined + // registration function. + string attributeName = register ? "ComRegisterFunctionAttribute" : "ComUnregisterFunctionAttribute"; + Type? regFuncAttrType = Type.GetType($"System.Runtime.InteropServices.{attributeName}, System.Runtime.InteropServices", throwOnError: false); + if (regFuncAttrType == null) + { + // If the COM registration attributes can't be found then it is not on the type. + return; + } + + if (!Path.IsPathRooted(cxt.AssemblyPath)) + { + throw new ArgumentException(); + } + + Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName); + + // Retrieve all the methods. + MethodInfo[] methods = classType.GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static); + + bool calledFunction = false; + + // Go through all the methods and check for the custom attribute. + foreach (MethodInfo method in methods) + { + // Check to see if the method has the custom attribute. + if (method.GetCustomAttributes(regFuncAttrType!, inherit: true).Length == 0) + { + continue; + } + + // Check to see if the method is static before we call it. + if (!method.IsStatic) + { + string msg = register ? SR.InvalidOperation_NonStaticComRegFunction : SR.InvalidOperation_NonStaticComUnRegFunction; + throw new InvalidOperationException(SR.Format(msg)); + } + + // Finally validate signature + ParameterInfo[] methParams = method.GetParameters(); + if (method.ReturnType != typeof(void) + || methParams == null + || methParams.Length != 1 + || (methParams[0].ParameterType != typeof(string) && methParams[0].ParameterType != typeof(Type))) + { + string msg = register ? SR.InvalidOperation_InvalidComRegFunctionSig : SR.InvalidOperation_InvalidComUnRegFunctionSig; + throw new InvalidOperationException(SR.Format(msg)); + } + + if (calledFunction) + { + string msg = register ? SR.InvalidOperation_MultipleComRegFunctions : SR.InvalidOperation_MultipleComUnRegFunctions; + throw new InvalidOperationException(SR.Format(msg)); + } + + // The function is valid so set up the arguments to call it. + var objs = new object[1]; + if (methParams[0].ParameterType == typeof(string)) + { + // We are dealing with the string overload of the function - provide the registry key - see comhost.dll implementation + objs[0] = $"HKEY_LOCAL_MACHINE\\SOFTWARE\\Classes\\CLSID\\{cxt.ClassId.ToString("B")}"; + } + else + { + // We are dealing with the type overload of the function. + objs[0] = classType; + } + + // Invoke the COM register function. + method.Invoke(null, objs); + calledFunction = true; + } + } + + /// <summary> /// Internal entry point for unmanaged COM activation API from native code /// </summary> /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param> @@ -146,15 +240,7 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments: try { - var cxt = new ComActivationContext() - { - ClassId = cxtInt.ClassId, - InterfaceId = cxtInt.InterfaceId, - AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!, - AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!, - TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))! - }; - + var cxt = ComActivationContext.Create(ref cxtInt); object cf = GetClassFactoryForType(cxt); IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf); Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown); @@ -167,6 +253,81 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments: return 0; } + /// <summary> + /// Internal entry point for registering a managed COM server API from native code + /// </summary> + /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param> + [CLSCompliant(false)] + public unsafe static int RegisterClassForTypeInternal(ref ComActivationContextInternal cxtInt) + { + if (IsLoggingEnabled()) + { + Log( +$@"{nameof(RegisterClassForTypeInternal)} arguments: + {cxtInt.ClassId} + {cxtInt.InterfaceId} + 0x{(ulong)cxtInt.AssemblyPathBuffer:x} + 0x{(ulong)cxtInt.AssemblyNameBuffer:x} + 0x{(ulong)cxtInt.TypeNameBuffer:x} + 0x{cxtInt.ClassFactoryDest.ToInt64():x}"); + } + + if (cxtInt.InterfaceId != Guid.Empty + || cxtInt.ClassFactoryDest != IntPtr.Zero) + { + throw new ArgumentException(); + } + + try + { + var cxt = ComActivationContext.Create(ref cxtInt); + ClassRegisterationScenarioForType(cxt, register: true); + } + catch (Exception e) + { + return e.HResult; + } + + return 0; + } + + /// <summary> + /// Internal entry point for unregistering a managed COM server API from native code + /// </summary> + [CLSCompliant(false)] + public unsafe static int UnregisterClassForTypeInternal(ref ComActivationContextInternal cxtInt) + { + if (IsLoggingEnabled()) + { + Log( +$@"{nameof(UnregisterClassForTypeInternal)} arguments: + {cxtInt.ClassId} + {cxtInt.InterfaceId} + 0x{(ulong)cxtInt.AssemblyPathBuffer:x} + 0x{(ulong)cxtInt.AssemblyNameBuffer:x} + 0x{(ulong)cxtInt.TypeNameBuffer:x} + 0x{cxtInt.ClassFactoryDest.ToInt64():x}"); + } + + if (cxtInt.InterfaceId != Guid.Empty + || cxtInt.ClassFactoryDest != IntPtr.Zero) + { + throw new ArgumentException(); + } + + try + { + var cxt = ComActivationContext.Create(ref cxtInt); + ClassRegisterationScenarioForType(cxt, register: false); + } + catch (Exception e) + { + return e.HResult; + } + + return 0; + } + private static bool IsLoggingEnabled() { #if COM_ACTIVATOR_DEBUG |