@@ -100,15 +100,29 @@ class APIListViewTestMixin(object):
100100 list_name = 'list'
101101 default_ids = []
102102 always_exclude = ['created' ]
103+ test_post_method = False
103104
104- def path (self , ids = None , fields = None , exclude = None , ** kwargs ):
105- query_params = {}
106- for query_arg , data in zip ([self .ids_param , 'fields' , 'exclude' ], [ids , fields , exclude ]) + kwargs .items ():
107- if data :
108- query_params [query_arg ] = ',' .join (data )
109- query_string = '?{}' .format (urlencode (query_params ))
105+ def path (self , query_data = None ):
106+ query_data = query_data or {}
107+ concat_query_data = {param : ',' .join (arg ) for param , arg in query_data .items () if arg }
108+ query_string = '?{}' .format (urlencode (concat_query_data )) if concat_query_data else ''
110109 return '/api/v0/{}/{}' .format (self .list_name , query_string )
111110
111+ def validated_request (self , ids = None , fields = None , exclude = None , ** extra_args ):
112+ params = [self .ids_param , 'fields' , 'exclude' ]
113+ args = [ids , fields , exclude ]
114+ data = {param : arg for param , arg in zip (params , args ) if arg }
115+ data .update (extra_args )
116+
117+ get_response = self .authenticated_get (self .path (data ))
118+ if self .test_post_method :
119+ post_response = self .authenticated_post (self .path (), data = data )
120+ self .assertEquals (get_response .status_code , post_response .status_code )
121+ if 200 <= get_response .status_code < 300 :
122+ self .assertEquals (get_response .data , post_response .data )
123+
124+ return get_response
125+
112126 def create_model (self , model_id , ** kwargs ):
113127 pass # implement in subclass
114128
@@ -134,19 +148,19 @@ def all_expected_results(self, ids=None, **kwargs):
134148
135149 def _test_all_items (self , ids ):
136150 self .generate_data ()
137- response = self .authenticated_get ( self . path ( ids = ids , exclude = self .always_exclude ) )
151+ response = self .validated_request ( ids = ids , exclude = self .always_exclude )
138152 self .assertEquals (response .status_code , 200 )
139153 self .assertItemsEqual (response .data , self .all_expected_results (ids = ids ))
140154
141155 def _test_one_item (self , item_id ):
142156 self .generate_data ()
143- response = self .authenticated_get ( self . path ( ids = [item_id ], exclude = self .always_exclude ) )
157+ response = self .validated_request ( ids = [item_id ], exclude = self .always_exclude )
144158 self .assertEquals (response .status_code , 200 )
145159 self .assertItemsEqual (response .data , [self .expected_result (item_id )])
146160
147161 def _test_fields (self , fields ):
148162 self .generate_data ()
149- response = self .authenticated_get ( self . path ( fields = fields ) )
163+ response = self .validated_request ( fields = fields )
150164 self .assertEquals (response .status_code , 200 )
151165
152166 # remove fields not requested from expected results
@@ -158,10 +172,10 @@ def _test_fields(self, fields):
158172 self .assertItemsEqual (response .data , expected_results )
159173
160174 def test_no_items (self ):
161- response = self .authenticated_get ( self . path () )
175+ response = self .validated_request ( )
162176 self .assertEquals (response .status_code , 404 )
163177
164178 def test_no_matching_items (self ):
165179 self .generate_data ()
166- response = self .authenticated_get ( self . path ( ids = ['no/items/found' ]) )
180+ response = self .validated_request ( ids = ['no/items/found' ])
167181 self .assertEquals (response .status_code , 404 )
0 commit comments