diff --git a/dlls/rpcrt4/rpc_server.c b/dlls/rpcrt4/rpc_server.c index 84e8854194e..4778a5eb403 100644 --- a/dlls/rpcrt4/rpc_server.c +++ b/dlls/rpcrt4/rpc_server.c @@ -67,6 +67,7 @@ static RpcObjTypeMap *RpcObjTypeMaps; /* list of type RpcServerProtseq */ static struct list protseqs = LIST_INIT(protseqs); static struct list server_interfaces = LIST_INIT(server_interfaces); +static struct list server_registered_auth_info = LIST_INIT(server_registered_auth_info); static CRITICAL_SECTION server_cs; static CRITICAL_SECTION_DEBUG server_cs_debug = @@ -86,6 +87,15 @@ static CRITICAL_SECTION_DEBUG listen_cs_debug = }; static CRITICAL_SECTION listen_cs = { &listen_cs_debug, -1, 0, 0, 0, 0 }; +static CRITICAL_SECTION server_auth_info_cs; +static CRITICAL_SECTION_DEBUG server_auth_info_cs_debug = +{ + 0, 0, &server_auth_info_cs, + { &server_auth_info_cs_debug.ProcessLocksList, &server_auth_info_cs_debug.ProcessLocksList }, + 0, 0, { (DWORD_PTR)(__FILE__ ": server_auth_info_cs") } +}; +static CRITICAL_SECTION server_auth_info_cs = { &server_auth_info_cs_debug, -1, 0, 0, 0, 0 }; + /* whether the server is currently listening */ static BOOL std_listen; /* number of manual listeners (calls to RpcServerListen) */ @@ -1134,15 +1144,114 @@ RPC_STATUS WINAPI RpcObjectSetType( UUID* ObjUuid, UUID* TypeUuid ) return RPC_S_OK; } +struct rpc_server_registered_auth_info +{ + struct list entry; + TimeStamp exp; + CredHandle cred; + ULONG max_token; + USHORT auth_type; +}; + +RPC_STATUS RPCRT4_ServerGetRegisteredAuthInfo( + USHORT auth_type, CredHandle *cred, TimeStamp *exp, ULONG *max_token) +{ + RPC_STATUS status = RPC_S_UNKNOWN_AUTHN_SERVICE; + struct rpc_server_registered_auth_info *auth_info; + + EnterCriticalSection(&server_auth_info_cs); + LIST_FOR_EACH_ENTRY(auth_info, &server_registered_auth_info, struct rpc_server_registered_auth_info, entry) + { + if (auth_info->auth_type == auth_type) + { + *cred = auth_info->cred; + *exp = auth_info->exp; + *max_token = auth_info->max_token; + status = RPC_S_OK; + break; + } + } + LeaveCriticalSection(&server_auth_info_cs); + + return status; +} + +void RPCRT4_ServerFreeAllRegisteredAuthInfo(void) +{ + struct rpc_server_registered_auth_info *auth_info, *cursor2; + + EnterCriticalSection(&server_auth_info_cs); + LIST_FOR_EACH_ENTRY_SAFE(auth_info, cursor2, &server_registered_auth_info, struct rpc_server_registered_auth_info, entry) + { + FreeCredentialsHandle(&auth_info->cred); + HeapFree(GetProcessHeap(), 0, auth_info); + } + LeaveCriticalSection(&server_auth_info_cs); +} + /*********************************************************************** * RpcServerRegisterAuthInfoA (RPCRT4.@) */ RPC_STATUS WINAPI RpcServerRegisterAuthInfoA( RPC_CSTR ServerPrincName, ULONG AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn, LPVOID Arg ) { - FIXME( "(%s,%u,%p,%p): stub\n", ServerPrincName, AuthnSvc, GetKeyFn, Arg ); - - return RPC_S_UNKNOWN_AUTHN_SERVICE; /* We don't know any authentication services */ + SECURITY_STATUS sec_status; + CredHandle cred; + TimeStamp exp; + ULONG package_count; + ULONG i; + PSecPkgInfoA packages; + ULONG max_token; + struct rpc_server_registered_auth_info *auth_info; + + TRACE("(%s,%u,%p,%p)\n", ServerPrincName, AuthnSvc, GetKeyFn, Arg); + + sec_status = EnumerateSecurityPackagesA(&package_count, &packages); + if (sec_status != SEC_E_OK) + { + ERR("EnumerateSecurityPackagesA failed with error 0x%08x\n", + sec_status); + return RPC_S_SEC_PKG_ERROR; + } + + for (i = 0; i < package_count; i++) + if (packages[i].wRPCID == AuthnSvc) + break; + + if (i == package_count) + { + WARN("unsupported AuthnSvc %u\n", AuthnSvc); + FreeContextBuffer(packages); + return RPC_S_UNKNOWN_AUTHN_SERVICE; + } + TRACE("found package %s for service %u\n", packages[i].Name, + AuthnSvc); + sec_status = AcquireCredentialsHandleA((SEC_CHAR *)ServerPrincName, + packages[i].Name, + SECPKG_CRED_INBOUND, NULL, NULL, + NULL, NULL, &cred, &exp); + max_token = packages[i].cbMaxToken; + FreeContextBuffer(packages); + if (sec_status != SEC_E_OK) + return RPC_S_SEC_PKG_ERROR; + + auth_info = HeapAlloc(GetProcessHeap(), 0, sizeof(*auth_info)); + if (!auth_info) + { + FreeCredentialsHandle(&cred); + return RPC_S_OUT_OF_RESOURCES; + } + + auth_info->exp = exp; + auth_info->cred = cred; + auth_info->max_token = max_token; + auth_info->auth_type = AuthnSvc; + + EnterCriticalSection(&server_auth_info_cs); + list_add_tail(&server_registered_auth_info, &auth_info->entry); + LeaveCriticalSection(&server_auth_info_cs); + + return RPC_S_OK; } /*********************************************************************** @@ -1151,9 +1260,63 @@ RPC_STATUS WINAPI RpcServerRegisterAuthInfoA( RPC_CSTR ServerPrincName, ULONG Au RPC_STATUS WINAPI RpcServerRegisterAuthInfoW( RPC_WSTR ServerPrincName, ULONG AuthnSvc, RPC_AUTH_KEY_RETRIEVAL_FN GetKeyFn, LPVOID Arg ) { - FIXME( "(%s,%u,%p,%p): stub\n", debugstr_w( ServerPrincName ), AuthnSvc, GetKeyFn, Arg ); - - return RPC_S_UNKNOWN_AUTHN_SERVICE; /* We don't know any authentication services */ + SECURITY_STATUS sec_status; + CredHandle cred; + TimeStamp exp; + ULONG package_count; + ULONG i; + PSecPkgInfoW packages; + ULONG max_token; + struct rpc_server_registered_auth_info *auth_info; + + TRACE("(%s,%u,%p,%p)\n", debugstr_w(ServerPrincName), AuthnSvc, GetKeyFn, Arg); + + sec_status = EnumerateSecurityPackagesW(&package_count, &packages); + if (sec_status != SEC_E_OK) + { + ERR("EnumerateSecurityPackagesW failed with error 0x%08x\n", + sec_status); + return RPC_S_SEC_PKG_ERROR; + } + + for (i = 0; i < package_count; i++) + if (packages[i].wRPCID == AuthnSvc) + break; + + if (i == package_count) + { + WARN("unsupported AuthnSvc %u\n", AuthnSvc); + FreeContextBuffer(packages); + return RPC_S_UNKNOWN_AUTHN_SERVICE; + } + TRACE("found package %s for service %u\n", debugstr_w(packages[i].Name), + AuthnSvc); + sec_status = AcquireCredentialsHandleW((SEC_WCHAR *)ServerPrincName, + packages[i].Name, + SECPKG_CRED_INBOUND, NULL, NULL, + NULL, NULL, &cred, &exp); + max_token = packages[i].cbMaxToken; + FreeContextBuffer(packages); + if (sec_status != SEC_E_OK) + return RPC_S_SEC_PKG_ERROR; + + auth_info = HeapAlloc(GetProcessHeap(), 0, sizeof(*auth_info)); + if (!auth_info) + { + FreeCredentialsHandle(&cred); + return RPC_S_OUT_OF_RESOURCES; + } + + auth_info->exp = exp; + auth_info->cred = cred; + auth_info->max_token = max_token; + auth_info->auth_type = AuthnSvc; + + EnterCriticalSection(&server_auth_info_cs); + list_add_tail(&server_registered_auth_info, &auth_info->entry); + LeaveCriticalSection(&server_auth_info_cs); + + return RPC_S_OK; } /*********************************************************************** diff --git a/dlls/rpcrt4/rpc_server.h b/dlls/rpcrt4/rpc_server.h index cacce794e34..07fb6234cea 100644 --- a/dlls/rpcrt4/rpc_server.h +++ b/dlls/rpcrt4/rpc_server.h @@ -80,5 +80,6 @@ void RPCRT4_new_client(RpcConnection* conn); const struct protseq_ops *rpcrt4_get_protseq_ops(const char *protseq); void RPCRT4_destroy_all_protseqs(void); +void RPCRT4_ServerFreeAllRegisteredAuthInfo(void); #endif /* __WINE_RPC_SERVER_H */ diff --git a/dlls/rpcrt4/rpcrt4_main.c b/dlls/rpcrt4/rpcrt4_main.c index 67059e77a6f..a790ccbd568 100644 --- a/dlls/rpcrt4/rpcrt4_main.c +++ b/dlls/rpcrt4/rpcrt4_main.c @@ -127,6 +127,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) case DLL_PROCESS_DETACH: RPCRT4_destroy_all_protseqs(); + RPCRT4_ServerFreeAllRegisteredAuthInfo(); break; }