diff --git a/dlls/rpcrt4/tests/ndr_marshall.c b/dlls/rpcrt4/tests/ndr_marshall.c index 5356c43a7bb..cba135a3d47 100644 --- a/dlls/rpcrt4/tests/ndr_marshall.c +++ b/dlls/rpcrt4/tests/ndr_marshall.c @@ -2055,6 +2055,213 @@ static void test_NdrMapCommAndFaultStatus(void) } } +static void test_NdrGetUserMarshalInfo(void) +{ + RPC_STATUS status; + MIDL_STUB_MESSAGE stubmsg; + USER_MARSHAL_CB umcb; + NDR_USER_MARSHAL_INFO umi; + unsigned char buffer[16]; + void *rpc_channel_buffer = (void *)(ULONG_PTR)0xcafebabe; + RPC_MESSAGE rpc_msg; + RPC_STATUS (RPC_ENTRY *pNdrGetUserMarshalInfo)(ULONG *,ULONG,NDR_USER_MARSHAL_INFO *); + + pNdrGetUserMarshalInfo = (void *)GetProcAddress(GetModuleHandle("rpcrt4.dll"), "NdrGetUserMarshalInfo"); + if (!pNdrGetUserMarshalInfo) + { + skip("NdrGetUserMarshalInfo not exported\n"); + return; + } + + /* unmarshall */ + + memset(&rpc_msg, 0xcc, sizeof(rpc_msg)); + rpc_msg.Buffer = buffer; + rpc_msg.BufferLength = 16; + + memset(&stubmsg, 0xcc, sizeof(stubmsg)); + stubmsg.RpcMsg = &rpc_msg; + stubmsg.dwDestContext = MSHCTX_INPROC; + stubmsg.pvDestContext = NULL; + stubmsg.Buffer = buffer + 15; + stubmsg.BufferLength = 0; + stubmsg.BufferEnd = NULL; + stubmsg.pRpcChannelBuffer = rpc_channel_buffer; + stubmsg.StubDesc = NULL; + stubmsg.pfnAllocate = my_alloc; + stubmsg.pfnFree = my_free; + + memset(&umcb, 0xcc, sizeof(umcb)); + umcb.Flags = MAKELONG(MSHCTX_INPROC, NDR_LOCAL_DATA_REPRESENTATION); + umcb.pStubMsg = &stubmsg; + umcb.Signature = USER_MARSHAL_CB_SIGNATURE; + umcb.CBType = USER_MARSHAL_CB_UNMARSHALL; + + memset(&umi, 0xaa, sizeof(umi)); + + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + ok( umi.InformationLevel == 1, + "umi.InformationLevel was %u instead of 1\n", + umi.InformationLevel); + ok( U(umi.Level1).Buffer == buffer + 15, + "U(umi.Level1).Buffer was %p instead of %p\n", + U(umi.Level1).Buffer, buffer); + ok( U(umi.Level1).BufferSize == 1, + "U(umi.Level1).BufferSize was %u instead of 1\n", + U(umi.Level1).BufferSize); + ok( U(umi.Level1).pfnAllocate == my_alloc, + "U(umi.Level1).pfnAllocate was %p instead of %p\n", + U(umi.Level1).pfnAllocate, my_alloc); + ok( U(umi.Level1).pfnFree == my_free, + "U(umi.Level1).pfnFree was %p instead of %p\n", + U(umi.Level1).pfnFree, my_free); + ok( U(umi.Level1).pRpcChannelBuffer == rpc_channel_buffer, + "U(umi.Level1).pRpcChannelBuffer was %p instead of %p\n", + U(umi.Level1).pRpcChannelBuffer, rpc_channel_buffer); + + /* buffer size */ + + rpc_msg.Buffer = buffer; + rpc_msg.BufferLength = 16; + + stubmsg.Buffer = buffer; + stubmsg.BufferLength = 16; + stubmsg.BufferEnd = NULL; + + umcb.CBType = USER_MARSHAL_CB_BUFFER_SIZE; + + memset(&umi, 0xaa, sizeof(umi)); + + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + ok( umi.InformationLevel == 1, + "umi.InformationLevel was %u instead of 1\n", + umi.InformationLevel); + ok( U(umi.Level1).Buffer == NULL, + "U(umi.Level1).Buffer was %p instead of NULL\n", + U(umi.Level1).Buffer); + ok( U(umi.Level1).BufferSize == 0, + "U(umi.Level1).BufferSize was %u instead of 0\n", + U(umi.Level1).BufferSize); + ok( U(umi.Level1).pfnAllocate == my_alloc, + "U(umi.Level1).pfnAllocate was %p instead of %p\n", + U(umi.Level1).pfnAllocate, my_alloc); + ok( U(umi.Level1).pfnFree == my_free, + "U(umi.Level1).pfnFree was %p instead of %p\n", + U(umi.Level1).pfnFree, my_free); + ok( U(umi.Level1).pRpcChannelBuffer == rpc_channel_buffer, + "U(umi.Level1).pRpcChannelBuffer was %p instead of %p\n", + U(umi.Level1).pRpcChannelBuffer, rpc_channel_buffer); + + /* marshall */ + + rpc_msg.Buffer = buffer; + rpc_msg.BufferLength = 16; + + stubmsg.Buffer = buffer + 15; + stubmsg.BufferLength = 0; + stubmsg.BufferEnd = NULL; + + umcb.CBType = USER_MARSHAL_CB_MARSHALL; + + memset(&umi, 0xaa, sizeof(umi)); + + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + ok( umi.InformationLevel == 1, + "umi.InformationLevel was %u instead of 1\n", + umi.InformationLevel); + ok( U(umi.Level1).Buffer == buffer + 15, + "U(umi.Level1).Buffer was %p instead of %p\n", + U(umi.Level1).Buffer, buffer); + ok( U(umi.Level1).BufferSize == 1, + "U(umi.Level1).BufferSize was %u instead of 1\n", + U(umi.Level1).BufferSize); + ok( U(umi.Level1).pfnAllocate == my_alloc, + "U(umi.Level1).pfnAllocate was %p instead of %p\n", + U(umi.Level1).pfnAllocate, my_alloc); + ok( U(umi.Level1).pfnFree == my_free, + "U(umi.Level1).pfnFree was %p instead of %p\n", + U(umi.Level1).pfnFree, my_free); + ok( U(umi.Level1).pRpcChannelBuffer == rpc_channel_buffer, + "U(umi.Level1).pRpcChannelBuffer was %p instead of %p\n", + U(umi.Level1).pRpcChannelBuffer, rpc_channel_buffer); + + /* free */ + + rpc_msg.Buffer = buffer; + rpc_msg.BufferLength = 16; + + stubmsg.Buffer = buffer; + stubmsg.BufferLength = 16; + stubmsg.BufferEnd = NULL; + + umcb.CBType = USER_MARSHAL_CB_FREE; + + memset(&umi, 0xaa, sizeof(umi)); + + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + ok( umi.InformationLevel == 1, + "umi.InformationLevel was %u instead of 1\n", + umi.InformationLevel); + ok( U(umi.Level1).Buffer == NULL, + "U(umi.Level1).Buffer was %p instead of NULL\n", + U(umi.Level1).Buffer); + ok( U(umi.Level1).BufferSize == 0, + "U(umi.Level1).BufferSize was %u instead of 0\n", + U(umi.Level1).BufferSize); + ok( U(umi.Level1).pfnAllocate == my_alloc, + "U(umi.Level1).pfnAllocate was %p instead of %p\n", + U(umi.Level1).pfnAllocate, my_alloc); + ok( U(umi.Level1).pfnFree == my_free, + "U(umi.Level1).pfnFree was %p instead of %p\n", + U(umi.Level1).pfnFree, my_free); + ok( U(umi.Level1).pRpcChannelBuffer == rpc_channel_buffer, + "U(umi.Level1).pRpcChannelBuffer was %p instead of %p\n", + U(umi.Level1).pRpcChannelBuffer, rpc_channel_buffer); + + /* boundary test */ + + rpc_msg.Buffer = buffer; + rpc_msg.BufferLength = 15; + + stubmsg.Buffer = buffer + 15; + stubmsg.BufferLength = 0; + stubmsg.BufferEnd = NULL; + + umcb.CBType = USER_MARSHAL_CB_MARSHALL; + + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + ok( U(umi.Level1).BufferSize == 0, + "U(umi.Level1).BufferSize was %u instead of 0\n", + U(umi.Level1).BufferSize); + + /* error conditions */ + + rpc_msg.BufferLength = 14; + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == ERROR_INVALID_USER_BUFFER, + "NdrGetUserMarshalInfo should have failed with ERROR_INVALID_USER_BUFFER instead of %d\n", status); + + rpc_msg.BufferLength = 15; + status = pNdrGetUserMarshalInfo(&umcb.Flags, 9999, &umi); + ok(status == RPC_S_INVALID_ARG, + "NdrGetUserMarshalInfo should have failed with RPC_S_INVALID_ARG instead of %d\n", status); + + umcb.CBType = 9999; + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_OK, "NdrGetUserMarshalInfo failed with error %d\n", status); + + umcb.CBType = USER_MARSHAL_CB_MARSHALL; + umcb.Signature = 0; + status = pNdrGetUserMarshalInfo(&umcb.Flags, 1, &umi); + ok(status == RPC_S_INVALID_ARG, + "NdrGetUserMarshalInfo should have failed with RPC_S_INVALID_ARG instead of %d\n", status); +} + START_TEST( ndr_marshall ) { determine_pointer_marshalling_style(); @@ -2073,4 +2280,5 @@ START_TEST( ndr_marshall ) test_conf_complex_struct(); test_ndr_buffer(); test_NdrMapCommAndFaultStatus(); + test_NdrGetUserMarshalInfo(); }