From d20ab01ed27e48e44fee8f2cb78a7804e78a26d4 Mon Sep 17 00:00:00 2001 From: Andrey Turkin Date: Wed, 14 Jan 2009 22:31:06 +0300 Subject: [PATCH] ole32: Do not crash in WriteClassStg if passed NULL pointer. --- dlls/ole32/storage32.c | 3 +++ dlls/ole32/tests/storage32.c | 51 ++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/dlls/ole32/storage32.c b/dlls/ole32/storage32.c index fe848046861..48a655f4a68 100644 --- a/dlls/ole32/storage32.c +++ b/dlls/ole32/storage32.c @@ -6281,6 +6281,9 @@ HRESULT WINAPI WriteClassStg(IStorage* pStg, REFCLSID rclsid) if(!pStg) return E_INVALIDARG; + if(!rclsid) + return STG_E_INVALIDPOINTER; + hRes = IStorage_SetClass(pStg, rclsid); return hRes; diff --git a/dlls/ole32/tests/storage32.c b/dlls/ole32/tests/storage32.c index f65af733d28..da2d0aa5a6e 100644 --- a/dlls/ole32/tests/storage32.c +++ b/dlls/ole32/tests/storage32.c @@ -803,6 +803,56 @@ static void test_storage_refcount(void) DeleteFileW(filename); } +static void test_writeclassstg(void) +{ + static const WCHAR szPrefix[] = { 's','t','g',0 }; + static const WCHAR szDot[] = { '.',0 }; + WCHAR filename[MAX_PATH]; + IStorage *stg = NULL; + HRESULT r; + CLSID temp_cls; + + if(!GetTempFileNameW(szDot, szPrefix, 0, filename)) + return; + + DeleteFileW(filename); + + /* create the file */ + r = StgCreateDocfile( filename, STGM_CREATE | STGM_SHARE_EXCLUSIVE | + STGM_READWRITE, 0, &stg); + ok(r==S_OK, "StgCreateDocfile failed\n"); + + r = ReadClassStg( NULL, NULL ); + ok(r == E_INVALIDARG, "ReadClassStg should return E_INVALIDARG instead of 0x%08X\n", r); + + r = ReadClassStg( stg, NULL ); + ok(r == E_INVALIDARG, "ReadClassStg should return E_INVALIDARG instead of 0x%08X\n", r); + + temp_cls.Data1 = 0xdeadbeef; + r = ReadClassStg( stg, &temp_cls ); + ok(r == S_OK, "ReadClassStg failed with 0x%08X\n", r); + + ok(IsEqualCLSID(&temp_cls, &CLSID_NULL), "ReadClassStg returned wrong clsid\n"); + + r = WriteClassStg( NULL, NULL ); + ok(r == E_INVALIDARG, "WriteClassStg should return E_INVALIDARG instead of 0x%08X\n", r); + + r = WriteClassStg( stg, NULL ); + ok(r == STG_E_INVALIDPOINTER, "WriteClassStg should return STG_E_INVALIDPOINTER instead of 0x%08X\n", r); + + r = WriteClassStg( stg, &test_stg_cls ); + ok( r == S_OK, "WriteClassStg failed with 0x%08X\n", r); + + r = ReadClassStg( stg, &temp_cls ); + ok( r == S_OK, "ReadClassStg failed with 0x%08X\n", r); + ok(IsEqualCLSID(&temp_cls, &test_stg_cls), "ReadClassStg returned wrong clsid\n"); + + r = IStorage_Release( stg ); + ok (r == 0, "storage not released\n"); + + DeleteFileW(filename); +} + static void test_streamenum(void) { static const WCHAR szPrefix[] = { 's','t','g',0 }; @@ -1229,4 +1279,5 @@ START_TEST(storage32) test_transact(); test_ReadClassStm(); test_access(); + test_writeclassstg(); }