Replaced use of CreateStreamOnHGlobal with a custom IStream

This commit is contained in:
Chuck Walbourn 2021-04-06 17:52:47 -07:00
parent 5f22c7cc39
commit 60fde45d28
3 changed files with 385 additions and 127 deletions

View File

@ -396,7 +396,11 @@ namespace DirectX
void *__cdecl GetBufferPointer() const noexcept { return m_buffer; } void *__cdecl GetBufferPointer() const noexcept { return m_buffer; }
size_t __cdecl GetBufferSize() const noexcept { return m_size; } size_t __cdecl GetBufferSize() const noexcept { return m_size; }
HRESULT __cdecl Resize(size_t size) noexcept;
// Reallocate for a new size
HRESULT __cdecl Trim(size_t size) noexcept; HRESULT __cdecl Trim(size_t size) noexcept;
// Shorten size without reallocation
private: private:
void* m_buffer; void* m_buffer;

View File

@ -1510,3 +1510,25 @@ HRESULT Blob::Trim(size_t size) noexcept
return S_OK; return S_OK;
} }
HRESULT Blob::Resize(size_t size) noexcept
{
if (!size)
return E_INVALIDARG;
if (!m_buffer || !m_size)
return E_UNEXPECTED;
void *tbuffer = _aligned_malloc(size, 16);
if (!tbuffer)
return E_OUTOFMEMORY;
memcpy(tbuffer, m_buffer, std::min(m_size, size));
Release();
m_buffer = tbuffer;
m_size = size;
return S_OK;
}

View File

@ -11,89 +11,14 @@
#include "DirectXTexP.h" #include "DirectXTexP.h"
//-------------------------------------------------------------------------------------
// IStream support for WIC Memory routines
//-------------------------------------------------------------------------------------
#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) && (WINAPI_FAMILY != WINAPI_FAMILY_PHONE_APP)
#include <shcore.h>
#pragma comment(lib,"shcore.lib")
#ifdef __cplusplus_winrt
static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream)
{
auto randomAccessStream = ref new ::Windows::Storage::Streams::InMemoryRandomAccessStream();
return CreateStreamOverRandomAccessStream(randomAccessStream, IID_PPV_ARGS(stream));
}
#else
#pragma warning(push)
#pragma warning(disable : 4619 5038)
#include <wrl\client.h>
#include <wrl\wrappers\corewrappers.h>
#pragma warning(pop)
#pragma warning(push)
#pragma warning(disable : 4471 5204)
#include <windows.storage.streams.h>
#pragma warning(pop)
static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream)
{
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> abiStream;
HRESULT hr = Windows::Foundation::ActivateInstance(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
abiStream.GetAddressOf());
if (SUCCEEDED(hr))
{
hr = CreateStreamOverRandomAccessStream(abiStream.Get(), IID_PPV_ARGS(stream));
}
return hr;
}
#endif // __cplusplus_winrt
#elif (!defined(WINAPI_FAMILY) || (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)) && (_WIN32_WINNT >= _WIN32_WINNT_WIN8)
#include <Shlwapi.h>
#pragma comment(lib,"shlwapi.lib")
static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) noexcept
{
if (!stream)
return E_INVALIDARG;
*stream = SHCreateMemStream(nullptr, 0u);
if (!*stream)
return E_OUTOFMEMORY;
return S_OK;
}
#else
#pragma prefast(suppress:6387 28196, "a simple wrapper around an existing annotated function" );
static inline HRESULT CreateMemoryStream(_Outptr_ IStream** stream) noexcept
{
return CreateStreamOnHGlobal(nullptr, TRUE, stream);
}
#endif
using namespace DirectX; using namespace DirectX;
using Microsoft::WRL::ComPtr; using Microsoft::WRL::ComPtr;
namespace namespace
{ {
//------------------------------------------------------------------------------------- //-------------------------------------------------------------------------------------
// WIC Pixel Format nearest conversion table // WIC Pixel Format nearest conversion table
//------------------------------------------------------------------------------------- //-------------------------------------------------------------------------------------
struct WICConvert struct WICConvert
{ {
const GUID& source; const GUID& source;
@ -266,6 +191,331 @@ namespace
} }
//-------------------------------------------------------------------------------------
// IStream over a Blob for WIC in-memory write functions
//-------------------------------------------------------------------------------------
class MemoryStreamOnBlob : public IStream
{
MemoryStreamOnBlob(Blob& blob) noexcept :
mBlob(blob),
m_streamPosition(0),
m_streamEOF(0),
mRefCount(1)
{
assert(mBlob.GetBufferPointer() && mBlob.GetBufferSize() > 0);
}
public:
virtual ~MemoryStreamOnBlob() = default;
MemoryStreamOnBlob(MemoryStreamOnBlob&&) = delete;
MemoryStreamOnBlob& operator= (MemoryStreamOnBlob&&) = delete;
MemoryStreamOnBlob(MemoryStreamOnBlob const&) = delete;
MemoryStreamOnBlob& operator= (MemoryStreamOnBlob const&) = delete;
// IUnknown
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override
{
if (iid == __uuidof(IUnknown)
|| iid == __uuidof(IStream)
|| iid == __uuidof(ISequentialStream))
{
*ppvObject = static_cast<IStream*>(this);
AddRef();
return S_OK;
}
else
return E_NOINTERFACE;
}
ULONG STDMETHODCALLTYPE AddRef() override
{
return InterlockedIncrement(&mRefCount);
}
ULONG STDMETHODCALLTYPE Release() override
{
ULONG res = InterlockedDecrement(&mRefCount);
if (res == 0)
{
delete this;
}
return res;
}
// ISequentialStream
HRESULT STDMETHODCALLTYPE Read(void* pv, ULONG cb, ULONG* pcbRead) override
{
size_t maxRead = m_streamEOF - m_streamPosition;
auto ptr = static_cast<const uint8_t*>(mBlob.GetBufferPointer());
if (cb > maxRead)
{
uint64_t pos = uint64_t(m_streamPosition) + uint64_t(maxRead);
if (pos > UINT32_MAX)
return HRESULT_E_ARITHMETIC_OVERFLOW;
memcpy(pv, &ptr[m_streamPosition], maxRead);
m_streamPosition = static_cast<size_t>(pos);
if (pcbRead)
{
*pcbRead = static_cast<ULONG>(maxRead);
}
return E_BOUNDS;
}
else
{
uint64_t pos = uint64_t(m_streamPosition) + uint64_t(cb);
if (pos > UINT32_MAX)
return HRESULT_E_ARITHMETIC_OVERFLOW;
memcpy(pv, &ptr[m_streamPosition], cb);
m_streamPosition = static_cast<size_t>(pos);
if (pcbRead)
{
*pcbRead = cb;
}
return S_OK;
}
}
HRESULT STDMETHODCALLTYPE Write(void const* pv, ULONG cb, ULONG* pcbWritten) override
{
size_t blobSize = mBlob.GetBufferSize();
size_t spaceAvailable = blobSize - m_streamPosition;
size_t growAmount = cb;
if (spaceAvailable > 0)
{
if (spaceAvailable >= growAmount)
{
growAmount = 0;
}
else
{
growAmount -= spaceAvailable;
}
}
if (growAmount > 0)
{
uint64_t newSize = uint64_t(blobSize);
uint64_t targetSize = uint64_t(blobSize) + growAmount;
HRESULT hr = ComputeGrowSize(newSize, targetSize);
if (FAILED(hr))
return hr;
hr = mBlob.Resize(static_cast<size_t>(newSize));
if (FAILED(hr))
return hr;
}
uint64_t pos = uint64_t(m_streamPosition) + uint64_t(cb);
if (pos > UINT32_MAX)
return HRESULT_E_ARITHMETIC_OVERFLOW;
auto ptr = static_cast<uint8_t*>(mBlob.GetBufferPointer());
memcpy(&ptr[m_streamPosition], pv, cb);
m_streamPosition = static_cast<size_t>(pos);
m_streamEOF = std::max(m_streamEOF, m_streamPosition);
if (pcbWritten)
{
*pcbWritten = cb;
}
return S_OK;
}
// IStream
HRESULT STDMETHODCALLTYPE SetSize(ULARGE_INTEGER size) override
{
if (size.HighPart > 0)
return E_OUTOFMEMORY;
size_t blobSize = mBlob.GetBufferSize();
if (blobSize >= size.LowPart)
{
auto ptr = static_cast<uint8_t*>(mBlob.GetBufferPointer());
if (m_streamEOF < size.LowPart)
{
memset(&ptr[m_streamEOF], 0, size.LowPart - m_streamEOF);
}
m_streamEOF = static_cast<size_t>(size.LowPart);
}
else
{
uint64_t newSize = uint64_t(blobSize);
uint64_t targetSize = uint64_t(size.QuadPart);
HRESULT hr = ComputeGrowSize(newSize, targetSize);
if (FAILED(hr))
return hr;
hr = mBlob.Resize(static_cast<size_t>(newSize));
if (FAILED(hr))
return hr;
blobSize = mBlob.GetBufferSize();
auto ptr = static_cast<uint8_t*>(mBlob.GetBufferPointer());
if (m_streamEOF < size.LowPart)
{
memset(&ptr[m_streamEOF], 0, size.LowPart - m_streamEOF);
}
m_streamEOF = static_cast<size_t>(size.LowPart);
}
if (m_streamPosition > m_streamEOF)
{
m_streamPosition = m_streamEOF;
}
return S_OK;
}
HRESULT STDMETHODCALLTYPE CopyTo(IStream*, ULARGE_INTEGER, ULARGE_INTEGER*, ULARGE_INTEGER*) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE Commit(DWORD) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE Revert(void) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE LockRegion(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE UnlockRegion(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE Clone(IStream**) override
{
return E_NOTIMPL;
}
HRESULT STDMETHODCALLTYPE Seek(LARGE_INTEGER liDistanceToMove, DWORD dwOrigin, ULARGE_INTEGER* lpNewFilePointer) override
{
LONGLONG newPosition = 0;
switch (dwOrigin)
{
case STREAM_SEEK_SET:
newPosition = liDistanceToMove.QuadPart;
break;
case STREAM_SEEK_CUR:
newPosition = static_cast<LONGLONG>(m_streamPosition) + liDistanceToMove.QuadPart;
break;
case STREAM_SEEK_END:
newPosition = static_cast<LONGLONG>(m_streamEOF) + liDistanceToMove.QuadPart;
break;
default:
return STG_E_INVALIDFUNCTION;
}
HRESULT result = S_OK;
if (newPosition > static_cast<LONGLONG>(m_streamEOF))
{
m_streamPosition = m_streamEOF;
result = E_BOUNDS;
}
else if (newPosition < 0)
{
m_streamPosition = 0;
result = E_BOUNDS;
}
else
{
m_streamPosition = static_cast<size_t>(newPosition);
}
if (lpNewFilePointer)
{
lpNewFilePointer->QuadPart = static_cast<ULONGLONG>(m_streamPosition);
}
return result;
}
HRESULT STDMETHODCALLTYPE Stat(STATSTG* pStatstg, DWORD) override
{
if (!pStatstg)
return E_INVALIDARG;
pStatstg->cbSize.QuadPart = static_cast<ULONGLONG>(m_streamEOF);
return S_OK;
}
HRESULT Finialize() noexcept
{
if (mRefCount > 1)
return E_FAIL;
return mBlob.Trim(m_streamEOF);
}
static HRESULT CreateMemoryStream(_Outptr_ MemoryStreamOnBlob** stream, Blob& blob) noexcept
{
if (!stream)
return E_INVALIDARG;
*stream = nullptr;
auto ptr = new (std::nothrow) MemoryStreamOnBlob(blob);
if (!ptr)
return E_OUTOFMEMORY;
*stream = ptr;
return S_OK;
}
private:
Blob& mBlob;
size_t m_streamPosition;
size_t m_streamEOF;
ULONG mRefCount;
static HRESULT ComputeGrowSize(uint64_t& newSize, uint64_t& targetSize) noexcept
{
// We grow by doubling until we hit 256MB, then we add 16MB at a time.
while (newSize < targetSize)
{
if (newSize < (256 * 1024 * 1024))
{
newSize <<= 1;
}
else
{
newSize += 16 * 1024 * 1024;
}
if (newSize > UINT32_MAX)
return E_OUTOFMEMORY;
}
return S_OK;
}
};
//------------------------------------------------------------------------------------- //-------------------------------------------------------------------------------------
// Determines metadata for image // Determines metadata for image
//------------------------------------------------------------------------------------- //-------------------------------------------------------------------------------------
@ -1198,42 +1448,33 @@ HRESULT DirectX::SaveToWICMemory(
if (!image.pixels) if (!image.pixels)
return E_POINTER; return E_POINTER;
blob.Release(); HRESULT hr = blob.Initialize(65535u);
ComPtr<IStream> stream;
HRESULT hr = CreateMemoryStream(stream.GetAddressOf());
if (FAILED(hr)) if (FAILED(hr))
return hr; return hr;
ComPtr<MemoryStreamOnBlob> stream;
hr = MemoryStreamOnBlob::CreateMemoryStream(&stream, blob);
if (FAILED(hr))
{
blob.Release();
return hr;
}
hr = EncodeSingleFrame(image, flags, containerFormat, stream.Get(), targetFormat, setCustomProps); hr = EncodeSingleFrame(image, flags, containerFormat, stream.Get(), targetFormat, setCustomProps);
if (FAILED(hr)) if (FAILED(hr))
{
blob.Release();
return hr; return hr;
}
// Copy stream data into blob hr = stream->Finialize();
STATSTG stat;
hr = stream->Stat(&stat, STATFLAG_NONAME);
if (FAILED(hr)) if (FAILED(hr))
{
blob.Release();
return hr; return hr;
}
if (stat.cbSize.HighPart > 0) stream.Reset();
return HRESULT_E_FILE_TOO_LARGE;
hr = blob.Initialize(stat.cbSize.LowPart);
if (FAILED(hr))
return hr;
LARGE_INTEGER li = {};
hr = stream->Seek(li, STREAM_SEEK_SET, nullptr);
if (FAILED(hr))
return hr;
DWORD bytesRead;
hr = stream->Read(blob.GetBufferPointer(), static_cast<ULONG>(blob.GetBufferSize()), &bytesRead);
if (FAILED(hr))
return hr;
if (bytesRead != blob.GetBufferSize())
return E_FAIL;
return S_OK; return S_OK;
} }
@ -1251,46 +1492,37 @@ HRESULT DirectX::SaveToWICMemory(
if (!images || nimages == 0) if (!images || nimages == 0)
return E_INVALIDARG; return E_INVALIDARG;
blob.Release(); HRESULT hr = blob.Initialize(65535u);
ComPtr<IStream> stream;
HRESULT hr = CreateMemoryStream(stream.GetAddressOf());
if (FAILED(hr)) if (FAILED(hr))
return hr; return hr;
ComPtr<MemoryStreamOnBlob> stream;
hr = MemoryStreamOnBlob::CreateMemoryStream(&stream, blob);
if (FAILED(hr))
{
blob.Release();
return hr;
}
if (nimages > 1) if (nimages > 1)
hr = EncodeMultiframe(images, nimages, flags, containerFormat, stream.Get(), targetFormat, setCustomProps); hr = EncodeMultiframe(images, nimages, flags, containerFormat, stream.Get(), targetFormat, setCustomProps);
else else
hr = EncodeSingleFrame(images[0], flags, containerFormat, stream.Get(), targetFormat, setCustomProps); hr = EncodeSingleFrame(images[0], flags, containerFormat, stream.Get(), targetFormat, setCustomProps);
if (FAILED(hr)) if (FAILED(hr))
{
blob.Release();
return hr; return hr;
}
// Copy stream data into blob hr = stream->Finialize();
STATSTG stat;
hr = stream->Stat(&stat, STATFLAG_NONAME);
if (FAILED(hr)) if (FAILED(hr))
{
blob.Release();
return hr; return hr;
}
if (stat.cbSize.HighPart > 0) stream.Reset();
return HRESULT_E_FILE_TOO_LARGE;
hr = blob.Initialize(stat.cbSize.LowPart);
if (FAILED(hr))
return hr;
LARGE_INTEGER li = {};
hr = stream->Seek(li, STREAM_SEEK_SET, nullptr);
if (FAILED(hr))
return hr;
DWORD bytesRead;
hr = stream->Read(blob.GetBufferPointer(), static_cast<ULONG>(blob.GetBufferSize()), &bytesRead);
if (FAILED(hr))
return hr;
if (bytesRead != blob.GetBufferSize())
return E_FAIL;
return S_OK; return S_OK;
} }