oleaut32: Correctly handle the case when the number of bytes in a BSTR is odd.

oldstable
Huw Davies 2006-05-10 12:41:32 +01:00 committed by Alexandre Julliard
parent 86d9457cdc
commit 40986cd729
2 changed files with 89 additions and 20 deletions

View File

@ -227,20 +227,20 @@ static void test_marshal_LPSAFEARRAY(void)
static void check_bstr(void *buffer, BSTR b)
{
DWORD *wireb = buffer;
DWORD len = SysStringLen(b);
DWORD len = SysStringByteLen(b);
ok(*wireb == len, "wv[0] %08lx\n", *wireb);
ok(*wireb == (len + 1) / 2, "wv[0] %08lx\n", *wireb);
wireb++;
if(len)
ok(*wireb == len * 2, "wv[1] %08lx\n", *wireb);
if(b)
ok(*wireb == len, "wv[1] %08lx\n", *wireb);
else
ok(*wireb == 0xffffffff, "wv[1] %08lx\n", *wireb);
wireb++;
ok(*wireb == len, "wv[2] %08lx\n", *wireb);
ok(*wireb == (len + 1) / 2, "wv[2] %08lx\n", *wireb);
if(len)
{
wireb++;
ok(!memcmp(wireb, b, len * 2), "strings differ\n");
ok(!memcmp(wireb, b, (len + 1) & ~1), "strings differ\n");
}
return;
}
@ -250,7 +250,7 @@ static void test_marshal_BSTR(void)
unsigned long size;
MIDL_STUB_MESSAGE stubMsg = { 0 };
USER_MARSHAL_CB umcb = { 0 };
unsigned char *buffer;
unsigned char *buffer, *next;
BSTR b, b2;
WCHAR str[] = {'m','a','r','s','h','a','l',' ','t','e','s','t','1',0};
DWORD len;
@ -271,14 +271,16 @@ static void test_marshal_BSTR(void)
ok(size == 38, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
BSTR_UserMarshal(&umcb.Flags, buffer, &b);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n");
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "BSTR didn't unmarshal\n");
ok(!memcmp(b, b2, (len + 1) * 2), "strings differ\n");
BSTR_UserFree(&umcb.Flags, &b2);
}
@ -291,11 +293,75 @@ static void test_marshal_BSTR(void)
ok(size == 12, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
BSTR_UserMarshal(&umcb.Flags, buffer, &b);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 == NULL, "NULL BSTR didn't unmarshal\n");
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer);
b = SysAllocStringByteLen("abc", 3);
*(((char*)b) + 3) = 'd';
len = SysStringLen(b);
ok(len == 1, "get %ld\n", len);
len = SysStringByteLen(b);
ok(len == 3, "get %ld\n", len);
size = BSTR_UserSize(&umcb.Flags, 0, &b);
ok(size == 16, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
memset(buffer, 0xcc, size);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
ok(buffer[15] == 'd', "buffer[15] %02x\n", buffer[15]);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "BSTR didn't unmarshal\n");
ok(!memcmp(b, b2, len), "strings differ\n");
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer);
SysFreeString(b);
b = SysAllocStringByteLen("", 0);
len = SysStringLen(b);
ok(len == 0, "get %ld\n", len);
len = SysStringByteLen(b);
ok(len == 0, "get %ld\n", len);
size = BSTR_UserSize(&umcb.Flags, 0, &b);
ok(size == 12, "size %ld\n", size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
next = BSTR_UserMarshal(&umcb.Flags, buffer, &b);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
check_bstr(buffer, b);
if (BSTR_UNMARSHAL_WORKS)
{
b2 = NULL;
next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2);
ok(next == buffer + size, "got %p expect %p\n", next, buffer + size);
ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n");
len = SysStringByteLen(b2);
ok(len == 0, "byte len %ld\n", len);
BSTR_UserFree(&umcb.Flags, &b2);
}
HeapFree(GetProcessHeap(), 0, buffer);
SysFreeString(b);
}
static void check_variant_header(DWORD *wirev, VARIANT *v, unsigned long size)

View File

@ -152,7 +152,7 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B
TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
ALIGN_LENGTH(Start, 3);
Start += sizeof(bstr_wire_t) + sizeof(OLECHAR) * (SysStringLen(*pstr));
Start += sizeof(bstr_wire_t) + ((SysStringByteLen(*pstr) + 1) & ~1);
TRACE("returning %ld\n", Start);
return Start;
}
@ -160,19 +160,21 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B
unsigned char * WINAPI BSTR_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr)
{
bstr_wire_t *header;
DWORD len = SysStringByteLen(*pstr);
TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
ALIGN_POINTER(Buffer, 3);
header = (bstr_wire_t*)Buffer;
header->len = header->len2 = SysStringLen(*pstr);
if (header->len)
header->len = header->len2 = (len + 1) / 2;
if (*pstr)
{
header->byte_len = header->len * sizeof(OLECHAR);
memcpy(header + 1, *pstr, header->byte_len);
header->byte_len = len;
memcpy(header + 1, *pstr, header->len * 2);
}
else
header->byte_len = 0xffffffff; /* special case for an empty string */
header->byte_len = 0xffffffff; /* special case for a null bstr */
return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len;
}
@ -187,14 +189,15 @@ unsigned char * WINAPI BSTR_UserUnmarshal(unsigned long *pFlags, unsigned char *
if(header->len != header->len2)
FIXME("len %08lx != len2 %08lx\n", header->len, header->len2);
if(header->len)
SysReAllocStringLen(pstr, (OLECHAR*)(header + 1), header->len);
else if (*pstr)
if(*pstr)
{
SysFreeString(*pstr);
*pstr = NULL;
}
if(header->byte_len != 0xffffffff)
*pstr = SysAllocStringByteLen((char*)(header + 1), header->byte_len);
if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr));
return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len;
}