summaryrefslogtreecommitdiff
path: root/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs')
-rw-r--r--src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs199
1 files changed, 180 insertions, 19 deletions
diff --git a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
index b8a8aac00c..d50dfcd7e5 100644
--- a/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
+++ b/src/System.Private.CoreLib/src/Internal/Runtime/InteropServices/ComActivator.cs
@@ -71,16 +71,6 @@ namespace Internal.Runtime.InteropServices
}
[StructLayout(LayoutKind.Sequential)]
- public struct ComActivationContext
- {
- public Guid ClassId;
- public Guid InterfaceId;
- public string AssemblyPath;
- public string AssemblyName;
- public string TypeName;
- }
-
- [StructLayout(LayoutKind.Sequential)]
[CLSCompliant(false)]
public unsafe struct ComActivationContextInternal
{
@@ -92,6 +82,29 @@ namespace Internal.Runtime.InteropServices
public IntPtr ClassFactoryDest;
}
+ [StructLayout(LayoutKind.Sequential)]
+ public struct ComActivationContext
+ {
+ public Guid ClassId;
+ public Guid InterfaceId;
+ public string AssemblyPath;
+ public string AssemblyName;
+ public string TypeName;
+
+ [CLSCompliant(false)]
+ public unsafe static ComActivationContext Create(ref ComActivationContextInternal cxtInt)
+ {
+ return new ComActivationContext()
+ {
+ ClassId = cxtInt.ClassId,
+ InterfaceId = cxtInt.InterfaceId,
+ AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
+ AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
+ TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
+ };
+ }
+ }
+
public static class ComActivator
{
// Collection of all ALCs used for COM activation. In the event we want to support
@@ -126,6 +139,87 @@ namespace Internal.Runtime.InteropServices
}
/// <summary>
+ /// Entry point for unmanaged COM register/unregister API from managed code
+ /// </summary>
+ /// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param>
+ /// <param name="register">true if called for register or false to indicate unregister</param>
+ public static void ClassRegisterationScenarioForType(ComActivationContext cxt, bool register)
+ {
+ // Retrieve the attribute type to use to determine if a function is the requested user defined
+ // registration function.
+ string attributeName = register ? "ComRegisterFunctionAttribute" : "ComUnregisterFunctionAttribute";
+ Type? regFuncAttrType = Type.GetType($"System.Runtime.InteropServices.{attributeName}, System.Runtime.InteropServices", throwOnError: false);
+ if (regFuncAttrType == null)
+ {
+ // If the COM registration attributes can't be found then it is not on the type.
+ return;
+ }
+
+ if (!Path.IsPathRooted(cxt.AssemblyPath))
+ {
+ throw new ArgumentException();
+ }
+
+ Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
+
+ // Retrieve all the methods.
+ MethodInfo[] methods = classType.GetMethods(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);
+
+ bool calledFunction = false;
+
+ // Go through all the methods and check for the custom attribute.
+ foreach (MethodInfo method in methods)
+ {
+ // Check to see if the method has the custom attribute.
+ if (method.GetCustomAttributes(regFuncAttrType!, inherit: true).Length == 0)
+ {
+ continue;
+ }
+
+ // Check to see if the method is static before we call it.
+ if (!method.IsStatic)
+ {
+ string msg = register ? SR.InvalidOperation_NonStaticComRegFunction : SR.InvalidOperation_NonStaticComUnRegFunction;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ // Finally validate signature
+ ParameterInfo[] methParams = method.GetParameters();
+ if (method.ReturnType != typeof(void)
+ || methParams == null
+ || methParams.Length != 1
+ || (methParams[0].ParameterType != typeof(string) && methParams[0].ParameterType != typeof(Type)))
+ {
+ string msg = register ? SR.InvalidOperation_InvalidComRegFunctionSig : SR.InvalidOperation_InvalidComUnRegFunctionSig;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ if (calledFunction)
+ {
+ string msg = register ? SR.InvalidOperation_MultipleComRegFunctions : SR.InvalidOperation_MultipleComUnRegFunctions;
+ throw new InvalidOperationException(SR.Format(msg));
+ }
+
+ // The function is valid so set up the arguments to call it.
+ var objs = new object[1];
+ if (methParams[0].ParameterType == typeof(string))
+ {
+ // We are dealing with the string overload of the function - provide the registry key - see comhost.dll implementation
+ objs[0] = $"HKEY_LOCAL_MACHINE\\SOFTWARE\\Classes\\CLSID\\{cxt.ClassId.ToString("B")}";
+ }
+ else
+ {
+ // We are dealing with the type overload of the function.
+ objs[0] = classType;
+ }
+
+ // Invoke the COM register function.
+ method.Invoke(null, objs);
+ calledFunction = true;
+ }
+ }
+
+ /// <summary>
/// Internal entry point for unmanaged COM activation API from native code
/// </summary>
/// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param>
@@ -146,15 +240,7 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
try
{
- var cxt = new ComActivationContext()
- {
- ClassId = cxtInt.ClassId,
- InterfaceId = cxtInt.InterfaceId,
- AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
- AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
- TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
- };
-
+ var cxt = ComActivationContext.Create(ref cxtInt);
object cf = GetClassFactoryForType(cxt);
IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf);
Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown);
@@ -167,6 +253,81 @@ $@"{nameof(GetClassFactoryForTypeInternal)} arguments:
return 0;
}
+ /// <summary>
+ /// Internal entry point for registering a managed COM server API from native code
+ /// </summary>
+ /// <param name="cxtInt">Reference to a <see cref="ComActivationContextInternal"/> instance</param>
+ [CLSCompliant(false)]
+ public unsafe static int RegisterClassForTypeInternal(ref ComActivationContextInternal cxtInt)
+ {
+ if (IsLoggingEnabled())
+ {
+ Log(
+$@"{nameof(RegisterClassForTypeInternal)} arguments:
+ {cxtInt.ClassId}
+ {cxtInt.InterfaceId}
+ 0x{(ulong)cxtInt.AssemblyPathBuffer:x}
+ 0x{(ulong)cxtInt.AssemblyNameBuffer:x}
+ 0x{(ulong)cxtInt.TypeNameBuffer:x}
+ 0x{cxtInt.ClassFactoryDest.ToInt64():x}");
+ }
+
+ if (cxtInt.InterfaceId != Guid.Empty
+ || cxtInt.ClassFactoryDest != IntPtr.Zero)
+ {
+ throw new ArgumentException();
+ }
+
+ try
+ {
+ var cxt = ComActivationContext.Create(ref cxtInt);
+ ClassRegisterationScenarioForType(cxt, register: true);
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return 0;
+ }
+
+ /// <summary>
+ /// Internal entry point for unregistering a managed COM server API from native code
+ /// </summary>
+ [CLSCompliant(false)]
+ public unsafe static int UnregisterClassForTypeInternal(ref ComActivationContextInternal cxtInt)
+ {
+ if (IsLoggingEnabled())
+ {
+ Log(
+$@"{nameof(UnregisterClassForTypeInternal)} arguments:
+ {cxtInt.ClassId}
+ {cxtInt.InterfaceId}
+ 0x{(ulong)cxtInt.AssemblyPathBuffer:x}
+ 0x{(ulong)cxtInt.AssemblyNameBuffer:x}
+ 0x{(ulong)cxtInt.TypeNameBuffer:x}
+ 0x{cxtInt.ClassFactoryDest.ToInt64():x}");
+ }
+
+ if (cxtInt.InterfaceId != Guid.Empty
+ || cxtInt.ClassFactoryDest != IntPtr.Zero)
+ {
+ throw new ArgumentException();
+ }
+
+ try
+ {
+ var cxt = ComActivationContext.Create(ref cxtInt);
+ ClassRegisterationScenarioForType(cxt, register: false);
+ }
+ catch (Exception e)
+ {
+ return e.HResult;
+ }
+
+ return 0;
+ }
+
private static bool IsLoggingEnabled()
{
#if COM_ACTIVATOR_DEBUG