/*
* Copyright (c) 2015-2016 Intel Corporation, Inc.  All rights reserved.
*
* This software is available to you under a choice of one of two
* licenses.  You may choose to be licensed under the terms of the GNU
* General Public License (GPL) Version 2, available from the file
* COPYING in the main directory of this source tree, or the
* BSD license below:
*
*     Redistribution and use in source and binary forms, with or
*     without modification, are permitted provided that the following
*     conditions are met:
*
*      - Redistributions of source code must retain the above
*        copyright notice, this list of conditions and the following
*        disclaimer.
*
*      - Redistributions in binary form must reproduce the above
*        copyright notice, this list of conditions and the following
*        disclaimer in the documentation and/or other materials
*        provided with the distribution.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
* BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
* ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

#ifdef _WIN32

#include <initguid.h>
#include <guiddef.h>

#include <ws2spi.h>
#include <cassert>
#include "ndspi.h"

#include "netdir.h"
#include "netdir_log.h"

#ifndef ofi_sizeofaddr
/*
 * Size in bytes of the concrete socket address structure selected by the
 * sa_family field of `address` (a pointer to struct sockaddr):
 * sizeof(struct sockaddr_in) for AF_INET, sizeof(struct sockaddr_in6) for
 * every other family.
 *
 * The whole expansion is parenthesized so the macro can be used inside a
 * larger expression (e.g. `1 + ofi_sizeofaddr(a)`); `address` is evaluated
 * exactly once.
 */
#define ofi_sizeofaddr(address)				\
	((address)->sa_family == AF_INET ?		\
		sizeof(struct sockaddr_in) :		\
		sizeof(struct sockaddr_in6))
#endif

/* dwServiceFlags1 bits that must all be set on a WSAPROTOCOL_INFOW entry
 * for ofi_nd_is_valid_proto() to accept it */
#define FI_ND_PROTO_FLAG (XP1_GUARANTEED_DELIVERY | XP1_GUARANTEED_ORDER | \
			  XP1_MESSAGE_ORIENTED | XP1_CONNECT_DATA)

/* set to 1 by ofi_nd_startup() after WSAStartup() succeeded and
 * ofi_nd_init() was run; reset to 0 by ofi_nd_shutdown() */
static int ofi_nd_startup_done = 0;

/* signatures used for the "DllCanUnloadNow" and "DllGetClassObject" exports
 * that ofi_nd_create_module() resolves with GetProcAddress().
 * NOTE(review): no calling convention is spelled out here - confirm that it
 * matches the provider DLL exports on 32-bit builds */
typedef HRESULT(*can_unload_now_t)(void);
typedef HRESULT(*get_class_object_t)(REFCLSID rclsid, REFIID rrid, LPVOID* ppv);

/* one loaded provider DLL; filled in by ofi_nd_create_module() */
struct module_t {
	/* heap copy (_wcsdup) of the DLL path; freed by ofi_nd_free_infra() */
	const wchar_t		*path;
	/* handle returned by LoadLibraryW() for `path` */
	HMODULE			module;
	/* "DllCanUnloadNow" export of the DLL */
	can_unload_now_t	can_unload_now;
	/* "DllGetClassObject" export of the DLL */
	get_class_object_t	get_class_object;
};

/* one accepted protocol catalog entry together with the COM objects that
 * were created for it */
struct factory_t {
	/* copy of the catalog entry, stored by ofi_nd_create_factory() */
	WSAPROTOCOL_INFOW	protocol;
	/* class factory obtained through module->get_class_object();
	 * released by ofi_nd_release_infra() */
	IClassFactory		*class_factory;
	/* provider created by ofi_nd_create_adapter() from class_factory;
	 * released by ofi_nd_release_infra() */
	IND2Provider		*provider;
	/* DLL that supplied the class factory (entry of the modules array) */
	struct module_t		*module;
	/* malloc'ed buffer filled by IND2Provider::QueryAddressList();
	 * freed by ofi_nd_free_infra() */
	SOCKET_ADDRESS_LIST	*addr_list;
};

/* one local address served by a provider; built by ofi_nd_create_adapter() */
struct adapter_t {
	/* address copied from the provider's address list; `addr` gives
	 * family-independent access to the same storage */
	union {
		struct sockaddr		addr;
		struct sockaddr_in	addr4;
		struct sockaddr_in6	addr6;
	} address;
	/* result of IND2Adapter::Query() for this address */
	ND2_ADAPTER_INFO		info;
	/* opened adapter interface; set to 0 by ofi_nd_release_infra() and
	 * opened again on demand by ofi_nd_lookup_adapter() */
	IND2Adapter			*adapter;
	/* factory whose provider reported this address */
	struct factory_t		*factory;
	/* heap string "netdir-<provider file>-<address>-<slot pointer>"
	 * produced with asprintf(); freed by ofi_nd_free_infra() */
	const char			*name;
};

/* file-global registry of everything discovered during ofi_nd_init():
 * each member pairs a heap array with the number of entries in use */
static struct ofi_nd_infra_t {
	/* loaded provider DLLs; array allocated by ofi_nd_alloc_infra(),
	 * entries appended by ofi_nd_create_module() */
	struct modules_t {
		struct module_t		*modules;
		size_t			count;
	} providers;

	/* class factories; array allocated by ofi_nd_alloc_infra(),
	 * entries appended by ofi_nd_create_factory() */
	struct class_factory_t {
		struct factory_t	*factory;
		size_t			count;
	} class_factories;

	/* adapters; array allocated and filled by ofi_nd_create_adapter() */
	struct adapters_t {
		struct adapter_t	*adapter;
		size_t			count;
	} adapters;
} ofi_nd_infra = {0};

/* release all objects, do not free strings or arrays */
static inline void ofi_nd_release_infra()
{
	size_t i;

	if (ofi_nd_infra.adapters.count) {
		assert(ofi_nd_infra.adapters.adapter);
		for (i = 0; i < ofi_nd_infra.adapters.count; i++) {
			struct adapter_t *adapter = &ofi_nd_infra.adapters.adapter[i];
			if (adapter->adapter) {
				adapter->adapter->lpVtbl->Release(adapter->adapter);
				adapter->adapter = 0;
			}
		}
	}

	if (ofi_nd_infra.class_factories.count) {
		assert(ofi_nd_infra.class_factories.factory);
		for (i = 0; i < ofi_nd_infra.class_factories.count; i++) {
			struct factory_t *factory = &ofi_nd_infra.class_factories.factory[i];
			if (factory->provider) {
				factory->provider->lpVtbl->Release(factory->provider);
				factory->provider = 0;
			}
			if (factory->class_factory) {
				factory->class_factory->lpVtbl->Release(factory->class_factory);
				factory->class_factory = 0;
			}
		}
	}
}

/*
 * Tear down the whole registry: release every COM reference through
 * ofi_nd_release_infra(), then free the adapter names, the per-factory
 * address lists, the module paths and finally the three arrays.
 * All array pointers and counters are reset, so calling the function a
 * second time is harmless.
 *
 * An array is freed whenever it is allocated, even when no entry was ever
 * stored in it (count == 0); keying the cleanup on the counter would leak
 * an allocated but still empty array.
 */
static inline void ofi_nd_free_infra(void)
{
	size_t i;

	ofi_nd_release_infra();

	assert(!ofi_nd_infra.adapters.count || ofi_nd_infra.adapters.adapter);
	if (ofi_nd_infra.adapters.adapter) {
		for (i = 0; i < ofi_nd_infra.adapters.count; i++) {
			struct adapter_t *adapter = &ofi_nd_infra.adapters.adapter[i];
			free((void*)adapter->name);
			adapter->name = NULL;
		}
		free(ofi_nd_infra.adapters.adapter);
	}
	ofi_nd_infra.adapters.adapter = NULL;
	ofi_nd_infra.adapters.count = 0;

	assert(!ofi_nd_infra.class_factories.count ||
	       ofi_nd_infra.class_factories.factory);
	if (ofi_nd_infra.class_factories.factory) {
		for (i = 0; i < ofi_nd_infra.class_factories.count; i++) {
			struct factory_t *factory = &ofi_nd_infra.class_factories.factory[i];
			assert(factory->module);
			free(factory->addr_list);
			factory->addr_list = NULL;
		}
		free(ofi_nd_infra.class_factories.factory);
	}
	ofi_nd_infra.class_factories.factory = NULL;
	ofi_nd_infra.class_factories.count = 0;

	assert(!ofi_nd_infra.providers.count || ofi_nd_infra.providers.modules);
	if (ofi_nd_infra.providers.modules) {
		for (i = 0; i < ofi_nd_infra.providers.count; i++) {
			struct module_t *module = &ofi_nd_infra.providers.modules[i];
			/* the path may be NULL if its copy could not be
			   allocated; free(NULL) is a no-op */
			free((void*)module->path);
			module->path = NULL;
		}
		free(ofi_nd_infra.providers.modules);
	}
	ofi_nd_infra.providers.modules = NULL;
	ofi_nd_infra.providers.count = 0;
}

/*
 * Reset ofi_nd_infra and allocate the module and class factory arrays with
 * room for `cnt` entries each. Both counters start at 0.
 *
 * The arrays are zero-initialized (calloc), so pointer members of slots
 * that are not completely filled in yet read as NULL instead of garbage.
 *
 * Returns S_OK on success. Returns ND_NO_MEMORY if an allocation fails; in
 * that case nothing stays allocated. `cnt` is expected to be non-zero (a
 * zero-sized calloc may legitimately return NULL and is then reported as
 * ND_NO_MEMORY).
 */
static inline HRESULT ofi_nd_alloc_infra(size_t cnt)
{
	memset(&ofi_nd_infra, 0, sizeof(ofi_nd_infra));

	ofi_nd_infra.providers.modules = (struct module_t*)
		calloc(cnt, sizeof(*ofi_nd_infra.providers.modules));
	if (!ofi_nd_infra.providers.modules)
		return ND_NO_MEMORY;

	ofi_nd_infra.class_factories.factory = (struct factory_t*)
		calloc(cnt, sizeof(*ofi_nd_infra.class_factories.factory));
	if (!ofi_nd_infra.class_factories.factory) {
		/* undo the first allocation right here: both counters are
		   still 0, so count-driven cleanup would not see the array */
		free(ofi_nd_infra.providers.modules);
		ofi_nd_infra.providers.modules = NULL;
		return ND_NO_MEMORY;
	}

	return S_OK;
}

static inline wchar_t *ofi_nd_get_provider_path(const WSAPROTOCOL_INFOW *proto)
{
	assert(proto);

	int len, lenex, err, res;
	wchar_t *prov, *provex;

	res = WSCGetProviderPath((GUID*)&proto->ProviderId, NULL, &len, &err);
	if (err != WSAEFAULT || !len)
		return NULL;

	prov = (wchar_t*)malloc(len * sizeof(*prov));
	if (!prov)
		return NULL;

	res = WSCGetProviderPath((GUID*)&proto->ProviderId, prov, &len, &err);
	if (res)
		goto fn1;

	lenex = ExpandEnvironmentStringsW(prov, NULL, 0);
	if (!lenex)
		goto fn1;

	provex = (wchar_t*)malloc(lenex * sizeof(*provex));
	if (!provex)
		goto fn1;

	lenex = ExpandEnvironmentStringsW(prov, provex, lenex);
	if (!lenex)
		goto fn2;

	free(prov);
	return provex;

fn2:
	free(provex);
fn1:
	free(prov);
	return NULL;
}

/*
 * Tell whether a protocol catalog entry is accepted by this provider.
 * An entry qualifies (result 1) only if all of the following hold:
 *  - every bit of FI_ND_PROTO_FLAG is set in dwServiceFlags1;
 *  - the address family is AF_INET or AF_INET6;
 *  - iSocketType is -1;
 *  - iProtocol and iProtocolMaxOffset are both 0.
 * Otherwise the result is 0.
 */
static inline int ofi_nd_is_valid_proto(const WSAPROTOCOL_INFOW *proto)
{
	int flags_ok, family_ok, type_ok, protocol_ok;

	assert(proto);

	flags_ok = (proto->dwServiceFlags1 & FI_ND_PROTO_FLAG) == FI_ND_PROTO_FLAG;
	family_ok = proto->iAddressFamily == AF_INET ||
		    proto->iAddressFamily == AF_INET6;
	type_ok = proto->iSocketType == -1;
	protocol_ok = !proto->iProtocol && !proto->iProtocolMaxOffset;

	return flags_ok && family_ok && type_ok && protocol_ok;
}

static inline struct module_t *ofi_nd_search_module(const wchar_t* path)
{
	size_t i;
	size_t j;

	for (i = 0; i < ofi_nd_infra.providers.count; i++) {
		if (path && ofi_nd_file_exists(path) &&
		    !ofi_nd_is_directory(path)) {
			for (j = 0; j < ofi_nd_infra.providers.count; j++) {
				if (ofi_nd_is_same_file(path, ofi_nd_infra.providers.modules[j].path)) {
					return &ofi_nd_infra.providers.modules[j];
				}
			}
		}
	}
	return NULL;
}

/*
 * Return the module entry for the provider DLL at `path`, loading and
 * registering the DLL if it is not known yet.
 *
 * A new DLL is loaded with LoadLibraryW() and must export both
 * "DllCanUnloadNow" and "DllGetClassObject". The entry is appended to
 * ofi_nd_infra.providers.modules only after everything it needs, including
 * the private copy of `path`, is available.
 *
 * Returns the existing or newly created entry, or NULL if the DLL cannot be
 * loaded, lacks the required exports, or memory runs out. The caller keeps
 * ownership of `path`.
 */
static inline struct module_t *ofi_nd_create_module(const wchar_t* path)
{
	struct module_t *module;
	HMODULE hmodule;
	can_unload_now_t unload;
	get_class_object_t getclass;
	wchar_t *path_copy;

	assert(ofi_nd_infra.providers.modules);

	module = ofi_nd_search_module(path);
	if (module)
		return module;

	/* ok, this is not duplicate. try to
	load it and get class factory*/
	hmodule = LoadLibraryW(path);
	if (!hmodule) {
		ND_LOG_WARN(FI_LOG_CORE,
			   "ofi_nd_create_module: provider : %S, failed to load: %s\n",
			   path, ofi_nd_strerror(GetLastError(), 0));
		return NULL;
	}

	unload = (can_unload_now_t)GetProcAddress(hmodule, "DllCanUnloadNow");
	getclass = (get_class_object_t)GetProcAddress(hmodule, "DllGetClassObject");
	if (!unload || !getclass) {
		ND_LOG_WARN(FI_LOG_CORE,
			   "ofi_nd_create_module: provider: %S, failed to import interface\n",
			   path);
		goto fn_noiface;
	}

	/* copy the path before claiming a slot, so that a failed allocation
	   never leaves a registered module without a path */
	path_copy = _wcsdup(path);
	if (!path_copy) {
		ND_LOG_WARN(FI_LOG_CORE,
			   "ofi_nd_create_module: provider: %S, failed to copy path\n",
			   path);
		goto fn_noiface;
	}

	module = &ofi_nd_infra.providers.modules[ofi_nd_infra.providers.count];
	ofi_nd_infra.providers.count++;

	module->path = path_copy;
	module->module = hmodule;
	module->can_unload_now = unload;
	module->get_class_object = getclass;

	return module;

fn_noiface:
	FreeLibrary(hmodule);
	return NULL;
}

static inline HRESULT ofi_nd_create_factory(const WSAPROTOCOL_INFOW* proto)
{
	wchar_t *path;
	struct module_t *module;
	IClassFactory* factory;
	HRESULT hr;
	struct factory_t *ftr;

	assert(proto);
	assert(ofi_nd_is_valid_proto(proto));
	assert(ofi_nd_infra.class_factories.factory);

	path = ofi_nd_get_provider_path(proto);
	if (path)
		ND_LOG_INFO(FI_LOG_CORE,
			    "ofi_nd_create_factory: provider " FI_ND_GUID_FORMAT " path: %S \n",
			    FI_ND_GUID_ARG(proto->ProviderId), path);
	else /* can't get provider path. just return */
		return S_OK;

	module = ofi_nd_create_module(path);
	free(path);
	if (!module)
		return S_OK;

	assert(module->get_class_object);
	hr = module->get_class_object(&proto->ProviderId, &IID_IClassFactory,
				      (void**)&factory);
	if (FAILED(hr))
		return hr;

	ftr = &ofi_nd_infra.class_factories.factory[ofi_nd_infra.class_factories.count];
	ofi_nd_infra.class_factories.count++;
	ftr->class_factory = factory;
	ftr->module = module;
	ftr->protocol = *proto;

	return S_OK;
}

/* Comparator over struct adapter_t elements (used for sorting and
 * de-duplicating the adapter array): the result is whatever
 * ofi_nd_addr_cmp() reports for the two adapters' addresses. */
static int ofi_nd_adapter_cmp(const void *adapter1, const void *adapter2)
{
	struct adapter_t *lhs = (struct adapter_t*)adapter1;
	struct adapter_t *rhs = (struct adapter_t*)adapter2;

	return ofi_nd_addr_cmp(&lhs->address, &rhs->address);
}

/*
 * Build the adapter table from the class factories registered so far.
 *
 *  1. For every factory create its IND2Provider and fetch the provider's
 *     address list.
 *  2. Copy every address accepted by ofi_nd_is_valid_addr() into one array,
 *     sort it with ofi_nd_adapter_cmp() and drop duplicates.
 *  3. For every remaining address resolve and open the IND2Adapter, query
 *     its ND2_ADAPTER_INFO and generate the adapter name
 *     "netdir-<provider file>-<address>-<slot pointer>".
 *
 * Returns S_OK on success, E_NOINTERFACE if no valid address was found,
 * ND_NO_MEMORY if an allocation fails, or the HRESULT of the provider call
 * that failed. Nothing is rolled back on failure; every pointer stored in
 * ofi_nd_infra is either NULL or valid at that point, so the regular
 * release/free routines can clean up.
 */
static HRESULT ofi_nd_create_adapter(void)
{
	size_t addr_count = 0;
	HRESULT hr;

	/* `provider` and `addr_list` are assigned for the first time below;
	   clear them for all factories up-front so that an early return
	   leaves no slot with an indeterminate pointer */
	for (size_t i = 0; i < ofi_nd_infra.class_factories.count; i++) {
		ofi_nd_infra.class_factories.factory[i].provider = NULL;
		ofi_nd_infra.class_factories.factory[i].addr_list = NULL;
	}

	for (size_t i = 0; i < ofi_nd_infra.class_factories.count; i++) {
		struct factory_t *factory = &ofi_nd_infra.class_factories.factory[i];
		ULONG listsize = 0;

		assert(factory->class_factory);

		hr = factory->class_factory->lpVtbl->CreateInstance(factory->class_factory,
			NULL, &IID_IND2Provider, (void**)&factory->provider);
		if (FAILED(hr))
			return hr;

		/* size probe.
		   NOTE(review): any result other than ND_BUFFER_OVERFLOW,
		   a success code included, ends the enumeration with that
		   code - confirm this is intended */
		hr = factory->provider->lpVtbl->QueryAddressList(factory->provider, NULL, &listsize);
		if (hr != ND_BUFFER_OVERFLOW)
			return hr;
		if (!listsize)
			continue;

		factory->addr_list = (SOCKET_ADDRESS_LIST*)malloc(listsize);
		if (!factory->addr_list)
			return ND_NO_MEMORY;

		hr = factory->provider->lpVtbl->QueryAddressList(factory->provider,
			factory->addr_list, &listsize);
		if (FAILED(hr))
			return hr;

		for (INT j = 0; j < factory->addr_list->iAddressCount; j++) {
			if (ofi_nd_is_valid_addr(factory->addr_list->Address[j].lpSockaddr))
				addr_count++;
		}
	}

	if (!addr_count)
		return E_NOINTERFACE;

	/* zero-initialized, so `adapter` and `name` of a slot that has not
	   been processed yet are NULL */
	ofi_nd_infra.adapters.adapter = (struct adapter_t*)
		calloc(addr_count, sizeof(*ofi_nd_infra.adapters.adapter));
	if (!ofi_nd_infra.adapters.adapter)
		return ND_NO_MEMORY;

	/* put all available valid addresses into common array */
	for (size_t i = 0; i < ofi_nd_infra.class_factories.count; i++) {
		struct factory_t *factory = &ofi_nd_infra.class_factories.factory[i];

		/* factories that reported an empty list have no buffer */
		if (!factory->addr_list)
			continue;

		for (INT j = 0; j < factory->addr_list->iAddressCount; j++) {
			struct adapter_t *adapter;

			if (!ofi_nd_is_valid_addr(factory->addr_list->Address[j].lpSockaddr))
				continue;

			adapter = &ofi_nd_infra.adapters.adapter[ofi_nd_infra.adapters.count];
			assert((int)sizeof(adapter->address) >= factory->addr_list->Address[j].iSockaddrLength);
			memcpy(&adapter->address,
			       factory->addr_list->Address[j].lpSockaddr,
			       factory->addr_list->Address[j].iSockaddrLength);
			adapter->factory = factory;
			ofi_nd_infra.adapters.count++;
		}
	}

	if (!ofi_nd_infra.adapters.count)
		return E_NOINTERFACE;

	/* sort adapters by address (intended to place IPv4 addresses first),
	   then remove duplicates */
	qsort(ofi_nd_infra.adapters.adapter, ofi_nd_infra.adapters.count,
	      sizeof(struct adapter_t), ofi_nd_adapter_cmp);
	ofi_nd_infra.adapters.count = unique(ofi_nd_infra.adapters.adapter,
					     ofi_nd_infra.adapters.count,
					     sizeof(struct adapter_t), ofi_nd_adapter_cmp);

	for (size_t i = 0; i < ofi_nd_infra.adapters.count; i++) {
		struct adapter_t *adapter = &ofi_nd_infra.adapters.adapter[i];
		struct factory_t *factory = adapter->factory;
		wchar_t *saddr = NULL;
		char *name = NULL;
		DWORD addrlen = 0;
		ULONG linfo = sizeof(adapter->info);
		UINT64 id;
		int res;

		assert(factory);
		assert(factory->provider);

		assert(adapter->address.addr.sa_family == AF_INET ||
		       adapter->address.addr.sa_family == AF_INET6);

		hr = factory->provider->lpVtbl->ResolveAddress(factory->provider,
			&adapter->address.addr,
			ofi_sizeofaddr(&adapter->address.addr), &id);
		if (FAILED(hr))
			return hr;

		hr = factory->provider->lpVtbl->OpenAdapter(factory->provider,
			&IID_IND2Adapter, id, (void**)&adapter->adapter);
		if (FAILED(hr))
			return hr;

		adapter->info.InfoVersion = ND_VERSION_2;
		hr = adapter->adapter->lpVtbl->Query(adapter->adapter, &adapter->info, &linfo);
		if (hr == ND_BUFFER_OVERFLOW) {
			/* the provider wants a larger buffer: query into a
			   temporary one and keep the leading part */
			ND2_ADAPTER_INFO *info = (ND2_ADAPTER_INFO*)malloc(linfo);
			if (!info)
				return ND_NO_MEMORY;
			info->InfoVersion = ND_VERSION_2;
			hr = adapter->adapter->lpVtbl->Query(adapter->adapter, info, &linfo);
			if (SUCCEEDED(hr))
				adapter->info = *info;
			/* released on the failure path as well */
			free(info);
		}
		if (FAILED(hr))
			return hr;

		/* generate adapter's name: probe for the string length first */
		res = WSAAddressToStringW(&adapter->address.addr,
					  ofi_sizeofaddr(&adapter->address.addr),
					  NULL, NULL, &addrlen);
		if (res == SOCKET_ERROR && WSAGetLastError() == WSAEFAULT && addrlen) {
			saddr = (wchar_t*)malloc((addrlen + 1) * sizeof(*saddr));
			if (!saddr)
				return ND_NO_MEMORY;
			if (WSAAddressToStringW(&adapter->address.addr,
						ofi_sizeofaddr(&adapter->address.addr),
						NULL, saddr, &addrlen) == SOCKET_ERROR) {
				/* buffer content is not usable; fall back */
				free(saddr);
				saddr = NULL;
			}
		}
		if (!saddr) {
			saddr = _wcsdup(L"unknown");
			if (!saddr)
				return ND_NO_MEMORY;
		}

		/* on failure the content of `name` is not defined */
		if (asprintf(&name, "netdir-%S-%S-%p",
			     ofi_nd_filename(adapter->factory->module->path),
			     saddr, (void*)adapter) < 0)
			name = NULL;
		free(saddr);
		if (!name)
			return ND_NO_MEMORY;
		adapter->name = name;
	}

	return S_OK;
}

/*
 * Discover the available providers and report every adapter to `cb`.
 *
 * The protocol catalog is read with WSCEnumProtocols(); entries accepted by
 * ofi_nd_is_valid_proto() are counted to size the registry and are then
 * turned into class factories. After ofi_nd_create_adapter() has collected
 * the adapters, all COM interfaces are released again and `cb` is invoked
 * once per adapter with its ND2_ADAPTER_INFO and its name.
 *
 * Returns the result of ofi_nd_create_adapter() on the normal path;
 * ND_NO_MEMORY if the catalog size probe does not behave as expected or the
 * catalog buffer cannot be allocated; ND_INTERNAL_ERROR if the catalog
 * cannot be read; E_NOINTERFACE if it holds no acceptable entry; or the
 * failure code of ofi_nd_alloc_infra().
 *
 * NOTE(review): when ofi_nd_create_adapter() fails, the registry built so
 * far is returned as is, without being released here.
 */
static HRESULT ofi_nd_init(ofi_nd_adapter_cb_t cb)
{
	WSAPROTOCOL_INFOW *catalog = 0;
	DWORD catalog_bytes = 0;
	size_t accepted = 0;
	size_t k;
	HRESULT hr;
	int entries, wsa_err, n;

	memset(&ofi_nd_infra, 0, sizeof(ofi_nd_infra));

	/* size probe: must fail with WSAENOBUFS and report the buffer size */
	if (WSCEnumProtocols(NULL, NULL, &catalog_bytes, &wsa_err) != SOCKET_ERROR ||
	    wsa_err != WSAENOBUFS)
		return ND_NO_MEMORY;

	catalog = (WSAPROTOCOL_INFOW*)(malloc(catalog_bytes));
	if (!catalog)
		return ND_NO_MEMORY;

	entries = WSCEnumProtocols(NULL, catalog, &catalog_bytes, &wsa_err);
	if (entries == SOCKET_ERROR) {
		free(catalog);
		return ND_INTERNAL_ERROR;
	}

	/* the number of acceptable entries is the upper bound for both the
	   module and the class factory arrays */
	for (n = 0; n < entries; n++) {
		if (ofi_nd_is_valid_proto(&catalog[n]))
			accepted++;
	}
	if (!accepted) {
		free(catalog);
		return E_NOINTERFACE;
	}

	hr = ofi_nd_alloc_infra(accepted);
	if (hr != S_OK) {
		free(catalog);
		return hr;
	}

	/* per-entry failures are ignored: the entry is simply left out */
	for (n = 0; n < entries; n++) {
		if (ofi_nd_is_valid_proto(&catalog[n]))
			ofi_nd_create_factory(&catalog[n]);
	}

	free(catalog);

	/* factories exist now: list the addresses, open the adapters and
	   collect their info */
	hr = ofi_nd_create_adapter();
	if (FAILED(hr))
		return hr;

	/* the interfaces are not needed any longer at this point */
	ofi_nd_release_infra();

	/* hand every adapter over to the callback */
	for (k = 0; k < ofi_nd_infra.adapters.count; k++) {
		struct adapter_t *ada = &ofi_nd_infra.adapters.adapter[k];
		cb(&ada->info, ada->name);
	}

	return hr;
}

/*
 * Initialize Winsock and the NetworkDirect registry, reporting every
 * adapter found to `cb`.
 *
 * Exclusive execution is not needed here because this function is called
 * from the OFI init routine, which is single threaded.
 *
 * The startup flag is raised once WSAStartup() succeeded, whether or not
 * ofi_nd_init() succeeded, so that ofi_nd_shutdown() still frees the
 * registry and calls WSACleanup(). The outcome of the initialization is
 * remembered: further calls made while the flag is raised return that same
 * outcome instead of unconditionally reporting success.
 *
 * Returns the HRESULT of ofi_nd_init(), or the converted WSAStartup() error
 * code if Winsock cannot be started (the flag stays cleared in that case).
 */
HRESULT ofi_nd_startup(ofi_nd_adapter_cb_t cb)
{
	/* outcome of the initialization run guarded by ofi_nd_startup_done */
	static HRESULT init_result = S_OK;
	WSADATA data;
	int ret;

	assert(cb);

	if (ofi_nd_startup_done)
		return init_result;

	ND_LOG_INFO(FI_LOG_CORE, "ofi_nd_startup: starting initialization\n");

	ret = WSAStartup(MAKEWORD(2, 2), &data);
	if (ret)
		return HRESULT_FROM_WIN32(ret);

	ND_LOG_DEBUG(FI_LOG_CORE, "ofi_nd_startup: WSAStartup complete\n");

	init_result = ofi_nd_init(cb);

	ofi_nd_startup_done = 1;

	return init_result;
}

HRESULT ofi_nd_shutdown(void)
{
	if (!ofi_nd_startup_done)
		return S_OK;

	ND_LOG_INFO(FI_LOG_CORE, "ofi_nd_shutdown: shutdown WSA\n");

	ofi_nd_free_infra();
	ofi_nd_startup_done = 0;

	return HRESULT_FROM_WIN32(WSACleanup());
}

int ofi_nd_lookup_adapter(const char *name, IND2Adapter **adapter, struct sockaddr** addr)
{
	size_t i;

	assert(name);
	assert(adapter);

	if (!ofi_nd_startup_done)
		return -FI_EOPBADSTATE;

	for (i = 0; i < ofi_nd_infra.adapters.count; i++) {
		struct adapter_t *ada = &ofi_nd_infra.adapters.adapter[i];
		if (ada->name && !strcmp(ada->name, name)) {
			HRESULT hr;
			UINT64 adapter_id;
			IClassFactory* factory = NULL;
			IND2Provider *provider = NULL;

			/* ok, we found good adapter. try to initialize it */
			if (ada->adapter) {
				*adapter = ada->adapter;
				*addr = &ada->address.addr;
				ada->adapter->lpVtbl->AddRef(ada->adapter);
				return FI_SUCCESS;
			}

			assert(ada->factory);
			assert(ada->factory->module);
			assert(ada->factory->module->get_class_object);

			hr = ada->factory->module->get_class_object(
				&ada->factory->protocol.ProviderId,
				&IID_IClassFactory,
				(void**)&factory);
			if (FAILED(hr))
				return H2F(hr);
			assert(factory);

			hr = factory->lpVtbl->CreateInstance(factory, NULL,
							     &IID_IND2Provider,
							     (void**)&provider);
			factory->lpVtbl->Release(factory);
			if (FAILED(hr))
				return H2F(hr);
			assert(provider);

			hr = provider->lpVtbl->ResolveAddress(provider, &ada->address.addr,
							      ofi_sizeofaddr(&ada->address.addr),
							      &adapter_id);
			if (FAILED(hr)) {
				provider->lpVtbl->Release(provider);
				return H2F(hr);
			}

			hr = provider->lpVtbl->OpenAdapter(provider, &IID_IND2Adapter, adapter_id,
							   (void**)&ada->adapter);
			provider->lpVtbl->Release(provider);
			if (FAILED(hr))
				return H2F(hr);

			*adapter = ada->adapter;
			*addr = &ada->address.addr;
			ada->adapter->lpVtbl->AddRef(ada->adapter);

			return FI_SUCCESS;
		}
	}

	return -FI_EINVAL;
}

#endif /* _WIN32 */